o
    ZhR                  	   @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddl	Zddlm
Z
 ddlmZmZmZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZmZmZmZ ddlmZ eeZ dZ!dZ"g dZ#dZ$dZ%d9dej&de'de(dej&fddZ)G dd de
j*Z+G dd de
j*Z,G dd de
j*Z-G d d! d!e
j*Z.G d"d# d#e
j*Z/G d$d% d%e
j*Z0G d&d' d'e
j*Z1G d(d) d)e
j*Z2G d*d+ d+e
j*Z3G d,d- d-e
j*Z4G d.d/ d/eZ5d0Z6d1Z7ed2e6G d3d4 d4e5Z8ed5e6G d6d7 d7e5Z9g d8Z:dS ):z-PyTorch Visual Attention Network (VAN) model.    N)OrderedDict)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputWithNoAttention(BaseModelOutputWithPoolingAndNoAttention$ImageClassifierOutputWithNoAttention)PreTrainedModel)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forwardlogging   )	VanConfigr   z!Visual-Attention-Network/van-base)r   i      r   ztabby, tabby cat        Finput	drop_probtrainingreturnc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r   r   r   )r   )dtypedevice)shapendimtorchZrandr   r   Zfloor_div)r   r   r   Z	keep_probr   Zrandom_tensoroutput r#   ^/var/www/auris/lib/python3.10/site-packages/transformers/models/deprecated/van/modeling_van.py	drop_path3   s   
r%   c                       sT   e Zd ZdZddee ddf fddZdejdejfdd	Z	de
fd
dZ  ZS )VanDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr   r   c                    s   t    || _d S N)super__init__r   )selfr   	__class__r#   r$   r)   J   s   

zVanDropPath.__init__hidden_statesc                 C   s   t || j| jS r'   )r%   r   r   )r*   r-   r#   r#   r$   forwardN   s   zVanDropPath.forwardc                 C   s   d | jS )Nzp={})formatr   )r*   r#   r#   r$   
extra_reprQ   s   zVanDropPath.extra_reprr'   )__name__
__module____qualname____doc__r   floatr)   r    Tensorr.   strr0   __classcell__r#   r#   r+   r$   r&   G   s
    r&   c                	       sJ   e Zd ZdZddedededef fdd	Zd
ejdejfddZ  Z	S )VanOverlappingPatchEmbedderz
    Downsamples the input using a patchify operation with a `stride` of 4 by default making adjacent windows overlap by
    half of the area. From [PVTv2: Improved Baselines with Pyramid Vision
    Transformer](https://arxiv.org/abs/2106.13797).
    r   r
   in_channelshidden_size
patch_sizestridec                    s4   t    tj|||||d d| _t|| _d S )N   )kernel_sizer=   padding)r(   r)   r   Conv2dconvolutionBatchNorm2dnormalization)r*   r:   r;   r<   r=   r+   r#   r$   r)   \   s
   
z$VanOverlappingPatchEmbedder.__init__r   r   c                 C   s   |  |}| |}|S r'   )rB   rD   )r*   r   hidden_stater#   r#   r$   r.   c   s   

z#VanOverlappingPatchEmbedder.forward)r   r
   
r1   r2   r3   r4   intr)   r    r6   r.   r8   r#   r#   r+   r$   r9   U   s     r9   c                       sR   e Zd ZdZ		ddededededef
 fd	d
Zdej	dej	fddZ
  ZS )VanMlpLayerz
    MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision
    Transformer](https://arxiv.org/abs/2106.13797).
    gelu      ?r:   r;   out_channels
hidden_actdropout_ratec                    sj   t    tj||dd| _tj||dd|d| _t| | _t|| _	tj||dd| _
t|| _d S )Nr   r?      r?   r@   groups)r(   r)   r   rA   in_dense
depth_wiser   
activationZDropoutdropout1	out_densedropout2)r*   r:   r;   rK   rL   rM   r+   r#   r$   r)   o   s   

