o
    Zhe                 
   @   s  d Z ddlZddlmZ ddlmZmZmZmZm	Z	 ddl
Z
ddlZ
ddl
mZ ddlmZmZmZ ddlmZ dd	lmZmZmZmZmZ dd
lmZ ddlmZmZmZ ddlm Z  ddl!m"Z" e#e$Z%eG dd deZ&eG dd deZ'eG dd deZ(eG dd deZ)G dd dej*Z+G dd dej*Z,G dd dej*Z-dNde
j.d e/d!e0d"e
j.fd#d$Z1G d%d& d&ej*Z2G d'd( d(ej*Z3G d)d* d*ej*Z4G d+d, d,ej*Z5d-e
j.d.ee6 d/ee6 d"e
j.fd0d1Z7G d2d3 d3ej*Z8d-e
j.d4ee6e6f d5ee6e6f d6eee6  d"e
j.f
d7d8Z9eG d9d: d:eZ:G d;d< d<ej*Z;eG d=d> d>e:Z<G d?d@ d@ej*Z=G dAdB dBej*Z>edCdDG dEdF dFe:Z?edGdDG dHdI dIe:Z@edJdDG dKdL dLe:e ZAg dMZBdS )OzPyTorch Hiera model.    N)	dataclass)DictListOptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BackboneOutputBaseModelOutputBaseModelOutputWithPoolingImageClassifierOutputModelOutput)PreTrainedModel)auto_docstringlogging	torch_int)BackboneMixin   )HieraConfigc                   @   sr   e Zd ZU dZdZeej ed< dZ	ee
ejdf  ed< dZee
ejdf  ed< dZee
ejdf  ed< dS )HieraEncoderOutputa  
    Hiera encoder's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`. Thesre are the unrolled hidden states of the model.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlast_hidden_state.hidden_states
attentionsreshaped_hidden_states)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r    r&   r&   W/var/www/auris/lib/python3.10/site-packages/transformers/models/hiera/modeling_hiera.pyr   +   s   
 r   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
ejed< dZeej ed< dZeeejdf  ed< dZeeejdf  ed	< dZeeejdf  ed
< dS )HieraModelOutputa	  
    Hiera model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Tensor indicating which patches are masked (0) and which are not (1).
        ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Tensor containing the original index of the (shuffled) masked patches.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nr   pooler_outputbool_masked_posids_restore.r   r   r   )r   r    r!   r"   r   r   r#   r$   r%   r)   r*   
BoolTensorr+   
LongTensorr   r   r   r   r&   r&   r&   r'   r(   L   s   
 r(   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	!HieraForImageClassificationOutputa  
    Hiera image classification outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, `optional`):
            Loss value for the training task.
        logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Prediction scores of the classification head (logits of the output layer).
        hidden_states (`tuple(torch.FloatTensor)`, `optional`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, `optional`):
            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, `optional`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    Nlosslogits.r   r   r   )r   r    r!   r"   r/   r   r#   r$   r%   r0   r   r   r   r   r&   r&   r&   r'   r.   v   s   
 r.   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
ejed< dZeej ed< dZeeej  ed< dZeeej  ed< dZeeej  ed	< dS )
HieraForPreTrainingOutputa  
    Class for HieraForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Pixel reconstruction loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Tensor indicating which patches are masked (0) and which are not (1).
        ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Tensor containing the original index of the (shuffled) masked patches.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs reshaped to include the spatial dimensions.
    Nr/   r0   r*   r+   r   r   r   )r   r    r!   r"   r/   r   r#   r$   r%   r0   r*   r,   r+   r-   r   r   r   r   r&   r&   r&   r'   r1      s   
 r1   c                       s   e Zd ZdZddef fddZ	ddejdeej	 d	ej
fd
dZ	ddejdeej d	eej	ejf fddZ	ddejdeej d	eej
eej	 eej f fddZ  ZS )HieraPatchEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    Fis_maec                    s   t    t|j| _| jdkrtd| j d|j| _|jdd  | _dd t|j|j	D | _
dd t| j
|jD | _|j| _|| _tj| j|j|j|j	|jd| _d S )	N   zAThe number of dimensions of the input image should be 2, but got .c                 S      g | ]\}}|| qS r&   r&   .0isr&   r&   r'   
<listcomp>       z1HieraPatchEmbeddings.__init__.<locals>.<listcomp>c                 S   r7   r&   r&   r8   r&   r&   r'   r<      r=   )kernel_sizestridepadding)super__init__lenZ
patch_sizeZspatial_dims
ValueErrornum_channels
image_sizezippatch_stridetokens_spatial_shapemasked_unit_sizemask_spatial_shape
mask_ratior3   r   Conv2d	embed_dimZpatch_padding
projection)selfconfigr3   	__class__r&   r'   rB      s"   

