"""PyTorch LXMERT model."""

import math
import os
import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, SmoothL1Loss

from ...activations import ACT2FN, gelu
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from .configuration_lxmert import LxmertConfig


logger = logging.get_logger(__name__)


class GeLU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return gelu(x)


@dataclass
class LxmertModelOutput(ModelOutput):
    """
    Lxmert's outputs that contain the last hidden states, pooled outputs, and attention probabilities for the language,
    visual, and cross-modality encoders. (Note: the visual encoder in Lxmert is referred to as the "relation-ship"
    encoder.)

    Args:
        language_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the language encoder.
        vision_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the visual encoder.
        pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification, CLS, token) further processed
            by a Linear layer and a Tanh activation function.
        language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlanguage_outputvision_outputpooled_outputlanguage_hidden_statesvision_hidden_stateslanguage_attentionsvision_attentionscross_encoder_attentions)r"   r#   r$   __doc__r'   r   torchFloatTensor__annotations__r(   r)   r*   r   r+   r,   r-   r.   r   r   r   r   r&   ,   s   
 "r&   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dZeeej  ed< dZeeej  ed< dZeeej  ed	< dS )
 LxmertForQuestionAnsweringOutputa	  
    Output type of [`LxmertForQuestionAnswering`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        question_answering_score (`torch.FloatTensor` of shape `(batch_size, n_qa_answers)`, *optional*):
            Prediction scores of question answering objective (classification).
        language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlossquestion_answering_scorer*   r+   r,   r-   r.   )r"   r#   r$   r/   r4   r   r0   r1   r2   r5   r*   r   r+   r,   r-   r.   r   r   r   r   r3   Z   s   
 r3   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeeej  ed< dZeeej  ed< dZeeej  ed	< dZeeej  ed
< dZeeej  ed< dS )LxmertForPreTrainingOutputak  
    Output type of [`LxmertForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cross_relationship_score (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the textual matching objective (classification) head (scores of True/False
            continuation before SoftMax).
        question_answering_score (`torch.FloatTensor` of shape `(batch_size, n_qa_answers)`):
            Prediction scores of question answering objective (classification).
        language_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for input features + one for the output of each cross-modality layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
        language_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
        cross_encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.

    Nr4   prediction_logitscross_relationship_scorer5   r*   r+   r,   r-   r.   )r"   r#   r$   r/   r4   r   r0   r1   r2   r7   r8   r5   r*   r   r+   r,   r-   r.   r   r   r   r   r6      s   
 #r6   c                 C   s  zddl }ddl}ddl}W n ty   td  w tj|}t	d|  |j
|}g }g }	|D ] \}
}t	d|
 d|  |j
||
}||
 |	| q6t||	D ]\}
}|
d}
tdd	 |
D rzt	d
d|
  q\| }|
D ]|}|d|r|d|}n|g}|d dks|d dkrt|d}nH|d dks|d dkrt|d}n6|d dkrt|d}n*|d dkrt|d}nz	t||d }W n ty   t	d
d|
  Y q~w t|dkrt|d }|| }q~|dd dkr
t|d}n
|dkr||}z|j|jksJ W n ty8 } z| j|j|jf7  _ d}~ww t	d|
  t||_q\| S )z'Load tf checkpoints in a pytorch model.r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z&Converting TensorFlow checkpoint from zLoading TF weight z with shape /c                 s   s    | ]}|d v V  qdS ))Zadam_vZadam_mZAdamWeightDecayOptimizerZAdamWeightDecayOptimizer_1Zglobal_stepNr   ).0nr   r   r   	<genexpr>   s    	
z,load_tf_weights_in_lxmert.<locals>.<genexpr>z	Skipping z[A-Za-z]+_\d+z_(\d+)ZkernelgammaweightZoutput_biasbetabiasZoutput_weightsZsquadZ
classifier   r   iZ_embeddingszInitialize PyTorch weight )renumpyZ
tensorflowImportErrorloggererrorospathabspathinfotrainZlist_variablesZload_variableappendzipsplitanyjoin	fullmatchgetattrAttributeErrorlenint	transposeshapeAssertionErrorargsr0   Z
from_numpydata)modelconfigZtf_checkpoint_pathrB   nptfZtf_pathZ	init_varsnamesZarraysnamerW   arrayZpointerZm_nameZscope_namesnumer   r   r   load_tf_weights_in_lxmert   s   

	

rd   c                       s*   e Zd ZdZ fddZdddZ  ZS )LxmertEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    sp   t    tj|j|jdd| _tj|j|jdd| _tj|j	|jdd| _
tj|jdd| _t|j| _d S )Nr   )padding_idx-q=Zeps)r   r   r   	Embedding
vocab_sizehidden_sizeword_embeddingsZmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddings	LayerNormDropouthidden_dropout_probdropoutr   r\   r   r   r   r     s   
zLxmertEmbeddings.__init__Nc                 C   s   |d ur|  }|j}n|  d d }|j}|d }tj|tj|d}|d|}|d u r;tj|tj| jjd}|d u rD| 	|}| 
|}| |}	|| |	 }
| |
}
| |
}
|
S )Nr   dtypedevicer   )sizerw   r0   Zarangelong	unsqueezeexpandzerosposition_idsrl   rm   rn   ro   rr   )r   	input_idstoken_type_idsinputs_embedsinput_shaperw   Z
seq_lengthr}   rm   rn   
embeddingsr   r   r   r     s$   




zLxmertEmbeddings.forwardNN)r"   r#   r$   r/   r   r   r%   r   r   r   r   re     s    re   c                       s0   e Zd Zd	 fdd	Zdd Zd
ddZ  ZS )LxmertAttentionNc                    s   t    |j|j dkrtd|j d|j d|j| _t|j|j | _| j| j | _|d u r5|j}t	|j| j| _
t	|| j| _t	|| j| _t|j| _d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ())r   r   rk   num_attention_heads
ValueErrorrU   attention_head_size	head_sizer   Linearquerykeyvaluerp   Zattention_probs_dropout_probrr   )r   r\   Zctx_dimr   r   r   r   *  s    

zLxmertAttention.__init__c                 C   s6   |  d d | j| jf }||}|ddddS )Nrt   r   rA   r   r
   )rx   r   r   viewpermute)r   r   Znew_x_shaper   r   r   transpose_for_scores>  s   
z$LxmertAttention.transpose_for_scoresFc                 C   s   |  |}| |}| |}| |}| |}	| |}
t||	dd}|t| j	 }|d ur8|| }t
jj|dd}| |}t||
}|dddd }| d d | jf }||}|rn||f}|S |f}|S )Nrt   )dimr   rA   r   r
   )r   r   r   r   r0   matmulrV   mathsqrtr   r   Z
functionalZsoftmaxrr   r   
contiguousrx   r   r   )r   hidden_statescontextattention_maskoutput_attentionsZmixed_query_layerZmixed_key_layerZmixed_value_layerZquery_layerZ	key_layerZvalue_layerZattention_scoresattention_probsZcontext_layerZnew_context_layer_shapeoutputsr   r   r   r   F  s(   







zLxmertAttention.forwardr   NF)r"   r#   r$   r   r   r   r%   r   r   r   r   r   )  s    r   c                       r   )LxmertAttentionOutputc                    s@   t    t|j|j| _tj|jdd| _t|j| _	d S Nrg   rh   )
r   r   r   r   rk   densero   rp   rq   rr   rs   r   r   r   r   g     
zLxmertAttentionOutput.__init__c                 C   &   |  |}| |}| || }|S r   r   rr   ro   r   r   input_tensorr   r   r   r   m     

zLxmertAttentionOutput.forwardr!   r   r   r   r   r   f      r   c                       &   e Zd Z fddZdddZ  ZS )LxmertCrossAttentionLayerc                    "   t    t|| _t|| _d S r   )r   r   r   attr   outputrs   r   r   r   r   u     

z"LxmertCrossAttentionLayer.__init__NFc           	      C   sH   | j ||||d}|r|d }| |d |}|r||f}|S |f}|S Nr   r   r   )r   r   )	r   r   Z
ctx_tensorctx_att_maskr   r   r   attention_outputr   r   r   r   r   z  s   z!LxmertCrossAttentionLayer.forwardr   r!   r   r   r   r   r   t      r   c                       s&   e Zd Z fddZdddZ  ZS )LxmertSelfAttentionLayerc                    r   r   )r   r   r   r   r   r   rs   r   r   r   r     r   z!LxmertSelfAttentionLayer.__init__Fc                 C   sH   | j ||||d}|r|d }| |d |}|r||f}|S |f}|S r   )r   r   )r   r   r   r   r   r   r   r   r   r   r   r     s   z LxmertSelfAttentionLayer.forwardFr!   r   r   r   r   r     r   r   c                       r   )LxmertIntermediatec                    s,   t    t|j|j| _t|j | _	d S r   )
r   r   r   r   rk   intermediate_sizer   r   
hidden_actintermediate_act_fnrs   r   r   r   r     s   
zLxmertIntermediate.__init__c                 C   s   |  |}| |}|S r   )r   r   r   r   r   r   r   r     s   

zLxmertIntermediate.forwardr!   r   r   r   r   r         r   c                       r   )LxmertOutputc                    s@   t    t|j|j| _tj|jdd| _t|j	| _
d S r   )r   r   r   r   r   rk   r   ro   rp   rq   rr   rs   r   r   r   r     r   zLxmertOutput.__init__c                 C   r   r   r   r   r   r   r   r     r   zLxmertOutput.forwardr!   r   r   r   r   r     r   r   c                       r   )LxmertLayerc                    s,   t    t|| _t|| _t|| _d S r   )r   r   r   	attentionr   intermediater   r   rs   r   r   r   r     s   


zLxmertLayer.__init__NFc                 C   sD   | j |||d}|d }| |}| ||}|f|dd   }|S )Nr   r   r   )r   r   r   )r   r   r   r   r   r   Zintermediate_outputZlayer_outputr   r   r   r     s   
zLxmertLayer.forwardr   r!   r   r   r   r   r     s    r   c                       sD   e Zd Z fddZ	dddZdd Zdd	 Z	dd
dZ  ZS )LxmertXLayerc                    sT   t    t|| _t|| _t|| _t|| _t	|| _
t|| _t	|| _d S r   )r   r   r   visual_attentionr   lang_self_attvisn_self_attr   
lang_interr   lang_output
visn_intervisn_outputrs   r   r   r   r     s   






zLxmertXLayer.__init__Fc                 C   s,   | j ||||d}| j |||dd}||fS )N)r   r   F)r   )r   
lang_inputlang_attention_maskvisual_inputvisual_attention_maskoutput_x_attentionslang_att_outputvisual_att_outputr   r   r   	cross_att  s   	zLxmertXLayer.cross_attc                 C   s0   | j ||dd}| j||dd}|d |d fS )NFr   r   )r   r   )r   r   r   r   r   r   r   r   r   r   self_att  s   zLxmertXLayer.self_attc                 C   s4   |  |}| |}| ||}| ||}||fS r   )r   r   r   r   )r   r   r   Zlang_inter_outputZvisual_inter_outputr   visual_outputr   r   r   	output_fc  s
   

zLxmertXLayer.output_fcc                 C   sj   | j |||||d\}}|dd  }| |d ||d |\}}| ||\}	}
|r1|	|
|d fS |	|
fS )N)r   r   r   r   r   r   r   )r   r   r   )r   
lang_featsr   visual_featsr   r   r   r   r   r   r   r   r   r   r     s0   
zLxmertXLayer.forwardr   )	r"   r#   r$   r   r   r   r   r   r%   r   r   r   r   r     s    


class LxmertVisualFeatureEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        feat_dim = config.visual_feat_dim
        pos_dim = config.visual_pos_dim

        # Object feature encoding
        self.visn_fc = nn.Linear(feat_dim, config.hidden_size)
        self.visn_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)

        # Box position encoding
        self.box_fc = nn.Linear(pos_dim, config.hidden_size)
        self.box_layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, visual_feats, visual_pos):
        x = self.visn_fc(visual_feats)
        x = self.visn_layer_norm(x)
        y = self.box_fc(visual_pos)
        y = self.box_layer_norm(y)
        output = (x + y) / 2

        output = self.dropout(output)
        return output


class LxmertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Obj-level image embedding layer
        self.visn_fc = LxmertVisualFeatureEncoder(config)
        self.config = config

        # Number of layers
        self.num_l_layers = config.l_layers
        self.num_x_layers = config.x_layers
        self.num_r_layers = config.r_layers

        # Layers (self.layer is used for the language stream to support loading BERT weights)
        self.layer = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_l_layers)])
        self.x_layers = nn.ModuleList([LxmertXLayer(config) for _ in range(self.num_x_layers)])
        self.r_layers = nn.ModuleList([LxmertLayer(config) for _ in range(self.num_r_layers)])

    def forward(
        self,
        lang_feats,
        lang_attention_mask,
        visual_feats,
        visual_pos,
        visual_attention_mask=None,
        output_attentions=None,
    ):
        vision_hidden_states = ()
        language_hidden_states = ()
        vision_attentions = () if output_attentions or self.config.output_attentions else None
        language_attentions = () if output_attentions or self.config.output_attentions else None
        cross_encoder_attentions = () if output_attentions or self.config.output_attentions else None

        visual_feats = self.visn_fc(visual_feats, visual_pos)

        # Run language layers
        for layer_module in self.layer:
            l_outputs = layer_module(lang_feats, lang_attention_mask, output_attentions=output_attentions)
            lang_feats = l_outputs[0]
            language_hidden_states = language_hidden_states + (lang_feats,)
            if language_attentions is not None:
                language_attentions = language_attentions + (l_outputs[1],)

        # Run relational layers
        for layer_module in self.r_layers:
            v_outputs = layer_module(visual_feats, visual_attention_mask, output_attentions=output_attentions)
            visual_feats = v_outputs[0]
            vision_hidden_states = vision_hidden_states + (visual_feats,)
            if vision_attentions is not None:
                vision_attentions = vision_attentions + (v_outputs[1],)

        # Run cross-modality layers
        for layer_module in self.x_layers:
            x_outputs = layer_module(
                lang_feats,
                lang_attention_mask,
                visual_feats,
                visual_attention_mask,
                output_attentions=output_attentions,
            )
            lang_feats, visual_feats = x_outputs[:2]
            vision_hidden_states = vision_hidden_states + (visual_feats,)
            language_hidden_states = language_hidden_states + (lang_feats,)
            if cross_encoder_attentions is not None:
                cross_encoder_attentions = cross_encoder_attentions + (x_outputs[2],)

        visual_encoder_outputs = (vision_hidden_states, vision_attentions if output_attentions else None)
        lang_encoder_outputs = (language_hidden_states, language_attentions if output_attentions else None)
        return (
            visual_encoder_outputs,
            lang_encoder_outputs,
            cross_encoder_attentions if output_attentions else None,
        )


class LxmertPooler(nn.Module):
    def __init__(self, config):
        super(LxmertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # "Pool" the model by taking the hidden state corresponding to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class LxmertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(LxmertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.transform_act_fn = ACT2FN[config.hidden_act]
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class LxmertLMPredictionHead(nn.Module):
    def __init__(self, config, lxmert_model_embedding_weights):
        super(LxmertLMPredictionHead, self).__init__()
        self.transform = LxmertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(
            lxmert_model_embedding_weights.size(1),
            lxmert_model_embedding_weights.size(0),
            bias=False,
        )
        self.decoder.weight = lxmert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(lxmert_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class LxmertVisualAnswerHead(nn.Module):
    def __init__(self, config, num_labels):
        super().__init__()
        hid_dim = config.hidden_size
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            GeLU(),
            nn.LayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, num_labels),
        )

    def forward(self, hidden_states):
        return self.logit_fc(hidden_states)


class LxmertVisualObjHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = LxmertPredictionHeadTransform(config)
        # Decide the use of visual losses
        visual_losses = {}
        if config.visual_obj_loss:
            visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels}
        if config.visual_attr_loss:
            visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels}
        if config.visual_feat_loss:
            visual_losses["feat"] = {"shape": (-1, config.visual_feat_dim), "num": config.visual_feat_dim}
        self.visual_losses = visual_losses

        # One decoder per enabled visual loss
        self.decoder_dict = nn.ModuleDict(
            {key: nn.Linear(config.hidden_size, self.visual_losses[key]["num"]) for key in self.visual_losses}
        )

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        output = {}
        for key in self.visual_losses:
            output[key] = self.decoder_dict[key](hidden_states)
        return output


class LxmertPreTrainingHeads(nn.Module):
    def __init__(self, config, lxmert_model_embedding_weights):
        super(LxmertPreTrainingHeads, self).__init__()
        self.predictions = LxmertLMPredictionHead(config, lxmert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


@auto_docstring
class LxmertPreTrainedModel(PreTrainedModel):
    config_class = LxmertConfig
    load_tf_weights = load_tf_weights_in_lxmert
    base_model_prefix = "lxmert"
    _supports_param_buffer_assignment = False

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, LxmertLMPredictionHead):
            module.bias.data.zero_()


@auto_docstring
class LxmertModel(LxmertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embeddings = LxmertEmbeddings(config)
        self.encoder = LxmertEncoder(config)
        self.pooler = LxmertPooler(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        visual_feats: Optional[torch.FloatTensor] = None,
        visual_pos: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[LxmertModelOutput, Tuple[torch.FloatTensor]]:
        r"""
        visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
            This input represents visual features: ROI-pooled object features extracted from bounding boxes by a
            Faster R-CNN model.

            These are currently not provided by the transformers library.
        visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
            This input represents spatial features corresponding to their relative (via index) visual features. The
            pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
            1.

            These are currently not provided by the transformers library.
        visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if visual_feats is None:
            raise ValueError("`visual_feats` cannot be `None`")
        if visual_pos is None:
            raise ValueError("`visual_pos` cannot be `None`")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # Turn the 2D padding mask into an additive mask broadcastable over the
        # attention scores: 0.0 for tokens to attend to, the dtype's most
        # negative value for masked positions.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        # Process the visual attention mask the same way
        if visual_attention_mask is not None:
            extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
            extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * torch.finfo(self.dtype).min
        else:
            extended_visual_attention_mask = None

        # Positional word embeddings
        embedding_output = self.embeddings(input_ids, token_type_ids, inputs_embeds)

        # Run Lxmert encoder
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            visual_feats=visual_feats,
            visual_pos=visual_pos,
            visual_attention_mask=extended_visual_attention_mask,
            output_attentions=output_attentions,
        )

        visual_encoder_outputs, lang_encoder_outputs = encoder_outputs[:2]
        vision_hidden_states = visual_encoder_outputs[0]
        language_hidden_states = lang_encoder_outputs[0]

        all_attentions = ()
        if output_attentions:
            language_attentions = lang_encoder_outputs[1]
            vision_attentions = visual_encoder_outputs[1]
            cross_encoder_attentions = encoder_outputs[2]
            all_attentions = (language_attentions, vision_attentions, cross_encoder_attentions)

        hidden_states = (language_hidden_states, vision_hidden_states) if output_hidden_states else ()

        visual_output = vision_hidden_states[-1]
        lang_output = language_hidden_states[-1]
        pooled_output = self.pooler(lang_output)

        if not return_dict:
            return (lang_output, visual_output, pooled_output) + hidden_states + all_attentions

        return LxmertModelOutput(
            pooled_output=pooled_output,
            language_output=lang_output,
            vision_output=visual_output,
            language_hidden_states=language_hidden_states if output_hidden_states else None,
            vision_hidden_states=vision_hidden_states if output_hidden_states else None,
            language_attentions=language_attentions if output_attentions else None,
            vision_attentions=vision_attentions if output_attentions else None,
            cross_encoder_attentions=cross_encoder_attentions if output_attentions else None,
        )
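

# Illustrative usage sketch (not part of the library API): LXMERT consumes
# pre-extracted Faster R-CNN region features, which the transformers library
# does not provide, so random tensors stand in for them here. The checkpoint
# name is the standard pretrained LXMERT release.
#
#     from transformers import LxmertModel, LxmertTokenizer
#
#     tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
#     model = LxmertModel.from_pretrained("unc-nlp/lxmert-base-uncased")
#
#     inputs = tokenizer("What is the man riding?", return_tensors="pt")
#     visual_feats = torch.randn(1, 36, model.config.visual_feat_dim)  # (batch, num_boxes, 2048)
#     visual_pos = torch.rand(1, 36, model.config.visual_pos_dim)  # normalized boxes in [0, 1]
#
#     outputs = model(**inputs, visual_feats=visual_feats, visual_pos=visual_pos)
#     outputs.language_output  # (1, seq_len, hidden_size)
#     outputs.vision_output  # (1, 36, hidden_size)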


@auto_docstring
class LxmertForPreTraining(LxmertPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        # Configuration
        self.config = config
        self.num_qa_labels = config.num_qa_labels
        self.visual_loss_normalizer = config.visual_loss_normalizer

        # Use of pretraining tasks
        self.task_mask_lm = config.task_mask_lm
        self.task_obj_predict = config.task_obj_predict
        self.task_matched = config.task_matched
        self.task_qa = config.task_qa

        # Lxmert backbone
        self.lxmert = LxmertModel(config)

        # Pre-training heads
        self.cls = LxmertPreTrainingHeads(config, self.lxmert.embeddings.word_embeddings.weight)
        if self.task_obj_predict:
            self.obj_predict_head = LxmertVisualObjHead(config)
        if self.task_qa:
            self.answer_head = LxmertVisualAnswerHead(config, self.num_qa_labels)

        # Initialize weights and apply final processing
        self.post_init()

        # Loss functions
        self.loss_fcts = {
            "l2": SmoothL1Loss(reduction="none"),
            "visual_ce": CrossEntropyLoss(reduction="none"),
            "ce": CrossEntropyLoss(),
        }

        visual_losses = {}
        if config.visual_obj_loss:
            visual_losses["obj"] = {"shape": (-1,), "num": config.num_object_labels, "loss": "visual_ce"}
        if config.visual_attr_loss:
            visual_losses["attr"] = {"shape": (-1,), "num": config.num_attr_labels, "loss": "visual_ce"}
        if config.visual_feat_loss:
            visual_losses["feat"] = {
                "shape": (-1, config.visual_feat_dim),
                "num": config.visual_feat_dim,
                "loss": "l2",
            }
        self.visual_losses = visual_losses

    def _tie_weights(self):
        self.cls.predictions.decoder.weight = self.lxmert.embeddings.word_embeddings.weight

    def resize_token_embeddings(
        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
    ) -> nn.Embedding:
        # Resize the LM-head bias together with the embeddings.
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
        self.cls.predictions.bias = self._resize_bias(self.cls.predictions.bias, new_num_tokens)
        return new_embeddings

    def _resize_bias(self, bias, new_num_tokens: int):
        old_num_tokens = bias.shape[0]
        if new_num_tokens <= old_num_tokens:
            new_bias = bias[:new_num_tokens]
        else:
            extra_bias = torch.zeros(new_num_tokens - old_num_tokens, device=bias.device)
            new_bias = torch.cat([bias, extra_bias])
        new_bias = nn.Parameter(new_bias)
        return new_bias

    def resize_num_qa_labels(self, num_labels):
        """
        Build a resized question answering linear layer Module from a provided new linear layer. Increasing the size
        will add newly initialized weights. Reducing the size will remove weights from the end.

        Args:
            num_labels (`int`, *optional*):
                New number of labels in the linear layer weight matrix. Increasing the size will add newly initialized
                weights at the end. Reducing the size will remove weights from the end. If not provided or `None`, just
                returns a pointer to the qa labels `torch.nn.Linear` module of the model without doing anything.

        Return:
            `torch.nn.Linear`: Pointer to the resized Linear layer or the old Linear layer
        Nget_qa_logit_layer_resize_qa_labelsr\   r(  r   r   cur_qa_logit_layernew_qa_logit_layerr   r   r   resize_num_qa_labels     
z)LxmertForPreTraining.resize_num_qa_labelsc                 C   &   |   }| ||}| | |   S r   r<  _get_resized_qa_labels_set_qa_logit_layerr>  r   r   r   r=       
z&LxmertForPreTraining._resize_qa_labelsc                 C      t | dr| jjd S dS )a  
        Returns the linear layer that produces question answering logits.

        Returns:
            `nn.Module`: A torch module mapping the question answering prediction hidden states or `None` if LXMERT
            does not have a visual answering head.
        """
        if hasattr(self, "answer_head"):
            return self.answer_head.logit_fc[-1]

    def _set_qa_logit_layer(self, qa_logit_layer):
        self.answer_head.logit_fc[-1] = qa_logit_layer

    def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
        if num_labels is None:
            return cur_qa_logit_layer

        cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
        if cur_qa_labels == num_labels:
            return cur_qa_logit_layer

        # Build a new linear output layer, with or without bias to match the current one
        if getattr(cur_qa_logit_layer, "bias", None) is not None:
            new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
        else:
            new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)

        new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)

        # Initialize all new labels
        self._init_weights(new_qa_logit_layer)

        # Copy labels from the previous weights
        num_labels_to_copy = min(cur_qa_labels, num_labels)
        new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
        if getattr(cur_qa_logit_layer, "bias", None) is not None:
            new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]

        return new_qa_logit_layer

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        visual_feats: Optional[torch.FloatTensor] = None,
        visual_pos: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        obj_labels: Optional[Dict[str, Tuple[torch.FloatTensor, torch.FloatTensor]]] = None,
        matched_label: Optional[torch.LongTensor] = None,
        ans: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[LxmertForPreTrainingOutput, Tuple[torch.FloatTensor]]:
        r"""
        visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
            This input represents visual features: ROI-pooled object features extracted from bounding boxes by a
            Faster R-CNN model.

            These are currently not provided by the transformers library.
        visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
            This input represents spatial features corresponding to their relative (via index) visual features. The
            pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
            1.

            These are currently not provided by the transformers library.
        visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        obj_labels (`Dict[Str: Tuple[Torch.FloatTensor, Torch.FloatTensor]]`, *optional*):
            each key is named after each one of the visual losses and each element of the tuple is of the shape
            `(batch_size, num_features)` and `(batch_size, num_features, visual_feature_dim)`, for the label id and
            the label score respectively.
        matched_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing whether or not the text input matches the image (classification) loss. Input
            should be a sequence pair (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates that the sentence does not match the image,
            - 1 indicates that the sentence does match the image.
        ans (`torch.Tensor` of shape `(batch_size)`, *optional*):
            A one-hot representation of the correct answer.
        Zmasked_lm_labelszlThe `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.N
r~   r   r   r   r   r   r   r  r   r  r   r   rA   r  r  r'  rt   rb   r4   rW   r
   )	r4   r7   r8   r5   r*   r+   r,   r-   r.   ) warningswarnFutureWarningpopr\   r  rw   r  r.  r-  r0  r0   Ztensorr*  r1  r   rj   r,  r+  r/  r   itemsr)  r   r  r(  r6   r*   r+   r,   r-   r.   )*r   r~   r   r   r   r   r   r   rU  rV  rW  rX  r   r  r  kwargsrw   lxmert_outputr   r   r)   Zlang_prediction_scoresr8   answer_scoreZ
total_lossZmasked_lm_lossZmatched_lossZtotal_visual_lossZvisual_prediction_scores_dictr   Zkey_infolabelZ	mask_confZ
output_dimZloss_fct_nameZlabel_shaper>   Zvisual_loss_fctZvisual_prediction_scoresZvisual_lossZanswer_lossr   r   r   r   r   <  s   8

 




 
zLxmertForPreTraining.forward)NT)NNNNNNNNNNNNNN)r"   r#   r$   Z_tied_weights_keysr   r2  rU   r   r"  r   ri   r6  r7  rA  r=  Moduler<  rF  rE  r   r0   r!  r1   r   strr   Tensorr   r6   r   r%   r   r   r   r   r#    s    7
	
r#  zR
    Lxmert Model with a visual-answering head on top for downstream QA tasks
    )Zcustom_introc                       s   e Zd Z fddZdd Zdd Zdejfdd	Zd
d Z	dd Z
e											ddeej deej deej deej deej deej deej deej dee dee dee deeeej f fddZ  ZS )LxmertForQuestionAnsweringc                    sN   t  | || _|j| _|j| _t|| _t|| j| _| 	  t
 | _d S r   )r   r   r\   r(  r)  r  r  r   r0  r  r   r4   rs   r   r   r   r     s   
z#LxmertForQuestionAnswering.__init__c                 C   r9  r:  r;  r>  r   r   r   rA    rB  z/LxmertForQuestionAnswering.resize_num_qa_labelsc                 C   rC  r   rD  r>  r   r   r   r=    rG  z,LxmertForQuestionAnswering._resize_qa_labelsr  c                 C   rH  )a  
        Returns the linear layer that produces question answering logits.

        Returns:
            `nn.Module`: A torch module mapping the question answering prediction hidden states. `None`: A NoneType
            object if Lxmert does not have the visual answering head.
        """
        if hasattr(self, "answer_head"):
            return self.answer_head.logit_fc[-1]

    def _set_qa_logit_layer(self, qa_logit_layer):
        self.answer_head.logit_fc[-1] = qa_logit_layer

    def _get_resized_qa_labels(self, cur_qa_logit_layer, num_labels):
        if num_labels is None:
            return cur_qa_logit_layer

        cur_qa_labels, hidden_dim = cur_qa_logit_layer.weight.size()
        if cur_qa_labels == num_labels:
            return cur_qa_logit_layer

        # Build a new linear output layer, with or without bias to match the current one
        if getattr(cur_qa_logit_layer, "bias", None) is not None:
            new_qa_logit_layer = nn.Linear(hidden_dim, num_labels)
        else:
            new_qa_logit_layer = nn.Linear(hidden_dim, num_labels, bias=False)

        new_qa_logit_layer.to(cur_qa_logit_layer.weight.device)

        # Initialize all new labels
        self._init_weights(new_qa_logit_layer)

        # Copy labels from the previous weights
        num_labels_to_copy = min(cur_qa_labels, num_labels)
        new_qa_logit_layer.weight.data[:num_labels_to_copy, :] = cur_qa_logit_layer.weight.data[:num_labels_to_copy, :]
        if getattr(cur_qa_logit_layer, "bias", None) is not None:
            new_qa_logit_layer.bias.data[:num_labels_to_copy] = cur_qa_logit_layer.bias.data[:num_labels_to_copy]

        return new_qa_logit_layer

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        visual_feats: Optional[torch.FloatTensor] = None,
        visual_pos: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        visual_attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[LxmertForQuestionAnsweringOutput, Tuple[torch.FloatTensor]]:
        r"""
        visual_feats (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_feat_dim)`):
            This input represents visual features: ROI-pooled object features extracted from bounding boxes by a
            Faster R-CNN model.

            These are currently not provided by the transformers library.
        visual_pos (`torch.FloatTensor` of shape `(batch_size, num_visual_features, visual_pos_dim)`):
            This input represents spatial features corresponding to their relative (via index) visual features. The
            pre-trained LXMERT model expects these spatial features to be normalized bounding boxes on a scale of 0 to
            1.

            These are currently not provided by the transformers library.
        visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        labels (`torch.Tensor` of shape `(batch_size)`, *optional*):
            A one-hot representation of the correct answer.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        lxmert_output = self.lxmert(
            input_ids=input_ids,
            visual_feats=visual_feats,
            visual_pos=visual_pos,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            visual_attention_mask=visual_attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )

        pooled_output = lxmert_output[2]
        answer_score = self.answer_head(pooled_output)
        loss = None
        if labels is not None:
            loss = self.loss(answer_score.view(-1, self.num_qa_labels), labels.view(-1))

        if not return_dict:
            output = (answer_score,) + lxmert_output[3:]
            return (loss,) + output if loss is not None else output

        return LxmertForQuestionAnsweringOutput(
            loss=loss,
            question_answering_score=answer_score,
            language_hidden_states=lxmert_output.language_hidden_states,
            vision_hidden_states=lxmert_output.vision_hidden_states,
            language_attentions=lxmert_output.language_attentions,
            vision_attentions=lxmert_output.vision_attentions,
            cross_encoder_attentions=lxmert_output.cross_encoder_attentions,
        )


__all__ = [
    "LxmertEncoder",
    "LxmertForPreTraining",
    "LxmertForQuestionAnswering",
    "LxmertModel",
    "LxmertPreTrainedModel",
    "LxmertVisualFeatureEncoder",
    "LxmertXLayer",
    "load_tf_weights_in_lxmert",
]
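

# Illustrative usage sketch for the QA head (assumes the fine-tuned
# "unc-nlp/lxmert-vqa-uncased" checkpoint; random tensors again stand in for
# the Faster R-CNN region features, which must normally come from an external
# feature extractor):
#
#     from transformers import LxmertForQuestionAnswering, LxmertTokenizer
#
#     tokenizer = LxmertTokenizer.from_pretrained("unc-nlp/lxmert-base-uncased")
#     model = LxmertForQuestionAnswering.from_pretrained("unc-nlp/lxmert-vqa-uncased")
#
#     inputs = tokenizer("What is the man riding?", return_tensors="pt")
#     visual_feats = torch.randn(1, 36, model.config.visual_feat_dim)
#     visual_pos = torch.rand(1, 36, model.config.visual_pos_dim)
#
#     outputs = model(**inputs, visual_feats=visual_feats, visual_pos=visual_pos)
#     answer_idx = outputs.question_answering_score.argmax(-1)  # index into the QA answer vocabulary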