o
    Zh~g                  	   @   s$  d Z ddlZddlmZ ddlmZmZmZ ddl	Z	ddl
Z	ddl	mZ ddlmZmZmZ ddlmZmZ dd	lmZmZmZ dd
lmZmZ ddlmZ eeZeG dd deZd<de	j de!de"de	j fddZ#G dd dej$Z%G dd dej$Z&G dd dej$Z'G dd dej$Z(G dd  d ej$Z)G d!d" d"ej$Z*G d#d$ d$ej$Z+G d%d& d&ej$Z,G d'd( d(ej$Z-G d)d* d*ej$Z.G d+d, d,ej$Z/G d-d. d.ej$Z0G d/d0 d0ej$Z1G d1d2 d2ej$Z2eG d3d4 d4eZ3eG d5d6 d6e3Z4ed7d8G d9d: d:e3Z5g d;Z6dS )=zPyTorch CvT model.    N)	dataclass)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )$ImageClassifierOutputWithNoAttentionModelOutput)PreTrainedModel find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringlogging   )	CvtConfigc                   @   sP   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dS )BaseModelOutputWithCLSTokena  
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cls_token_value (`torch.FloatTensor` of shape `(batch_size, 1, hidden_size)`):
            Classification token at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
    Nlast_hidden_statecls_token_value.hidden_states)__name__
__module____qualname____doc__r   r   torchZFloatTensor__annotations__r   r   r    r   r   S/var/www/auris/lib/python3.10/site-packages/transformers/models/cvt/modeling_cvt.pyr   #   s
   
 r           Finput	drop_probtrainingreturnc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r    r   r   )r   )dtypedevice)shapendimr   Zrandr%   r&   Zfloor_div)r!   r"   r#   Z	keep_probr'   Zrandom_tensoroutputr   r   r   	drop_path9   s   
r+   c                       sT   e Zd ZdZddee ddf fddZdejdejfdd	Z	de
fd
dZ  ZS )CvtDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr"   r$   c                    s   t    || _d S N)super__init__r"   )selfr"   	__class__r   r   r/   Q   s   

zCvtDropPath.__init__r   c                 C   s   t || j| jS r-   )r+   r"   r#   )r0   r   r   r   r   forwardU   s   zCvtDropPath.forwardc                 C   s   d | jS )Nzp={})formatr"   )r0   r   r   r   
extra_reprX   s   zCvtDropPath.extra_reprr-   )r   r   r   r   r   floatr/   r   Tensorr3   strr5   __classcell__r   r   r1   r   r,   N   s
    r,   c                       (   e Zd ZdZ fddZdd Z  ZS )CvtEmbeddingsz'
    Construct the CvT embeddings.
    c                    s.   t    t|||||d| _t|| _d S )N)
patch_sizenum_channels	embed_dimstridepadding)r.   r/   CvtConvEmbeddingsconvolution_embeddingsr   Dropoutdropout)r0   r<   r=   r>   r?   r@   dropout_rater1   r   r   r/   a   s
   

zCvtEmbeddings.__init__c                 C   s   |  |}| |}|S r-   )rB   rD   )r0   pixel_valueshidden_stater   r   r   r3   h      

zCvtEmbeddings.forwardr   r   r   r   r/   r3   r9   r   r   r1   r   r;   \       r;   c                       r:   )rA   z"
    Image to Conv Embedding.
    c                    sP   t    t|tjjr|n||f}|| _tj|||||d| _	t
|| _d S )N)kernel_sizer?   r@   )r.   r/   
isinstancecollectionsabcIterabler<   r   Conv2d
projection	LayerNormnormalization)r0   r<   r=   r>   r?   r@   r1   r   r   r/   s   s
   
zCvtConvEmbeddings.__init__c                 C   sf   |  |}|j\}}}}|| }||||ddd}| jr$| |}|ddd||||}|S Nr      r   )rQ   r'   viewpermuterS   )r0   rF   
batch_sizer=   heightwidthhidden_sizer   r   r   r3   z   s   

