"""PyTorch Data2VecVision model."""

import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import (
    compile_compatible_method_lru_cache,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from ...utils import auto_docstring, logging, torch_int
from .configuration_data2vec_vision import Data2VecVisionConfig


logger = logging.get_logger(__name__)


@dataclass
class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling):
    r"""
    Class for outputs of [`Data2VecVisionModel`].

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
            *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
            will be returned.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
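
    Example (an illustrative sketch; the checkpoint name is borrowed from the semantic segmentation example further
    down in this module, and the 224x224 input resolution and base-sized hidden dimension are assumptions about that
    checkpoint rather than values fixed by this class):

    ```python
    >>> import torch
    >>> from transformers import Data2VecVisionModel

    >>> model = Data2VecVisionModel.from_pretrained("facebook/data2vec-vision-base", add_pooling_layer=True)
    >>> outputs = model(pixel_values=torch.randn(1, 3, 224, 224))
    >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden_size)
    torch.Size([1, 197, 768])
    >>> outputs.pooler_output.shape  # (batch_size, hidden_size)
    torch.Size([1, 768])
    ```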
    """


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
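
    Example (a minimal sketch of the behaviour described above; the tensor shape is arbitrary):

    ```python
    >>> import torch

    >>> x = torch.ones(4, 197, 768)
    >>> out = drop_path(x, drop_prob=0.5, training=True)
    >>> # every sample (row along dim 0) of `out` is either all zeros or `x` rescaled by 1 / (1 - drop_prob)
    >>> drop_path(x, drop_prob=0.5, training=False) is x  # no-op outside of training
    True
    ```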
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with tensors of any dimensionality
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class Data2VecVisionDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class Data2VecVisionEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.

    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        if config.use_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        else:
            self.mask_token = None
        self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
        self.patch_size = config.patch_size
        self.image_size = (
            config.image_size
            if isinstance(config.image_size, collections.abc.Iterable)
            else (config.image_size, config.image_size)
        )
        num_patches = self.patch_embeddings.num_patches
        if config.use_absolute_position_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        else:
            self.position_embeddings = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
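
        Example (worked shape arithmetic; the 224x224 pretraining resolution and 16x16 patch size are assumed
        defaults of the base checkpoints, not values fixed by this method):

        ```python
        >>> 224 // 16, (224 // 16) ** 2  # pretraining patch grid and number of stored patch positions
        (14, 196)
        >>> 336 // 16, (336 // 16) ** 2  # grid the stored embeddings are bicubically resized to for a 336x336 input
        (21, 441)
        ```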
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing so the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: Optional[bool] = None,
    ):
        if self.position_embeddings is not None and interpolate_pos_encoding is not None:
            warnings.warn(
                "`interpolate_pos_encoding` argument has no effect for BEiTEmbeddings, embeddings are always "
                "interpolated to the input image size. The argument will be removed in transformers v4.51.0."
            )

        _, _, height, width = pixel_values.shape
        embeddings, (patch_height, patch_width) = self.patch_embeddings(pixel_values)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1 - w) + mask_tokens * w

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        if self.position_embeddings is not None:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)

        embeddings = self.dropout(embeddings)

        return embeddings, (patch_height, patch_width)


class Data2VecVisionPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
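
    Example (an illustrative shape check; the 224x224 image size, 16x16 patch size and hidden size 768 are assumed
    defaults of `Data2VecVisionConfig`, not values hard-coded in this class):

    ```python
    >>> import torch
    >>> from transformers import Data2VecVisionConfig

    >>> patch_embeddings = Data2VecVisionPatchEmbeddings(Data2VecVisionConfig())
    >>> embeddings, (patch_height, patch_width) = patch_embeddings(torch.randn(1, 3, 224, 224))
    >>> embeddings.shape, (patch_height, patch_width)
    (torch.Size([1, 196, 768]), (14, 14))
    ```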
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.patch_shape = patch_shape

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        embeddings = self.projection(pixel_values)
        patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, (patch_height, patch_width)
ej	deej	 de
deej	 de
deee  deeej	 eej	ej	f f fddZ  ZS )Data2VecVisionSelfAttentionNrA   window_sizer%   c                    s   t    || _|j|j dkr"t|ds"td|j d|j d|j| _t|j|j | _| j| j | _	t
|j| j	| _t
j|j| j	dd| _t
|j| j	| _t
|j| _t|| _| jrkt||d| _d S d S )	Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .F)biasr   )r1   r2   rA   rD   num_attention_headshasattrr~   ru   attention_head_sizeall_head_sizer   LinearquerykeyvaluerQ   attention_probs_dropout_probrS   rw   has_relative_position_bias"Data2VecVisionRelativePositionBiasrelative_position_biasr3   rA   r   r4   r   r    r2      s&   


z$Data2VecVisionSelfAttention.__init__c                 C   s6   |  d d | j| jf }|j| }|ddddS )NrW   r   rX   r   r   )rZ   r   r   rd   ra   )r3   xZnew_x_shaper   r   r    transpose_for_scores  s   
z0Data2VecVisionSelfAttention.transpose_for_scoresFr6   	head_maskoutput_attentionsr   rh   
resolutionc                 C   s.  |  |}| | |}| | |}	| |}
t|
|dd}|t| j	 }| j
rL|\}}|| jj || jj f}|| j|||jd d }|d urT|| }tjj|dd}| |}|d uri|| }t||	}|dddd }| d d | jf }|j| }|r||f}|S |f}|S )	NrW   r   dim_sizer]   r   rX   r   )r   r   r   r   r+   matmulr   mathsqrtr   r   rA   rI   r   r)   r   rb   ZsoftmaxrS   ra   
contiguousrZ   r   rd   )r3   r6   r   r   r   rh   r   mixed_query_layer	key_layervalue_layerquery_layerZattention_scoresrU   rV   r   Zattention_probscontext_layernew_context_layer_shapeoutputsr   r   r    r8     s4   
	


z#Data2VecVisionSelfAttention.forwardr0   NFNFN)r   r   r   r   r   tupler2   r   r+   r=   rw   r   ru   r   r8   r?   r   r   r4   r    r      s.     
r   c                       sv   e Zd Z					ddejdeej dedeej dedeee  d	e	eej eejejf f f fd
dZ
  ZS )Data2VecVisionSdpaSelfAttentionNFr6   r   r   r   rh   r   r%   c              	      s8  |s|d urt d t j||||||dS | |}| | |}| | |}	| |}
d }| jrQ|\}}|| j	j
 || j	j
 f}| j|||jd d}|d ur`|d u r\|}n||7 }dt| j }tjjj|
||	|| jrw| j	jndd|d}|dd	dd
 }| d d | jf }|j| }|d fS )Na  `Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.)r6   r   r   r   rh   r   r   r   r!   F)Z	attn_maskZ	dropout_pZ	is_causalscaler   rX   r   r   )loggerZwarning_oncer1   r8   r   r   r   r   r   rA   rI   r   r)   r   r   r   r+   r   rb   Zscaled_dot_product_attentionr$   r   ra   r   rZ   r   rd   )r3   r6   r   r   r   rh   r   r   r   r   r   Z	attn_biasrU   rV   r   Zscalingr   r   r4   r   r    r8   O  sR   	
	
	
z'Data2VecVisionSdpaSelfAttention.forwardr   )r   r   r   r+   r=   r   rw   r   ru   r   r8   r?   r   r   r4   r    r   N  s*    
r   c                       sH   e Zd ZdZdeddf fddZddejdejdejfd	d
Z  Z	S )Data2VecVisionSelfOutputz
    The residual connection is defined in Data2VecVisionLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


DATA2VEC_VISION_SELF_ATTENTION_CLASSES = {
    "eager": Data2VecVisionSelfAttention,
    "sdpa": Data2VecVisionSdpaSelfAttention,
}


class Data2VecVisionAttention(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, window_size=window_size
        )
        self.output = Data2VecVisionSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional[torch.Tensor] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_outputs = self.attention(
            hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
        )

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
class Data2VecVisionIntermediate(nn.Module):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class Data2VecVisionOutput(nn.Module):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class Data2VecVisionLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(
        self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0
    ) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Data2VecVisionAttention(config, window_size=window_size)
        self.intermediate = Data2VecVisionIntermediate(config)
        self.output = Data2VecVisionOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        init_values = config.layer_scale_init_value
        if init_values > 0:
            self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
            self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
        else:
            self.lambda_1, self.lambda_2 = None, None

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional[torch.Tensor] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in Data2VecVision, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
            interpolate_pos_encoding=interpolate_pos_encoding,
            resolution=resolution,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # apply lambda_1 if present
        if self.lambda_1 is not None:
            attention_output = self.lambda_1 * attention_output

        # first residual connection
        hidden_states = self.drop_path(attention_output) + hidden_states

        # in Data2VecVision, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)

        layer_output = self.intermediate(layer_output)
        layer_output = self.output(layer_output)

        if self.lambda_2 is not None:
            layer_output = self.lambda_2 * layer_output

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs

        return outputs
class Data2VecVisionRelativePositionBias(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None:
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, config.num_attention_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # entries for cls-to-token, token-to-cls and cls-to-cls are appended at the end of the table

    @compile_compatible_method_lru_cache(maxsize=10)
    def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
        """
        This method creates the relative position index, modified to support arbitrary window sizes,
        as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
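
        Example (worked numbers for a 14x14 window, which corresponds to a 224x224 input with 16x16 patches; this
        is an assumption about typical usage rather than something this method fixes):

        ```python
        >>> window_size = (14, 14)
        >>> (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3  # num_relative_distance (incl. 3 [CLS] slots)
        732
        >>> window_size[0] * window_size[1] + 1  # rows/columns of the returned index: all patches plus [CLS]
        197
        ```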
        """
        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        window_area = window_size[0] * window_size[1]
        grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
        coords = torch.stack(grid)  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = num_relative_distance - 3
        relative_position_index[0:, 0] = num_relative_distance - 2
        relative_position_index[0, 0] = num_relative_distance - 1
        return relative_position_index

    def forward(
        self, window_size, interpolate_pos_encoding: bool = False, dim_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
        """
        old_height = 2 * self.window_size[0] - 1
        old_width = 2 * self.window_size[1] - 1

        new_height = 2 * window_size[0] - 1
        new_width = 2 * window_size[1] - 1

        old_relative_position_bias_table = self.relative_position_bias_table

        old_num_relative_distance = self.num_relative_distance
        new_num_relative_distance = new_height * new_width + 3

        old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]

        old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
        new_sub_table = nn.functional.interpolate(
            old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
        )
        new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

        new_relative_position_bias_table = torch.cat(
            [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
        )

        relative_position_index = self.generate_relative_position_index(window_size)
        relative_position_bias = new_relative_position_bias_table[relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(
            window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
        )  # Wh*Ww, Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        if interpolate_pos_encoding:
            relative_position_bias = nn.functional.interpolate(
                relative_position_bias.unsqueeze(1),
                size=(dim_size, dim_size),
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)

        return relative_position_bias.unsqueeze(0)
class Data2VecVisionEncoder(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.config = config
        self.has_relative_position_bias = config.use_shared_relative_position_bias
        if self.has_relative_position_bias:
            self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
        self.layer = nn.ModuleList(
            [
                Data2VecVisionLayer(
                    config,
                    window_size=window_size if config.use_relative_position_bias else None,
                    drop_path_rate=dpr[i],
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int, int]] = None,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.has_relative_position_bias:
                height, width = resolution
                window_size = (height // self.config.patch_size, width // self.config.patch_size)
                relative_position_bias = self.relative_position_bias(
                    window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
                )
            else:
                relative_position_bias = None

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                    relative_position_bias,
                    interpolate_pos_encoding,
                    resolution,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                    relative_position_bias,
                    interpolate_pos_encoding,
                    resolution,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
@auto_docstring
class Data2VecVisionPreTrainedModel(PreTrainedModel):
    config_class = Data2VecVisionConfig
    base_model_prefix = "data2vec_vision"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Data2VecVisionLayer"]
    _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Data2VecVisionEmbeddings):
            module.cls_token.data.zero_()
            if module.mask_token is not None:
                module.mask_token.data.zero_()
            if module.position_embeddings is not None:
                module.position_embeddings.data.zero_()
        elif isinstance(module, Data2VecVisionRelativePositionBias):
            module.relative_position_bias_table.data.zero_()
        elif isinstance(module, Data2VecVisionLayer):
            if module.lambda_1 is not None:
                module.lambda_1.data.fill_(self.config.layer_scale_init_value)
                module.lambda_2.data.fill_(self.config.layer_scale_init_value)


@auto_docstring
class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False) -> None:
        r"""
        add_pooling_layer (bool, *optional*, defaults to `False`):
            Whether to add a pooling layer
        r   r   N)r1   r2   rA   r@   rT   r   rH   r{   encoderuse_mean_poolingr   r   r   rD   r   	layernormData2VecVisionPoolerpooler	post_init)r3   rA   r   r4   r   r    r2     s   
zData2VecVisionModel.__init__c                 C   s   | j jS r0   )rT   rH   r:   r   r   r    get_input_embeddings!  s   z(Data2VecVisionModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output, _ = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        resolution = pixel_values.shape[2:]

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            resolution=resolution,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return Data2VecVisionModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class Data2VecVisionPooler(nn.Module):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.layernorm = (
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.layernorm is not None:
            # Mean pool the final hidden states of the patch tokens
            patch_tokens = hidden_states[:, 1:, :]
            pooled_output = self.layernorm(patch_tokens.mean(1))
        else:
            # Pool by simply taking the final hidden state of the [CLS] token
            pooled_output = hidden_states[:, 0]

        return pooled_output


@auto_docstring(
    custom_intro="""
    Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
    the final hidden states of the patch tokens) e.g. for ImageNet.
    """
)
class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=True)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.data2vec_vision(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class Data2VecVisionConvModule(nn.Module):
    """
    A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
    layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
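
    Example (a minimal sketch; the channel sizes and spatial resolution below are arbitrary):

    ```python
    >>> import torch

    >>> conv_module = Data2VecVisionConvModule(in_channels=3, out_channels=16, kernel_size=3, padding=1)
    >>> conv_module(torch.randn(1, 3, 64, 64)).shape
    torch.Size([1, 16, 64, 64])
    ```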
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int], str] = 0,
        bias: bool = False,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
            dilation=dilation,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = self.conv(input)
        output = self.bn(output)
        output = self.activation(output)

        return output


class Data2VecVisionPyramidPoolingBlock(nn.Module):
    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
        super().__init__()
        self.layers = [
            nn.AdaptiveAvgPool2d(pool_scale),
            Data2VecVisionConvModule(in_channels, channels, kernel_size=1),
        ]
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


class Data2VecVisionPyramidPoolingModule(nn.Module):
    """
    Pyramid Pooling Module (PPM) used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        align_corners (bool): align_corners argument of F.interpolate.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.blocks = []
        for i, pool_scale in enumerate(pool_scales):
            block = Data2VecVisionPyramidPoolingBlock(
                pool_scale=pool_scale, in_channels=in_channels, channels=channels
            )
            self.blocks.append(block)
            self.add_module(str(i), block)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        ppm_outs = []
        for ppm in self.blocks:
            ppm_out = ppm(x)
            upsampled_ppm_out = nn.functional.interpolate(
                ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
            )
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


class Data2VecVisionUperHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://arxiv.org/abs/1807.10221).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()

        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
        self.channels = config.hidden_size
        self.align_corners = False
        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

        # PSP Module
        self.psp_modules = Data2VecVisionPyramidPoolingModule(
            self.pool_scales,
            self.in_channels[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        self.bottleneck = Data2VecVisionConvModule(
            self.in_channels[-1] + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = Data2VecVisionConvModule(in_channels, self.channels, kernel_size=1)
            fpn_conv = Data2VecVisionConvModule(self.channels, self.channels, kernel_size=3, padding=1)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = Data2VecVisionConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

    def psp_forward(self, inputs):
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # build laterals
        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]

        laterals.append(self.psp_forward(encoder_hidden_states))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.classifier(output)

        return output


class Data2VecVisionFCNHead(nn.Module):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is the implementation of
    [FCNNet](https://arxiv.org/abs/1411.4038).

    Args:
        config (Data2VecVisionConfig): Configuration.
        in_channels
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        dilation (int): The dilation rate for convs in the head. Default: 1.


    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self,
        config: Data2VecVisionConfig,
        in_index: int = 2,
        kernel_size: int = 3,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        self.in_channels = config.hidden_size
        self.channels = config.auxiliary_channels
        self.num_convs = config.auxiliary_num_convs
        self.concat_input = config.auxiliary_concat_input
        self.in_index = in_index

        conv_padding = (kernel_size // 2) * dilation
        convs = []
        convs.append(
            Data2VecVisionConvModule(
                self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
            )
        )
        for i in range(self.num_convs - 1):
            convs.append(
                Data2VecVisionConvModule(
                    self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
                )
            )
        if self.num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = Data2VecVisionConvModule(
                self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
            )

        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # just take the relevant feature maps
        hidden_states = encoder_hidden_states[self.in_index]
        output = self.convs(hidden_states)
        if self.concat_input:
            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
        output = self.classifier(output)
        return output


@auto_docstring
class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False)

        if len(self.config.out_indices) != 4:
            raise ValueError(
                "Data2VecVisionForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
                "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
                "a base-sized architecture."
            )

        # FPNs
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
            nn.BatchNorm2d(config.hidden_size),
            nn.GELU(),
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn3 = nn.Identity()
        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Semantic segmentation head(s)
        self.decode_head = Data2VecVisionUperHead(config)
        self.auxiliary_head = Data2VecVisionFCNHead(config) if config.use_auxiliary_head else None

        # Initialize weights and apply final processing
        self.post_init()

    def compute_loss(self, logits, auxiliary_logits, labels):
        # upsample logits to the images' original size
        upsampled_logits = nn.functional.interpolate(
            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )
        if auxiliary_logits is not None:
            upsampled_auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
        # compute weighted loss
        loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
        main_loss = loss_fct(upsampled_logits, labels)
        loss = main_loss
        if auxiliary_logits is not None:
            auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
            loss += self.config.auxiliary_loss_weight * auxiliary_loss

        return loss

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, Data2VecVisionForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
        >>> model = Data2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.data2vec_vision(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features, and reshape
        # note that we do +1 as the encoder_hidden_states also include the initial embeddings
        features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
        batch_size = pixel_values.shape[0]
        patch_resolution = self.config.image_size // self.config.patch_size
        features = [
            x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
        ]

        # apply FPNs
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])

        logits = self.decode_head(features)

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(features)

        loss = None
        if labels is not None:
            loss = self.compute_loss(logits, auxiliary_logits, labels)

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = [
    "Data2VecVisionForImageClassification",
    "Data2VecVisionForSemanticSegmentation",
    "Data2VecVisionModel",
    "Data2VecVisionPreTrainedModel",
]