zHieraPatchEmbeddings.__init__Npixel_valuesr*   returnc                 C   s^   |du r	|  |S |jdd }|j|jd dg| jR  }tjj| |d}|  || S )zZero-out the masked regions of the input before conv.
        Prevents leakage of masked regions when using overlapping kernels.
        Nr4   r   r   )size)rO   shapeviewrK   r   
functionalinterpolatefloat)rP   rT   r*   Ztarget_sizer&   r&   r'   masked_conv   s   
z HieraPatchEmbeddings.masked_convnoisec           	      C   s   |j d }t| j}t|d| j  }|du r!tj|||jd}tj	|dd}tj	|dd
|j}tj||g|jd}d|ddd|f< tj|d|d }||fS )a  
        Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
        noise.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`)
            noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*) which is
                mainly used for testing purposes to control randomness and maintain the reproducibility
        r   r   Ndevicedim)ra   index)rW   mathprodrK   intrL   r#   randr_   Zargsorttozerosgatherbool)	rP   rT   r]   
batch_sizenum_windowsZlen_keepZids_shuffler+   r*   r&   r&   r'   random_masking   s   
z#HieraPatchEmbeddings.random_maskingc                 C   sD   | j r
| j||dnd\}}| ||}|ddd}|||fS )Nr]   )NNr4   r   )r3   rm   r\   flatten	transpose)rP   rT   r]   r*   r+   
embeddingsr&   r&   r'   forward  s
   
zHieraPatchEmbeddings.forwardFN)r   r    r!   r"   rj   rB   r#   r$   r   r,   Tensorr\   r   r-   rm   rr   __classcell__r&   r&   rR   r'   r2      s6    

%r2   c                       s   e Zd ZdZddededdf fddZd	ejd
ejde	de	dejf
ddZ
d	ejde	de	dedejf
ddZ		ddejdeej dedeejeej eej f fddZ  ZS )HieraEmbeddingsz2
    Construct position and patch embeddings.
    FrQ   r3   rU   Nc                    s   t    |j| _dd t|j|jD }dd t||jD | _t|| _	|| _
t||d| _ttd| j	|j| _d S )Nc                 S   r7   r&   r&   r8   r&   r&   r'   r<   %  r=   z,HieraEmbeddings.__init__.<locals>.<listcomp>c                 S   r7   r&   r&   r8   r&   r&   r'   r<   &  r=   r3   r   )rA   rB   rH   rG   rF   rJ   rK   rc   rd   Z
num_tokensr3   r2   patch_embeddingsr   	Parameterr#   rh   rN   position_embeddings)rP   rQ   r3   rI   rR   r&   r'   rB   "  s   
zHieraEmbeddings.__init__rq   
pos_embedsheightwidthc                 C   s   |j d }|j d }tj s||kr||kr|S |j d }|| jd  }|| jd  }	t|d }
|d|
|
|}|dddd}tj	j
|||	fddd	}|dddddd|}|S )
a2  
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing, no class embeddings, and different patch strides.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   r         ?r   r4   ZbicubicF)rV   modeZalign_corners)rW   r#   Zjit
is_tracingrH   r   reshapepermuter   rY   rZ   rX   )rP   rq   r|   r}   r~   Znum_patchesZnum_positionsra   Z
new_heightZ	new_widthZsqrt_num_positionsr&   r&   r'   interpolate_pos_encoding.  s$   


z(HieraEmbeddings.interpolate_pos_encodingr   c                 C   s   |r|  || j||S | jS rt   )r   r{   )rP   rq   r}   r~   r   r&   r&   r'   get_position_embeddingT  s
   z&HieraEmbeddings.get_position_embeddingrT   r]   c           	      C   sD   |j dd  \}}| j||d\}}}|| |||| }|||fS )Nr6   rn   )rW   ry   r   )	rP   rT   r]   r   r}   r~   rq   r*   r+   r&   r&   r'   rr   ]  s   
zHieraEmbeddings.forwardrs   NF)r   r    r!   r"   r   rj   rB   r#   ru   re   r   r$   r   r   r   r,   r-   rr   rv   r&   r&   rR   r'   rw     sH    
&
rw   c                       s~   e Zd ZdZ			ddedededed	ed
eddf fddZ		ddejde	ej
 dedeeje	ej f fddZ  ZS )HieraMaskUnitAttentionz
    Computes either Mask Unit or Global Attention. Also is able to perform query pooling.

    Note: this assumes the tokens have already been flattened and unrolled into mask units.
    r   r   Fhidden_sizehidden_size_output	num_headsquery_stridewindow_sizeuse_mask_unit_attnrU   Nc                    sb   t    || _|| _|| _|| | _| jd | _t|d| | _	t||| _
|| _|| _d S )Ng      r   )rA   rB   r   r   r   head_dimscaler   Linearqkvprojr   r   )rP   r   r   r   r   r   r   rR   r&   r'   rB   p  s   
	

zHieraMaskUnitAttention.__init__r   	head_maskoutput_attentionsc                 C   s  |j \}}}d}| jr|| j| j  }| |}||d|d| j| j}|dddddd}|	d\}	}
}| jdkrO|	
|| j|| jd| j}	|	jddj}	|	| j |
dd	 }|jdd}|d
urh|| }|| }|dd|d| j}| |}|r||fS |d
fS )z3Input should be of shape [batch, tokens, channels].r   r   r   r      r4      r`   r6   N)rW   r   r   r   r   r   r   r   r   ZunbindrX   maxvaluesr   rp   Zsoftmaxr   r   )rP   r   r   r   rk   seq_len_rl   r   querykeyvalueattn_weightsZattn_outputr&   r&   r'   rr     s&   