zCvtConvEmbeddings.forwardrI   r   r   r1   r   rA   n   rJ   rA   c                       $   e Zd Z fddZdd Z  ZS )CvtSelfAttentionConvProjectionc              	      s4   t    tj|||||d|d| _t|| _d S )NF)rK   r@   r?   biasgroups)r.   r/   r   rP   convolutionZBatchNorm2drS   )r0   r>   rK   r@   r?   r1   r   r   r/      s   
	z'CvtSelfAttentionConvProjection.__init__c                 C      |  |}| |}|S r-   )r`   rS   r0   rG   r   r   r   r3      rH   z&CvtSelfAttentionConvProjection.forwardr   r   r   r/   r3   r9   r   r   r1   r   r]      s    r]   c                   @   s   e Zd Zdd ZdS ) CvtSelfAttentionLinearProjectionc                 C   s2   |j \}}}}|| }||||ddd}|S rT   )r'   rV   rW   )r0   rG   rX   r=   rY   rZ   r[   r   r   r   r3      s   z(CvtSelfAttentionLinearProjection.forwardN)r   r   r   r3   r   r   r   r   rd      s    rd   c                       s&   e Zd Zd fdd	Zdd Z  ZS )CvtSelfAttentionProjectiondw_bnc                    s.   t    |dkrt||||| _t | _d S )Nrf   )r.   r/   r]   convolution_projectionrd   linear_projection)r0   r>   rK   r@   r?   projection_methodr1   r   r   r/      s   
z#CvtSelfAttentionProjection.__init__c                 C   ra   r-   )rg   rh   rb   r   r   r   r3      rH   z"CvtSelfAttentionProjection.forward)rf   rc   r   r   r1   r   re      s    re   c                       0   e Zd Z	d fdd	Zdd Zdd Z  ZS )	CvtSelfAttentionTc                    s   t    |d | _|| _|| _|| _t|||||dkrdn|d| _t|||||d| _t|||||d| _	t
j|||	d| _t
j|||	d| _t
j|||	d| _t
|
| _d S )Ng      ZavgZlinear)ri   )r^   )r.   r/   scalewith_cls_tokenr>   	num_headsre   convolution_projection_queryconvolution_projection_keyconvolution_projection_valuer   Linearprojection_queryprojection_keyprojection_valuerC   rD   )r0   rn   r>   rK   	padding_q
padding_kvstride_q	stride_kvqkv_projection_methodqkv_biasattention_drop_raterm   kwargsr1   r   r   r/      s,   



zCvtSelfAttention.__init__c                 C   s6   |j \}}}| j| j }|||| j|ddddS )Nr   rU   r   r
   )r'   r>   rn   rV   rW   )r0   rG   rX   r[   _head_dimr   r   r   "rearrange_for_multi_head_attention   s   z3CvtSelfAttention.rearrange_for_multi_head_attentionc                 C   sT  | j rt|d|| gd\}}|j\}}}|ddd||||}| |}| |}	| |}
| j rPtj	||	fdd}	tj	||fdd}tj	||
fdd}
| j
| j }| | |	}	| | |}| | |
}
td|	|g| j }tjjj|dd}| |}td||
g}|j\}}}}|dddd ||| j| }|S )	Nr   r   rU   dimzbhlk,bhtk->bhltzbhlt,bhtv->bhlvr
   )rm   r   splitr'   rW   rV   rp   ro   rq   catr>   rn   r   rs   rt   ru   Zeinsumrl   r   Z
functionalZsoftmaxrD   
contiguous)r0   rG   rY   rZ   	cls_tokenrX   r[   r=   keyqueryvaluer   Zattention_scoreZattention_probscontextr~   r   r   r   r3      s,   



$zCvtSelfAttention.forwardT)r   r   r   r/   r   r3   r9   r   r   r1   r   rk      s
    )rk   c                       r:   )CvtSelfOutputz
    The residual connection is defined in CvtLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    c                    s(   t    t||| _t|| _d S r-   )r.   r/   r   rr   denserC   rD   )r0   r>   	drop_rater1   r   r   r/     s   
zCvtSelfOutput.__init__c                 C   ra   r-   r   rD   r0   rG   Zinput_tensorr   r   r   r3     rH   zCvtSelfOutput.forwardrI   r   r   r1   r   r     s    r   c                       rj   )	CvtAttentionTc                    s@   t    t|||||||||	|
|| _t||| _t | _d S r-   )r.   r/   rk   	attentionr   r*   setpruned_heads)r0   rn   r>   rK   rv   rw   rx   ry   rz   r{   r|   r   rm   r1   r   r   r/     s    
zCvtAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )lenr   r   Znum_attention_headsZattention_head_sizer   r   r   r   r   r*   r   Zall_head_sizeunion)r0   headsindexr   r   r   prune_heads4  s   zCvtAttention.prune_headsc                 C   s   |  |||}| ||}|S r-   )r   r*   )r0   rG   rY   rZ   Zself_outputattention_outputr   r   r   r3   F  s   zCvtAttention.forwardr   )r   r   r   r/   r   r3   r9   r   r   r1   r   r     s
     r   c                       r\   )CvtIntermediatec                    s.   t    t|t|| | _t | _d S r-   )r.   r/   r   rr   intr   ZGELU