zVanMlpLayer.__init__rE   r   c                 C   s@   |  |}| |}| |}| |}| |}| |}|S r'   )rR   rS   rT   rU   rV   rW   r*   rE   r#   r#   r$   r.      s   





zVanMlpLayer.forward)rI   rJ   )r1   r2   r3   r4   rG   r7   r5   r)   r    r6   r.   r8   r#   r#   r+   r$   rH   i   s     
rH   c                       <   e Zd ZdZdef fddZdejdejfddZ  Z	S )	VanLargeKernelAttentionz-
    Basic Large Kernel Attention (LKA).
    r;   c                    sN   t    tj||dd|d| _tj||ddd|d| _tj||dd	| _d S )
N   r>   rP   r   rO   	   )r?   Zdilationr@   rQ   r   rN   )r(   r)   r   rA   rS   depth_wise_dilated
point_wiser*   r;   r+   r#   r$   r)      s   
z VanLargeKernelAttention.__init__rE   r   c                 C   s"   |  |}| |}| |}|S r'   )rS   r]   r^   rX   r#   r#   r$   r.      s   


zVanLargeKernelAttention.forwardrF   r#   r#   r+   r$   rZ      s    rZ   c                       rY   )	VanLargeKernelAttentionLayerzV
    Computes attention using Large Kernel Attention (LKA) and attends the input.
    r;   c                    s   t    t|| _d S r'   )r(   r)   rZ   	attentionr_   r+   r#   r$   r)      s   
z%VanLargeKernelAttentionLayer.__init__rE   r   c                 C   s   |  |}|| }|S r'   )ra   )r*   rE   ra   Zattendedr#   r#   r$   r.      s   
z$VanLargeKernelAttentionLayer.forwardrF   r#   r#   r+   r$   r`      s    r`   c                       B   e Zd ZdZddedef fddZdejdejfd	d
Z	  Z
S )VanSpatialAttentionLayerz
    Van spatial attention layer composed by projection (via conv) -> act -> Large Kernel Attention (LKA) attention ->
    projection (via conv) + residual connection.
    rI   r;   rL   c              	      sV   t    ttdtj||ddfdt| fg| _t|| _	tj||dd| _