zHieraMaskUnitAttention.forward)r   r   Fr   )r   r    r!   r"   re   rj   rB   r#   ru   r   r$   r   rr   rv   r&   r&   rR   r'   r   i  s>    r           Finput	drop_probtrainingrU   c                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r   r   r   r   )dtyper_   )rW   ndimr#   rf   r   r_   Zfloor_div)r   r   r   Z	keep_probrW   Zrandom_tensoroutputr&   r&   r'   	drop_path  s   
r   c                       sT   e Zd ZdZddee ddf fddZdejdejfdd	Z	de
fd
dZ  ZS )HieraDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr   rU   c                    s   t    || _d S rt   )rA   rB   r   )rP   r   rR   r&   r'   rB     s   

zHieraDropPath.__init__r   c                 C   s   t || j| jS rt   )r   r   r   rP   r   r&   r&   r'   rr     s   zHieraDropPath.forwardc                 C   s   d | jS )Nzp={})formatr   rP   r&   r&   r'   
extra_repr  s   zHieraDropPath.extra_reprrt   )r   r    r!   r"   r   r[   rB   r#   ru   rr   strr   rv   r&   r&   rR   r'   r     s
    r   c                       s<   e Zd Zdeddf fddZdejdejfddZ  ZS )	HieraMlpra   rU   Nc                    sJ   t    t|j | _t|t||j | _	tt||j || _
d S rt   )rA   rB   r   Z
hidden_actactivation_fnr   r   re   Z	mlp_ratiofc1fc2)rP   rQ   ra   rR   r&   r'   rB     s   
zHieraMlp.__init__r   c                 C   s"   |  |}| |}| |}|S rt   )r   r   r   r   r&   r&   r'   rr     s   


zHieraMlp.forward)	r   r    r!   re   rB   r#   ru   rr   rv   r&   r&   rR   r'   r     s    r   c                       s   e Zd Z				ddedededed	ed
ededdf fddZ		ddejde	ej
 dedeeje	ej f fddZ  ZS )
HieraLayerr   r   r   Fr   r   r   r   r   r   r   rU   Nc	           	         s   t    || _|| _|| _tj||jd| _t	||||||d| _
tj||jd| _t||| _|dkr9t|nt | _||krKt||| _d S d S )NZeps)r   r   r   r   r   r   r   )rA   rB   r   r   r   r   	LayerNormlayer_norm_epslayernorm_beforer   attnlayernorm_afterr   mlpr   Identityr   r   r   )	rP   rQ   r   r   r   r   r   r   r   rR   r&   r'   rB     s&   
	zHieraLayer.__init__r   r   r   c           
      C   s   |j \}}}| |}| j| jkr%| |}||| jd| jjddj}| j	|||d\}}|| 
