"""PyTorch Neighborhood Attention Transformer model."""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ....activations import ACT2FN
from ....modeling_outputs import BackboneOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ....utils import (
    ModelOutput,
    OptionalDependencyNotAvailable,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_natten_available,
    logging,
    replace_return_docstrings,
    requires_backends,
)
from ....utils.backbone_utils import BackboneMixin
from .configuration_nat import NatConfig


if is_natten_available():
    from natten.functional import natten2dav, natten2dqkrpb
else:
    # Import-time stubs so the module can be imported without natten installed.
    def natten2dav(*args, **kwargs):
        raise OptionalDependencyNotAvailable()

    def natten2dqkrpb(*args, **kwargs):
        raise OptionalDependencyNotAvailable()


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "NatConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "shi-labs/nat-mini-in1k-224"
_EXPECTED_OUTPUT_SHAPE = [1, 7, 7, 512]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "shi-labs/nat-mini-in1k-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat"


@dataclass
class NatEncoderOutput(ModelOutput):
    """
    Nat encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlast_hidden_state.hidden_states
attentionsreshaped_hidden_states)__name__
__module____qualname____doc__r)   r   torchFloatTensor__annotations__r*   r   r+   r,   r$   r$   r$   r%   r(   I   s   
 r(   c                   @      e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	NatModelOutputaS  
    Nat model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class NatImageClassifierOutput(ModelOutput):
    """
    Nat outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlosslogits.r*   r+   r,   )r-   r.   r/   r0   r9   r   r1   r2   r3   r:   r*   r   r+   r,   r$   r$   r$   r%   r8      r7   r8   c                       s>   e Zd ZdZ fddZdeej deej	 fddZ
  ZS )NatEmbeddingsz6
    Construct the patch and position embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = NatPatchEmbeddings(config)

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor]:
        embeddings = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings


class NatPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, height, width, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        patch_size = config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        self.num_channels = num_channels

        if patch_size != 4:
            # Note: the upstream error message said "Dinat"; corrected to "Nat" here.
            raise ValueError("Nat only supports patch size of 4 at the moment.")

        # Two stride-2 convolutions give an overall patch stride of 4.
        self.projection = nn.Sequential(
            nn.Conv2d(self.num_channels, hidden_size // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.Conv2d(hidden_size // 2, hidden_size, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
        )

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> torch.Tensor:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.projection(pixel_values)
        embeddings = embeddings.permute(0, 2, 3, 1)

        return embeddings


class NatDownsampler(nn.Module):
    """
    Convolutional Downsampling Layer.

    Args:
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.dim = dim
        self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, input_feature: torch.Tensor) -> torch.Tensor:
        # Convolve in channels-first layout, then return to channels-last.
        input_feature = self.reduction(input_feature.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        input_feature = self.norm(input_feature)
        return input_feature


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class NatDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)
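
# Editorial sketch (not in the original file): what stochastic depth does at
# train time. Each sample's residual branch is either zeroed out or rescaled
# by 1/keep_prob, so the expected value is unchanged:
#
#     >>> branch = torch.ones(4, 2, 2, 8)
#     >>> out = drop_path(branch, drop_prob=0.5, training=True)
#     >>> set(out.unique().tolist()).issubset({0.0, 2.0})
#     True
#
# In eval mode (`training=False`) `drop_path` is the identity.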


class NeighborhoodAttention(nn.Module):
    def __init__(self, config, dim, num_heads, kernel_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.kernel_size = kernel_size

        # Learnable relative positional biases: one (2k - 1) x (2k - 1) table per head.
        self.rpb = nn.Parameter(torch.zeros(num_heads, (2 * self.kernel_size - 1), (2 * self.kernel_size - 1)))

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # (batch, height, width, dim) -> (batch, heads, height, width, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 3, 1, 2, 4)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scale the query rather than the scores; the result is identical and the
        # query is the smaller tensor.
        query_layer = query_layer / math.sqrt(self.attention_head_size)

        # Neighborhood attention scores between each query and the keys in its
        # kernel_size x kernel_size neighborhood, plus relative positional biases.
        attention_scores = natten2dqkrpb(query_layer, key_layer, self.rpb, self.kernel_size, 1)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This drops entire tokens to attend to, following the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Weighted average of the values in each query's neighborhood.
        context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, 1)
        context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class NeighborhoodAttentionOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class NeighborhoodAttentionModule(nn.Module):
    def __init__(self, config, dim, num_heads, kernel_size):
        super().__init__()
        self.self = NeighborhoodAttention(config, dim, num_heads, kernel_size)
        self.output = NeighborhoodAttentionOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(hidden_states, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class NatIntermediate(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class NatOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class NatLayer(nn.Module):
    def __init__(self, config, dim, num_heads, drop_path_rate=0.0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.kernel_size = config.kernel_size
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = NeighborhoodAttentionModule(config, dim, num_heads, kernel_size=self.kernel_size)
        self.drop_path = NatDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = NatIntermediate(config, dim)
        self.output = NatOutput(config, dim)
        self.layer_scale_parameters = (
            nn.Parameter(config.layer_scale_init_value * torch.ones((2, dim)), requires_grad=True)
            if config.layer_scale_init_value > 0
            else None
        )

    def maybe_pad(self, hidden_states, height, width):
        window_size = self.kernel_size
        pad_values = (0, 0, 0, 0, 0, 0)
        if height < window_size or width < window_size:
            pad_l = pad_t = 0
            pad_r = max(0, window_size - width)
            pad_b = max(0, window_size - height)
            pad_values = (0, 0, pad_l, pad_r, pad_t, pad_b)
            hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, height, width, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        # pad hidden_states if they are smaller than kernel size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)

        _, height_pad, width_pad, _ = hidden_states.shape

        attention_outputs = self.attention(hidden_states, output_attentions=output_attentions)

        attention_output = attention_outputs[0]

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_output = attention_output[:, :height, :width, :].contiguous()

        if self.layer_scale_parameters is not None:
            attention_output = self.layer_scale_parameters[0] * attention_output

        hidden_states = shortcut + self.drop_path(attention_output)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.output(self.intermediate(layer_output))

        if self.layer_scale_parameters is not None:
            layer_output = self.layer_scale_parameters[1] * layer_output

        layer_output = hidden_states + self.drop_path(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs


class NatStage(nn.Module):
    def __init__(self, config, dim, depth, num_heads, drop_path_rate, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.layers = nn.ModuleList(
            [
                NatLayer(
                    config=config,
                    dim=dim,
                    num_heads=num_heads,
                    drop_path_rate=drop_path_rate[i],
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        _, height, width, _ = hidden_states.size()
        for i, layer_module in enumerate(self.layers):
            layer_outputs = layer_module(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states_before_downsampling)

        stage_outputs = (hidden_states, hidden_states_before_downsampling)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


class NatEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_levels = len(config.depths)
        self.config = config
        # Linear stochastic-depth schedule over all layers of all stages.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
        self.levels = nn.ModuleList(
            [
                NatStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    drop_path_rate=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    downsample=NatDownsampler if (i_layer < self.num_levels - 1) else None,
                )
                for i_layer in range(self.num_levels)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, NatEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            # rearrange b h w c -> b c h w
            reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.levels):
            layer_outputs = layer_module(hidden_states, output_attentions)

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]

            if output_hidden_states and output_hidden_states_before_downsampling:
                # rearrange b h w c -> b c h w
                reshaped_hidden_state = hidden_states_before_downsampling.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                # rearrange b h w c -> b c h w
                reshaped_hidden_state = hidden_states.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[2:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return NatEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
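
# Editorial sketch (not in the original file): the stochastic-depth schedule
# built in NatEncoder.__init__ ramps linearly from 0 to config.drop_path_rate
# over all layers of all stages. For example, with drop_path_rate=0.1 and
# depths=[2, 2] (four layers total):
#
#     >>> [round(x.item(), 3) for x in torch.linspace(0, 0.1, 4)]
#     [0.0, 0.033, 0.067, 0.1]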


class NatPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = NatConfig
    base_model_prefix = "nat"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


NAT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`NatConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

NAT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Nat Model transformer outputting raw hidden-states without any specific head on top.",
    NAT_START_DOCSTRING,
)
class NatModel(NatPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.config = config
        self.num_levels = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_levels - 1))

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NatModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, NatModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.flatten(1, 2).transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return NatModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )


@add_start_docstrings(
    """
    Nat Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    NAT_START_DOCSTRING,
)
class NatForImageClassification(NatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        requires_backends(self, ["natten"])

        self.num_labels = config.num_labels
        self.nat = NatModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.nat.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=NatImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, NatImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.nat(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return NatImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


@add_start_docstrings(
    "NAT backbone, to be used with frameworks like DETR and MaskFormer.",
    NAT_START_DOCSTRING,
)
class NatBackbone(NatPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        requires_backends(self, ["natten"])

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self.out_features, self.channels):
            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
        >>> model = AutoBackbone.from_pretrained(
        ...     "shi-labs/nat-mini-in1k-224", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 512, 7, 7]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        embedding_output = self.embeddings(pixel_values)

        outputs = self.encoder(
            embedding_output,
            output_attentions=output_attentions,
            output_hidden_states=True,
            output_hidden_states_before_downsampling=True,
            return_dict=True,
        )

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for stage, hidden_state in zip(self.stage_names, hidden_states):
            if stage in self.out_features:
                batch_size, num_channels, height, width = hidden_state.shape
                # Normalize in channels-last layout, then restore channels-first.
                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
                hidden_state = self.hidden_states_norms[stage](hidden_state)
                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
                feature_maps += (hidden_state,)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )


__all__ = ["NatForImageClassification", "NatModel", "NatPreTrainedModel", "NatBackbone"]