d S )Nconvr   rN   Zact)r(   r)   r   
Sequentialr   rA   r   pre_projectionr`   attention_layerpost_projection)r*   r;   rL   r+   r#   r$   r)      s   


z!VanSpatialAttentionLayer.__init__rE   r   c                 C   s.   |}|  |}| |}| |}|| }|S r'   )rf   rg   rh   r*   rE   Zresidualr#   r#   r$   r.      s   


z VanSpatialAttentionLayer.forward)rI   )r1   r2   r3   r4   rG   r7   r)   r    r6   r.   r8   r#   r#   r+   r$   rc      s    rc   c                       rb   )VanLayerScalingzT
    Scales the inputs by a learnable parameter initialized by `initial_value`.
    {Gz?r;   initial_valuec                    s(   t    tj|t| dd| _d S )NT)Zrequires_grad)r(   r)   r   	Parameterr    Zonesweight)r*   r;   rl   r+   r#   r$   r)      s   
zVanLayerScaling.__init__rE   r   c                 C   s   | j dd| }|S )N)rn   Z	unsqueezerX   r#   r#   r$   r.      s   zVanLayerScaling.forward)rk   )r1   r2   r3   r4   rG   r5   r)   r    r6   r.   r8   r#   r#   r+   r$   rj      s    rj   c                	       sN   e Zd ZdZ		ddedededef fdd	Zd
ej	dej	fddZ
  ZS )VanLayerzv
    Van layer composed by normalization layers, large kernel attention (LKA) and a multi layer perceptron (MLP).
    r
   rJ   configr;   	mlp_ratiodrop_path_ratec                    s   t    |dkrt|nt | _t|| _t||j	| _
t||j| _t|| _t||| ||j	|j| _t||j| _d S )Nr   )r(   r)   r&   r   Identityr%   rC   pre_normomalizationrc   rL   ra   rj   Zlayer_scale_init_valueattention_scalingpost_normalizationrH   rM   mlpmlp_scaling)r*   rq   r;   rr   rs   r+   r#   r$   r)      s   
zVanLayer.__init__rE   r   c                 C   sl   |}|  |}| |}| |}| |}|| }|}| |}| |}| |}| |}|| }|S r'   )ru   ra   rv   r%   rw   rx   ry   ri   r#   r#   r$   r.      s   







zVanLayer.forward)r
   rJ   r1   r2   r3   r4   r   rG   r5   r)   r    r6   r.   r8   r#   r#   r+   r$   rp      s    rp   c                       s^   e Zd ZdZ		ddededededed	ed
edef fddZdej	dej	fddZ
  ZS )VanStagez2
    VanStage, consisting of multiple layers.
    r
   r   rq   r:   r;   r<   r=   depthrr   rs   c	           	         sT   t    t|||| _tj fddt|D  | _tj j	d| _
d S )Nc                    s   g | ]
}t  d qS ))rr   rs   )rp   ).0_rq   rs   r;   rr   r#   r$   
<listcomp>  s    z%VanStage.__init__.<locals>.<listcomp>Zeps)r(   r)   r9   
embeddingsr   re   rangelayers	LayerNormlayer_norm_epsrD   )	r*   rq   r:   r;   r<   r=   r|   rr   rs   r+   r   r$   r)     s   
zVanStage.__init__rE   r   c                 C   s^   |  |}| |}|j\}}}}|ddd}| |}|||||dddd}|S )Nr>   r   r   rO   )r   r   r   flattenZ	transposerD   viewZpermute)r*   rE   Z
batch_sizer;   heightwidthr#   r#   r$   r.   !  s   


zVanStage.forward)r
   r   rz   r#   r#   r+   r$   r{     s,    	r{   c                       sX   e Zd ZdZdef fddZ		ddejdee	 d	ee	 d
e
eef fddZ  ZS )
VanEncoderz4
    VanEncoder, consisting of multiple stages.
    rq   c                    s   t    tg | _|j}|j}|j}|j}|j	}dd t
jd|jt|jddD }tt||||||D ])\}\}	}
}}}}|dk}||d  }|rP|j}| jt||||	|
|||d q7d S )Nc                 S   s   g | ]}|  qS r#   )item)r}   xr#   r#   r$   r   :  s    z'VanEncoder.__init__.<locals>.<listcomp>r   cpu)r   r   )r<   r=   r|   rr   rs   )r(   r)   r   Z
ModuleListstagespatch_sizesstrideshidden_sizesdepths
mlp_ratiosr    Zlinspacers   sum	enumeratezipZnum_channelsappendr{   )r*   rq   r   r   r   r   r   Zdrop_path_ratesZ	num_stager<   r=   r;   r|   Zmlp_expantionrs   Zis_first_stager:   r+   r#   r$   r)   2  s<   
zVanEncoder.__init__FTrE   output_hidden_statesreturn_dictr   c                 C   s\   |rdnd }t | jD ]\}}||}|r||f }q|s(tdd ||fD S t||dS )Nr#   c                 s   s    | ]	}|d ur|V  qd S r'   r#   )r}   vr#   r#   r$   	<genexpr>a  s    z%VanEncoder.forward.<locals>.<genexpr>)last_hidden_stater-   )r   r   tupler   )r*   rE   r   r   Zall_hidden_statesr~   Zstage_moduler#   r#   r$   r.   R  s   
zVanEncoder.forward)FT)r1   r2   r3   r4   r   r)   r    r6   r   boolr   r   r   r.   r8   r#   r#   r+   r$   r   -  s    #
r   c                   @   s(   e Zd ZdZeZdZdZdZdd Z	dS )VanPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    vanpixel_valuesTc                 C   s   t |tjr*tjj|j| jjd t |tjr&|jdur(tj	|jd dS dS dS t |tj
rBtj	|jd tj	|jd dS t |tjrt|jd |jd  |j }||j }|jjdtd|  |jdurv|jj  dS dS dS )zInitialize the weights)ZstdNr   g      ?r   g       @)
isinstancer   LinearinitZtrunc_normal_rn   rq   Zinitializer_rangeZbiasZ	constant_r   rA   r?   rK   rQ   dataZnormal_mathsqrtZzero_)r*   moduleZfan_outr#   r#   r$   _init_weightsq  s    

z VanPreTrainedModel._init_weightsN)
r1   r2   r3   r4   r   config_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   r#   r#   r#   r$   r   f  s    r   aE  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`VanConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
aF  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ConvNextImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all stages. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
zxThe bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding layer.c                       sl   e Zd Z fddZeeeeee	de
d		ddeej dee dee d	eeef fd
dZ  ZS )VanModelc                    s@   t  | || _t|| _tj|jd |jd| _	| 
  d S )Nro   r   )r(   r)   rq   r   encoderr   r   r   r   Z	layernorm	post_initr*   rq   r+   r#   r$   r)     s
   