| }|}	| |}| |}|	| 
| }||fS )Nr   r   r`   r   )rW   r   r   r   r   rX   r   r   r   r   r   r   r   )
rP   r   r   r   rk   r   r   Zhidden_states_normr   Zresidualr&   r&   r'   rr     s   




zHieraLayer.forward)r   r   r   Fr   )r   r    r!   re   r[   rj   rB   r#   ru   r   r$   r   rr   rv   r&   r&   rR   r'   r     sB    	
%r   c                       s   e Zd Z	ddededededee dee ded	ed
ee ddf fddZ	dde	j
dee	j dedee	j
ee	j
 f fddZ  ZS )
HieraStageNdepthr   r   r   r   r   r   r   	stage_numrU   c                    sb   t    d|
d ur j|
dkr|
d nd t f	ddt|D | _d S )NFr   r   c                    sD   g | ]}t  |d krn| | po|d kdqS )r   )rQ   r   r   r   r   r   r   r   )r   r9   r:   	rQ   r   r   r   r   Z$previous_stage_used_masked_attentionr   r   r   r&   r'   r<   6  s    z'HieraStage.__init__.<locals>.<listcomp>)rA   rB   masked_unit_attentionr   
ModuleListrangelayers)rP   rQ   r   r   r   r   r   r   r   r   r   rR   r   r'   rB      s   

zHieraStage.__init__Fr   r   r   c                 C   sB   t | jD ]\}}|d ur|| nd }||||d\}}q||fS )Nr   )	enumerater   )rP   r   r   r   r:   Zlayer_modulelayer_head_maskr   r&   r&   r'   rr   E  s   zHieraStage.forwardrt   rs   )r   r    r!   re   r   r[   rj   r   rB   r#   ru   r$   r   rr   rv   r&   r&   rR   r'   r     sB    	
&r   r   rW   mask_unit_shapec                 C   sr   | j d | j d }}dd t||D }| j|g|||R  } | dddddd	} | j|g||R  } | S )
a]  
    Restore spatial organization by undoing windowed organization of mask units.

    Args:
        hidden_states (`torch.Tensor`): The hidden states tensor of shape `[batch_size, num_mask_unit_height*num_mask_unit_width, hidden_size]`.
        shape (`List[int]`): The original shape of the hidden states tensor before windowing.
        mask_unit_shape (`List[int]`): The shape of the mask units used for windowing.

    Returns:
        torch.Tensor: The restored hidden states tensor of shape [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size].
    r   r   c                 S   r7   r&   r&   )r9   r;   mur&   r&   r'   r<   `  r=   z"undo_windowing.<locals>.<listcomp>r   r   r4   r   r   )rW   rG   rX   r   r   )r   rW   r   rk   r   num_mask_unitsr&   r&   r'   undo_windowingQ  s   r   c                       s   e Zd Zdeddf fddZ	ddejdedeej	 dejfd	d
Z
					ddejdeej	 deej dedededeeef fddZ  ZS )HieraEncoderrQ   rU   Nc                    s  t    t j}dd tjd j|ddD }tj jddd	 }|d  j
  fddt|D }t | _ j}dg| }t j}t j}	t jD ]E\}
}t j j|
  }t ||| j|
 |||
 ||
d   |||
 ||
d   t||	|
    j|
 |
d
}|}| j| qWd	d t j jD } jgt jd d
  }i | _tt jD ]}
||f| j|
< |
 j
k rdd t| jD }|dd  }qd| _ d S )Nc                 S   s   g | ]}|  qS r&   )item)r9   xr&   r&   r'   r<   p  s    z)HieraEncoder.__init__.<locals>.<listcomp>r   cpur^   c                    s$   g | ]}|v rt  jnd qS r   )rc   rd   r   r   rQ   Zquery_pool_layerr&   r'   r<   t  s   $ r   )
rQ   r   r   r   r   r   r   r   r   r   c                 S   r7   r&   r&   r8   r&   r&   r'   r<     r=   r   c                 S   r7   r&   r&   r8   r&   r&   r'   r<     r=   F)!rA   rB   sumdepthsr#   ZlinspaceZdrop_path_rateZtensorZcumsumtolistnum_query_poolr   r   r   stagesrN   rc   rd   rJ   r   r   re   embed_dim_multiplierr   r   r   appendrG   rF   rH   rC   schedulegradient_checkpointing)rP   rQ   Ztotal_depthZdprZcumulative_depthsZquery_stridesr   Z
stage_endsZmasked_unit_areaZquery_stride_areaZ	idx_stager   r   stageZ
stage_sizeunroll_schedulerR   r   r'   rB   l  sJ   





zHieraEncoder.__init__r   	stage_idxr*   c              
   C   s   | j | \}}|j\}}}t|}	dg|	 }
|D ]C}|j|g||t| |
|R  }|ddddddd}t|	D ]}|
|  || 9  < q=|j|dg|
|R  }|jd }q|j||g|
|R  }|d	urn|S t	|||
}|S )
a\  
        Roll the given tensor back up to spatial order assuming it's from the given block.

        If no bool_masked_pos is provided returns:
            - [batch_size, height, width, hidden_size]
        If a bool_masked_pos is provided returns:
            - [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
        r   r   r   r   r4   r      r   N)
