"""PyTorch TVP Model"""

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import prune_linear_layer
from ...utils import auto_docstring, logging
from ...utils.backbone_utils import load_backbone
from .configuration_tvp import TvpConfig


logger = logging.get_logger(__name__)


@dataclass
class TvpVideoGroundingOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Temporal-Distance IoU loss for video grounding.
        logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Contains start_time/duration and end_time/duration. It is the time slot of the videos corresponding to the
            input texts.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class TvpLoss(nn.Module):
    """
    This class computes the losses for `TvpForVideoGrounding`. It supervises the predicted (start, end) time slot
    against the ground-truth annotation with a combination of IoU, mid-point distance, and duration losses.

    Args:
        losses (`List[str]`):
            List of all the losses to be applied.
    """

    def __init__(self, losses):
        super().__init__()
        self.loss_map = {
            "iou": self.loss_iou,
            "distance": self.loss_distance,
            "duration": self.loss_duration,
        }
        for loss in losses:
            if loss not in self.loss_map:
                raise ValueError(f"Loss {loss} not supported")

        self.losses = losses

    def loss_iou(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
        """
        Measure the intersection over union.
        """
        inter = torch.min(candidates_end_time, end_time) - torch.max(candidates_start_time, start_time)
        union = torch.max(candidates_end_time, end_time) - torch.min(candidates_start_time, start_time)
        iou = 1 - inter.clamp(min=0) / union

        return iou

    def loss_distance(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
        """
        Measure the distance of mid points.
        """
        mid_candidates = torch.div(torch.add(candidates_start_time, candidates_end_time), 2.0)
        mid_groundtruth = torch.div(torch.add(start_time, end_time), 2.0)
        distance_diff = torch.div(
            torch.max(mid_candidates, mid_groundtruth) - torch.min(mid_candidates, mid_groundtruth), duration
        ).clamp(min=0.2)

        return distance_diff

    def loss_duration(self, start_time, end_time, candidates_start_time, candidates_end_time, duration):
        """
        Measure the difference of duration.
        """
        duration_candidates = torch.sub(candidates_end_time, candidates_start_time)
        duration_groundtruth = torch.sub(end_time, start_time)
        duration_diff = torch.square(torch.div(torch.sub(duration_candidates, duration_groundtruth), duration))
        duration_diff = duration_diff.clamp(min=0.4)

        return duration_diff

    def forward(self, logits, labels):
        """
        This performs the loss computation.

        Args:
            logits (`torch.FloatTensor`):
                The output logits of head module.
            labels (`List[torch.FloatTensor]`):
                List of tensors ([duration, start, end]), which contains the duration, the start time, and the end
                time of the video corresponding to the text.
        """
        duration, start_time, end_time = labels
        candidates = torch.mul(logits, duration)
        candidates_start_time, candidates_end_time = candidates[:, 0].float(), candidates[:, 1].float()

        losses_dict = {}
        for loss in self.losses:
            losses_dict.update(
                {loss: self.loss_map[loss](start_time, end_time, candidates_start_time, candidates_end_time, duration)}
            )

        return losses_dict
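
# A minimal usage sketch for `TvpLoss` (values are illustrative; the labels list is
# unpacked as `[duration, start, end]` in `TvpLoss.forward` above):
#
#     >>> criterion = TvpLoss(["iou", "distance", "duration"])
#     >>> logits = torch.tensor([[0.2, 0.5]])  # normalized (start, end) predictions
#     >>> labels = [torch.tensor([30.0]), torch.tensor([5.0]), torch.tensor([16.0])]
#     >>> loss_dict = criterion(logits, labels)
#     >>> sorted(loss_dict.keys())
#     ['distance', 'duration', 'iou']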


class TvpVisionModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.backbone = load_backbone(config)

        if config.backbone_config is not None:
            in_channels = config.backbone_config.hidden_sizes[-1]
        elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_sizes"):
            in_channels = self.backbone.config.hidden_sizes[-1]
        elif hasattr(self.backbone, "config") and hasattr(self.backbone.config, "hidden_size"):
            in_channels = self.backbone.config.hidden_size
        else:
            raise ValueError("Backbone config not found")

        self.grid_encoder_conv = nn.Conv2d(
            in_channels,
            config.hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            bias=False,
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        # (batch_size * num_frames, num_channels, height, width)
        pixel_values = pixel_values.view(batch_size * num_frames, num_channels, height, width)
        grid_feat_outputs = self.backbone(pixel_values)["feature_maps"][0]
        grid = self.grid_encoder_conv(grid_feat_outputs)
        grid = nn.functional.max_pool2d(grid, kernel_size=2, stride=2)
        grid = nn.functional.relu(grid, inplace=True)
        new_channel, new_height, new_width = grid.shape[-3:]
        # (batch_size, num_frames, num_channels, new_height, new_width)
        grid = grid.view(batch_size, num_frames, new_channel, new_height, new_width)
        # (batch_size, num_frames, new_height, new_width, num_channels)
        grid = grid.permute(0, 1, 3, 4, 2)
        return grid


class TvpVisualInputEmbedding(nn.Module):
    """
    Takes input of both image and video (multi-frame)
    """

    def __init__(self, config):
        super().__init__()
        # sequence embedding
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.row_position_embeddings = nn.Embedding(config.max_grid_row_position_embeddings, config.hidden_size)
        self.col_position_embeddings = nn.Embedding(config.max_grid_col_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
        self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings

    def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method interpolates the pre-trained position encodings, to be able to use the model on collections of
        high resolution images (high resolution videos).
        """
        h0 = w0 = 1
        # if the height dimension is to be interpolated
        if height > self.max_grid_row_position_embeddings:
            h0 = height / self.max_grid_row_position_embeddings
        # if the width dimension is to be interpolated
        if width > self.max_grid_col_position_embeddings:
            w0 = width / self.max_grid_col_position_embeddings
        embedding = embedding.permute(0, 3, 1, 2)  # (batch_size, hidden_dim, height, width)
        embedding = nn.functional.interpolate(
            embedding,
            scale_factor=(h0, w0),
            mode="bicubic",
            align_corners=False,
        )
        embedding = embedding.permute(0, 2, 3, 1)  # (batch_size, height, width, hidden_dim)
        return embedding

    def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
        """
        Args:
            grid: (batch_size, height, width, hidden_dim)
            interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the pre-trained position encodings.
        Returns:
            grid + col_position_embeddings.view(*col_shape): (batch_size, *, height, width, hidden_dim)
        """
        batch_size, height, width, hidden_dim = grid.shape

        # add row-wise position embeddings
        # (height, )
        row_height = min(self.max_grid_row_position_embeddings, height)
        row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
        # (height, hidden_dim)
        row_position_embeddings = self.row_position_embeddings(row_position_ids)
        row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
        # (1, height, 1, hidden_dim)
        row_position_embeddings = row_position_embeddings.view(*row_shape)

        # add column-wise position embeddings
        row_width = min(self.max_grid_col_position_embeddings, width)
        col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
        # (width, hidden_dim)
        col_position_embeddings = self.col_position_embeddings(col_position_ids)
        col_shape = (1, 1, row_width, hidden_dim)
        # (1, 1, width, hidden_dim)
        col_position_embeddings = col_position_embeddings.view(*col_shape)
        positional_embeddings = row_position_embeddings + col_position_embeddings

        # interpolation is triggered ONLY when the input is larger than the grid of
        # pre-trained position embeddings in either spatial dimension
        if interpolate_pos_encoding and (
            height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
        ):
            grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
        else:
            grid = grid + positional_embeddings
        return grid

    def forward(self, grid, interpolate_pos_encoding: bool = False):
        """
        Args:
            grid: Array of shape (batch_size, num_frames, height, width, num_channels).
                It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
                num_frames can be 1
            interpolate_pos_encoding: (bool, *optional*, defaults to `False`):
                Whether to interpolate the pre-trained position encodings.

        Returns:
            embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
        """
        batch_size, num_frames, height, width, num_channels = grid.shape
        # temporal mean pooling, (batch_size, height, width, hidden_size)
        grid = grid.mean(1)
        grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
        # image token sequence, (batch_size, height*width, num_channels)
        visual_tokens = grid.view(batch_size, -1, num_channels)
        visual_tokens_shape = visual_tokens.shape[:-1]
        device = visual_tokens.device

        # image token type embeddings
        token_type_ids = torch.zeros(visual_tokens_shape, dtype=torch.long, device=device)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = visual_tokens + token_type_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
zTvpVisualInputEmbedding.forwardF)r   r   r   r   r&   r   Tensorintr|   boolr   rB   rC   r   r   r.   r   rc      s    )rc   c                       s*   e Zd ZdZ fddZdddZ  ZS )TvpTextInputEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    sl   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _d S )N)Zpadding_idxrd   )r%   r&   r   re   Z
vocab_sizerI   Zpad_token_idword_embeddingsrf   rg   Ztype_vocab_sizerl   rm   rn   ro   rp   rq   rr   rs   r.   r   r   r&   )  s   
zTvpTextInputEmbeddings.__init__Nc                 C   s   |d ur	|  }n|  d d }|d }|d ur|jn|j}|d u r4tj|tj|d}|d|}|d u rAtj|tj|d}|d u rJ| |}| 	|}| 
|}	|| |	 }
| |
}
| |
}
|
S )NrF   r   r}   r   )sizer   r   r   r   Z	unsqueezeexpandr   r   rg   rl   ro   rr   )r-   	input_idsr   Zposition_idsZinputs_embedsZinput_shapeZ
seq_lengthr   rg   rl   r   r   r   r   rB   1  s$   





zTvpTextInputEmbeddings.forward)NNNNr   r   r   r   r&   rB   rC   r   r   r.   r   r   &  s    r   c                       sV   e Zd Z fddZdd Zdejdedefdd	Z	
	
	
dde	e
 fddZ  ZS )TvpAttentionc                    s   t    |j|j dkrt|dstd|j d|j |j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _t	
|j|j| _t	j|j|jd| _t	|j| _t | _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads rd   )r%   r&   rI   num_attention_headsrP   r+   r   attention_head_sizeall_head_sizer   Linearquerykeyvaluerp   Zattention_probs_dropout_probattn_dropoutdenserm   rn   ro   rq   rr   setpruned_headsrs   r.   r   r   r&   K  s    
zTvpAttention.__init__c                    s   t |dkrd S t| j| j}t|| j }|D ]  t fdd| jD   d| < q|d	 
d}tt ||  }t| j|| _t| j|| _t| j|| _t| j|dd| _| jt | | _| j| j | _| j|| _d S )Nr   c                 3   s     | ]}| k r
d ndV  qdS )r   r   Nr   ).0hheadr   r   	<genexpr>g  s    z+TvpAttention.prune_heads.<locals>.<genexpr>rF   r   dim)r   r   onesr   r   r   r   sumrW   
contiguouseqr   r   r   r   r   r   r   r   r9   )r-   headsmaskindexr   r   r   prune_heads`  s    
zTvpAttention.prune_headstensorsequence_lengthr\   c                 C   s    | ||| j| jdd S )Nr   rS   )rW   r   r   	transposer   )r-   r   r   r\   r   r   r   _reshapew  s   zTvpAttention._reshapeNoutput_attentionsc                 C   s   |j d d \}}| |}| |}| |}	| |||}
| |||}| |	||}t|
|dd}|t	| j
 }|d urG|| }tjj|dd}| |}|d ur\|| }t||}|dd }|||| j}| |}| |}| || }|r||f}|S |f}|S )NrS   rF   r   r   )rV   r   r   r   r   r   matmulr   mathsqrtr   r   rX   Zsoftmaxr   r   reshaper   r   rr   ro   )r-   r   attention_mask	head_maskr   r\   r   Zmixed_query_layerZmixed_key_layerZmixed_value_layerZquery_layerZ	key_layerZvalue_layerZattention_scoresZattention_probsZattn_outputoutputsr   r   r   rB   ~  s2   





zTvpAttention.forwardNNN)r   r   r   r&   r   r   r   r   r   r   r   rB   rC   r   r   r.   r   r   J  s    
r   c                       2   e Zd Z fddZdejdejfddZ  ZS )TvpIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S N)r%   r&   r   r   rI   intermediate_sizer   
isinstanceZ
hidden_actstrr   intermediate_act_fnrs   r.   r   r   r&     s
   
zTvpIntermediate.__init__r   ru   c                 C   s   |  |}| |}|S r   )r   r   )r-   r   r   r   r   rB     s   

zTvpIntermediate.forwardr   r   r   r&   r   r   rB   rC   r   r   r.   r   r     s    r   c                       s8   e Zd Z fddZdejdejdejfddZ  ZS )TvpOutputLayerc                    sB   t    t|j|j| _tj|j|jd| _	t
|j| _d S )Nrd   )r%   r&   r   r   r   rI   r   rm   rn   ro   rp   rq   rr   rs   r.   r   r   r&     s   
zTvpOutputLayer.__init__r   input_tensorru   c                 C   s&   |  |}| |}| || }|S r   )r   rr   ro   )r-   r   r   r   r   r   rB     s   

zTvpOutputLayer.forwardr   r   r   r.   r   r     s    $r   c                       s6   e Zd Z fddZ			ddee fddZ  ZS )TvpEncodeLayerc                    s,   t    t|| _t|| _t|| _d S r   )r%   r&   r   	attentionr   intermediater   outputrs   r.   r   r   r&     s   


zTvpEncodeLayer.__init__Nr   c           
      C   sJ   | j ||||d}|d }|dd  }| |}| ||}	|	f| }|S )N)r   r   r   )r   r   r   )
r-   r   r   r   r   Zself_attention_outputsZattention_outputr   Zintermediate_outputZlayer_outputr   r   r   rB     s   

zTvpEncodeLayer.forwardr   )r   r   r   r&   r   r   rB   rC   r   r   r.   r   r     s    	r   c                
       sT   e Zd Z fddZ					d
deej dee dee dee fdd	Z  Z	S )
TvpEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r   )r   )r   _rG   r   r   
<listcomp>  s    z'TvpEncoder.__init__.<locals>.<listcomp>F)	r%   r&   rG   r   Z
ModuleListrangenum_hidden_layerslayergradient_checkpointingrs   r.   r   r   r&     s   
 
zTvpEncoder.__init__Nr   r   output_hidden_statesreturn_dictc                 C   s  |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}d}d}t| jD ]:\}	}
|r2||f }| jrK| jrK| |
j	|||d urF||	 nd |}n	|
||||	 |}|d }|ra||d f }q'|ri||f }|s~|f}|ru||f }|r|||f }|S t
||r|nd |r|dS d dS )Nr   r   r   )last_hidden_stater   r   )rG   r   r   r   	enumerater   r   ZtrainingZ_gradient_checkpointing_func__call__r   )r-   r   r   r   r   r   r   Zall_hidden_statesZall_attentionsiZlayer_moduleZlayer_outputsr   r   r   r   rB     sL   	




zTvpEncoder.forward)NNNNN)
r   r   r   r&   r   r   r   r   rB   rC   r   r   r.   r   r     s     	r   c                       r   )	TvpPoolerc                    s*   t    t|j|j| _t | _d S r   )r%   r&   r   r   rI   r   ZTanh
activationrs   r.   r   r   r&   &  s   
zTvpPooler.__init__r   ru   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r-   r   Zfirst_token_tensorpooled_outputr   r   r   rB   +  s   

zTvpPooler.forwardr   r   r   r.   r   r   %  s    r   c                   @   s    e Zd ZeZdZdZdd ZdS )TvpPreTrainedModelmodelTc                 C   s   t |tjtjfr|jjjd| jjd nt |tj	r(|j
j  |jjd t |tjr9|j
dur9|j
j  t |tjrXtjj|jddd |j
durZtj|j
d dS dS dS )	zInitialize the weights        )r   Zstdg      ?NZfan_outrY   )rx   Znonlinearityr   )r   r   r   re   weightdataZnormal_rG   Zinitializer_rangerm   rN   Zzero_Zfill_rQ   initZkaiming_normal_Z	constant_)r-   moduler   r   r   _init_weights:  s   
z TvpPreTrainedModel._init_weightsN)r   r   r   r   Zconfig_classZbase_model_prefixZsupports_gradient_checkpointingr   r   r   r   r   r   4  s
    r   c                       s(   e Zd ZdZ fddZdd Z  ZS )TvpFrameDownPadPrompterz>
    Pad frames extracted from videos only at the bottom.
    """

    def __init__(self, config):
        if config.visual_prompter_apply not in ("add", "replace", "remove"):
            raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)")

        super().__init__()
        self.visual_prompt_size = config.visual_prompt_size
        self.frame_num = config.frame_num
        self.max_img_size = config.max_img_size
        self.visual_prompter_apply = config.visual_prompter_apply

        self.pad_down = nn.Parameter(
            torch.randn([1, config.frame_num, 3, config.visual_prompt_size, config.max_img_size])
        )

    def forward(self, pixel_values):
        if self.visual_prompter_apply != "add":
            visual_prompt_mask = torch.ones(
                [self.max_img_size, self.max_img_size], dtype=pixel_values.dtype, device=pixel_values.device
            )
            visual_prompt_mask[self.max_img_size - self.visual_prompt_size : self.max_img_size, :] = 0.0
            pixel_values *= visual_prompt_mask
        if self.visual_prompter_apply != "remove":
            prompt = torch.zeros(
                [pixel_values.shape[0], pixel_values.shape[1], 3, self.max_img_size, self.max_img_size],
                device=pixel_values.device,
            )
            start_point = self.max_img_size - self.visual_prompt_size
            prompt[:, :, :, start_point : self.max_img_size, :] = self.pad_down
            pixel_values += prompt.to(pixel_values.dtype)
        return pixel_values


class TvpFramePadPrompter(nn.Module):
    """
    Pad frames extracted from videos in the surroundings.
    """

    def __init__(self, config):
        if config.visual_prompter_apply not in ("add", "replace", "remove"):
            raise ValueError("`visual_prompter_apply` must be in (add, replace, remove)")

        super().__init__()
        self.num_frames = config.num_frames
        self.max_img_size = config.max_img_size
        self.visual_prompter_apply = config.visual_prompter_apply
        self.base_size = config.max_img_size - config.visual_prompt_size * 2
        self.pad_up = nn.Parameter(
            torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
        )
        self.pad_down = nn.Parameter(
            torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
        )
        self.pad_left = nn.Parameter(
            torch.randn(
                [1, config.num_frames, 3, config.max_img_size - config.visual_prompt_size * 2, config.visual_prompt_size]
            )
        )
        self.pad_right = nn.Parameter(
            torch.randn(
                [1, config.num_frames, 3, config.max_img_size - config.visual_prompt_size * 2, config.visual_prompt_size]
            )
        )

    def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method interpolates the pre-trained pad weights, to be able to use the model on collections of high
        resolution images (high resolution videos).
        """
        # creates the scale factors from the height and width of the original image w.r.t. config.max_img_size
        h0, w0 = height / self.max_img_size, width / self.max_img_size

        batch, num_frames, channels, prompt_height, prompt_width = prompt.shape

        # reshape the batch and num_frames dimensions into a single one
        # (i.e. (b, frames, c, h, w) --> (b*frames, c, h, w)) to apply bicubic interpolation
        prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)

        prompt = nn.functional.interpolate(
            prompt,
            scale_factor=(h0, w0),
            mode="bicubic",
            align_corners=False,
        )
        # reverse back to (batch, frames, channels, height, width),
        # where height and width are the new, interpolated sizes
        prompt = prompt.reshape(batch, num_frames, channels, height, width)
        return prompt

    def forward(self, pixel_values, interpolate_pad_encoding: bool = False):
        height, width = (
            (pixel_values.shape[-2], pixel_values.shape[-1])
            if interpolate_pad_encoding
            else (self.max_img_size, self.max_img_size)
        )
        if self.visual_prompter_apply not in ("add", "remove", "replace"):
            raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
        if self.visual_prompter_apply in ("replace", "remove"):
            visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device)
            pixel_values *= visual_prompt_mask
        if self.visual_prompter_apply in ("replace", "add"):
            base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
            prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
            prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
            prompt = torch.cat(pixel_values.size(0) * [prompt])
            if interpolate_pad_encoding:
                prompt = self.interpolate_pad_encoding(prompt, height, width)
            pixel_values = pixel_values + prompt.to(pixel_values.dtype)
        return pixel_values


TVP_PROMPTER_CLASSES_MAPPING = {
    "framedownpad": TvpFrameDownPadPrompter,
    "framepad": TvpFramePadPrompter,
}


@auto_docstring(
    custom_intro="""
    The bare Tvp Model transformer outputting BaseModelOutputWithPooling object without any specific head on top.
    """
)
class TvpModel(TvpPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.vision_model = TvpVisionModel(config)
        self.embeddings = TvpTextInputEmbeddings(config)
        self.visual_embeddings = TvpVisualInputEmbedding(config)
        self.encoder = TvpEncoder(config)
        self.pooler = TvpPooler(config)
        self.text_prompt = nn.Parameter(torch.randn([1, 10, config.hidden_size]))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        if config.visual_prompter_type not in TVP_PROMPTER_CLASSES_MAPPING:
            raise ValueError("`visual_prompter_type` must be in (framedownpad, framepad)")
        self.visual_prompter = TVP_PROMPTER_CLASSES_MAPPING[config.visual_prompter_type](config)

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        r"""
        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoConfig, AutoTokenizer, TvpModel

        >>> model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp")

        >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

        >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
        >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
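        >>> # A sketch of inspecting the pooled cross-modal representation that feeds the
        >>> # grounding head (illustrative; the shape depends on the checkpoint's hidden size):
        >>> pooled_output = output.pooler_output  # (batch_size, hidden_size)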
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Add the visual prompt; it compensates for the spatiotemporal information loss in 2D visual features.
        pixel_values = self.vision_model(
            self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding)
        )
        # (batch_size, sequence_length, hidden_size)
        text_embedding_output = self.embeddings(input_ids=input_ids)
        # (batch_size, visual_sequence_length, hidden_size)
        visual_embedding_output = self.visual_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )
        if attention_mask is not None:
            # (batch_size, visual_sequence_length)
            visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
            pt_mask = torch.ones(attention_mask.shape[0], 10).to(
                device=attention_mask.device, dtype=attention_mask.dtype
            )
            attention_mask = torch.cat([pt_mask, attention_mask, visual_attention_mask], dim=-1)
            # We can provide a self-attention mask of dimensions
            # [batch_size, from_seq_length, to_seq_length] ourselves,
            # in which case we just need to make it broadcastable to all heads.
            attention_mask = self.get_extended_attention_mask(attention_mask, attention_mask.shape).to(
                attention_mask.device
            )
        text_prompt = self.text_prompt.expand(text_embedding_output.shape[0], -1, -1)
        # (batch_size, sequence_length + visual_sequence_length, hidden_size)
        embedding_output = torch.cat([text_prompt, text_embedding_output, visual_embedding_output], dim=1)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            head_mask=self.get_head_mask(head_mask, self.config.num_hidden_layers),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        last_hidden_state = self.dropout(last_hidden_state)
        pooled_output = self.dropout(pooled_output)
        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class TvpVideoGroundingHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_0 = nn.Linear(config.hidden_size, config.hidden_size * 2)
        self.layer_1 = nn.Linear(config.hidden_size * 2, 2)
        self.activation_0 = nn.ReLU()
        self.activation_1 = nn.Sigmoid()

    def forward(self, pooler_output):
        logits = self.activation_0(self.layer_0(pooler_output))
        logits = self.activation_1(self.layer_1(logits))
        return logits


@auto_docstring(
    custom_intro="""
    Tvp Model with a video grounding head on top computing IoU, distance, and duration loss.
    """
)
class TvpForVideoGrounding(TvpPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = TvpModel(config)
        self.video_grounding_head = TvpVideoGroundingHead(config)

        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[Tuple[torch.Tensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
            The labels contains duration, start time, and end time of the video corresponding to the text.
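            A `labels` input can be built as a list of three tensors, `[duration, start_time,
            end_time]` in seconds, e.g. `[torch.tensor([30.0]), torch.tensor([5.0]),
            torch.tensor([16.0])]` (illustrative values).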

        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoConfig, AutoTokenizer, TvpForVideoGrounding

        >>> model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp")

        >>> tokenizer = AutoTokenizer.from_pretrained("Jiqing/tiny-random-tvp")

        >>> pixel_values = torch.rand(1, 1, 3, 448, 448)
        >>> text_inputs = tokenizer("This is an example input", return_tensors="pt")
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
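        >>> # Scaling the normalized logits by the clip duration yields timestamps in seconds
        >>> # (a sketch; 30.0 is a hypothetical duration, not part of the API):
        >>> start_sec, end_sec = (output.logits * 30.0)[0].tolist()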
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            pixel_values,
            attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        pooler_output = outputs[1]
        logits = self.video_grounding_head(pooler_output)

        loss = None
        if labels is not None:
            criterion = TvpLoss(["iou", "distance", "duration"])
            criterion.to(self.device)
            loss_dict = criterion(logits, labels)
            loss = (
                loss_dict["iou"]
                + self.config.distance_loss_weight * loss_dict["distance"]
                + self.config.duration_loss_weight * loss_dict["duration"]
            )
        if not return_dict:
            outputs = (logits,) + outputs[2:]
            if loss is not None:
                outputs = (loss,) + outputs
            return outputs

        return TvpVideoGroundingOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = ["TvpModel", "TvpPreTrainedModel", "TvpForVideoGrounding"]