"""PyTorch Swinv2 Transformer model."""

import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import Tensor, nn

from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from ...utils.backbone_utils import BackboneMixin
from .configuration_swinv2 import Swinv2Config


logger = logging.get_logger(__name__)


@dataclass
class Swinv2EncoderOutput(ModelOutput):
    """
    Swinv2 encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlast_hidden_state.hidden_states
attentionsreshaped_hidden_states)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r    r"   r"   Y/var/www/auris/lib/python3.10/site-packages/transformers/models/swinv2/modeling_swinv2.pyr   *   s   
 r   c                   @      e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	Swinv2ModelOutputaV  
    Swinv2 model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class Swinv2MaskedImageModelingOutput(ModelOutput):
    """
    Swinv2 masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MLM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlossreconstruction.r   r   r   c                 C   s   t dt | jS )Nzlogits attribute is deprecated and will be removed in version 5 of Transformers. Please use the reconstruction attribute to retrieve the final output instead.)warningswarnFutureWarningr*   selfr"   r"   r#   logits   s
   z&Swinv2MaskedImageModelingOutput.logits)r   r   r   r   r)   r   r   r    r!   r*   r   r   r   r   propertyr0   r"   r"   r"   r#   r(   q   s   
 r(   c                   @   r$   )	Swinv2ImageClassifierOutputa  
    Swinv2 outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
    num_channels = windows.shape[-1]
    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
    return windows


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # work with tensors of any dimensionality, not just 2D ConvNets
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class Swinv2DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)
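# A minimal, illustrative round trip through the two window helpers defined above.
# The batch size, spatial resolution and window size are assumptions chosen only
# for this sketch; they do not come from any particular checkpoint:
#
#     features = torch.randn(2, 56, 56, 96)           # (batch, height, width, channels)
#     windows = window_partition(features, 7)         # -> (2 * 8 * 8, 7, 7, 96)
#     restored = window_reverse(windows, 7, 56, 56)   # -> (2, 56, 56, 96), identical to `features`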


class Swinv2Embeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = Swinv2PatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor]:
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions


class Swinv2PatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, output_dimensions


class Swinv2PatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)

        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
        height, width = input_dimensions
        # `dim` is height * width
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)
        # pad input to be divisible by width and height, if needed
        input_feature = self.maybe_pad(input_feature, height, width)
        # each slice has shape [batch_size, height/2, width/2, num_channels]
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]
        # [batch_size, height/2 * width/2, 4*num_channels]
        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)

        input_feature = self.reduction(input_feature)
        input_feature = self.norm(input_feature)

        return input_feature
    **r   c                       sj   e Zd Zddgf fdd	Zdd Z			ddejd	eej d
eej dee	 de
ej f
ddZ  ZS )Swinv2SelfAttentionr   c              
      s  t    || dkrtd| d| d|| _t|| | _| j| j | _t|tj	j
r0|n||f| _|| _ttdt|ddf | _ttjddd	d
tjd	dtjd|dd
| _tj| jd d  | jd tjd }tj| jd d  | jd tjd }tt||gddddd d}|d dkr|d d d d d d df  |d d   < |d d d d d d df  |d d   < n3|dkr|d d d d d d df  | jd d   < |d d d d d d df  | jd d   < |d9 }t|t t!|d  t" d }|#t$| j% j&}| j'd|dd t| jd }	t| jd }
tt|	|
gdd}t(|d}|d d d d d f |d d d d d f  }|ddd }|d d d d df  | jd d 7  < |d d d d df  | jd d 7  < |d d d d df  d| jd  d 9  < |)d}| j'd|dd tj| j| j|j*d
| _+tj| j| jdd
| _,tj| j| j|j*d
| _-t.|j/| _0d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()
   r   r3   i   Tr   )ZinplaceFrJ   Zij)Zindexing   r}   relative_coords_table)
persistentr6   relative_position_index)1rR   rS   
ValueErrornum_attention_headsr   attention_head_sizeall_head_sizer   r   r   r   r=   pretrained_window_sizer   rc   r   logZoneslogit_scale
Sequentialr   ZReLUcontinuous_position_bias_mlpZarangeZint64rZ   stackr   r:   r;   r   signlog2absmathtonext
parametersrJ   Zregister_bufferr   sumZqkv_biasquerykeyvaluerj   attention_probs_dropout_probrl   )r/   rn   ru   	num_headsr=   r   Zrelative_coords_hZrelative_coords_wr   Zcoords_hZcoords_wZcoordsZcoords_flattenZrelative_coordsr   rT   r"   r#   rS     s`   
"&((
,.
..&,((,
zSwinv2SelfAttention.__init__c                 C   s6   |  d d | j| jf }||}|ddddS )Nr6   r   r3   r   r   )rr   r   r   r9   r:   )r/   xZnew_x_shaper"   r"   r#   transpose_for_scores  s   
z(Swinv2SelfAttention.transpose_for_scoresNFr   attention_mask	head_maskoutput_attentionsrI   c                 C   s  |j \}}}| |}| | |}	| | |}
| |}tjj|ddtjj|	dddd }t	j
| jtdd }|| }| | jd| j}|| jd | jd | jd  | jd | jd  d}|ddd }d	t	| }||d }|d ur|j d }||| || j|||dd }||dd }|d| j||}tjj|dd}| |}|d ur|| }t	||
}|dddd
 }| d d | jf }||}|r||f}|S |f}|S )Nr6   rt   g      Y@)maxr   r   r3      r   )r8   r   r   r   r   r   rx   	normalizer   r   clampr   r   r   expr   r   r9   r   r   r=   r:   r;   Zsigmoidr   Zsoftmaxrl   matmulrr   r   )r/   r   r   r   r   r>   ru   rA   Zmixed_query_layerZ	key_layerZvalue_layerZquery_layerZattention_scoresr   Zrelative_position_bias_tableZrelative_position_biasZ
mask_shapeZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr"   r"   r#   rW     sT   

&


zSwinv2SelfAttention.forwardNNF)r   r   r   rS   r   r   r   r   r    r   r   rW   r\   r"   r"   rT   r#   r     s"    =r   c                       s8   e Zd Z fddZdejdejdejfddZ  ZS )Swinv2SelfOutputc                    s*   t    t||| _t|j| _d S rQ   )rR   rS   r   r   denserj   r   rl   r/   rn   ru   rT   r"   r#   rS   B  s   
zSwinv2SelfOutput.__init__r   input_tensorrI   c                 C      |  |}| |}|S rQ   r   rl   )r/   r   r   r"   r"   r#   rW   G  s   

zSwinv2SelfOutput.forwardr   r   r   rS   r   r   rW   r\   r"   r"   rT   r#   r   A  s    $r   c                       sd   e Zd Zd fdd	Zdd Z			ddejd	eej d
eej dee	 de
ej f
ddZ  ZS )Swinv2Attentionr   c                    sL   t    t||||t|tjjr|n||fd| _t||| _	t
 | _d S )Nrn   ru   r   r=   r   )rR   rS   r   r   r   r   r   r/   r   rN   setpruned_heads)r/   rn   ru   r   r=   r   rT   r"   r#   rS   O  s   
	zSwinv2Attention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rt   )lenr   r/   r   r   r   r   r   r   r   rN   r   r   union)r/   headsindexr"   r"   r#   prune_heads]  s   zSwinv2Attention.prune_headsNFr   r   r   r   rI   c                 C   s6   |  ||||}| |d |}|f|dd   }|S )Nr   r   )r/   rN   )r/   r   r   r   r   Zself_outputsattention_outputr   r"   r"   r#   rW   o  s   zSwinv2Attention.forwardr   r   )r   r   r   rS   r   r   r   r   r    r   r   rW   r\   r"   r"   rT   r#   r   N  s"    r   c                       2   e Zd Z fddZdejdejfddZ  ZS )Swinv2Intermediatec                    sJ   t    t|t|j| | _t|jt	rt
|j | _d S |j| _d S rQ   )rR   rS   r   r   r   	mlp_ratior   r   Z
hidden_actr[   r	   intermediate_act_fnr   rT   r"   r#   rS   ~  s
   
zSwinv2Intermediate.__init__r   rI   c                 C   r   rQ   )r   r   rV   r"   r"   r#   rW        

zSwinv2Intermediate.forwardr   r"   r"   rT   r#   r   }  s    r   c                       r   )Swinv2Outputc                    s4   t    tt|j| || _t|j| _	d S rQ   )
rR   rS   r   r   r   r   r   rj   rk   rl   r   rT   r"   r#   rS     s   
zSwinv2Output.__init__r   rI   c                 C   r   rQ   r   rV   r"   r"   r#   rW     r   zSwinv2Output.forwardr   r"   r"   rT   r#   r     s    r   c                       s   e Zd Z	d fdd	Zdeeeef eeef f fddZdd	 Zd
d Z		dde	j
deeef dee	j dee dee	j
e	j
f f
ddZ  ZS )Swinv2LayerrE   r   c           	         s   t    || _| |j|jf||f\}}|d | _|d | _t|||| jt|tj	j
r/|n||fd| _tj||jd| _|dkrGt|nt | _t||| _t||| _tj||jd| _d S )Nr   r   ZepsrE   )rR   rS   r   _compute_window_shiftr=   
shift_sizer   r   r   r   r   	attentionr   rh   layer_norm_epslayernorm_beforerP   IdentityrO   r   intermediater   rN   layernorm_after)	r/   rn   ru   r   r   drop_path_rater   r   r=   rT   r"   r#   rS     s*   


	zSwinv2Layer.__init__rI   c                 C   s6   dd t | j|D }dd t | j||D }||fS )Nc                 S   s    g | ]\}}||kr|n|qS r"   r"   ).0rwr"   r"   r#   
<listcomp>  s     z5Swinv2Layer._compute_window_shift.<locals>.<listcomp>c                 S   s"   g | ]\}}}||krd n|qS r   r"   )r   r   r   sr"   r"   r#   r     s   " )zipr   )r/   Ztarget_window_sizeZtarget_shift_sizer=   r   r"   r"   r#   r     s   z!Swinv2Layer._compute_window_shiftc              	   C   s  | j dkrtjd||df|d}td| j t| j | j  t| j  d f}td| j t| j | j  t| j  d f}d}|D ]}|D ]}	||d d ||	d d f< |d7 }qDq@t|| j}
|
d| j| j }
|
d|
d }||dkt	d|dkt	d}|S d }|S )Nr   r   r   r6   r3   g      YrE   )
r   r   rd   slicer=   rC   r9   r   Zmasked_fillrZ   )r/   r?   r@   rJ   Zimg_maskZheight_slicesZwidth_slicescountZheight_sliceZwidth_sliceZmask_windows	attn_maskr"   r"   r#   get_attn_mask  s.   

$zSwinv2Layer.get_attn_maskc                 C   sR   | j || j   | j  }| j || j   | j  }ddd|d|f}tj||}||fS Nr   )r=   r   rx   r   )r/   r   r?   r@   	pad_rightZ
pad_bottomr   r"   r"   r#   r     s
   zSwinv2Layer.maybe_padNFr   r   r   r   c                 C   s  |\}}|  \}}}	|}
|||||	}| |||\}}|j\}}}}| jdkr9tj|| j | j fdd}n|}t|| j}|d| j| j |	}| j	|||j
d}|d ur_||j}| j||||d}|d }|d| j| j|	}t|| j||}| jdkrtj|| j| jfdd}n|}|d dkp|d dk}|r|d d d |d |d d f  }|||| |	}| |}|
| | }| |}| |}|| | | }|r||d	 f}|S |f}|S )
Nr   )r   r3   )Zshiftsdimsr6   r   )r   r   r5   r   )rr   r9   r   r8   r   r   ZrollrC   r=   r   rJ   r   rK   r   rD   r;   r   rO   r   rN   r   )r/   r   r   r   r   r?   r@   r>   r   ZchannelsZshortcutr   Z
height_padZ	width_padZshifted_hidden_statesZhidden_states_windowsr   Zattention_outputsr   Zattention_windowsZshifted_windowsZ
was_paddedZlayer_outputlayer_outputsr"   r"   r#   rW     sH   

$


zSwinv2Layer.forward)rE   r   r   r   )r   r   r   rS   r   r   r   r   r   r   r   r   r    r   rW   r\   r"   r"   rT   r#   r     s&    &
r   c                       s^   e Zd Z	d fdd	Z		ddejdeeef deej	 d	ee
 d
eej f
ddZ  ZS )Swinv2Stager   c	              
      s   t    || _|| _g }	t|D ]}
t||||||
 |
d dkr#dn|jd |d}|	| qt	|	| _
|d urE|||tjd| _nd | _d| _d S )Nr3   r   )rn   ru   r   r   r   r   r   )ru   r   F)rR   rS   rn   ru   ranger   r=   appendr   
ModuleListblocksrh   
downsampleZpointing)r/   rn   ru   r   depthr   rO   r  r   r   iblockrT   r"   r#   rS     s(   
	
zSwinv2Stage.__init__NFr   r   r   r   rI   c                 C   s   |\}}t | jD ]\}}|d ur|| nd }	||||	|}
|
d }q	|}| jd urD|d d |d d }}||||f}| ||}n||||f}|||f}|rY||
dd  7 }|S )Nr   r   r3   )	enumerater   r  )r/   r   r   r   r   r?   r@   r  layer_modulelayer_head_maskr   !hidden_states_before_downsamplingZheight_downsampledZwidth_downsampledr   Zstage_outputsr"   r"   r#   rW   3  s(   


zSwinv2Stage.forwardr   r   )r   r   r   rS   r   r   r   r   r   r    r   rW   r\   r"   r"   rT   r#   r     s      
r   c                       s|   e Zd Zd fdd	Z					ddejdeeef d	eej	 d
ee
 dee
 dee
 dee
 deeef fddZ  ZS )Swinv2Encoderr   r   r   r   c                    s  t    t|j| _|| _| jjd ur|j}dd tjd|j	t
|jddD }g }t| jD ]M}t|t|jd|  |d d|  |d d|  f|j| |j| |t
|jd | t
|jd |d   || jd k rqtnd || d}|| q0t|| _d	| _d S )
Nc                 S   s   g | ]}|  qS r"   )item)r   r   r"   r"   r#   r   ]  s    z*Swinv2Encoder.__init__.<locals>.<listcomp>r   cpu)rK   r3   r   )rn   ru   r   r  r   rO   r  r   F)rR   rS   r   depths
num_layersrn   pretrained_window_sizesr   Zlinspacer   r   r   r   r   re   r   r   r   r   r   layersgradient_checkpointing)r/   rn   ra   r  Zdprr  Zi_layerstagerT   r"   r#   rS   W  s*   
$*

zSwinv2Encoder.__init__NFTr   r   r   r   output_hidden_states(output_hidden_states_before_downsamplingreturn_dictrI   c                 C   s  |rdnd }|r
dnd }	|rdnd }
|r7|j \}}}|j|g||R  }|dddd}||f7 }|	|f7 }	t| jD ]\}}|d urH|| nd }| jrZ| jrZ| |j|||}n|||||}|d }|d }|d }|d |d f}|r|r|j \}}}|j|g|d |d f|R  }|dddd}||f7 }|	|f7 }	n'|r|s|j \}}}|j|g||R  }|dddd}||f7 }|	|f7 }	|r|
|dd  7 }
q<|st	dd	 |||
|	fD S t
|||
|	d
S )Nr"   r   r   r   r3   r   r6   c                 s   s    | ]	}|d ur|V  qd S rQ   r"   )r   vr"   r"   r#   	<genexpr>  s    z(Swinv2Encoder.forward.<locals>.<genexpr>)r   r   r   r   )r8   r9   r:   r  r  r  rH   Z_gradient_checkpointing_func__call__tupler   )r/   r   r   r   r   r  r  r  Zall_hidden_statesZall_reshaped_hidden_statesZall_self_attentionsr>   r   r   Zreshaped_hidden_stater  r  r  r   r  r   r"   r"   r#   rW   p  sp   







zSwinv2Encoder.forward)r
  )NFFFT)r   r   r   rS   r   r   r   r   r   r    r   r   r   rW   r\   r"   r"   rT   r#   r	  V  s0    

	r	  c                   @   s*   e Zd ZeZdZdZdZdgZdd Z	dS )Swinv2PreTrainedModelswinv2r{   Tr   c                 C   s   t |tjtjfr#|jjjd| jjd |j	dur!|j	j
  dS dS t |tjr8|j	j
  |jjd dS t |trW|jdurH|jj
  |jdurU|jj
  dS dS t |trh|jjtd dS dS )zInitialize the weightsrE   )meanZstdNr}   r   )r   r   r   r   weightdataZnormal_rn   Zinitializer_ranger   Zzero_rh   Zfill_r]   rf   rg   r   r   r   r   )r/   moduler"   r"   r#   _init_weights  s"   




z#Swinv2PreTrainedModel._init_weightsN)
r   r   r   r   Zconfig_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesr   r"   r"   r"   r#   r    s    r  c                       s   e Zd Zd fdd	Zdd Zdd Ze													dd
eej	 deej
 deej	 dee dee dedee deeef fddZ  ZS )Swinv2ModelTFc                    s   t  | || _t|j| _t|jd| jd   | _t	||d| _
t|| j
j| _tj| j|jd| _|r<tdnd| _|   dS )a  
        add_pooling_layer (`bool`, *optional*, defaults to `True`):
            Whether or not to apply pooling layer.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether or not to create and apply mask tokens in the embedding layer.
        r3   r   )ro   r   N)rR   rS   rn   r   r  r  r   re   num_featuresr]   rp   r	  rb   encoderr   rh   r   	layernormZAdaptiveAvgPool1dpooler	post_init)r/   rn   add_pooling_layerro   rT   r"   r#   rS     s   zSwinv2Model.__init__c                 C      | j jS rQ   rp   r_   r.   r"   r"   r#   get_input_embeddings     z Swinv2Model.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr#  layerr   r   )r/   Zheads_to_pruner-  r   r"   r"   r#   _prune_heads  s   zSwinv2Model._prune_headsNr{   r|   r   r   r  rz   r  rI   c                 C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}|du r&td| |t| j j}| j|||d\}}	| j	||	||||d}
|
d }| 
|}d}| jdurd| |dd}t|d}|sr||f|
dd  }|S t|||
j|
j|
jdS )	z
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Nz You have to specify pixel_values)r|   rz   )r   r   r  r  r   r   r3   )r   r&   r   r   r   )rn   r   r  use_return_dictr   Zget_head_maskr   r  rp   r#  r$  r%  r   r   r   r%   r   r   r   )r/   r{   r|   r   r   r  rz   r  embedding_outputr   Zencoder_outputssequence_outputpooled_outputrN   r"   r"   r#   rW     sD   
	

zSwinv2Model.forward)TFNNNNNFN)r   r   r   rS   r*  r.  r   r   r   r    r   r   r   r   r%   rW   r\   r"   r"   rT   r#   r!    s:    
	r!  av  
        Swinv2 Model with a decoder on top for masked image modeling, as proposed in
    [SimMIM](https://arxiv.org/abs/2111.09886).

        <Tip>

        Note that we provide a script to pre-train this model on custom data in our [examples
        directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

        </Tip>
    )Zcustom_introc                       s   e Zd Z fddZe							ddeej deej deej dee	 d	ee	 d
e	dee	 de
eef fddZ  ZS )Swinv2ForMaskedImageModelingc                    sn   t  | t|ddd| _t|jd|jd   }ttj	||j
d |j ddt|j
| _|   d S )NFT)r'  ro   r3   r   )Zin_channelsZout_channelsr   )rR   rS   r!  r  r   re   r  r   r   r   Zencoder_striderA   ZPixelShuffledecoderr&  )r/   rn   r"  rT   r"   r#   rS   O  s   
z%Swinv2ForMaskedImageModeling.__init__NFr{   r|   r   r   r  rz   r  rI   c              	   C   s>  |dur|n| j j}| j|||||||d}|d }	|	dd}	|	j\}
}}t|d  }}|	|
|||}	| |	}d}|dur}| j j	| j j
 }|d||}|| j j
d| j j
dd }tjj||dd	}||  | d
  | j j }|s|f|dd  }|dur|f| S |S t|||j|j|jdS )a?  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 256, 256]
        ```N)r|   r   r   r  rz   r  r   r   r3   rq   r6   none)r   gh㈵>)r)   r*   r   r   r   )rn   r/  r  r   r8   r   floorrw   r5  r   rm   Zrepeat_interleaver   r;   r   rx   Zl1_lossr   rA   r(   r   r   r   )r/   r{   r|   r   r   r  rz   r  r   r1  r>   rA   Zsequence_lengthr?   r@   Zreconstructed_pixel_valuesZmasked_im_lossrr   r   Zreconstruction_lossrN   r"   r"   r#   rW   _  sJ   &

 z$Swinv2ForMaskedImageModeling.forwardr3  )r   r   r   rS   r   r   r   r    r   r   r   r   r(   rW   r\   r"   r"   rT   r#   r4  @  s6    
	r4  a  
    Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.

    <Tip>

        Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    c                       s   e Zd Z fddZe							ddeej deej deej dee	 d	ee	 d
e	dee	 de
eef fddZ  ZS )Swinv2ForImageClassificationc                    sP   t  | |j| _t|| _|jdkrt| jj|jnt | _	| 
  d S r   )rR   rS   Z
num_labelsr!  r  r   r   r"  r   
classifierr&  r/   rn   rT   r"   r#   rS     s   
"z%Swinv2ForImageClassification.__init__NFr{   r   labelsr   r  rz   r  rI   c                 C   s   |dur|n| j j}| j||||||d}|d }	| |	}
d}|dur.| j|
||
| j d}|sD|
f|dd  }|durB|f| S |S t||
|j|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N)r   r   r  rz   r  r   )r0   r;  Zpooled_logitsrn   r3   )r)   r0   r   r   r   )	rn   r/  r  r9  Zloss_functionr2   r   r   r   )r/   r{   r   r;  r   r  rz   r  r   r2  r0   r)   rN   r"   r"   r#   rW     s0   	
z$Swinv2ForImageClassification.forwardr3  )r   r   r   rS   r   r   r   r    Z
LongTensorr   r   r   r2   rW   r\   r"   r"   rT   r#   r8    s6    
	r8  zO
    Swinv2 backbone, to be used with frameworks like DETR and MaskFormer.
    c                       sZ   e Zd Z fddZdd Ze			ddedee dee d	ee d
e	f
ddZ
  ZS )Swinv2Backbonec                    sd   t    t     jg fddtt jD  | _t | _	t
 | j	j| _|   d S )Nc                    s   g | ]}t  jd |  qS )r3   )r   re   )r   r  rn   r"   r#   r     s    z+Swinv2Backbone.__init__.<locals>.<listcomp>)rR   rS   Z_init_backbonere   r   r   r  r"  r]   rp   r	  rb   r#  r&  r:  rT   r=  r#   rS   
  s   &
zSwinv2Backbone.__init__c                 C   r(  rQ   r)  r.   r"   r"   r#   r*    r+  z#Swinv2Backbone.get_input_embeddingsNr{   r   r  r  rI   c              	   C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}| |\}}| j||d|dd|d}|r6|jn|d }d}	t| j|D ]\}
}|
| j	v rP|	|f7 }	qB|sj|	f}|r_||d f7 }|rh||d f7 }|S t
|	|rq|jnd|jdS )	aK  
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
        >>> model = AutoBackbone.from_pretrained(
        ...     "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 2048, 7, 7]
        ```NT)r   r   r  r  r  r6   r"   r   r3   )feature_mapsr   r   )rn   r/  r  r   rp   r#  r   r   Zstage_namesZout_featuresr
   r   r   )r/   r{   r   r  r  r0  r   r   r   r>  r  Zhidden_staterN   r"   r"   r#   rW     s@    


zSwinv2Backbone.forward)NNN)r   r   r   rS   r*  r   r   r   r   r
   rW   r\   r"   r"   rT   r#   r<    s$    r<  )r8  r4  r!  r  r<  )rE   F)Ar   collections.abcr   r   r+   dataclassesr   typingr   r   r   r   Ztorch.utils.checkpointr   r   Zactivationsr	   Zmodeling_outputsr
   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   r   r   Zutils.backbone_utilsr   Zconfiguration_swinv2r   Z
get_loggerr   loggerr   r%   r(   r2   rC   rD   rZ   r   rO   r   rP   r]   r^   r   r   r   r   r   r   r   r   r	  r  r!  r4  r8  r<  __all__r"   r"   r"   r#   <module>   st   
 #,$ ]+6 /}@icg@W