r   rW   rC   rX   rc   rd   r   r   r   r   )rP   r   r   r*   r   rV   rk   r   r   Znum_dimr   stridesr:   r&   r&   r'   reroll  s4   
zHieraEncoder.rerollFTr   r   output_hidden_statesreturn_dictc                 C   s  |rdnd }|r
dnd }|rdnd }	|r&||f }| j |d|d}
||
f }t| jD ]E\}}|d ur7|| nd }| jrI| jrI| |j|||}n||||}|d }|r\|	|d f }	|rp||f }| j |||d}
||
f }q+|stdd |||	|fD S t|||	|dS )Nr&   r   )r   r*   r   c                 s   s    | ]	}|d ur|V  qd S rt   r&   )r9   vr&   r&   r'   	<genexpr>  s    z'HieraEncoder.forward.<locals>.<genexpr>)r   r   r   r   )	r   r   r   r   r   Z_gradient_checkpointing_func__call__tupler   )rP   r   r*   r   r   r   r   Zall_hidden_statesZall_reshaped_hidden_statesZall_self_attentionsr   r:   Zstage_moduler   Zlayer_outputsr&   r&   r'   rr     s@   	





zHieraEncoder.forwardrt   )NNFFT)r   r    r!   r   rB   r#   ru   re   r   r,   r   r$   rj   r   r   r   rr   rv   r&   r&   rR   r'   r   k  s@    3
1
r   image_shaperH   r   c                 C   s  | j \}}}dd t||D }|}| j|g| |g  } |D ]W}	dd t||	D }dd t||	D }
|g|
 |g }
| |
} t|
}dgttd|d d ttd|d d |d g }| |} | dt|	} |t	|	9 }q| 
dt	||} | S )	a  
    Reorders the tokens such that patches are contiguous in memory.
    E.g., given [batch_size, (height, width), hidden_size] and stride of (stride, stride), this will re-order the tokens as
    [batch_size, (stride, stride, height // stride, width // stride), hidden_size]

    This allows operations like Max2d to be computed as x.view(batch_size, stride*stride, -1, hidden_size).max(dim=1).
    Not only is this faster, but it also makes it easy to support inputs of arbitrary
    dimensions in addition to patch-wise sparsity.

    Performing this operation multiple times in sequence puts entire windows as contiguous
    in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
    size 8x8 would be contiguous in memory, allowing operations like mask unit attention
    computed easily and efficiently, while also allowing max to be applied sequentially.

    Note: This means that intermediate values of the model are not in height x width order, so they
    need to be re-rolled if you want to use the intermediate values as a height x width feature map.
    The last block of the network is fine though, since by then the strides are all consumed.
    c                 S   r7   r&   r&   r8   r&   r&   r'   r<     r=   zunroll.<locals>.<listcomp>c                 S   r7   r&   r&   r8   r&   r&   r'   r<   !  r=   c                 S   s   g | ]	}|D ]}|qqS r&   r&   )r9   pairr   r&   r&   r'   r<   #  s    r   r4   r   r   )rW   rG   rX   rC   listr   r   ro   rc   rd   r   )r   r   rH   r   rk   r   r   rV   Zcurrent_sizer   Z	new_shapeZnum_dimsr   r&   r&   r'   unroll   s    
8
r   c                   @   s&   e Zd ZeZdZdZdZdddZdS )	HieraPreTrainedModelhierarT   TrU   Nc                 C   s   | j j}t|trtjj|j|d dS t|tr-tjj|j	|d tjj|j
|d dS t|tjtjtjfrRtjj|j|d |jdurPtj|j| dS dS t|tjrltj|j| tj|j| j j dS dS )zInitialize the weights)stdN)rQ   Zinitializer_range
isinstancerw   r   initZtrunc_normal_r{   HieraDecoder
mask_tokendecoder_position_embeddingsr   ZConv1drM   weightZbiasZ	constant_r   Zlayer_norm_init)rP   moduler   r&   r&   r'   _init_weights=  s   


z"HieraPreTrainedModel._init_weights)rU   N)	r   r    r!   r   Zconfig_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   r&   r&   r&   r'   r   6  s    r   c                       s8   e Zd Zdef fddZdejdejfddZ  ZS )HieraPoolerrQ   c                    sJ   t    t|j|jt|jd   }tj||j	d| _
td| _d S )Nr   r   )rA   rB   re   rN   r   rC   r   r   r   r   	layernormZAdaptiveAvgPool1dpooler)rP   rQ   num_featuresrR   r&   r'   rB   S  s   
zHieraPooler.__init__r   rU   c                 C   s0   | dd}| |}t|d}| |}|S )Nr   r4   )rp   r   r#   ro   r   )rP   r   pooled_outputr&   r&   r'   rr   Y  s
   