activation)r0   r>   	mlp_ratior1   r   r   r/   M  s   
zCvtIntermediate.__init__c                 C   ra   r-   )r   r   rb   r   r   r   r3   R  rH   zCvtIntermediate.forwardrc   r   r   r1   r   r   L      r   c                       r\   )	CvtOutputc                    s0   t    tt|| || _t|| _d S r-   )r.   r/   r   rr   r   r   rC   rD   )r0   r>   r   r   r1   r   r   r/   Y  s   
zCvtOutput.__init__c                 C   s    |  |}| |}|| }|S r-   r   r   r   r   r   r3   ^  s   

zCvtOutput.forwardrc   r   r   r1   r   r   X  r   r   c                       s,   e Zd ZdZ	d fdd	Zdd Z  ZS )CvtLayerzb
    CvtLayer composed by attention layers, normalization and multi-layer perceptrons (mlps).
    Tc                    s|   t    t|||||||||	|
||| _t||| _t|||| _|dkr+t|dnt	
 | _t	|| _t	|| _d S )Nr    )r"   )r.   r/   r   r   r   intermediater   r*   r,   r   Identityr+   rR   layernorm_beforelayernorm_after)r0   rn   r>   rK   rv   rw   rx   ry   rz   r{   r|   r   r   drop_path_raterm   r1   r   r   r/   j  s(   
zCvtLayer.__init__c                 C   sX   |  | |||}|}| |}|| }| |}| |}| ||}| |}|S r-   )r   r   r+   r   r   r*   )r0   rG   rY   rZ   Zself_attention_outputr   Zlayer_outputr   r   r   r3     s   



zCvtLayer.forwardr   rI   r   r   r1   r   r   e  s
    'r   c                       r\   )CvtStagec                    s   t     _|_jjj r!ttddjj	d _t
 jj  jj jdkr4 jn j	jd   j	j  jj  jj d_dd tjd jj  j| ddD tj fd	dt jj D  _d S )
Nr   r   r   )r<   r?   r=   r>   r@   rE   c                 S   s   g | ]}|  qS r   )item).0xr   r   r   
<listcomp>  s    z%CvtStage.__init__.<locals>.<listcomp>cpu)r&   c                    s   g | ]K}t  jj  jj  jj  jj  jj  jj  jj  j	j  j
j  jj  jj j  jj  jj d qS ))rn   r>   rK   rv   rw   ry   rx   rz   r{   r|   r   r   r   rm   )r   rn   stager>   Z
kernel_qkvrv   rw   ry   rx   rz   r{   r|   r   r   r   )r   r~   configZdrop_path_ratesr0   r   r   r     s&    












)r.   r/   r   r   r   r   	Parameterr   Zrandnr>   r;   Zpatch_sizesZpatch_strider=   Zpatch_paddingr   	embeddingZlinspacer   depthZ
Sequentialrangelayers)r0   r   r   r1   r   r   r/     s*   





	
zCvtStage.__init__c           	      C   s   d }|  |}|j\}}}}||||| ddd}| jj| j r4| j|dd}tj	||fdd}| j
D ]
}||||}|}q7| jj| j rVt|d|| gd\}}|ddd||||}||fS )Nr   rU   r   r   r   )r   r'   rV   rW   r   r   r   expandr   r   r   r   )	r0   rG   r   rX   r=   rY   rZ   layerZlayer_outputsr   r   r   r3     s   

