"""PyTorch VideoMAE (masked autoencoder) model."""

import collections.abc
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging
from ...utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .configuration_videomae import VideoMAEConfig


logger = logging.get_logger(__name__)


@dataclass
class VideoMAEDecoderOutput(ModelOutput):
    """
    Class for VideoMAEDecoder's outputs, with potential hidden states and attentions.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class VideoMAEForPreTrainingOutput(ModelOutput):
    """
    Class for VideoMAEForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Pixel reconstruction loss.
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


def get_sinusoid_encoding_table(n_position, d_hid):
    """Sinusoid position encoding table"""

    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even dimensions: sine
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd dimensions: cosine

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)


class VideoMAEEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings.

    """

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = VideoMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # fixed sin-cos position embedding
        self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
        self.config = config

    def forward(self, pixel_values, bool_masked_pos):
        # create patch embeddings
        embeddings = self.patch_embeddings(pixel_values)

        # add (non-trainable) position embeddings
        embeddings = embeddings + self.position_embeddings.detach().type_as(embeddings).to(
            device=embeddings.device, copy=True
        )

        # only keep visible patches (~bool_masked_pos means visible)
        if bool_masked_pos is not None:
            batch_size, _, num_channels = embeddings.shape
            embeddings = embeddings[~bool_masked_pos]
            embeddings = embeddings.reshape(batch_size, -1, num_channels)

        return embeddings


class VideoMAEPatchEmbeddings(nn.Module):
    """
    Video to Patch Embedding. This module turns a batch of videos of shape (batch_size, num_frames, num_channels,
    height, width) into a tensor of shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.

    The seq_len (the number of patches) equals (number of frames // tubelet_size) * (height // patch_size) * (width //
    patch_size).

    """

    def __init__(self, config):
        super().__init__()

        image_size = config.image_size
        patch_size = config.patch_size
        num_channels = config.num_channels
        hidden_size = config.hidden_size
        num_frames = config.num_frames
        tubelet_size = config.tubelet_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size
        self.tubelet_size = int(tubelet_size)
        num_patches = (
            (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
        )
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.projection = nn.Conv3d(
            in_channels=num_channels,
            out_channels=hidden_size,
            kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
            stride=(self.tubelet_size, patch_size[0], patch_size[1]),
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        # permute to (batch_size, num_channels, num_frames, height, width)
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling

    # Normalize the attention scores to probabilities.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    # Mask heads if we want to
    if attention_mask is not None:
        attn_weights = attn_weights * attention_mask

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class VideoMAESelfAttention(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        self.scaling = self.attention_head_size**-0.5
        self.is_causal = False

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)

        if config.qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
            self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))
        else:
            self.q_bias = None
            self.v_bias = None

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # the key projection never gets a bias; query and value get their (optional) learned biases
        k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
        keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
        values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
        queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)

        key_layer = self.transpose_for_scores(keys)
        value_layer = self.transpose_for_scores(values)
        query_layer = self.transpose_for_scores(queries)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`."
                    " Falling back to eager attention. This warning can be removed using the argument"
                    ' `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class VideoMAESelfOutput(nn.Module):
    """
    The residual connection is defined in VideoMAELayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class VideoMAEAttention(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.attention = VideoMAESelfAttention(config)
        self.output = VideoMAESelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class VideoMAEIntermediate(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class VideoMAEOutput(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


class VideoMAELayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = VideoMAEAttention(config)
        self.intermediate = VideoMAEIntermediate(config)
        self.output = VideoMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in VideoMAE, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in VideoMAE, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


class VideoMAEEncoder(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoMAELayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@auto_docstring
class VideoMAEPreTrainedModel(PreTrainedModel):
    config_class = VideoMAEConfig
    base_model_prefix = "videomae"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@auto_docstring
class VideoMAEModel(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = VideoMAEEmbeddings(config)
        self.encoder = VideoMAEEncoder(config)

        if config.use_mean_pooling:
            self.layernorm = None
        else:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
            batch must have the same number of masked patches. If `None`, then all patches are considered. Sequence
            length is `(num_frames // tubelet_size) * (image_size // patch_size) ** 2`.

        Examples:

        ```python
        >>> import av
        >>> import numpy as np

        >>> from transformers import AutoImageProcessor, VideoMAEModel
        >>> from huggingface_hub import hf_hub_download

        >>> np.random.seed(0)


        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`List[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`List[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices


        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample 16 frames
        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
        >>> video = read_video_pyav(container, indices)

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        >>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")

        >>> # prepare video for the model
        >>> inputs = image_processor(list(video), return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 1568, 768]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare head mask if needed: 1.0 in head_mask indicates we keep the head
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        if self.layernorm is not None:
            sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class VideoMAEDecoder(nn.Module):
    def __init__(self, config, num_patches):
        super().__init__()

        decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2

        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [VideoMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.norm = nn.LayerNorm(config.decoder_hidden_size)
        self.head = (
            nn.Linear(config.decoder_hidden_size, decoder_num_labels) if decoder_num_labels > 0 else nn.Identity()
        )

        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states,
        return_token_num,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # apply Transformer layers (blocks)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for layer_module in self.decoder_layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_token_num > 0:
            # only keep the (masked) tokens that have to be reconstructed
            hidden_states = hidden_states[:, -return_token_num:]

        # predictor projection
        hidden_states = self.norm(hidden_states)
        logits = self.head(hidden_states)

        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return VideoMAEDecoderOutput(
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@auto_docstring(
    custom_intro="""
    The VideoMAE Model transformer with the decoder on top for self-supervised pre-training.
    """
)
class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.videomae = VideoMAEModel(config)

        self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=False)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        self.position_embeddings = get_sinusoid_encoding_table(
            self.videomae.embeddings.num_patches, config.decoder_hidden_size
        )

        self.decoder = VideoMAEDecoder(config, num_patches=self.videomae.embeddings.num_patches)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        bool_masked_pos: torch.BoolTensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, VideoMAEForPreTrainingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
            batch must have the same number of masked patches. Sequence length is `(num_frames // tubelet_size) *
            (image_size // patch_size) ** 2`.

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, VideoMAEForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 16
        >>> video = list(np.random.randint(0, 256, (num_frames, 3, 224, 224)))

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        >>> model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")

        >>> pixel_values = image_processor(video, return_tensors="pt").pixel_values

        >>> num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
        >>> seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
        >>> bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.videomae(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.encoder_to_decoder(
            sequence_output
        )  # [batch_size, num_visible_patches, decoder_hidden_size]
        batch_size, seq_len, num_channels = sequence_output.shape

        # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly.
        if bool_masked_pos is None:
            raise ValueError("One must provide a boolean mask")
        expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
        expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).clone().detach()
        pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
        pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)

        # [batch_size, num_patches, decoder_hidden_size]
        x_full = torch.cat([sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1)

        # [batch_size, num_masked_patches, num_channels * tubelet_size * patch_size * patch_size]
        decoder_outputs = self.decoder(x_full, pos_emb_mask.shape[1])
        logits = decoder_outputs.logits

        loss = None
        with torch.no_grad():
            # calculate the labels to be predicted
            if self.config.num_channels != 3:
                # can't unnormalize with the default ImageNet means/stds
                frames = pixel_values
            else:
                # first, unnormalize the frames
                device = pixel_values.device
                dtype = pixel_values.dtype
                mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
                std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
                frames = pixel_values * std + mean  # in [0, 1]

            batch_size, time, num_channels, height, width = frames.shape
            tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
            if self.config.norm_pix_loss:
                # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                frames = frames.view(
                    batch_size,
                    time // tubelet_size,
                    tubelet_size,
                    num_channels,
                    height // patch_size,
                    patch_size,
                    width // patch_size,
                    patch_size,
                )
                # step 2: move dimensions to concatenate
                frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
                # step 3: concatenate
                frames = frames.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size,
                    num_channels,
                )
                # step 4: normalize each patch over its pixels
                frames_norm = (frames - frames.mean(dim=-2, keepdim=True)) / (
                    frames.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
                )
                # step 5: reshape to (batch_size, T//ts * H//ps * W//ps, ts * ps * ps * C)
                videos_patch = frames_norm.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size * num_channels,
                )
            else:
                if self.config.num_channels != 3:
                    raise ValueError(
                        "Can't unnormalize non-RGB images. Consider setting config.norm_pix_loss to False."
                    )
                # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                frames = frames.view(
                    batch_size,
                    time // tubelet_size,
                    tubelet_size,
                    num_channels,
                    height // patch_size,
                    patch_size,
                    width // patch_size,
                    patch_size,
                )
                # step 2: move dimensions to concatenate
                frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
                # step 3: concatenate
                videos_patch = frames.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size * num_channels,
                )

            batch_size, _, num_channels = videos_patch.shape
            # labels are the pixel values of the masked patches
            labels = videos_patch[bool_masked_pos].reshape(batch_size, -1, num_channels)

        loss_fct = MSELoss()
        loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return VideoMAEForPreTrainingOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    VideoMAE Model transformer with a video classification head on top (a linear layer on top of the average pooled hidden
    states of all tokens) e.g. for ImageNet.
    """
)
class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.videomae = VideoMAEModel(config)

        # Classifier head
        self.fc_norm = nn.LayerNorm(config.hidden_size) if config.use_mean_pooling else None
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> import av
        >>> import torch
        >>> import numpy as np

        >>> from transformers import AutoImageProcessor, VideoMAEForVideoClassification
        >>> from huggingface_hub import hf_hub_download

        >>> np.random.seed(0)


        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`List[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`List[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices


        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample 16 frames
        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
        >>> video = read_video_pyav(container, indices)

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
        >>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")

        >>> inputs = image_processor(list(video), return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     logits = outputs.logits

        >>> # model predicts one of the 400 Kinetics-400 classes
        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])
        eating spaghetti
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.videomae(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if self.fc_norm is not None:
            sequence_output = self.fc_norm(sequence_output.mean(1))
        else:
            sequence_output = sequence_output[:, 0]

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["VideoMAEForPreTraining", "VideoMAEModel", "VideoMAEPreTrainedModel", "VideoMAEForVideoClassification"]
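

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It builds
# randomly initialised models from a default `VideoMAEConfig` and runs a
# forward pass on dummy data, so nothing is downloaded. The shapes printed
# assume the default config (16 frames, 224x224 images, patch size 16,
# tubelet size 2 -> 1568 patches, hidden size 768); adjust if yours differs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = VideoMAEConfig()
    num_patches = (config.num_frames // config.tubelet_size) * (config.image_size // config.patch_size) ** 2

    pixel_values = torch.randn(
        1, config.num_frames, config.num_channels, config.image_size, config.image_size
    )

    # Plain encoder: one hidden state per (visible) patch.
    model = VideoMAEModel(config)
    with torch.no_grad():
        hidden = model(pixel_values).last_hidden_state
    print(hidden.shape)  # torch.Size([1, 1568, 768]) with the default config

    # Pre-training: mask 90% of the patches and reconstruct the masked pixels.
    bool_masked_pos = torch.zeros(1, num_patches, dtype=torch.bool)
    bool_masked_pos[:, : int(0.9 * num_patches)] = True
    pretraining_model = VideoMAEForPreTraining(config)
    with torch.no_grad():
        out = pretraining_model(pixel_values, bool_masked_pos=bool_masked_pos)
    print(out.loss, out.logits.shape)  # reconstruction loss and per-masked-patch pixel logits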