zHieraPooler.forward)	r   r    r!   r   rB   r#   ru   rr   rv   r&   r&   rR   r'   r   R  s    r   c                       s   e Zd Zddededef fddZdefd	d
Zdee	e
e	 f ddfddZe							ddeej deej deej dee dee dee dee deeef fddZ  ZS )
HieraModelTFrQ   add_pooling_layerr3   c                    s~   t  | t|j|jt|jd   | _t||d| _	t
|| _|jgt|jdd  | _|r6t|nd| _|   dS )z
        add_pooling_layer (`bool`, *optional*, defaults to `True`):
            Whether or not to apply pooling layer.
        is_mae (`bool`, *optional*, defaults to `False`):
            Whether or not to run the model on MAE mode.
        r   rx   Nr   )rA   rB   re   rN   r   rC   r   r   rw   rq   r   encoderr   r   r   r   	post_init)rP   rQ   r   r3   rR   r&   r'   rB   c  s    
zHieraModel.__init__rU   c                 C      | j jS rt   rq   ry   r   r&   r&   r'   get_input_embeddingsw     zHieraModel.get_input_embeddingsheads_to_pruneNc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   layerZ	attentionZprune_heads)rP   r  r  Zheadsr&   r&   r'   _prune_headsz  s   zHieraModel._prune_headsrT   r]   r   r   r   r   r   c              	   C   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}|du r&td| |t| j j}| j|||d\}}	}
|j	d |j	d f}t
||| j j| jd}|	durvt| j j}|j	\}}}|	dd||}|| }||d|}| j||	||||d}|d	 }d}| jdur| |}|s|dur||fn|f}|	dur||	|
f n|}||dd  S t|||	|
|j|j|jd
S )z
        noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
            Mainly used for testing purposes to control randomness and maintain the reproducibility
        Nz You have to specify pixel_values)r   r]   r6   r   )r   rH   r   r   )r*   r   r   r   r   r   )r   r)   r*   r+   r   r   r   )rQ   r   r   use_return_dictrD   Zget_head_maskrC   r   rq   rW   r   rH   r   rc   rd   rJ   Z	unsqueezeZtilerX   r   r   r(   r   r   r   )rP   rT   r]   r   r   r   r   r   embedding_outputr*   r+   r   r   Zmask_unit_areark   r   r   Z	positionsZencoder_outputsZsequence_outputr   Zhead_outputsr&   r&   r'   rr     sb   

zHieraModel.forward)TFNNNNNNN)r   r    r!   r   rj   rB   r2   r  r   re   r   r  r   r   r#   ru   r$   r   r   r   rr   rv   r&   r&   rR   r'   r   a  s:    
	r   c                       s\   e Zd Zdef fddZ		ddejdejdeej d	e	d
e
ejejf f
ddZ  ZS )r   rQ   c                    sP  t    t j jt jd   }dd t j j	D } fddt| j
D | _ fddt j j
D | _t| j| _ttdd j| _ttdt| j j| _t  j j j jddg j dg j dd		| _tj j jd
| _ j	d  j
d  j   | _!| j!t j
  j" }t j|| _#d S )Nr   c                 S   r7   r&   r&   r8   r&   r&   r'   r<     r=   z)HieraDecoder.__init__.<locals>.<listcomp>c                       g | ]\}}|| j   qS r&   r   r8   rQ   r&   r'   r<         c                    r
  r&   r  r8   r  r&   r'   r<     r  Fr   r   )	rQ   r   r   r   r   r   r   r   r   r   r   )$rA   rB   re   rN   r   rC   r   rG   rF   rH   r   tokens_spatial_shape_finalrJ   mask_unit_spatial_shape_finalr   r   decoder_hidden_sizedecoder_embeddingsrz   r#   rh   r   rc   rd   r   r   Zdecoder_num_headsZdecoder_depthdecoder_blockr   r   decoder_normr   pred_striderE   decoder_pred)rP   rQ   r   rI   Zpred_dimrR   r  r'   rB     s:   





zHieraDecoder.__init__NFencoder_hidden_statesr*   r   r   rU   c              	   C   s6  |  |}|jdd  \}}}|j\}	}
tj|	|
||||j|jd}| jddddd}||	|
ddd}|	dd|||}|
 ||< d|  | | |  }t|| j| j}t|dddf | j| j}||jd d|jd }||jd d}|| j }| j|||d\}}| |}| |}||fS )Nr4   )r_   r   r   r   .r   )r   r   )r  rW   r#   rh   r_   r   r   rX   r   expandro   r[   r   r  r  r   r  r  r  )rP   r  r*   r   r   r   mask_unit_heightmask_unit_widthr  rk   r   Zdecoder_hidden_statesZmask_tokensr   r&   r&   r'   rr     sP   

	





zHieraDecoder.forwardr   )r   r    r!   r   rB   r#   ru   r,   r   rj   r   rr   rv   r&   r&   rR   r'   r     s    +r   c                       sX   e Zd Zdef fddZdejdejdejfddZ	d	e
ej dejfd
dZ  ZS )HieraMultiScaleHeadrQ   c              	      s   t     fddt j jD | _ fddtt jD | _	 j}t
 | _t jD ]+}dd t|| jD }dd t| jD }| jt
