"""PyTorch LUKE model."""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN, gelu
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import ModelOutput, auto_docstring, logging
from .configuration_luke import LukeConfig


logger = logging.get_logger(__name__)


@dataclass
class BaseLukeModelOutputWithPooling(BaseModelOutputWithPooling):
    """
    Base class for outputs of the LUKE model.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`):
            Sequence of entity hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
            Linear layer and a Tanh activation function.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length +
            entity_length, sequence_length + entity_length)`. Attention weights after the attention softmax, used to
            compute the weighted average in the self-attention heads.
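
    Example (an illustrative sketch; the checkpoint and entity span are taken from the [`LukeModel`] usage example
    rather than being requirements of this class):

    ```python
    >>> from transformers import AutoTokenizer, LukeModel

    >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
    >>> model = LukeModel.from_pretrained("studio-ousia/luke-base")

    >>> text = "Beyoncé lives in Los Angeles."
    >>> encoding = tokenizer(text, entity_spans=[(0, 7)], add_prefix_space=True, return_tensors="pt")
    >>> outputs = model(**encoding)

    >>> word_states = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
    >>> entity_states = outputs.entity_last_hidden_state  # (batch_size, entity_length, hidden_size)
    >>> pooled = outputs.pooler_output  # (batch_size, hidden_size)
    ```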
    Nentity_last_hidden_state.entity_hidden_states__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r    r"   r"   U/var/www/auris/lib/python3.10/site-packages/transformers/models/luke/modeling_luke.pyr   %      
 r   c                   @   r   )BaseLukeModelOutputa#  
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        entity_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, entity_length, hidden_size)`):
            Sequence of entity hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    entity_last_hidden_state: Optional[torch.FloatTensor] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class LukeMaskedLMOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            The sum of masked language modeling (MLM) loss and entity prediction loss.
        mlm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked language modeling (MLM) loss.
        mep_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Masked entity prediction (MEP) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        entity_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the entity prediction head (scores for each entity vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlossmlm_lossmep_losslogitsentity_logitshidden_states.r   
attentions)r   r   r   r   r'   r   r   r    r!   r(   r)   r*   r+   r,   r   r   r-   r"   r"   r"   r#   r&   c   s   
 r&   c                   @      e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	EntityClassificationOutputay  
    Outputs of entity classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class EntityPairClassificationOutput(ModelOutput):
    """
    Outputs of entity pair classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class EntitySpanClassificationOutput(ModelOutput):
    """
    Outputs of entity span classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, entity_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class LukeSequenceClassifierOutput(ModelOutput):
    """
    Outputs of sentence classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class LukeTokenClassifierOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class LukeQuestionAnsweringModelOutput(ModelOutput):
    """
    Outputs of question answering models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-start scores (before SoftMax).
        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Span-end scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: Optional[torch.FloatTensor] = None
    end_logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


@dataclass
class LukeMultipleChoiceModelOutput(ModelOutput):
    """
    Outputs of multiple choice models.

    Args:
        loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided):
            Classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            *num_choices* is the second dimension of the input tensors. (see *input_ids* above).

            Classification scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        entity_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, entity_length, hidden_size)`. Entity hidden-states of the model at the output of each
            layer plus the initial entity embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class LukeEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _|j| _tj|j|j| jd| _	d S )Npadding_idxZeps)super__init__r   	Embedding
vocab_sizehidden_sizeZpad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutr=   selfconfig	__class__r"   r#   r@   }  s   
zLukeEmbeddings.__init__Nc           	      C   s   |d u r|d urt || j|j}n| |}|d ur!| }n| d d }|d u r8tj|tj| j	jd}|d u rA| 
|}| |}| |}|| | }| |}| |}|S )Ndtypedevice)"create_position_ids_from_input_idsr=   torV   &create_position_ids_from_inputs_embedssizer   zeroslongposition_idsrD   rF   rH   rI   rM   )	rO   	input_idstoken_type_idsr]   inputs_embedsinput_shaperF   rH   
embeddingsr"   r"   r#   forward  s"   






zLukeEmbeddings.forwardc                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
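
        Example (illustrative): with `padding_idx = 1` and `inputs_embeds` of sequence length 4, the generated
        position ids are `[2, 3, 4, 5]` for every sequence in the batch.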
        NrS   r   rT   r   )rZ   r   Zaranger=   r\   rV   	unsqueezeexpand)rO   r`   ra   Zsequence_lengthr]   r"   r"   r#   rY     s   	z5LukeEmbeddings.create_position_ids_from_inputs_embeds)NNNN)r   r   r   r   r@   rc   rY   __classcell__r"   r"   rQ   r#   r;   x  s    
!r;   c                       sF   e Zd Zdef fddZ	d
dejdejdeej fdd	Z  Z	S )LukeEntityEmbeddingsrP   c                    s   t    || _tj|j|jdd| _|j|jkr$tj	|j|jdd| _
t|j|j| _t|j|j| _tj|j|jd| _t|j| _d S )Nr   r<   Fbiasr>   )r?   r@   rP   r   rA   entity_vocab_sizeentity_emb_sizeentity_embeddingsrC   Linearentity_embedding_denserE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   rQ   r"   r#   r@     s   
zLukeEntityEmbeddings.__init__N
entity_idsr]   r_   c           	      C   s   |d u r	t |}| |}| jj| jjkr| |}| |jdd}|dk	|
d}|| }t j|dd}||jddjdd }| |}|| | }| |}| |}|S )Nr   )minrS   dimgHz>)r   Z
zeros_likerl   rP   rk   rC   rn   rF   clamptype_asrd   sumrH   rI   rM   )	rO   ro   r]   r_   rl   rF   Zposition_embedding_maskrH   rb   r"   r"   r#   rc     s   





zLukeEntityEmbeddings.forwardN)
r   r   r   r   r@   r   
LongTensorr   rc   rf   r"   r"   rQ   r#   rg     s    rg   c                       4   e Zd Z fddZdd Z			d	ddZ  ZS )
LukeSelfAttentionc                    s   t    |j|j dkrt|dstd|j d|j d|j| _t|j|j | _| j| j | _|j	| _	t
|j| j| _t
|j| j| _t
|j| j| _| j	rpt
|j| j| _t
|j| j| _t
|j| j| _t
|j| _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .)r?   r@   rC   num_attention_headshasattr
ValueErrorintattention_head_sizeall_head_sizeuse_entity_aware_attentionr   rm   querykeyvalue	w2e_query	e2w_query	e2e_queryrK   Zattention_probs_dropout_probrM   rN   rQ   r"   r#   r@     s&   

zLukeSelfAttention.__init__c                 C   s6   |  d d | j| jf }|j| }|ddddS )NrS   r      r   r
   )rZ   r|   r   viewpermute)rO   xZnew_x_shaper"   r"   r#   transpose_for_scores  s   
z&LukeSelfAttention.transpose_for_scoresNFc                  C   s  | d}|d u r|}n	tj||gdd}| | |}| | |}	| jr|d ur| | |}
| | |}| | 	|}| | 
|}|d d d d d |d d f }|d d d d d |d d f }|d d d d |d d d f }|d d d d |d d d f }t|
|dd}t||dd}t||dd}t||dd}tj||gdd}tj||gdd}tj||gdd}n| | |}t||dd}|t| j }|d ur|| }tjj|dd}| |}|d ur|| }t||	}|dddd }|  d d | jf }|j| }|d d d |d d f }|d u r>d }n|d d |d d d f }|rU|||f}|S ||f}|S )Nr   rr   rS   rq   r
   r   r   )rZ   r   catr   r   r   r   r   r   r   r   matmulZ	transposemathsqrtr   r   
functionalZsoftmaxrM   r   
contiguousr   r   ) rO   word_hidden_statesr   attention_mask	head_maskoutput_attentions	word_sizeconcat_hidden_statesZ	key_layerZvalue_layerZw2w_query_layerZw2e_query_layerZe2w_query_layerZe2e_query_layerZw2w_key_layerZe2w_key_layerZw2e_key_layerZe2e_key_layerZw2w_attention_scoresZw2e_attention_scoresZe2w_attention_scoresZe2e_attention_scoresZword_attention_scoresZentity_attention_scoresZattention_scoresZquery_layerZattention_probsZcontext_layerZnew_context_layer_shapeZoutput_word_hidden_statesZoutput_entity_hidden_statesoutputsr"   r"   r#   rc   
  sX   
    




zLukeSelfAttention.forwardNNF)r   r   r   r@   r   rc   rf   r"   r"   rQ   r#   rz     s    	rz   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )LukeSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr>   )r?   r@   r   rm   rC   denserI   rJ   rK   rL   rM   rN   rQ   r"   r#   r@   _     
zLukeSelfOutput.__init__r,   input_tensorreturnc                 C   &   |  |}| |}| || }|S rw   r   rM   rI   rO   r,   r   r"   r"   r#   rc   e     

zLukeSelfOutput.forwardr   r   r   r@   r   ZTensorrc   rf   r"   r"   rQ   r#   r   ^      $r   c                       ry   )
LukeAttentionc                    s*   t    t|| _t|| _t | _d S rw   )r?   r@   rz   rO   r   outputsetZpruned_headsrN   rQ   r"   r#   r@   m  s   


zLukeAttention.__init__c                 C      t dNz4LUKE does not support the pruning of attention headsNotImplementedError)rO   Zheadsr"   r"   r#   prune_headss     zLukeAttention.prune_headsNFc                 C   s   | d}| |||||}|d u r|d }|}	ntj|d d dd}tj||gdd}	| ||	}
|
d d d |d d f }|d u rGd }n|
d d |d d d f }||f|dd   }|S )Nr   r   r   rr   )rZ   rO   r   r   r   )rO   r   r   r   r   r   r   Zself_outputsZconcat_self_outputsr   attention_outputZword_attention_outputZentity_attention_outputr   r"   r"   r#   rc   v  s(   
zLukeAttention.forwardr   )r   r   r   r@   r   rc   rf   r"   r"   rQ   r#   r   l  s    r   c                       2   e Zd Z fddZdejdejfddZ  ZS )LukeIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S rw   )r?   r@   r   rm   rC   intermediate_sizer   
isinstance
hidden_actstrr   intermediate_act_fnrN   rQ   r"   r#   r@     s
   
zLukeIntermediate.__init__r,   r   c                 C   s   |  |}| |}|S rw   )r   r   rO   r,   r"   r"   r#   rc     s   

zLukeIntermediate.forwardr   r"   r"   rQ   r#   r     s    r   c                       r   )
LukeOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r?   r@   r   rm   r   rC   r   rI   rJ   rK   rL   rM   rN   rQ   r"   r#   r@     r   zLukeOutput.__init__r,   r   r   c                 C   r   rw   r   r   r"   r"   r#   rc     r   zLukeOutput.forwardr   r"   r"   rQ   r#   r     r   r   c                       s4   e Zd Z fddZ			d	ddZdd Z  ZS )
	LukeLayerc                    s:   t    |j| _d| _t|| _t|| _t|| _	d S Nr   )
r?   r@   chunk_size_feed_forwardseq_len_dimr   	attentionr   intermediater   r   rN   rQ   r"   r#   r@     s   


zLukeLayer.__init__NFc                 C   s   | d}| j|||||d}|d u r|d }ntj|d d dd}|dd  }	t| j| j| j|}
|
d d d |d d f }|d u rGd }n|
d d |d d d f }||f|	 }	|	S )Nr   )r   r   r   rr   )rZ   r   r   r   r   feed_forward_chunkr   r   )rO   r   r   r   r   r   r   Zself_attention_outputsZconcat_attention_outputr   layer_outputZword_layer_outputZentity_layer_outputr"   r"   r#   rc     s*   

zLukeLayer.forwardc                 C   s   |  |}| ||}|S rw   )r   r   )rO   r   Zintermediate_outputr   r"   r"   r#   r     s   
zLukeLayer.feed_forward_chunkr   )r   r   r   r@   rc   r   rf   r"   r"   rQ   r#   r     s    
%r   c                       s0   e Zd Z fddZ					dddZ  ZS )	LukeEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r"   )r   ).0_rP   r"   r#   
<listcomp>  s    z(LukeEncoder.__init__.<locals>.<listcomp>F)	r?   r@   rP   r   Z
ModuleListrangenum_hidden_layerslayergradient_checkpointingrN   rQ   r   r#   r@     s   
 
zLukeEncoder.__init__NFTc              	   C   s  |rdnd }|r
dnd }	|rdnd }
t | jD ]I\}}|r'||f }|	|f }	|d ur/|| nd }| jrC| jrC| |j|||||}n||||||}|d }|d urW|d }|r`|
|d f }
q|rm||f }|	|f }	|s}tdd |||
||	fD S t|||
||	dS )Nr"   r   r   r   c                 s       | ]	}|d ur|V  qd S rw   r"   r   vr"   r"   r#   	<genexpr>(      z&LukeEncoder.forward.<locals>.<genexpr>)last_hidden_stater,   r-   r   r   )	enumerater   r   ZtrainingZ_gradient_checkpointing_func__call__tupler%   )rO   r   r   r   r   r   output_hidden_statesreturn_dictZall_word_hidden_statesZall_entity_hidden_statesZall_self_attentionsiZlayer_moduleZlayer_head_maskZlayer_outputsr"   r"   r#   rc     sb   


	

zLukeEncoder.forward)NNFFTr   r   r   r@   rc   rf   r"   r"   rQ   r#   r     s    
r   c                       r   )
LukePoolerc                    s*   t    t|j|j| _t | _d S rw   )r?   r@   r   rm   rC   r   ZTanh
activationrN   rQ   r"   r#   r@   >  s   
zLukePooler.__init__r,   r   c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )rO   r,   Zfirst_token_tensorpooled_outputr"   r"   r#   rc   C  s   

zLukePooler.forwardr   r"   r"   rQ   r#   r   =  s    r   c                       $   e Zd Z fddZdd Z  ZS )EntityPredictionHeadTransformc                    sV   t    t|j|j| _t|jt	rt
|j | _n|j| _tj|j|jd| _d S r   )r?   r@   r   rm   rC   rk   r   r   r   r   r   transform_act_fnrI   rJ   rN   rQ   r"   r#   r@   M  s   
z&EntityPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S rw   )r   r   rI   r   r"   r"   r#   rc   V  s   


z%EntityPredictionHeadTransform.forwardr   r"   r"   rQ   r#   r   L  s    	r   c                       r   )EntityPredictionHeadc                    sH   t    || _t|| _tj|j|jdd| _	t
t|j| _d S )NFrh   )r?   r@   rP   r   	transformr   rm   rk   rj   decoder	Parameterr   r[   ri   rN   rQ   r"   r#   r@   ^  s
   

zEntityPredictionHead.__init__c                 C   s   |  |}| || j }|S rw   )r   r   ri   r   r"   r"   r#   rc   e  s   
zEntityPredictionHead.forwardr   r"   r"   rQ   r#   r   ]  s    r   c                   @   s0   e Zd ZeZdZdZddgZdej	fddZ
dS )	LukePreTrainedModellukeTr   rg   modulec                 C   s   t |tjr |jjjd| jjd |jdur|jj	  dS dS t |tj
rO|jdkr2|jj	  n|jjjd| jjd |jdurM|jj|j 	  dS dS t |tjrd|jj	  |jjd dS dS )zInitialize the weightsg        )meanZstdNr         ?)r   r   rm   weightdataZnormal_rP   Zinitializer_rangeri   Zzero_rA   Zembedding_dimr=   rI   Zfill_)rO   r   r"   r"   r#   _init_weightss  s    


z!LukePreTrainedModel._init_weightsN)r   r   r   r   Zconfig_classZbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesr   Moduler   r"   r"   r"   r#   r   l  s    r   zt
    The bare LUKE model transformer outputting raw hidden-states for both word tokens and entities without any
    )Zcustom_introc                        s  e Zd Zd$dedef fddZdd Zdd	 Zd
d Zdd Z	dd Z
e													d%deej deej deej deej deej deej deej deej deej deej dee dee dee deeef fdd Zd!ejdeej fd"d#Z  ZS )&	LukeModelTrP   add_pooling_layerc                    sN   t  | || _t|| _t|| _t|| _|rt	|nd| _
|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)r?   r@   rP   r;   rb   rg   rl   r   encoderr   pooler	post_init)rO   rP   r   rQ   r"   r#   r@     s   


zLukeModel.__init__c                 C      | j jS rw   rb   rD   rO   r"   r"   r#   get_input_embeddings  r   zLukeModel.get_input_embeddingsc                 C      || j _d S rw   r   rO   r   r"   r"   r#   set_input_embeddings     zLukeModel.set_input_embeddingsc                 C   s   | j j S rw   rl   r   r"   r"   r#   get_entity_embeddings  r   zLukeModel.get_entity_embeddingsc                 C   s   || j _ d S rw   r   r   r"   r"   r#   set_entity_embeddings  r   zLukeModel.set_entity_embeddingsc                 C   r   r   r   )rO   Zheads_to_pruner"   r"   r#   _prune_heads  r   zLukeModel._prune_headsNr^   r   r_   r]   ro   entity_attention_maskentity_token_type_idsentity_position_idsr   r`   r   r   r   r   c              	   C   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}|dur*|
dur*td|dur9| || | }n|
durF|
 dd }ntd|\}}|durU|jn|
j}|du retj	||f|d}|du rrtj
|tj|d}|dur|d}|du rtj	||f|d}|du rtj
||ftj|d}| |	| j j}	| j||||
d}| ||}|du rd}n| |||}| j||||	|||d	}|d
 }| jdur| |nd}|s||f|dd  S t|||j|j|j|jdS )uz  
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LukeModel

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
        >>> model = LukeModel.from_pretrained("studio-ousia/luke-base")
        # Compute the contextualized entity representation corresponding to the entity mention "Beyoncé"

        >>> text = "Beyoncé lives in Los Angeles."
        >>> entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé"

        >>> encoding = tokenizer(text, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt")
        >>> outputs = model(**encoding)
        >>> word_last_hidden_state = outputs.last_hidden_state
        >>> entity_last_hidden_state = outputs.entity_last_hidden_state
        # Input Wikipedia entities to obtain enriched contextualized representations of word tokens

        >>> text = "Beyoncé lives in Los Angeles."
        >>> entities = [
        ...     "Beyoncé",
        ...     "Los Angeles",
        ... ]  # Wikipedia entity titles corresponding to the entity mentions "Beyoncé" and "Los Angeles"
        >>> entity_spans = [
        ...     (0, 7),
        ...     (17, 28),
        ... ]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"

        >>> encoding = tokenizer(
        ...     text, entities=entities, entity_spans=entity_spans, add_prefix_space=True, return_tensors="pt"
        ... )
        >>> outputs = model(**encoding)
        >>> word_last_hidden_state = outputs.last_hidden_state
        >>> entity_last_hidden_state = outputs.entity_last_hidden_state
        ```NzDYou cannot specify both input_ids and inputs_embeds at the same timerS   z5You have to specify either input_ids or inputs_embeds)rV   rT   r   )r^   r]   r_   r`   )r   r   r   r   r   r   )r   pooler_outputr,   r-   r   r   )rP   r   r   use_return_dictr~   Z%warn_if_padding_and_no_attention_maskrZ   rV   r   Zonesr[   r\   Zget_head_maskr   rb   get_extended_attention_maskrl   r   r   r   r,   r-   r   r   )rO   r^   r   r_   r]   ro   r   r   r   r   r`   r   r   r   ra   Z
batch_sizeZ
seq_lengthrV   Zentity_seq_lengthZword_embedding_outputextended_attention_maskZentity_embedding_outputZencoder_outputssequence_outputr   r"   r"   r#   rc     sp   I

zLukeModel.forwardword_attention_maskc                 C   s   |}|durt j||gdd}| dkr$|dddddddf }n| dkr7|ddddddf }n	td|j d|j| jd}d	| t | jj }|S )
ac  
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            word_attention_mask (`torch.LongTensor`):
                Attention mask for word tokens with ones indicating tokens to attend to, zeros for tokens to ignore.
            entity_attention_mask (`torch.LongTensor`, *optional*):
                Attention mask for entity tokens with ones indicating tokens to attend to, zeros for tokens to ignore.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        NrS   rr   r
   r   z&Wrong shape for attention_mask (shape ))rU   r   )	r   r   rs   r~   shaperX   rU   Zfinforp   )rO   r  r   r   r  r"   r"   r#   r  G  s   z%LukeModel.get_extended_attention_mask)T)NNNNNNNNNNNNN)r   r   r   r   boolr@   r   r   r   r   r   r   r   r   rx   r    r   r   r   rc   r  rf   r"   r"   rQ   r#   r     sp    	

 r   c                 C   s2   |  | }tj|dd|| }| | S )a  
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: torch.Tensor

    Returns: torch.Tensor
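
    Example (illustrative): for `input_ids = [[0, 9, 7, 1, 1]]` with `padding_idx = 1`, the mask of non-padding
    positions is `[[1, 1, 1, 0, 0]]`, its cumulative sum gives `[[1, 2, 3, 0, 0]]`, and adding `padding_idx` yields
    position ids `[[2, 3, 4, 1, 1]]` (padding tokens keep `padding_idx`).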
    r   rr   )ner   r   Zcumsumru   r\   )r^   r=   maskZincremental_indicesr"   r"   r#   rW   f  s   rW   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )
LukeLMHeadz*Roberta Head for masked language modeling.c                    sd   t    t|j|j| _tj|j|jd| _t|j|j	| _
tt|j	| _| j| j
_d S r   )r?   r@   r   rm   rC   r   rI   rJ   
layer_normrB   r   r   r   r[   ri   rN   rQ   r"   r#   r@   z  s   
zLukeLMHead.__init__c                 K   s*   |  |}t|}| |}| |}|S rw   )r   r   r  r   )rO   featureskwargsr   r"   r"   r#   rc     s
   


zLukeLMHead.forwardc                 C   s,   | j jjjdkr| j| j _d S | j j| _d S )Nmeta)r   ri   rV   typer   r"   r"   r#   _tie_weights  s   zLukeLMHead._tie_weights)r   r   r   r   r@   rc   r  rf   r"   r"   rQ   r#   r  w  s
    	
r  z
    The LUKE model with a language modeling head and entity prediction head on top for masked language modeling and
    masked entity prediction.
    c                $       s
  e Zd Zg dZ fddZ fddZdd Zdd	 Ze	
	
	
	
	
	
	
	
	
	
	
	
	
	
	
dde	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e de	e de	e deeef f ddZ  ZS )LukeForMaskedLM)zlm_head.decoder.weightzlm_head.decoder.biasz!entity_predictions.decoder.weightc                    s@   t  | t|| _t|| _t|| _t	 | _
|   d S rw   )r?   r@   r   r   r  lm_headr   entity_predictionsr   r   loss_fnr   rN   rQ   r"   r#   r@     s   



zLukeForMaskedLM.__init__c                    s$   t    | | jj| jjj d S rw   )r?   tie_weightsZ_tie_or_clone_weightsr  r   r   rl   r   rQ   r"   r#   r    s   
zLukeForMaskedLM.tie_weightsc                 C   r   rw   r  r   r   r"   r"   r#   get_output_embeddings  r   z%LukeForMaskedLM.get_output_embeddingsc                 C   r   rw   r  )rO   Znew_embeddingsr"   r"   r#   set_output_embeddings  r   z%LukeForMaskedLM.set_output_embeddingsNr^   r   r_   r]   ro   r   r   r   labelsentity_labelsr   r`   r   r   r   r   c                 C   s.  |dur|n| j j}| j||||||||||||dd}d}d}| |j}|	durE|	|j}	| |d| j j	|	d}|du rE|}d}d}|j
durr| |j
}|
durr| |d| j j|
d}|du rn|}n|| }|stdd ||||||j|j|jfD S t||||||j|j|jdS )aC  
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        entity_labels (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Labels for computing the masked entity prediction loss. Indices should be in `[-100, 0, ...,
            config.entity_vocab_size]`. Entities with indices set to `-100` are ignored (masked); the loss is only
            computed for the entities with labels in `[0, ..., config.entity_vocab_size]`.
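
        Example (an illustrative sketch; the base checkpoint name is an assumption and no entity inputs are passed):

        ```python
        >>> from transformers import AutoTokenizer, LukeForMaskedLM

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
        >>> model = LukeForMaskedLM.from_pretrained("studio-ousia/luke-base")

        >>> inputs = tokenizer("Beyoncé lives in Los Angeles.", return_tensors="pt")
        >>> outputs = model(**inputs, labels=inputs["input_ids"])
        >>> mlm_loss = outputs.mlm_loss
        >>> logits = outputs.logits  # (batch_size, sequence_length, vocab_size)
        ```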
        NTr^   r   r_   r]   ro   r   r   r   r   r`   r   r   r   rS   c                 s   r   rw   r"   r   r"   r"   r#   r     s    
z*LukeForMaskedLM.forward.<locals>.<genexpr>)r'   r(   r)   r*   r+   r,   r   r-   )rP   r  r   r  r   rX   rV   r  r   rB   r   r  rj   r   r,   r   r-   r&   )rO   r^   r   r_   r]   ro   r   r   r   r  r  r   r`   r   r   r   r   r'   r(   r*   r)   r+   r"   r"   r#   rc     sn   1
zLukeForMaskedLM.forwardNNNNNNNNNNNNNNN)r   r   r   Z_tied_weights_keysr@   r  r  r  r   r   r   rx   r    r	  r   r   r&   rc   rf   r"   r"   rQ   r#   r    sn    	

r  z
    The LUKE model with a classification head on top (a linear layer on top of the hidden state of the first entity
    token) for entity classification tasks, such as Open Entity.
    c                "          e Zd Z fddZe														ddeej deej deej deej deej d	eej d
eej deej deej deej deej dee	 dee	 dee	 de
eef fddZ  ZS )LukeForEntityClassificationc                    sJ   t  | t|| _|j| _t|j| _t	|j
|j| _|   d S rw   r?   r@   r   r   
num_labelsr   rK   rL   rM   rm   rC   
classifierr   rN   rQ   r"   r#   r@   2  s   
z$LukeForEntityClassification.__init__Nr^   r   r_   r]   ro   r   r   r   r   r`   r  r   r   r   r   c                 C   s   |dur|n| j j}| j|||||||||	|
||dd}|jdddddf }| |}| |}d}|dur[||j}|jdkrKt	j
||}nt	j
|d|d|}|sntdd |||j|j|jfD S t|||j|j|jd	S )
u
  
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
            Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
            used for the single-label classification. In this case, labels should contain the indices that should be in
            `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy
            loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0
            and 1 indicate false and true, respectively.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LukeForEntityClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")
        >>> model = LukeForEntityClassification.from_pretrained("studio-ousia/luke-large-finetuned-open-entity")

        >>> text = "Beyoncé lives in Los Angeles."
        >>> entity_spans = [(0, 7)]  # character-based entity span corresponding to "Beyoncé"
        >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: person
        ```NTr  r   r   rS   c                 s   r   rw   r"   r   r"   r"   r#   r         z6LukeForEntityClassification.forward.<locals>.<genexpr>r'   r*   r,   r   r-   )rP   r  r   r   rM   r#  rX   rV   ndimr   r   cross_entropy binary_cross_entropy_with_logitsr   ru   r   r,   r   r-   r/   rO   r^   r   r_   r]   ro   r   r   r   r   r`   r  r   r   r   r   feature_vectorr*   r'   r"   r"   r#   rc   >  sH   >


 z#LukeForEntityClassification.forwardNNNNNNNNNNNNNN)r   r   r   r@   r   r   r   rx   r    r	  r   r   r/   rc   rf   r"   r"   rQ   r#   r   +  `    	

r   z
    The LUKE model with a classification head on top (a linear layer on top of the hidden states of the two entity
    tokens) for entity pair classification tasks, such as TACRED.
    c                "       s   e Zd Z fddZe														ddeej deej deej deej deej d	eej d
eej deej deej deej deej dee	 dee	 dee	 de
eef fddZ  ZS )LukeForEntityPairClassificationc                    sP   t  | t|| _|j| _t|j| _t	|j
d |jd| _|   d S )Nr   Fr!  rN   rQ   r"   r#   r@     s   
z(LukeForEntityPairClassification.__init__Nr^   r   r_   r]   ro   r   r   r   r   r`   r  r   r   r   r   c                 C   s  |dur|n| j j}| j|||||||||	|
||dd}tj|jdddddf |jdddddf gdd}| |}| |}d}|durl||j	}|j
dkr\tj||}ntj|d|d|}|stdd	 |||j|j|jfD S t|||j|j|jd
S )u  
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
            Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
            used for the single-label classification. In this case, labels should contain the indices that should be in
            `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, num_labels)`, the binary cross entropy
            loss is used for the multi-label classification. In this case, labels should only contain `[0, 1]`, where 0
            and 1 indicate false and true, respectively.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LukeForEntityPairClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
        >>> model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")

        >>> text = "Beyoncé lives in Los Angeles."
        >>> entity_spans = [
        ...     (0, 7),
        ...     (17, 28),
        ... ]  # character-based entity spans corresponding to "Beyoncé" and "Los Angeles"
        >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: per:cities_of_residence
        ```NTr  r   r   rr   rS   c                 s   r   rw   r"   r   r"   r"   r#   r   %  r$  z:LukeForEntityPairClassification.forward.<locals>.<genexpr>r%  )rP   r  r   r   r   r   rM   r#  rX   rV   r&  r   r   r'  r(  r   ru   r   r,   r   r-   r2   r)  r"   r"   r#   rc     sL   A0


 z'LukeForEntityPairClassification.forwardr+  )r   r   r   r@   r   r   r   rx   r    r	  r   r   r2   rc   rf   r"   r"   rQ   r#   r-    r,  r-  z
    The LUKE model with a span classification head on top (a linear layer on top of the hidden states output) for tasks
    such as named entity recognition.
    """
    # [compiled bytecode: LukeForEntitySpanClassification. __init__ builds a LukeModel, a dropout
    #  layer, and a classifier nn.Linear(config.hidden_size * 3, config.num_labels); forward gathers
    #  the word hidden states at entity_start_positions and entity_end_positions, concatenates them
    #  with the entity hidden states, and classifies every candidate span.]
        r"""
S )u  
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        entity_start_positions (`torch.LongTensor`):
            The start positions of entities in the word token sequence.
        entity_end_positions (`torch.LongTensor`):
            The end positions of entities in the word token sequence.
        labels (`torch.LongTensor` of shape `(batch_size, entity_length)` or `(batch_size, entity_length, num_labels)`, *optional*):
            Labels for computing the classification loss. If the shape is `(batch_size, entity_length)`, cross-entropy
            loss is used for single-label classification; in this case, labels should be indices in
            `[0, ..., config.num_labels - 1]`. If the shape is `(batch_size, entity_length, num_labels)`, binary
            cross-entropy loss is used for multi-label classification; in this case, labels should only contain `0`
            and `1`, where 0 and 1 indicate false and true, respectively.
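
        The per-span features that these labels score can be sketched as follows (an illustration
        with random tensors and made-up sizes, not the module's exact code):

        ```python
        import torch
        from torch import nn

        batch, seq_len, entity_len, hidden, num_labels = 2, 10, 4, 16, 5
        word_states = torch.randn(batch, seq_len, hidden)       # word hidden states from the encoder
        entity_states = torch.randn(batch, entity_len, hidden)  # entity hidden states from the encoder
        entity_start_positions = torch.randint(0, seq_len, (batch, entity_len))
        entity_end_positions = torch.randint(0, seq_len, (batch, entity_len))

        # gather the word hidden states at each span's start and end token
        start_idx = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden)
        end_idx = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden)
        start_states = torch.gather(word_states, dim=1, index=start_idx)
        end_states = torch.gather(word_states, dim=1, index=end_idx)

        # concatenate start, end, and entity states, then classify every candidate span
        feature_vector = torch.cat([start_states, end_states, entity_states], dim=2)
        logits = nn.Linear(hidden * 3, num_labels)(feature_vector)  # (batch, entity_len, num_labels)
        ```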

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LukeForEntitySpanClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")
        >>> model = LukeForEntitySpanClassification.from_pretrained("studio-ousia/luke-large-finetuned-conll-2003")

        >>> text = "Beyoncé lives in Los Angeles"
        >>> # List all possible entity spans in the text

        >>> word_start_positions = [0, 8, 14, 17, 21]  # character-based start positions of word tokens
        >>> word_end_positions = [7, 13, 16, 20, 28]  # character-based end positions of word tokens
        >>> entity_spans = []
        >>> for i, start_pos in enumerate(word_start_positions):
        ...     for end_pos in word_end_positions[i:]:
        ...         entity_spans.append((start_pos, end_pos))

        >>> inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> predicted_class_indices = logits.argmax(-1).squeeze().tolist()
        >>> for span, predicted_class_idx in zip(entity_spans, predicted_class_indices):
        ...     if predicted_class_idx != 0:
        ...         print(text[span[0] : span[1]], model.config.id2label[predicted_class_idx])
        Beyoncé PER
        Los Angeles LOC
        ```
        """
    # [compiled bytecode: loss computation (cross-entropy over the flattened span logits for 2-D
    #  labels, binary cross-entropy with logits for 3-D labels) and return of
    #  LukeForEntitySpanClassification.forward; end of the class.]

    # Class-level documentation (auto_docstring intro) of LukeForSequenceClassification:
    """
    The LUKE Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """
    # [compiled bytecode: LukeForSequenceClassification. __init__ builds a LukeModel, a dropout layer
    #  (classifier_dropout if set, otherwise hidden_dropout_prob), and a classifier
    #  nn.Linear(hidden_size, num_labels); forward classifies the pooled output and, depending on
    #  config.problem_type, applies a mean-squared-error, cross-entropy, or binary-cross-entropy loss.]
        r"""
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1`, a regression loss is computed (mean-square loss); if
            `config.num_labels > 1`, a classification loss is computed (cross-entropy).
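
        Example (a minimal usage sketch; `"studio-ousia/luke-base"` is only an illustrative checkpoint
        without a fine-tuned classification head, so the head weights are freshly initialized and the
        prediction is not meaningful):

        ```python
        from transformers import AutoTokenizer, LukeForSequenceClassification

        tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
        model = LukeForSequenceClassification.from_pretrained("studio-ousia/luke-base", num_labels=2)

        inputs = tokenizer("LUKE uses entity-aware self-attention.", return_tensors="pt")
        outputs = model(**inputs)
        predicted_class_idx = outputs.logits.argmax(-1).item()
        ```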
        """
    # [compiled bytecode: problem_type dispatch ("regression", "single_label_classification",
    #  "multi_label_classification"), loss computation, and return of
    #  LukeForSequenceClassification.forward; end of the class.]

    # Class-level documentation (auto_docstring intro) of LukeForTokenClassification:
    """
    The LUKE Model with a token classification head on top (a linear layer on top of the hidden-states output). To
    solve a named-entity recognition (NER) task with LUKE, `LukeForEntitySpanClassification` is more suitable than
    this class.
    """
    # [compiled bytecode: LukeForTokenClassification. __init__ builds a LukeModel without the pooling
    #  layer, a dropout layer, and a classifier nn.Linear(hidden_size, num_labels); forward classifies
    #  every word hidden state.]
        r"""
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`.
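
        Example (a minimal usage sketch; `"studio-ousia/luke-base"` and the label count are
        illustrative choices, so the randomly initialized head does not produce meaningful tags):

        ```python
        import torch
        from transformers import AutoTokenizer, LukeForTokenClassification

        tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
        model = LukeForTokenClassification.from_pretrained("studio-ousia/luke-base", num_labels=9)

        inputs = tokenizer("Beyoncé lives in Los Angeles.", return_tensors="pt")
        outputs = model(**inputs)
        predictions = outputs.logits.argmax(-1)  # (batch_size, sequence_length)

        # labels carry one class index per word token
        labels = torch.zeros_like(inputs["input_ids"])
        loss = model(**inputs, labels=labels).loss
        ```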
        """
    # [compiled bytecode: cross-entropy loss computation and return of
    #  LukeForTokenClassification.forward; end of the class.]


    # [compiled bytecode: LukeForQuestionAnswering. __init__ builds a LukeModel without the pooling
    #  layer and qa_outputs = nn.Linear(hidden_size, num_labels); forward splits the projected word
    #  hidden states into start and end logits, clamps out-of-range start_positions/end_positions,
    #  and averages the two cross-entropy loss terms.]
        r"""
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the start of the labelled answer span, used to compute the loss.
            Positions are clamped to the length of the sequence (`sequence_length`); positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for the position (index) of the end of the labelled answer span, used to compute the loss.
            Positions are clamped to the length of the sequence (`sequence_length`); positions outside of the
            sequence are not taken into account for computing the loss.
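
        Example (a minimal usage sketch; `"studio-ousia/luke-base"` is only an illustrative checkpoint,
        so the randomly initialized QA head does not return a meaningful span):

        ```python
        from transformers import AutoTokenizer, LukeForQuestionAnswering

        tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
        model = LukeForQuestionAnswering.from_pretrained("studio-ousia/luke-base")

        question, context = "Where does Beyoncé live?", "Beyoncé lives in Los Angeles."
        inputs = tokenizer(question, context, return_tensors="pt")
        outputs = model(**inputs)

        start = outputs.start_logits.argmax(-1).item()
        end = outputs.end_logits.argmax(-1).item()
        answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
        ```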
        """
    # [compiled bytecode: clamping of start/end positions, averaged cross-entropy loss, and return of
    #  LukeForQuestionAnswering.forward; end of the class.]


    # [compiled bytecode: LukeForMultipleChoice. __init__ builds a LukeModel, a dropout layer, and a
    #  classifier nn.Linear(hidden_size, 1); forward flattens the (batch_size, num_choices, ...)
    #  inputs, scores each choice from the pooled output, and reshapes the logits to
    #  (batch_size, num_choices).]
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        entity_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`):
            Indices of entity tokens in the entity vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
        entity_attention_mask (`torch.FloatTensor` of shape `(batch_size, entity_length)`, *optional*):
            Mask to avoid performing attention on padding entity token indices. Mask values selected in `[0, 1]`:

            - 1 for entity tokens that are **not masked**,
            - 0 for entity tokens that are **masked**.
        entity_token_type_ids (`torch.LongTensor` of shape `(batch_size, entity_length)`, *optional*):
            Segment token indices to indicate first and second portions of the entity token inputs. Indices are
            selected in `[0, 1]`:

            - 0 corresponds to a *portion A* entity token,
            - 1 corresponds to a *portion B* entity token.
        entity_position_ids (`torch.LongTensor` of shape `(batch_size, entity_length, max_mention_length)`, *optional*):
            Indices of positions of each input entity in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
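
        Example (a minimal usage sketch; `"studio-ousia/luke-base"` is only an illustrative checkpoint
        without a fine-tuned multiple-choice head):

        ```python
        import torch
        from transformers import AutoTokenizer, LukeForMultipleChoice

        tokenizer = AutoTokenizer.from_pretrained("studio-ousia/luke-base")
        model = LukeForMultipleChoice.from_pretrained("studio-ousia/luke-base")

        prompt = "Beyoncé lives in"
        choices = ["Los Angeles.", "a submarine."]
        encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
        inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}  # (batch_size=1, num_choices, seq_len)

        labels = torch.tensor(0).unsqueeze(0)  # index of the correct choice
        outputs = model(**inputs, labels=labels)
        loss, logits = outputs.loss, outputs.logits  # logits: (batch_size, num_choices)
        ```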
        """
    # [compiled bytecode: cross-entropy loss over the reshaped choice logits and return of
    #  LukeForMultipleChoice.forward; end of the class.]


# [compiled module trailer: the surviving name table shows an __all__ export list covering LukeModel,
#  LukePreTrainedModel, LukeForMaskedLM, LukeForEntityClassification, LukeForEntityPairClassification,
#  LukeForEntitySpanClassification, LukeForSequenceClassification, LukeForTokenClassification,
#  LukeForQuestionAnswering, and LukeForMultipleChoice, followed by the frozen import and name tables
#  of the .pyc file.]