"""PyTorch DeiT model."""

import collections.abc
from dataclasses import dataclass
from typing import Callable, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    MaskedImageModelingOutput,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from .configuration_deit import DeiTConfig


logger = logging.get_logger(__name__)


class DeiTEmbeddings(nn.Module):
    """
    Construct the CLS token, distillation token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.distillation_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = DeiTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 2, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing and 2 class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 2
        num_positions = self.position_embeddings.shape[1] - 2

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_and_dist_pos_embed = self.position_embeddings[:, :2]
        patch_pos_embed = self.position_embeddings[:, 2:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values)
        batch_size, seq_length, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        distillation_tokens = self.distillation_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, distillation_tokens, embeddings), dim=1)
        position_embedding = self.position_embeddings

        if interpolate_pos_encoding:
            position_embedding = self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = embeddings + position_embedding
        embeddings = self.dropout(embeddings)

        return embeddings

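# ---------------------------------------------------------------------------
# Shape sketch (illustrative; not part of the upstream file). With the default
# DeiT-base configuration (224x224 images, 16x16 patches, hidden size 768), the
# embedding module prepends the CLS and distillation tokens to the patch tokens:
#
#     from transformers import DeiTConfig
#
#     embeddings = DeiTEmbeddings(DeiTConfig())
#     pixel_values = torch.randn(1, 3, 224, 224)
#     embeddings(pixel_values).shape  # torch.Size([1, 198, 768]): (224 // 16) ** 2 + 2 tokens
# ---------------------------------------------------------------------------
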
class DeiTPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

class DeiTSelfAttention(nn.Module):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class DeiTSelfOutput(nn.Module):
    """
    The residual connection is defined in DeiTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class DeiTAttention(nn.Module):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.attention = DeiTSelfAttention(config)
        self.output = DeiTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

class DeiTIntermediate(nn.Module):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class DeiTOutput(nn.Module):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


class DeiTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: DeiTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = DeiTAttention(config)
        self.intermediate = DeiTIntermediate(config)
        self.output = DeiTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in DeiT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in DeiT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs

ededede	e
ef fddZ  ZS )DeiTEncoderr   r   Nc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r4   )r   ).0rL   r   r4   r5   
<listcomp>  s    z(DeiTEncoder.__init__.<locals>.<listcomp>F)	r!   r"   r   r   Z
ModuleListrangenum_hidden_layerslayergradient_checkpointingr   r2   r   r5   r"     s   
 
zDeiTEncoder.__init__FTr   r   r   output_hidden_statesreturn_dictc                 C   s   |rdnd }|r
dnd }t | jD ]8\}}	|r||f }|d ur$|| nd }
| jr6| jr6| |	j||
|}n|	||
|}|d }|rI||d f }q|rQ||f }|s_tdd |||fD S t|||dS )Nr4   r   r   c                 s   s    | ]	}|d ur|V  qd S r   r4   )r   vr4   r4   r5   	<genexpr>  s    z&DeiTEncoder.forward.<locals>.<genexpr>)last_hidden_stater   
attentions)	enumerater   r   rq   Z_gradient_checkpointing_func__call__tupler   )r1   r   r   r   r   r   Zall_hidden_statesZall_self_attentionsiZlayer_moduleZlayer_head_maskZlayer_outputsr4   r4   r5   rO     s6   

zDeiTEncoder.forward)NFFT)rQ   rR   rS   r   r"   r$   rV   r   rU   r   r   r   rO   rY   r4   r4   r2   r5   r     s&    	
r   c                   @   sL   e Zd ZeZdZdZdZdgZdZ	dZ
deejejejf ddfdd	ZdS )
DeiTPreTrainedModeldeitrG   Tr   rh   r   Nc                 C   s   t |tjtjfr0tjj|jjt	j
d| jjd|jj|j_|jdur.|jj  dS dS t |tjrE|jj  |jjd dS t |tri|jj  |jj  |jj  |jdurk|jj  dS dS dS )zInitialize the weightsrg   )meanZstdNrI   )r]   r   r   ra   initZtrunc_normal_weightdatart   r$   rs   r   Zinitializer_rangero   rz   Zzero_r   Zfill_r   r&   r,   r'   r(   )r1   rh   r4   r4   r5   _init_weights  s(   



z!DeiTPreTrainedModel._init_weights)rQ   rR   rS   r   Zconfig_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesZ_supports_sdpaZ_supports_flash_attn_2r   r   r   ra   r   r   r4   r4   r4   r5   r     s    &r   c                       s   e Zd Zddedededdf fdd	Zdefd
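# ---------------------------------------------------------------------------
# Gradient-checkpointing sketch (illustrative; relies on the generic
# `PreTrainedModel` toggle, which flips `DeiTEncoder.gradient_checkpointing`
# because `supports_gradient_checkpointing = True` above):
#
#     model = DeiTModel(DeiTConfig())
#     model.gradient_checkpointing_enable()
#     model.encoder.gradient_checkpointing  # True; activations recomputed on backward
# ---------------------------------------------------------------------------
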
dZdd Ze								dde
ej de
ej de
ej de
e de
e de
e dedeeef fddZ  ZS )	DeiTModelTFr   add_pooling_layerr   r   Nc                    s\   t  | || _t||d| _t|| _tj|j	|j
d| _|r%t|nd| _|   dS )z
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        )r   r   N)r!   r"   r   r   r6   r   encoderr   r   r%   r   	layernorm
DeiTPoolerpooler	post_init)r1   r   r   r   r2   r4   r5   r"     s   
zDeiTModel.__init__c                 C   s   | j jS r   )r6   r*   )r1   r4   r4   r5   get_input_embeddings  s   zDeiTModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r1   Zheads_to_pruner   r   r4   r4   r5   _prune_heads  s   zDeiTModel._prune_headsrG   rH   r   r   r   r   rF   c                 C   s
  |dur|n| j j}|dur|n| j j}|dur|n| j j}|du r&td| || j j}| jjj	j
j}|j|kr?||}| j|||d}	| j|	||||d}
|
d }| |}| jdurd| |nd}|s{|durp||fn|f}||
dd  S t|||
j|
jdS )z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)rH   rF   )r   r   r   r   r   r   )r   Zpooler_outputr   r   )r   r   r   use_return_dictrc   Zget_head_maskr   r6   r*   rb   r   ro   rt   r   r   r   r   r   r   )r1   rG   rH   r   r   r   r   rF   Zexpected_dtypeZembedding_outputZencoder_outputssequence_outputpooled_outputZhead_outputsr4   r4   r5   rO     s@   


zDeiTModel.forward)TFNNNNNNF)rQ   rR   rS   r   rU   r"   r)   r   r   r   r   r$   rV   rX   r   r   r   rO   rY   r4   r4   r2   r5   r     s:     
class DeiTPooler(nn.Module):
    def __init__(self, config: DeiTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.pooler_output_size)
        self.activation = ACT2FN[config.pooler_act]

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

@auto_docstring(
    custom_intro="""
    DeiT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """
)
class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__(config)

        self.deit = DeiTModel(config, add_pooling_layer=False, use_mask_token=True)

        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=config.hidden_size,
                out_channels=config.encoder_stride**2 * config.num_channels,
                kernel_size=1,
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DeiTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
        >>> model = DeiTForMaskedImageModeling.from_pretrained("facebook/deit-base-distilled-patch16-224")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deit(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        sequence_output = outputs[0]

        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output[:, 1:-1]
        batch_size, sequence_length, num_channels = sequence_output.shape
        height = width = int(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@auto_docstring(
    custom_intro="""
    DeiT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """
)
class DeiTForImageClassification(DeiTPreTrainedModel):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.deit = DeiTModel(config, add_pooling_layer=False)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, DeiTForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a DeiTForImageClassificationWithTeacher from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random
        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/deit-base-distilled-patch16-224")
        >>> model = DeiTForImageClassification.from_pretrained("facebook/deit-base-distilled-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the 1000 ImageNet classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: Polaroid camera, Polaroid Land camera
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        sequence_output = outputs[0]

        logits = self.classifier(sequence_output[:, 0, :])
        # we don't use the distillation token

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@dataclass
class DeiTForImageClassificationWithTeacherOutput(ModelOutput):
    """
    Output type of [`DeiTForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the cls_logits and distillation logits.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of the
            class token).
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
            distillation token).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    logits: Optional[torch.FloatTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    distillation_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

@auto_docstring(
    custom_intro="""
    DeiT Model transformer with image classification heads on top (a linear layer on top of the final hidden state of
    the [CLS] token and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.

    .. warning::

           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.
    """
)
class DeiTForImageClassificationWithTeacher(DeiTPreTrainedModel):
    def __init__(self, config: DeiTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.deit = DeiTModel(config, add_pooling_layer=False)

        # Classifier heads
        self.cls_classifier = (
            nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )
        self.distillation_classifier = (
            nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[tuple, DeiTForImageClassificationWithTeacherOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        sequence_output = outputs[0]

        cls_logits = self.cls_classifier(sequence_output[:, 0, :])
        distillation_logits = self.distillation_classifier(sequence_output[:, 1, :])

        # during inference, return the average of both classifier predictions
        logits = (cls_logits + distillation_logits) / 2

        if not return_dict:
            output = (logits, cls_logits, distillation_logits) + outputs[1:]
            return output

        return DeiTForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distillation_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

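# ---------------------------------------------------------------------------
# Inference sketch (illustrative). The two heads read the CLS token (index 0)
# and the distillation token (index 1), and the returned `logits` are their
# plain average:
#
#     model = DeiTForImageClassificationWithTeacher.from_pretrained(
#         "facebook/deit-base-distilled-patch16-224"
#     )
#     outputs = model(pixel_values=torch.randn(1, 3, 224, 224))
#     torch.allclose(outputs.logits, (outputs.cls_logits + outputs.distillation_logits) / 2)  # True
# ---------------------------------------------------------------------------
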
__all__ = [
    "DeiTForImageClassification",
    "DeiTForImageClassificationWithTeacher",
    "DeiTForMaskedImageModeling",
    "DeiTModel",
    "DeiTPreTrainedModel",
]