# coding=utf-8
"""PyTorch TVLT model."""

import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ....activations import ACT2FN
from ....modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ....utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_tvlt import TvltConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "TvltConfig"
_CHECKPOINT_FOR_DOC = "ZinengTang/tvlt-base"


@dataclass
class TvltModelOutput(ModelOutput):
    """
    Class for TvltModel's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        last_pixel_hidden_state (`torch.FloatTensor` of shape `(batch_size, pixel_sequence_length, hidden_size)`):
            Pixel sequence of hidden-states at the output of the last layer of the model.
        last_audio_hidden_state (`torch.FloatTensor` of shape `(batch_size, audio_sequence_length, hidden_size)`):
            Audio sequence of hidden-states at the output of the last layer of the model.
        pixel_label_masks (`torch.FloatTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor indicating which pixel patches are masked (1) and which are not (0).
        audio_label_masks (`torch.FloatTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor indicating which audio patches are masked (1) and which are not (0).
        pixel_ids_restore (`torch.LongTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor containing the ids permutation of pixel masking.
        audio_ids_restore (`torch.LongTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor containing the ids permutation of audio masking.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlast_hidden_statelast_pixel_hidden_statelast_audio_hidden_statepixel_label_masksaudio_label_maskspixel_ids_restoreaudio_ids_restore.hidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   
LongTensorr   r   r    r!   r   r"    r+   r+   `/var/www/auris/lib/python3.10/site-packages/transformers/models/deprecated/tvlt/modeling_tvlt.pyr   0   s   
 r   c                   @   sX   e Zd ZU dZdZeej ed< dZ	ee
ejdf  ed< dZee
ejdf  ed< dS )TvltDecoderOutputaM  
    Class for TvltDecoder's outputs, with potential hidden states and attentions.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlogits.r!   r"   )r#   r$   r%   r&   r.   r   r'   r(   r)   r!   r   r"   r+   r+   r+   r,   r-   Y   s
   
 r-   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeeejdf  ed< dZeeejdf  ed	< dS )
TvltForPreTrainingOutputa
  
    Class for TvltForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Pixel reconstruction loss.
        matching_logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Matching objective logits.
        pixel_logits (`torch.FloatTensor` of shape
            `(batch_size, pixel_patch_length, image_patch_size ** 3 * pixel_num_channels)`): Pixel reconstruction
            logits.
        audio_logits (`torch.FloatTensor` of shape
            `(batch_size, audio_patch_length, image_patch_size[0] * image_patch_size[1])`): Audio reconstruction
            logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlossmatching_logitspixel_logitsaudio_logits.r!   r"   )r#   r$   r%   r&   r0   r   r'   r(   r)   r1   r2   r3   r!   r   r"   r+   r+   r+   r,   r/   p   s   
 r/         ?c                 C   s>   | j dd \}}tj||f| jd}t|d|  }||fS )!Generate noise for audio masking.N   devicer   )shaper'   randr8   int)pixel_values
pixel_mask
mask_ratio
batch_sizeseq_lennoiselen_keepr+   r+   r,   generate_pixel_mask_noise   s   rC   patch-level   c           
      C   s   | j dd \}}|dkr'|| }tj||| jdddd|||}n|dkr4tj||| jd}t|d|  }	||	fS )r5   Nr6   zframe-levelr7   r   rD   )r9   r'   r:   r8   	unsqueezerepeatviewr;   )
audio_values
audio_maskr>   	mask_typefreq_lenr?   r@   num_time_patchesrA   rB   r+   r+   r,   generate_audio_mask_noise   s   
rO   c                 C   s   | j \}}}tj|dd}tj|dd}|ddd|f }	tj| d|	ddd|d}
tj||g| jd}d|ddd|f< tj|d|d}|durZ||9 }tj|d|	d}|
|||fS )z
def random_masking(sequence, noise, len_keep, attention_masks=None):
    """
    Perform random masking by per-sample shuffling. Per-sample shuffling is done by argsort of the random noise.
    `sequence` has shape `[batch_size, seq_len, hidden_dim]`.
    """
    batch_size, seq_len, hidden_dim = sequence.shape

    # sort noise for each sample: ascending, so small noise is kept and large noise is removed
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first `len_keep` patches
    ids_keep = ids_shuffle[:, :len_keep]
    sequence_masked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, hidden_dim))

    # generate the binary mask: 0 is keep, 1 is remove
    label_masks = torch.ones([batch_size, seq_len], device=sequence.device)
    label_masks[:, :len_keep] = 0
    # unshuffle to align the binary mask with the original patch order
    label_masks = torch.gather(label_masks, dim=1, index=ids_restore)

    if attention_masks is not None:
        label_masks *= attention_masks
        attention_masks = torch.gather(attention_masks, dim=1, index=ids_keep)

    return sequence_masked, attention_masks, label_masks, ids_restore


class TvltPixelEmbeddings(nn.Module):
    """Construct the patch and position embeddings."""

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = TvltPixelPatchEmbeddings(config)
        self.num_patches_per_image = self.patch_embeddings.num_patches_per_image

        self.type_embed_v = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.temporal_embed = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
        self.pos_embed_v = nn.Parameter(torch.zeros(1, self.num_patches_per_image, config.hidden_size))

        self.config = config

    def forward(self, pixel_values, attention_masks=None):
        # create patch embeddings
        batch_size, num_frames, num_channels, height, width = pixel_values.shape

        embeddings = self.patch_embeddings(pixel_values)
        embeddings += self.pos_embed_v.repeat(1, num_frames, 1)
        embeddings += torch.repeat_interleave(self.temporal_embed[:, :num_frames], self.num_patches_per_image, dim=1)
        embeddings += self.type_embed_v

        return embeddings, attention_masks


class TvltAudioEmbeddings(nn.Module):
    """Construct the patch and position embeddings."""

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = TvltAudioPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches

        self.type_embed_a = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.num_freq_patches = config.frequency_length // config.audio_patch_size[1]
        self.pos_embed_a = nn.Parameter(torch.zeros(1, self.num_patches // self.num_freq_patches, config.hidden_size))
        self.freq_embed = nn.Parameter(torch.zeros(1, self.num_freq_patches, config.hidden_size))

        self.config = config

    def forward(self, audio_values, attention_masks=None):
        # create patch embeddings
        embeddings = self.patch_embeddings(audio_values)

        num_time_patches = embeddings.size(1) // self.num_freq_patches
        embeddings += self.freq_embed.repeat(1, num_time_patches, 1)
        embeddings += torch.repeat_interleave(self.pos_embed_a[:, :num_time_patches], self.num_freq_patches, dim=1)
        embeddings += self.type_embed_a

        return embeddings, attention_masks
class TvltPixelPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_frames, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.image_patch_size
        num_channels, hidden_size = config.num_image_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches_per_image = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches_per_image = num_patches_per_image
        self.hidden_size = hidden_size

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )

        # fold the frame axis into the batch, patchify each frame, then restore the sequence layout
        pixel_values = pixel_values.reshape(batch_size * num_frames, num_channels, height, width)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        embeddings = embeddings.reshape(batch_size, num_frames * self.num_patches_per_image, self.hidden_size)

        return embeddings
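
# Shape walk-through (assumed example: 224x224 frames with 16x16 patches, the
# sizes used by the doctest examples below; not read from a config): each frame
# yields (224 // 16) ** 2 = 196 patches, so 8 frames enter the encoder as a
# sequence of 8 * 196 = 1568 pixel tokens.
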
class TvltAudioPatchEmbeddings(nn.Module):
    """
    This class turns `audio_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        spectrogram_length, frequency_length, patch_size = (
            config.spectrogram_length,
            config.frequency_length,
            config.audio_patch_size,
        )
        num_channels, hidden_size = config.num_audio_channels, config.hidden_size

        spectrogram_size = (spectrogram_length, frequency_length)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (spectrogram_size[1] // patch_size[1]) * (spectrogram_size[0] // patch_size[0])
        patch_shape = (spectrogram_size[0] // patch_size[0], spectrogram_size[1] // patch_size[1])

        self.spectrogram_size = spectrogram_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.patch_shape = patch_shape

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, audio_values: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, height, width = audio_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the audio values match with the one set in the configuration."
            )
        if height > self.spectrogram_size[0] or width != self.spectrogram_size[1]:
            raise ValueError(
                f"Input audio size ({height}*{width}) doesn't match model"
                f" ({self.spectrogram_size[0]}*{self.spectrogram_size[1]})."
            )
        embeddings = self.projection(audio_values).flatten(2).transpose(1, 2)

        return embeddings


class TvltSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads "
                f"{config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # [batch, seq, hidden] -> [batch, seq, heads, head_size] -> [batch, heads, seq, head_size]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
class TvltSelfOutput(nn.Module):
    """
    The residual connection is defined in TvltLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class TvltAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = TvltSelfAttention(config)
        self.output = TvltSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        self_outputs = self.attention(hidden_states, attention_mask, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class TvltIntermediate(nn.Module):
    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class TvltOutput(nn.Module):
    def __init__(self, config: TvltConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


class TvltLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = TvltAttention(config)
        self.intermediate = TvltIntermediate(config)
        self.output = TvltOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in TVLT, layernorm is applied before self-attention
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states.to(attention_output.device)

        # in TVLT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


class TvltEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([TvltLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # recompute activations in the backward pass to trade compute for memory
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class TvltPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = TvltConfig
    base_model_prefix = "tvlt"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


TVLT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`TvltConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a	  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Audio values. Audio values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_mask (`torch.FloatTensor` of shape `(batch_size, num_pixel_patches)`):
            Pixel masks. Pixel masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_mask (`torch.FloatTensor` of shape `(batch_size, num_audio_patches)`):
            Audio masks. Audio masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Pixel values mixed can
            be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel masks mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        mask_pixel (`bool`, *optional*):
            Whether to mask pixel for MAE tasks. Only set to True in TvltForPreTraining.

        mask_audio (`bool`, *optional*):
            Whether to mask audio for MAE tasks. Only set to True in TvltForPreTraining.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z^The bare TVLT Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Z fddZdd Zdd Zeeee	e
d									dd
ejdejdeej deej dededee dee dee deeej e	f fddZ  ZS )	TvltModelc                    sv   t  | || _t|| _t|| _t|| _t	
tdd|j| _|jr+d | _n
t	j|j|jd| _|   d S r   )r]   r^   ri   rZ   pixel_embeddingsrx   audio_embeddingsr   encoderr   rb   r'   rc   rd   cls_embeddingZuse_mean_pooling	layernormr   r   	post_initrj   rl   r+   r,   r^     s   


zTvltModel.__init__c                 C   s   | j j| jjfS ru   )r  r`   r	  )rk   r+   r+   r,   get_input_embeddings  s   zTvltModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr
  r   r   r   )rk   Zheads_to_pruner   r   r+   r+   r,   _prune_heads  s   zTvltModel._prune_headsoutput_typer  NFr<   rJ   r=   rK   
mask_pixel
mask_audior   r   r   r   c
                 C   s  |dur|n| j j}|dur|n| j j}|	dur|	n| j j}	| ||\}
}| ||\}}d}d}|rKt|
|| j jd\}}t|
|||d\}
}}}d}d}|rv| j j	| j j
d  }t||| j j| j j|d\}}t||||d\}}}}|d}t| j|dd|
|gd}|
d}d}|dur|durt|ddddf ||gd}| }d}|dur| ||}| j|||||	d}|d }| jdur| |}|dddd| f }|ddd| df }|	s|||||||f|dd  S t||||||||j|jd	S )	a  
        Returns:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltModel
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))

        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltModel.from_pretrained("ZinengTang/tvlt-base")

        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        pixel_embedding_output, pixel_mask = self.pixel_embeddings(pixel_values, pixel_mask)

        audio_embedding_output, audio_mask = self.audio_embeddings(audio_values, audio_mask)

        # Mask pixel if mask_pixel is True
        pixel_label_masks = None
        pixel_ids_restore = None
        if mask_pixel:
            pixel_mask_noise, pixel_len_keep = generate_pixel_mask_noise(
                pixel_embedding_output, pixel_mask=pixel_mask, mask_ratio=self.config.pixel_mask_ratio
            )
            pixel_embedding_output, pixel_mask, pixel_label_masks, pixel_ids_restore = random_masking(
                pixel_embedding_output,
                pixel_mask_noise,
                pixel_len_keep,
                attention_masks=pixel_mask,
            )

        # Mask audio if mask_audio is True
        audio_label_masks = None
        audio_ids_restore = None
        if mask_audio:
            num_freq_patches = self.config.frequency_length // self.config.audio_patch_size[1]
            audio_mask_noise, audio_len_keep = generate_audio_mask_noise(
                audio_embedding_output,
                audio_mask=audio_mask,
                mask_ratio=self.config.audio_mask_ratio,
                mask_type=self.config.audio_mask_type,
                freq_len=num_freq_patches,
            )
            audio_embedding_output, audio_mask, audio_label_masks, audio_ids_restore = random_masking(
                audio_embedding_output,
                audio_mask_noise,
                audio_len_keep,
                attention_masks=audio_mask,
            )

        # prepend the [CLS] token and concatenate the pixel and audio sequences
        batch_size = pixel_values.size(0)
        embedding_output = torch.cat(
            [self.cls_embedding.repeat(batch_size, 1, 1), pixel_embedding_output, audio_embedding_output], 1
        )
        masked_pixel_len = pixel_embedding_output.size(1)

        attention_mask = None
        if pixel_mask is not None and audio_mask is not None:
            attention_mask = torch.cat([pixel_mask[:, :1], pixel_mask, audio_mask], 1)

        input_shape = embedding_output.size()
        extended_attention_mask = None
        if attention_mask is not None:
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        if self.layernorm is not None:
            sequence_output = self.layernorm(sequence_output)

        pixel_sequence_output = sequence_output[:, 1 : 1 + masked_pixel_len]
        audio_sequence_output = sequence_output[:, 1 + masked_pixel_len :]
        if not return_dict:
            return (
                sequence_output,
                pixel_sequence_output,
                audio_sequence_output,
                pixel_label_masks,
                audio_label_masks,
                pixel_ids_restore,
                audio_ids_restore,
            ) + encoder_outputs[1:]

        return TvltModelOutput(
            last_hidden_state=sequence_output,
            last_pixel_hidden_state=pixel_sequence_output,
            last_audio_hidden_state=audio_sequence_output,
            pixel_label_masks=pixel_label_masks,
            audio_label_masks=audio_label_masks,
            pixel_ids_restore=pixel_ids_restore,
            audio_ids_restore=audio_ids_restore,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class TvltDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [TvltLayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.layernorm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # apply Transformer layers (blocks)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.decoder_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # predictor projection
        logits = self.layernorm(hidden_states)

        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return TvltDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)


@add_start_docstrings(
    "The TVLT Model transformer with the decoder on top for self-supervised pre-training.",
    TVLT_START_DOCSTRING,
)
class TvltForPreTraining(TvltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.task_matching = config.task_matching
        self.task_mae = config.task_mae
        if not (self.task_matching or self.task_mae):
            raise ValueError("Must set at least one of matching task and MAE task to true")

        self.tvlt = TvltModel(config)

        if self.task_matching:
            self.matching_head = TvltMatchingHead(config)

        if self.task_mae:
            self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True)

            self.pixel_mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
            self.audio_mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))

            self.decoder = TvltDecoder(config)

            decoder_hidden_size = config.decoder_hidden_size

            num_frames = config.num_frames
            num_patches_per_image = self.tvlt.pixel_embeddings.num_patches_per_image
            self.decoder_pixel_pos_embed = nn.Parameter(torch.zeros(1, num_patches_per_image, decoder_hidden_size))
            self.decoder_temporal_embed = nn.Parameter(torch.zeros(1, config.num_frames, decoder_hidden_size))
            self.decoder_pixel_type_embed = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

            num_audio_patches = self.tvlt.audio_embeddings.num_patches
            num_freq_patches = config.frequency_length // config.audio_patch_size[1]
            self.decoder_audio_pos_embed = nn.Parameter(
                torch.zeros(1, num_audio_patches // num_freq_patches, decoder_hidden_size)
            )
            self.decoder_freq_embed = nn.Parameter(torch.zeros(1, num_freq_patches, decoder_hidden_size))
            self.decoder_audio_type_embed = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

            pixel_mae_output_dim = self.config.image_patch_size[0] ** 2 * self.config.num_image_channels
            self.pixel_mae_head = TvltMAEHead(config, pixel_mae_output_dim)
            audio_mae_output_dim = (
                self.config.audio_patch_size[0] * self.config.audio_patch_size[1] * self.config.num_audio_channels
            )
            self.audio_mae_head = TvltMAEHead(config, audio_mae_output_dim)

            self.num_frames = num_frames
            self.num_patches_per_image = num_patches_per_image
            self.num_freq_patches = num_freq_patches
            self.image_patch_size = config.image_patch_size
            self.audio_patch_size = config.audio_patch_size

        # Initialize weights and apply final processing
        self.post_init()

    def patchify_pixel(self, pixel_values):
        """
        pixel_values: [batch_size, num_frames, 3, height, width]
        """
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        num_patches_height = height // self.image_patch_size[0]
        num_patches_width = width // self.image_patch_size[1]
        patchified_pixel_values = pixel_values.reshape(
            shape=(
                batch_size,
                num_frames,
                num_channels,
                num_patches_height,
                self.image_patch_size[0],
                num_patches_width,
                self.image_patch_size[1],
            )
        )
        patchified_pixel_values = torch.einsum("ntchpwq->nthwpqc", patchified_pixel_values)
        patchified_pixel_values = patchified_pixel_values.reshape(
            shape=(
                batch_size,
                num_patches_height * num_patches_width * num_frames,
                self.image_patch_size[0] * self.image_patch_size[1] * num_channels,
            )
        )
        return patchified_pixel_values

    def patchify_audio(self, audio_values):
        """
        audio_values: [batch_size, 1, height, width]
        """
        batch_size, num_channels, height, width = audio_values.shape
        num_patches_height = height // self.audio_patch_size[0]
        num_patches_width = width // self.audio_patch_size[1]
        patchified_audio_values = audio_values.reshape(
            shape=(
                batch_size,
                num_channels,
                num_patches_height,
                self.audio_patch_size[0],
                num_patches_width,
                self.audio_patch_size[1],
            )
        )
        patchified_audio_values = torch.einsum("nchpwq->nhwpqc", patchified_audio_values)
        patchified_audio_values = patchified_audio_values.reshape(
            shape=(
                batch_size,
                num_patches_height * num_patches_width,
                self.audio_patch_size[0] * self.audio_patch_size[1] * num_channels,
            )
        )
        return patchified_audio_values

    def pixel_mae_loss(self, pixel_values, pixel_predictions, mask):
        patchified_pixel_values = self.patchify_pixel(pixel_values)
        loss = (pixel_predictions - patchified_pixel_values) ** 2
        loss = loss.mean(dim=-1)  # [batch_size, pixel_patch_length], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def audio_mae_loss(self, audio_values, audio_predictions, mask):
        patchified_audio_values = self.patchify_audio(audio_values)
        loss = (audio_predictions - patchified_audio_values) ** 2
        loss = loss.mean(dim=-1)  # [batch_size, audio_patch_length], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def concatenate_mask(self, mask_token, sequence, ids_restore):
        batch_size, seq_length, dim = sequence.shape
        mask_tokens = mask_token.repeat(batch_size, ids_restore.shape[1] - seq_length, 1)
        padded_sequence = torch.cat([sequence, mask_tokens], dim=1)
        padded_sequence = torch.gather(
            padded_sequence, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, dim)
        )  # unshuffle
        return padded_sequence

    @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TvltForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        audio_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        audio_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values_mixed: Optional[torch.FloatTensor] = None,
        pixel_mask_mixed: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], TvltForPreTrainingOutput]:
        r"""
        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Audio values can be
            obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel values mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the vision audio matching loss. Indices should be in `[0, 1]`. num_labels has to be 1.

        Return:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> images_mixed = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForPreTraining.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(
        ...     images, audio, images_mixed, sampling_rate=44100, mask_pixel=True, mask_audio=True, return_tensors="pt"
        ... )

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        total_loss = 0.0
        matching_logits = None
        pixel_logits = None
        audio_logits = None

        if self.task_matching:
            if labels is None:
                raise ValueError("Matching task requires labels")
            if pixel_values_mixed is None:
                raise ValueError("Matching task requires pixel_values_mixed")
            outputs = self.tvlt(
                pixel_values_mixed,
                audio_values,
                pixel_mask=pixel_mask_mixed,
                audio_mask=audio_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            sequence_output = outputs[0]

            matching_logits = self.matching_head(sequence_output)

            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(matching_logits.view(-1), labels.view(-1))
            total_loss += loss

        if self.task_mae and self.training:
            outputs = self.tvlt(
                pixel_values,
                audio_values,
                pixel_mask=pixel_mask,
                audio_mask=audio_mask,
                mask_pixel=True,
                mask_audio=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            pixel_sequence_output = outputs.last_pixel_hidden_state if return_dict else outputs[1]
            audio_sequence_output = outputs.last_audio_hidden_state if return_dict else outputs[2]
            pixel_label_masks = outputs.pixel_label_masks if return_dict else outputs[3]
            audio_label_masks = outputs.audio_label_masks if return_dict else outputs[4]
            pixel_ids_restore = outputs.pixel_ids_restore if return_dict else outputs[5]
            audio_ids_restore = outputs.audio_ids_restore if return_dict else outputs[6]

            pixel_decoder_input = self.encoder_to_decoder(pixel_sequence_output)
            audio_decoder_input = self.encoder_to_decoder(audio_sequence_output)

            # pad the pixel sequence back to full length with mask tokens, then add decoder position embeddings
            num_frames = pixel_values.size(1)
            pixel_decoder_input = self.concatenate_mask(self.pixel_mask_token, pixel_decoder_input, pixel_ids_restore)
            pixel_decoder_input = pixel_decoder_input + self.decoder_pixel_pos_embed.repeat(1, num_frames, 1)
            pixel_decoder_input = pixel_decoder_input + torch.repeat_interleave(
                self.decoder_temporal_embed[:, :num_frames], self.num_patches_per_image, dim=1
            )
            pixel_decoder_input = pixel_decoder_input + self.decoder_pixel_type_embed
            pixel_decoder_outputs = self.decoder(pixel_decoder_input)
            pixel_logits = self.pixel_mae_head(pixel_decoder_outputs.logits)

            # same treatment for the audio sequence
            audio_decoder_input = self.concatenate_mask(self.audio_mask_token, audio_decoder_input, audio_ids_restore)
            num_time_patches = audio_decoder_input.size(1) // self.num_freq_patches
            audio_decoder_input = audio_decoder_input + self.decoder_freq_embed.repeat(1, num_time_patches, 1)
            audio_decoder_input = audio_decoder_input + torch.repeat_interleave(
                self.decoder_audio_pos_embed[:, :num_time_patches], self.num_freq_patches, dim=1
            )
            audio_decoder_input = audio_decoder_input + self.decoder_audio_type_embed
            audio_decoder_outputs = self.decoder(audio_decoder_input)
            audio_logits = self.audio_mae_head(audio_decoder_outputs.logits)

            loss = self.pixel_mae_loss(pixel_values, pixel_logits, pixel_label_masks) + self.audio_mae_loss(
                audio_values, audio_logits, audio_label_masks
            )
            total_loss += loss

        if not return_dict:
            output = (matching_logits, pixel_logits, audio_logits) + outputs[7:]
            return (total_loss,) + output

        return TvltForPreTrainingOutput(
            loss=total_loss,
            matching_logits=matching_logits,
            pixel_logits=pixel_logits,
            audio_logits=audio_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class TvltPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class TvltMatchingHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.pooler = TvltPooler(config)
        self.fc = nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states):
        hidden_states = self.fc(self.pooler(hidden_states))
        return hidden_states


class TvltMAEHead(nn.Module):
    def __init__(self, config, output_dim=None):
        super().__init__()
        self.config = config
        self.decoder = nn.Linear(config.decoder_hidden_size, output_dim)

    def forward(self, hidden_states):
        hidden_states = self.decoder(hidden_states)
        return hidden_states
@add_start_docstrings(
    """
    Tvlt Model transformer with a classifier head on top (an MLP on top of the final hidden state of the [CLS] token)
    for audiovisual classification tasks, e.g. CMU-MOSEI Sentiment Analysis and Audio to Video Retrieval.
    """,
    TVLT_START_DOCSTRING,
)
class TvltForAudioVisualClassification(TvltPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.tvlt = TvltModel(config)

        # Classifier head
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 2),
            nn.LayerNorm(config.hidden_size * 2, eps=config.layer_norm_eps),
            nn.GELU(),
            nn.Linear(config.hidden_size * 2, config.num_labels),
        )
        self.config = config

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(TVLT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        audio_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,
        audio_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes
            refers to the number of classes in audiovisual tasks.

        Return:

        Examples:
        ```python
        >>> from transformers import TvltProcessor, TvltForAudioVisualClassification
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForAudioVisualClassification.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.tvlt(
            pixel_values,
            audio_values,
            pixel_mask=pixel_mask,
            audio_mask=audio_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0][:, 0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.loss_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits, labels)
            elif self.config.loss_type == "classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # hidden states / attentions follow the seven tensors returned by TvltModel
            output = (logits,) + outputs[7:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["TvltModel", "TvltForPreTraining", "TvltForAudioVisualClassification", "TvltPreTrainedModel"]