zVanModel.__init__Zvision)
checkpointoutput_typer   Zmodalityexpected_outputNr   r   r   r   c                 C   sx   |d ur|n| j j}|d ur|n| j j}| j|||d}|d }|jddgd}|s4||f|dd   S t|||jdS )Nr   r   r   ro   )dimr   )r   pooler_outputr-   )rq   r   use_return_dictr   meanr   r-   )r*   r   r   r   Zencoder_outputsr   pooled_outputr#   r#   r$   r.     s"   zVanModel.forward)NN)r1   r2   r3   r)   r   VAN_INPUTS_DOCSTRINGr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPEr   r    FloatTensorr   r   r   r.   r8   r#   r#   r+   r$   r     s*    	

r   z
    VAN Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    c                       sx   e Zd Z fddZeeeeee	e
d				ddeej deej dee dee d	eeef f
d
dZ  ZS )VanForImageClassificationc                    sJ   t  | t|| _|jdkrt|jd |jnt | _	| 
  d S )Nr   ro   )r(   r)   r   r   
num_labelsr   r   r   rt   
classifierr   r   r+   r#   r$   r)     s
   
$z"VanForImageClassification.__init__)r   r   r   r   Nr   labelsr   r   r   c                 C   sj  |dur|n| j j}| j|||d}|r|jn|d }| |}d}|dur| j jdu rR| j jdkr7d| j _n| j jdkrN|jtj	ksI|jtj
krNd| j _nd| j _| j jdkrqt }	| j jdkrk|	| | }n,|	||}n&| j jdkrt }	|	|d| j j|d}n| j jdkrt }	|	||}|s|f|dd  }
|dur|f|
 S |
S t|||jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationro   r>   )losslogitsr-   )rq   r   r   r   r   Zproblem_typer   r   r    longrG   r	   Zsqueezer   r   r   r   r-   )r*   r   r   r   r   Zoutputsr   r   r   Zloss_fctr"   r#   r#   r$   r.     s6   

$

z!VanForImageClassification.forward)NNNN)r1   r2   r3   r)   r   r   r   _IMAGE_CLASS_CHECKPOINTr   r   _IMAGE_CLASS_EXPECTED_OUTPUTr   r    r   Z
LongTensorr   r   r   r.   r8   r#   r#   r+   r$   r     s0    
r   )r   r   r   )r   F);r4   r   collectionsr   typingr   r   r   r    Ztorch.utils.checkpointr   Ztorch.nnr   r   r	   Zactivationsr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   utilsr   r   r   r   Zconfiguration_vanr   Z
get_loggerr1   loggerr   r   r   r   r   r6   r5   r   r%   Moduler&   r9   rH   rZ   r`   rc   rj   rp   r{   r   r   ZVAN_START_DOCSTRINGr   r   r   __all__r#   r#   r#   r$   <module>   sX   
  ++90F