zCvtStage.forwardrc   r   r   r1   r   r     s    *r   c                       s&   e Zd Z fddZdddZ  ZS )
CvtEncoderc                    sF   t    || _tg | _tt|jD ]}| j	t
|| qd S r-   )r.   r/   r   r   Z
ModuleListstagesr   r   r   appendr   )r0   r   Z	stage_idxr1   r   r   r/     s   
zCvtEncoder.__init__FTc           	      C   sl   |rdnd }|}d }t | jD ]\}}||\}}|r ||f }q|s/tdd |||fD S t|||dS )Nr   c                 s   s    | ]	}|d ur|V  qd S r-   r   )r   vr   r   r   	<genexpr>  s    z%CvtEncoder.forward.<locals>.<genexpr>r   r   r   )	enumerater   tupler   )	r0   rF   output_hidden_statesreturn_dictZall_hidden_statesrG   r   r~   Zstage_moduler   r   r   r3     s   
zCvtEncoder.forward)FTrc   r   r   r1   r   r     s    r   c                   @   s&   e Zd ZeZdZdZdgZdd ZdS )CvtPreTrainedModelcvtrF   r   c                 C   s   t |tjtjfr'tjj|jjd| jj	d|j_|j
dur%|j
j  dS dS t |tjr<|j
j  |jjd dS t |trY| jj|j r[tjj|jjd| jj	d|j_dS dS dS )zInitialize the weightsr    )meanZstdNg      ?)rL   r   rr   rP   initZtrunc_normal_weightdatar   Zinitializer_ranger^   Zzero_rR   Zfill_r   r   r   )r0   moduler   r   r   _init_weights	  s   

z CvtPreTrainedModel._init_weightsN)	r   r   r   r   Zconfig_classZbase_model_prefixZmain_input_nameZ_no_split_modulesr   r   r   r   r   r     s    r   c                       sb   e Zd Zd fdd	Zdd Ze			ddeej dee	 d	ee	 d
e
eef fddZ  ZS )CvtModelTc                    s(   t  | || _t|| _|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)r.   r/   r   r   encoder	post_init)r0   r   add_pooling_layerr1   r   r   r/     s   
zCvtModel.__init__c                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r0   Zheads_to_pruner   r   r   r   r   _prune_heads%  s   zCvtModel._prune_headsNrF   r   r   r$   c                 C   sx   |d ur|n| j j}|d ur|n| j j}|d u rtd| j|||d}|d }|s3|f|dd   S t||j|jdS )Nz You have to specify pixel_valuesr   r   r   r   r   )r   r   use_return_dict
ValueErrorr   r   r   r   )r0   rF   r   r   Zencoder_outputssequence_outputr   r   r   r3   -  s$   zCvtModel.forwardr   )NNN)r   r   r   r/   r   r   r   r   r7   boolr   r   r   r3   r9   r   r   r1   r   r     s     

r   z
    Cvt Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    )Zcustom_introc                       sd   e Zd Z fddZe				ddeej deej dee dee de	e
ef f
d	d
Z  ZS )CvtForImageClassificationc                    sh   t  | |j| _t|dd| _t|jd | _|jdkr)t	|jd |jnt
 | _|   d S )NF)r   r   r   )r.   r/   
num_labelsr   r   r   rR   r>   	layernormrr   r   
classifierr   )r0   r   r1   r   r   r/   T  s   $z"CvtForImageClassification.__init__NrF   labelsr   r   r$   c                 C   s  |dur|n| j j}| j|||d}|d }|d }| j jd r&| |}n|j\}}	}
}|||	|
| ddd}| |}|jdd}| 	|}d}|dur| j j
du r}| j jdkrbd| j _
n| j jdkry|jtjkst|jtjkryd	| j _
nd
| j _
| j j
dkrt }| j jdkr|| | }n,|||}n&| j j
d	krt }||d| j j|d}n| j j
d
krt }|||}|s|f|dd  }|dur|f| S |S t|||jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   r   r   rU   r   Z
regressionZsingle_label_classificationZmulti_label_classification)losslogitsr   )r   r   r   r   r   r'   rV   rW   r   r   Zproblem_typer   r%   r   longr   r	   Zsqueezer   r   r   r   )r0   rF   r   r   r   Zoutputsr   r   rX   r=   rY   rZ   Zsequence_output_meanr   r   Zloss_fctr*   r   r   r   r3   b  sL   


$

z!CvtForImageClassification.forward)NNNN)r   r   r   r/   r   r   r   r7   r   r   r   r   r3   r9   r   r   r1   r   r   M  s$    
r   )r   r   r   )r    F)7r   collections.abcrM   dataclassesr   typingr   r   r   r   Ztorch.utils.checkpointr   Ztorch.nnr   r   r	   Zmodeling_outputsr   r   Zmodeling_utilsr   r   r   utilsr   r   Zconfiguration_cvtr   Z
get_loggerr   loggerr   r7   r6   r   r+   Moduler,   r;   rA   r]   rd   re   rk   r   r   r   r   r   r   r   r   r   r   __all__r   r   r   r   <module>   sN   
 	Q9B?3O