j| j	| | j	d ||d q0| jt
  d S )Nc                    r
  r&   r  r8   r  r&   r'   r<   @  r  z0HieraMultiScaleHead.__init__.<locals>.<listcomp>c                        g | ]}t  j j|  qS r&   re   rN   r   r   r  r&   r'   r<   C      c                 S   r7   r&   r&   r8   r&   r&   r'   r<   J  r=   c                 S   r7   r&   r&   r8   r&   r&   r'   r<   K  r=   r   )r>   r?   )rA   rB   rG   rJ   r   r  r   rC   r   Zstage_dimensionsr   r   multi_scale_fusion_headsr   r   rM   r   )rP   rQ   Zcurrent_masked_unit_sizeidxZkernelrR   r  r'   rB   >  s*   



zHieraMultiScaleHead.__init__headr   rU   c           
      C   s   t |tjr|S |j\}}}}}||| |||}|dddd}||}|dddd}|jdd  \}}	}|||||	|}|S )Nr   r   r   r4   )r   r   r   rW   r   r   )
rP   r   r   rk   r   r  r  r   Zmask_unit_height_finalZmask_unit_width_finalr&   r&   r'   apply_fusion_headV  s   
z%HieraMultiScaleHead.apply_fusion_headfeature_mapsc                 C   s.   d}t | j|D ]\}}|| || }q|S )Nr   )rG   r  r!  )rP   r"  r   r   Zfeature_mapr&   r&   r'   rr   m  s   zHieraMultiScaleHead.forward)r   r    r!   r   rB   r   Moduler#   ru   r!  r   rr   rv   r&   r&   rR   r'   r  =  s    "r  a;  
    The Hiera Model transformer with the decoder on top for self-supervised pre-training.

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    )Zcustom_introc                       s   e Zd Zdeddf fddZdejdejdejfdd	Zdejd
ejdejfddZ	e
							ddeej deej deej dee dee dee dee deeef fddZ  ZS )HieraForPreTrainingrQ   rU   Nc                    s\   t  | t|ddd| _tj| jj|jd| _t	|| _
t|| _| jj| _|   d S )NFTr   r3   r   )rA   rB   r   r   r   r   r   r   encoder_normr  multiscale_fusionr   decoderr  r   rP   rQ   rR   r&   r'   rB     s   


zHieraForPreTraining.__init__rT   r*   c                 C   s   | dddd}| j}|d||d||}|ddd}|| }| jjr@|jddd}|jddd}|| |d d	  }|S )
Nr   r4   r   r   r   T)ra   Zkeepdimgư>r   )r   r  Zunfoldro   rQ   Znormalize_pixel_lossmeanvar)rP   rT   r*   rV   labelr*  r+  r&   r&   r'   get_pixel_label_2d  s   z&HieraForPreTraining.get_pixel_label_2dr0   c                 C   s2   | }|  ||}|| }|| d }| }|S )Nr4   )r-  r*  )rP   rT   r0   r*   r,  r/   r&   r&   r'   forward_loss  s   z HieraForPreTraining.forward_lossr]   r   r   r   r   r   c              	   C   sN  |dur|n| j j}|dur|n| j j}|dur|n| j j}| j||||d||d}|d }	|d }
|d }|	d| jj jd  |	d f }	| |	}| |}| j||
||d\}}
| 	|||
}|s||
|f}|rr||d f }|r{||d	 f }|r||d f }|dur|f| S |S t
|||
||r|jnd|j|r|jd
S dd
S )a  
        noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
            Mainly used for testing purposes to control randomness and maintain the reproducibility

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, HieraForPreTraining
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-mae-hf")
        >>> model = HieraForPreTraining.from_pretrained("facebook/hiera-tiny-224-mae-hf")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> loss = outputs.loss
        >>> print(list(logits.shape))
        [1, 196, 768]
        ```NT)r]   r   r   r   r   r   r   r   r4   )r*   r   r   r   r   )r/   r0   r*   r+   r   r   r   )rQ   r  r   r   r   r   r'  r&  r(  r.  r1   r   r   r   )rP   rT   r]   r   r   r   r   r   outputsr"  r*   Zids_to_restoreZfused_hidden_statesr0   r/   r   r&   r&   r'   rr     s\   $
 



zHieraForPreTraining.forwardr	  )r   r    r!   r   rB   r#   ru   r,   r-  r.  r   r   r$   rj   r   r   r1   rr   rv   r&   r&   rR   r'   r$  v  s:    
	r$  a  
    Hiera Model transformer with an image classification head on top (a linear layer on top of the final hidden state with
    average pooling) e.g. for ImageNet.

    <Tip>

        Note that it's possible to fine-tune Hiera on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    c                       s   e Zd Zdeddf fddZe						ddeej deej dee	 d	ee	 d
ee	 dee	 de
eef fddZ  ZS )HieraForImageClassificationrQ   rU   Nc                    sV   t  | |j| _t|ddd| _|jdkr t| jj|jnt | _	| 
  d S )NTFr%  r   )rA   rB   
num_labelsr   r   r   r   r   r   
classifierr   r)  rR   r&   r'   rB     s   "z$HieraForImageClassification.__init__r   labelsr   r   r   r   c                 C   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}| j||||||d}|d }	| |	}
d}|dur||
j}| j jdu rh| j	dkrNd| j _n| j	dkrd|j
tjks_|j
tjkrdd| j _nd| j _| j jdkrt }| j	dkr||
 | }n+||
|}n%| j jdkrt }||
d| j	|d}n| j jdkrt }||
|}|s|
f|dd  }|dur|f| S |S t||
|j|j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N)r   r   r   r   r   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr   r4   )r/   r0   r   r   r   )rQ   r  r   r   r   r2  rg   r_   Zproblem_typer1  r   r#   longre   r   Zsqueezer
   rX   r	   r.   r   r   r   )rP   rT   r   r3  r   r   r   r   r/  r   r0   r/   Zloss_fctr   r&   r&   r'   rr   #  sX   	


"


z#HieraForImageClassification.forward)NNNNNN)r   r    r!   r   rB   r   r   r#   ru   rj   r   r   r.   rr   rv   r&   r&   rR   r'   r0    s0    
	r0  zN
    Hiera backbone, to be used with frameworks like DETR and MaskFormer.
    c                       s^   e Zd Zdef fddZdd Z			ddejdee	 d	ee	 d
ee	 de
f
ddZ  ZS )HieraBackbonerQ   c                    s   t    t     jg fddtt jD  | _t dd| _	t
 | _i }t| j| jD ]\}}t|||< q4t|| _|   d S )Nc                    r  r&   r  r   r  r&   r'   r<   w  r  z*HieraBackbone.__init__.<locals>.<listcomp>Frx   )rA   rB   Z_init_backbonerN   r   rC   r   r   rw   rq   r   r   rG   Z_out_featuresZchannelsr   r   Z
ModuleDicthidden_states_normsr   )rP   rQ   r6  r   rE   rR   r  r'   rB   s  s   

zHieraBackbone.__init__c                 C   r   rt   r   r   r&   r&   r'   r    r  z"HieraBackbone.get_input_embeddingsNrT   r   r   r   rU   c                 C   sD  |dur|n| j j}|dur|n| j j}|dur|n| j j}| |\}}}| j|d|d|d}|d }d}	t| j|D ]7\}
}|
| jv rs|j	\}}}}|
||| |}| j|
 |}|
||||}|dddd	 }|	|f7 }	q<|s|	f}|r||d f7 }|r||d	 f7 }|S t|	|r|d nd|r|d	 d
S dd
S )a?  
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-hf")
        >>> model = AutoBackbone.from_pretrained(
        ...     "facebook/hiera-tiny-224-hf", out_features=["stage1", "stage2", "stage3", "stage4"]
        ... )

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> feature_maps = outputs.feature_maps
        >>> list(feature_maps[-1].shape)
        [1, 768, 7, 7]
        ```NT)r   r   r   r   r   r&   r   r   r   r4   )r"  r   r   )rQ   r  r   r   rq   r   rG   Zstage_namesZout_featuresrW   rX   r6  r   
contiguousr   )rP   rT   r   r   r   r  r   r/  r   r"  r   Zhidden_staterk   r}   r~   rE   r   r&   r&   r'   rr     sJ    


zHieraBackbone.forward)NNN)r   r    r!   r   rB   r  r#   ru   r   rj   r   rr   rv   r&   r&   rR   r'   r5  m  s"    r5  )r0  r$  r5  r   r   )r   F)Cr"   rc   dataclassesr   typingr   r   r   r   r   r#   Ztorch.utils.checkpointr   Ztorch.nnr	   r
   r   Zactivationsr   Zmodeling_outputsr   r   r   r   r   Zmodeling_utilsr   utilsr   r   r   Zutils.backbone_utilsr   Zconfiguration_hierar   Z
get_loggerr   loggerr   r(   r.   r1   r#  r2   rw   r   ru   r[   rj   r   r   r   r   r   re   r   r   r   r   r   r   r   r  r$  r0  r5  __all__r&   r&   r&   r'   <module>   s   
 )#$^L E@&2 



6ri9 Yb