o
    Zh9                     @   s  d Z ddlmZmZmZ ddlZddlZddlmZ ddlm	Z	m
Z
mZ ddlmZ ddlmZmZmZmZmZ dd	lmZ dd
lmZmZ ddlmZ eeZG dd dejZG dd dejZ ej!j"dd Z#ej!j"dd Z$ej!j"dd Z%ej!j"dd Z&ej!j"dej'de(fddZ)ej!j"dej'dej'fddZ*ej!j"dej'dej'd e(fd!d"Z+ej!j"dej'dej'fd#d$Z,G d%d& d&ejZ-G d'd( d(ejZ.G d)d* d*ejZ/G d+d, d,ejZ0G d-d. d.ejZ1G d/d0 d0ejZ2G d1d2 d2ejZ3eG d3d4 d4eZ4eG d5d6 d6e4Z5G d7d8 d8ejZ6G d9d: d:ejZ7G d;d< d<ejZ8G d=d> d>ejZ9G d?d@ d@ejZ:eG dAdB dBe4Z;G dCdD dDejZ<edEdFG dGdH dHe4Z=eG dIdJ dJe4Z>eG dKdL dLe4Z?g dMZ@dS )NzPyTorch DeBERTa model.    )OptionalTupleUnionN)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputMaskedLMOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)auto_docstringlogging   )DebertaConfigc                       s*   e Zd ZdZd fdd	Zdd Z  ZS )DebertaLayerNormzBLayerNorm module in the TF style (epsilon inside the square root).-q=c                    s8   t    tt|| _tt|| _|| _	d S N)
super__init__r   	Parametertorchonesweightzerosbiasvariance_epsilon)selfsizeeps	__class__ [/var/www/auris/lib/python3.10/site-packages/transformers/models/deberta/modeling_deberta.pyr   +   s   

zDebertaLayerNorm.__init__c                 C   sj   |j }| }|jddd}|| djddd}|| t|| j  }||}| j| | j	 }|S )NT)Zkeepdim   )
dtypefloatmeanpowr   sqrtr    tor   r   )r!   hidden_statesZ
input_typer,   Zvarianceyr&   r&   r'   forward1   s   
zDebertaLayerNorm.forward)r   __name__
__module____qualname____doc__r   r2   __classcell__r&   r&   r$   r'   r   (   s    r   c                       $   e Zd Z fddZdd Z  ZS )DebertaSelfOutputc                    s>   t    t|j|j| _t|j|j| _t	|j
| _d S r   )r   r   r   Linearhidden_sizedenser   layer_norm_eps	LayerNormDropouthidden_dropout_probdropoutr!   configr$   r&   r'   r   =   s   
zDebertaSelfOutput.__init__c                 C   &   |  |}| |}| || }|S r   r=   rB   r?   r!   r0   Zinput_tensorr&   r&   r'   r2   C      

zDebertaSelfOutput.forwardr4   r5   r6   r   r2   r8   r&   r&   r$   r'   r:   <   s    r:   c                 C   s   |  d}| d}tj|tj| jd}tj|tj|jd}|dddf |dd|d }|d|ddf }|d}|S )a  
    Build relative position according to the query and key

    We assume the absolute position of query \(P_q\) is range from (0, query_size) and the absolute position of key
    \(P_k\) is range from (0, key_size), The relative positions from query to key is \(R_{q \rightarrow k} = P_q -
    P_k\)

    Args:
        query_size (int): the length of query
        key_size (int): the length of key

    Return:
        `torch.LongTensor`: A tensor with shape [1, query_size, key_size]

    r*   deviceNr   r(   r   )r"   r   arangelongrL   viewrepeat	unsqueeze)query_layer	key_layerZ
query_sizeZkey_sizeZq_idsZk_idsZrel_pos_idsr&   r&   r'   build_relative_positionJ   s   

$
rT   c                 C   s*   |  |d|d|d|dgS )Nr   r   r)   r(   expandr"   )c2p_posrR   relative_posr&   r&   r'   c2p_dynamic_expandg      *rY   c                 C   s*   |  |d|d|d|dgS )Nr   r   rJ   rU   )rW   rR   rS   r&   r&   r'   p2c_dynamic_expandl   rZ   r[   c                 C   s*   |  | d d | d|df S )Nr)   rJ   rU   )	pos_indexp2c_attrS   r&   r&   r'   pos_dynamic_expandq   rZ   r^   rR   scale_factorc                 C   s    t t j| dt jd| S )Nr(   r*   )r   r.   tensorr"   r+   )rR   r_   r&   r&   r'   scaled_size_sqrty   s    rb   rS   c                 C   s"   |  d| dkrt| |S |S NrJ   )r"   rT   )rR   rS   rX   r&   r&   r'   
build_rpos~   s   
rd   max_relative_positionsc                 C   s"   t tt| d|d|S rc   )r   ra   minmaxr"   )rR   rS   re   r&   r&   r'   compute_attention_span   s   "rh   c                 C   sR   | d| dkr'|d d d d d d df d}tj| dt|| |dS | S )NrJ   r   r(   r)   dimindex)r"   rQ   r   gatherr^   )r]   rR   rS   rX   r\   r&   r&   r'   uneven_size_corrected   s   "rm   c                       s   e Zd ZdZ fddZdd Z				ddejd	ejd
ede	ej de	ej de	ej de
eje	ej f fddZdejdejdejdejdef
ddZ  ZS )DisentangledSelfAttentiona  
    Disentangled self-attention module

    Parameters:
        config (`str`):
            A model config class instance with the configuration to build a new model. The schema is similar to
            *BertConfig*, for more details, please refer [`DebertaConfig`]

    c                    s  t    |j|j dkrtd|j d|j d|j| _t|j|j | _| j| j | _tj	|j| jd dd| _
ttj| jtjd| _ttj| jtjd| _|jd ur]|jng | _t|d	d| _t|d
d| _| jrtj	|j|jdd| _tj	|j|jdd| _nd | _d | _| jrt|dd| _| jdk r|j| _t|j| _d| jv rtj	|j| jdd| _d| jv rt	|j| j| _t|j| _d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()r	   Fr   r`   relative_attentiontalking_headre   r(   r   c2pp2c) r   r   r<   num_attention_heads
ValueErrorintZattention_head_sizeZall_head_sizer   r;   in_projr   r   r   r+   q_biasv_biaspos_att_typegetattrrq   rr   head_logits_projhead_weights_projre   max_position_embeddingsr@   rA   pos_dropoutpos_proj
pos_q_projZattention_probs_dropout_probrB   rC   r$   r&   r'   r      s>   




z"DisentangledSelfAttention.__init__c                 C   s4   |  d d | jdf }||}|ddddS )Nr(   r   r)   r   r	   )r"   ru   rO   permute)r!   xZnew_x_shaper&   r&   r'   transpose_for_scores   s   
z.DisentangledSelfAttention.transpose_for_scoresFNr0   attention_maskoutput_attentionsquery_statesrX   rel_embeddingsreturnc                    s  |du r  |} |jddd\}}	}
nZ j jj jd dd fddtdD }t|d | j	|d j
d}t|d	 | j	|d	 j
d}t|d
 | j	|d
 j
d} fdd|||fD \}}	}
|  jddddf  }|
  jddddf  }
d}d	t j }t||}||j	|j
d }t||	dd} jr|dur|durɈ |} ||	|||}|dur|| } jdur |dd
dd	ddd	d
}| }|| t|j
j}tjj|dd} |} jdur |dd
dd	ddd	d
}t||
}|dd
d	d }|  dd d }|!|}|sA|dfS ||fS )a  
        Call the module

        Args:
            hidden_states (`torch.FloatTensor`):
                Input states to the module usually the output from previous layer, it will be the Q,K and V in
                *Attention(Q,K,V)*

            attention_mask (`torch.BoolTensor`):
                An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                th token.

            output_attentions (`bool`, *optional*):
                Whether return the attention matrix.

            query_states (`torch.FloatTensor`, *optional*):
                The *Q* state in *Attention(Q,K,V)*.

            relative_pos (`torch.LongTensor`):
                The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
                values ranging in [*-max_relative_positions*, *max_relative_positions*].

            rel_embeddings (`torch.FloatTensor`):
                The embedding of relative distances. It's a tensor of shape [\(2 \times
                \text{max_relative_positions}\), *hidden_size*].


        Nr	   r(   rj   r   c                    s0   g | ] t j fd dtjD ddqS )c                    s   g | ]
}|d     qS )r	   r&   ).0i)kwsr&   r'   
<listcomp>   s    z@DisentangledSelfAttention.forward.<locals>.<listcomp>.<listcomp>r   r   )r   catrangeru   )r   r!   r   )r   r'   r      s   0 z5DisentangledSelfAttention.forward.<locals>.<listcomp>r`   r   r)   c                    s   g | ]}  |qS r&   )r   )r   r   r!   r&   r'   r      s    rJ   r(   )"rx   r   chunkr   ru   r   r   matmultr/   r*   ry   rz   lenr{   rb   	transposerq   r   disentangled_att_biasr}   r   boolZmasked_fillZfinforf   r   Z
functionalZsoftmaxrB   r~   
contiguousr"   rO   )r!   r0   r   r   r   rX   r   ZqprR   rS   Zvalue_layerZqkvwqr   vZrel_attr_   scaleZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shaper&   r   r'   r2      sH   &
"""


"
"
z!DisentangledSelfAttention.forwardrR   rS   r_   c                 C   s  |d u rt |||j}| dkr|dd}n| dkr&|d}n| dkr5td|  t||| j}| }|| j| | j| d d f d}d}d| jv r| 	|}| 
|}t||dd	}	t|| d|d d }
tj|	dt|
||d
}	||	7 }d| jv r| |}| 
|}|t|| }t|||}t| | d|d d }t||dd	j|jd}tj|dt|||d
dd	}t||||}||7 }|S )Nr)   r   r	   r      z2Relative position ids must be of dim 2 or 3 or 4. rs   r(   rJ   ri   rt   r`   )rT   rL   rj   rQ   rv   rh   re   rN   r{   r   r   r   r   r   clamprl   rY   r   rb   rd   r/   r*   r[   rm   )r!   rR   rS   rX   r   r_   Zatt_spanZscoreZpos_key_layerZc2p_attrW   Zpos_query_layerZr_posZp2c_posr]   r&   r&   r'   r   $  sT   





z/DisentangledSelfAttention.disentangled_att_biasFNNN)r4   r5   r6   r7   r   r   r   Tensorr   r   r   r2   rw   r   r8   r&   r&   r$   r'   rn      sD    
&	
Wrn   c                       s*   e Zd ZdZ fddZdddZ  ZS )DebertaEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    t|dd}t|d|j| _tj|j| j|d| _t|dd| _	| j	s,d | _
n	t|j| j| _
|jdkrDt|j| j| _nd | _| j|jkrYtj| j|jdd| _nd | _t|j|j| _t|j| _|| _| jd	t|jd
dd d S )Npad_token_idr   embedding_size)padding_idxposition_biased_inputTFrp   position_ids)r   r(   )
persistent)r   r   r|   r<   r   r   	Embedding
vocab_sizeword_embeddingsr   position_embeddingsr   Ztype_vocab_sizetoken_type_embeddingsr;   
embed_projr   r>   r?   r@   rA   rB   rD   Zregister_bufferr   rM   rV   )r!   rD   r   r$   r&   r'   r   `  s(   


zDebertaEmbeddings.__init__Nc                 C   sH  |d ur	|  }n|  d d }|d }|d u r$| jd d d |f }|d u r3tj|tj| jjd}|d u r<| |}| jd urI| | }nt|}|}	| j	rW|	| }	| j
d ure| 
|}
|	|
 }	| jd uro| |	}	| |	}	|d ur| |	 kr| dkr|dd}|d}||	j}|	| }	| |	}	|	S )Nr(   r   rK   r   r)   )r"   r   r   r   rN   rL   r   r   Z
zeros_liker   r   r   r?   rj   squeezerQ   r/   r*   rB   )r!   	input_idstoken_type_idsr   maskinputs_embedsinput_shapeZ
seq_lengthr   
embeddingsr   r&   r&   r'   r2     s>   










zDebertaEmbeddings.forward)NNNNNr3   r&   r&   r$   r'   r   ]  s    r   c                       sH   e Zd Z fddZ				d	dedeejeej f fddZ	  Z
S )
DebertaAttentionc                    s(   t    t|| _t|| _|| _d S r   )r   r   rn   r!   r:   outputrD   rC   r$   r&   r'   r     s   



zDebertaAttention.__init__FNr   r   c           
      C   sF   | j ||||||d\}}|d u r|}| ||}	|r|	|fS |	d fS )N)r   rX   r   )r!   r   )
r!   r0   r   r   r   rX   r   Zself_output
att_matrixattention_outputr&   r&   r'   r2     s   	
zDebertaAttention.forwardr   r4   r5   r6   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r     s    
r   c                       2   e Zd Z fddZdejdejfddZ  ZS )DebertaIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r   )r   r   r   r;   r<   intermediate_sizer=   
isinstance
hidden_actstrr
   intermediate_act_fnrC   r$   r&   r'   r     s
   
zDebertaIntermediate.__init__r0   r   c                 C      |  |}| |}|S r   )r=   r   r!   r0   r&   r&   r'   r2        

zDebertaIntermediate.forwardr4   r5   r6   r   r   r   r2   r8   r&   r&   r$   r'   r     s    r   c                       r9   )DebertaOutputc                    sD   t    t|j|j| _t|j|j| _	t
|j| _|| _d S r   )r   r   r   r;   r   r<   r=   r   r>   r?   r@   rA   rB   rD   rC   r$   r&   r'   r     s
   

zDebertaOutput.__init__c                 C   rE   r   rF   rG   r&   r&   r'   r2     rH   zDebertaOutput.forwardrI   r&   r&   r$   r'   r     s    r   c                       sH   e Zd Z fddZ				d	dedeejeej f fddZ	  Z
S )
DebertaLayerc                    s,   t    t|| _t|| _t|| _d S r   )r   r   r   	attentionr   intermediater   r   rC   r$   r&   r'   r     s   


zDebertaLayer.__init__NFr   r   c                 C   sD   | j ||||||d\}}| |}	| |	|}
|r|
|fS |
d fS )Nr   r   rX   r   )r   r   r   )r!   r0   r   r   rX   r   r   r   r   Zintermediate_outputZlayer_outputr&   r&   r'   r2     s   	

zDebertaLayer.forward)NNNFr   r&   r&   r$   r'   r     s    
r   c                       sh   e Zd ZdZ fddZdd Zdd Zdd	d
Z					ddej	dej	de
de
de
f
ddZ  ZS )DebertaEncoderz8Modified BertEncoder with relative position bias supportc                    s~   t    t fddt jD | _t dd| _| jr:t dd| _	| j	dk r/ j
| _	t| j	d  j| _d| _d S )	Nc                    s   g | ]}t  qS r&   )r   r   _rD   r&   r'   r     s    z+DebertaEncoder.__init__.<locals>.<listcomp>rq   Fre   r(   r   r)   )r   r   r   Z
ModuleListr   Znum_hidden_layerslayerr|   rq   re   r   r   r<   r   gradient_checkpointingrC   r$   r   r'   r     s   
 

zDebertaEncoder.__init__c                 C   s   | j r	| jj}|S d }|S r   )rq   r   r   )r!   r   r&   r&   r'   get_rel_embedding  s   z DebertaEncoder.get_rel_embeddingc                 C   sN   |  dkr|dd}||dd }|S |  dkr%|d}|S )Nr)   r   rJ   r(   r	   )rj   rQ   r   )r!   r   Zextended_attention_maskr&   r&   r'   get_attention_mask"  s   
z!DebertaEncoder.get_attention_maskNc                 C   s2   | j r|d u r|d urt||}|S t||}|S r   )rq   rT   )r!   r0   r   rX   r&   r&   r'   get_rel_pos+  s   

zDebertaEncoder.get_rel_posTFr0   r   output_hidden_statesr   return_dictc              
   C   s   |  |}| |||}|r|fnd }|rdnd }	|}
|  }t| jD ]<\}}| jr=| jr=| |j|
|||||\}}n||
|||||d\}}|rP||f }|d urW|}n|}
|r`|	|f }	q$|sot	dd |||	fD S t
|||	dS )Nr&   )r   rX   r   r   c                 s   s    | ]	}|d ur|V  qd S r   r&   )r   r   r&   r&   r'   	<genexpr>g  s    z)DebertaEncoder.forward.<locals>.<genexpr>Zlast_hidden_stater0   
attentions)r   r   r   	enumerater   r   ZtrainingZ_gradient_checkpointing_func__call__tupler   )r!   r0   r   r   r   r   rX   r   Zall_hidden_statesZall_attentionsZnext_kvr   r   Zlayer_moduleZatt_mr&   r&   r'   r2   3  sL   




	

zDebertaEncoder.forward)NN)TFNNT)r4   r5   r6   r7   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r     s,    
	r   c                   @   s&   e Zd ZeZdZdgZdZdd ZdS )DebertaPreTrainedModeldebertar   Tc                 C   s   t |tjr |jjjd| jjd |jdur|jj	  dS dS t |tj
rC|jjjd| jjd |jdurA|jj|j 	  dS dS t |tjtfrZ|jjd |jj	  dS t |trm|jj	  |jj	  dS t |ttfr||jj	  dS dS )zInitialize the weights.g        )r,   ZstdNg      ?)r   r   r;   r   dataZnormal_rD   Zinitializer_ranger   Zzero_r   r   r?   r   Zfill_rn   ry   rz   LegacyDebertaLMPredictionHeadDebertaLMPredictionHead)r!   moduler&   r&   r'   _init_weightst  s&   


z$DebertaPreTrainedModel._init_weightsN)	r4   r5   r6   r   Zconfig_classZbase_model_prefixZ"_keys_to_ignore_on_load_unexpectedZsupports_gradient_checkpointingr   r&   r&   r&   r'   r   m  s    r   c                       s   e Zd Z fddZdd Zdd Zdd Ze																dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee dee dee deeef fddZ  ZS )DebertaModelc                    s8   t  | t|| _t|| _d| _|| _|   d S Nr   )	r   r   r   r   r   encoderz_stepsrD   	post_initrC   r$   r&   r'   r     s   

zDebertaModel.__init__c                 C      | j jS r   r   r   r   r&   r&   r'   get_input_embeddings  s   z!DebertaModel.get_input_embeddingsc                 C   s   || j _d S r   r   r!   Znew_embeddingsr&   r&   r'   set_input_embeddings  s   z!DebertaModel.set_input_embeddingsc                 C   s   t d)z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        z7The prune function is not implemented in DeBERTa model.)NotImplementedError)r!   Zheads_to_pruner&   r&   r'   _prune_heads  s   zDebertaModel._prune_headsNr   r   r   r   r   r   r   r   r   c	              	      s  |d ur|n j j}|d ur|n j j}|d ur|n j j}|d ur*|d ur*td|d ur9 || | }	n|d urF| d d }	ntd|d urQ|jn|j}
|d u r_tj	|	|
d}|d u rltj
|	tj|
d} j|||||d} j||d||d}|d	 } jd	kr|d
 } fddt jD }|d } j } j|} j|}|d	d  D ]}|||d|||d}|| q|d }|s|f||rd	ndd   S t||r|jnd |jdS )NzDYou cannot specify both input_ids and inputs_embeds at the same timer(   z5You have to specify either input_ids or inputs_embeds)rL   rK   )r   r   r   r   r   T)r   r   r   r   rJ   c                    s   g | ]} j jd  qS r   )r   r   r   r   r&   r'   r     s    z(DebertaModel.forward.<locals>.<listcomp>Fr   r)   r   )rD   r   r   use_return_dictrv   Z%warn_if_padding_and_no_attention_maskr"   rL   r   r   r   rN   r   r   r   r   r   r   r   appendr   r0   r   )r!   r   r   r   r   r   r   r   r   r   rL   Zembedding_outputZencoder_outputsZencoded_layersr0   Zlayersr   r   Zrel_posr   sequence_outputr&   r   r'   r2     sr   


zDebertaModel.forward)NNNNNNNN)r4   r5   r6   r   r   r   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r     sB    
	

r   c                       r9   )$LegacyDebertaPredictionHeadTransformc                    sf   t    t|d|j| _t|j| j| _t|j	t
r#t|j	 | _n|j	| _tj| j|jd| _d S )Nr   )r#   )r   r   r|   r<   r   r   r;   r=   r   r   r   r
   transform_act_fnr?   r>   rC   r$   r&   r'   r     s   
z-LegacyDebertaPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r=   r   r?   r   r&   r&   r'   r2     s   


z,LegacyDebertaPredictionHeadTransform.forwardrI   r&   r&   r$   r'   r     s    r   c                       s,   e Zd Z fddZdd Zdd Z  ZS )r   c                    s\   t    t|| _t|d|j| _tj| j|j	dd| _
tt|j	| _| j| j
_d S )Nr   Frp   )r   r   r   	transformr|   r<   r   r   r;   r   decoderr   r   r   r   rC   r$   r&   r'   r   	  s   

z&LegacyDebertaLMPredictionHead.__init__c                 C   s   | j | j_ d S r   )r   r   r   r&   r&   r'   _tie_weights  s   z*LegacyDebertaLMPredictionHead._tie_weightsc                 C   r   r   )r   r   r   r&   r&   r'   r2     r   z%LegacyDebertaLMPredictionHead.forward)r4   r5   r6   r   r   r2   r8   r&   r&   r$   r'   r     s    r   c                       r   )LegacyDebertaOnlyMLMHeadc                       t    t|| _d S r   )r   r   r   predictionsrC   r$   r&   r'   r   "     
z!LegacyDebertaOnlyMLMHead.__init__r   r   c                 C   s   |  |}|S r   )r   )r!   r   prediction_scoresr&   r&   r'   r2   &  s   
z LegacyDebertaOnlyMLMHead.forwardr   r&   r&   r$   r'   r   !  s    r   c                       s(   e Zd ZdZ fddZdd Z  ZS )r   zMhttps://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270c                    sl   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jdd| _tt|j| _d S )NT)r#   Zelementwise_affine)r   r   r   r;   r<   r=   r   r   r   r
   r   r?   r>   r   r   r   r   r   rC   r$   r&   r'   r   .  s   
z DebertaLMPredictionHead.__init__c                 C   s:   |  |}| |}| |}t||j | j }|S r   )r=   r   r?   r   r   r   r   r   )r!   r0   r   r&   r&   r'   r2   <  s   

zDebertaLMPredictionHead.forwardr3   r&   r&   r$   r'   r   +  s    r   c                       r9   )DebertaOnlyMLMHeadc                    r   r   )r   r   r   lm_headrC   r$   r&   r'   r   G  r   zDebertaOnlyMLMHead.__init__c                 C   s   |  ||}|S r   )r   )r!   r   r   r   r&   r&   r'   r2   L  s   zDebertaOnlyMLMHead.forwardrI   r&   r&   r$   r'   r   F  s    r   c                       s   e Zd ZddgZ fddZdd Zdd Ze																		dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee dee dee deeef fddZ  ZS )DebertaForMaskedLMzcls.predictions.decoder.weightzcls.predictions.decoder.biasc                    sP   t  | |j| _t|| _| jrt|| _n
ddg| _t|| _	| 
  d S )Nzlm_predictions.lm_head.weightz)deberta.embeddings.word_embeddings.weight)r   r   legacyr   r   r   cls_tied_weights_keysr   lm_predictionsr   rC   r$   r&   r'   r   U  s   


zDebertaForMaskedLM.__init__c                 C   s   | j r| jjjS | jjjS r   )r  r  r   r   r  r   r=   r   r&   r&   r'   get_output_embeddingsb  s   

z(DebertaForMaskedLM.get_output_embeddingsc                 C   s:   | j r|| jj_|j| jj_d S || jj_|j| jj_d S r   )r  r  r   r   r   r  r   r=   r   r&   r&   r'   set_output_embeddingsh  s
   

z(DebertaForMaskedLM.set_output_embeddingsNr   r   r   r   r   labelsr   r   r   r   c
              
   C   s   |	dur|	n| j j}	| j||||||||	d}
|
d }| jr$| |}n	| || jjj}d}|durDt }||	d| j j
|	d}|	sZ|f|
dd  }|durX|f| S |S t|||
j|
jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        Nr   r   r   r   r   r   r   r   r(   r   losslogitsr0   r   )rD   r   r   r  r  r  r   r   r   rO   r   r   r0   r   )r!   r   r   r   r   r   r  r   r   r   outputsr   r   Zmasked_lm_lossloss_fctr   r&   r&   r'   r2   p  s8   zDebertaForMaskedLM.forward	NNNNNNNNN)r4   r5   r6   r  r   r  r  r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r  Q  sH    	

r  c                       s0   e Zd Z fddZdd Zedd Z  ZS )ContextPoolerc                    s4   t    t|j|j| _t|j| _|| _	d S r   )
r   r   r   r;   Zpooler_hidden_sizer=   r@   Zpooler_dropoutrB   rD   rC   r$   r&   r'   r     s   

zContextPooler.__init__c                 C   s8   |d d df }|  |}| |}t| jj |}|S r   )rB   r=   r
   rD   Zpooler_hidden_act)r!   r0   Zcontext_tokenpooled_outputr&   r&   r'   r2     s
   

zContextPooler.forwardc                 C   r   r   )rD   r<   r   r&   r&   r'   
output_dim  s   zContextPooler.output_dim)r4   r5   r6   r   r2   propertyr  r8   r&   r&   r$   r'   r    s
    
r  z
    DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    )Zcustom_introc                       s   e Zd Z fddZdd Zdd Ze									ddeej	 d	eej	 d
eej	 deej	 deej	 deej	 dee
 dee
 dee
 deeef fddZ  ZS ) DebertaForSequenceClassificationc                    s   t  | t|dd}|| _t|| _t|| _| jj}t	
||| _t|dd }|d u r2| jjn|}t	|| _|   d S )N
num_labelsr)   Zcls_dropout)r   r   r|   r  r   r   r  poolerr  r   r;   
classifierrD   rA   r@   rB   r   )r!   rD   r  r  Zdrop_outr$   r&   r'   r     s   

z)DebertaForSequenceClassification.__init__c                 C   s
   | j  S r   )r   r   r   r&   r&   r'   r     s   
z5DebertaForSequenceClassification.get_input_embeddingsc                 C   s   | j | d S r   )r   r   r   r&   r&   r'   r     s   z5DebertaForSequenceClassification.set_input_embeddingsNr   r   r   r   r   r  r   r   r   r   c
              
   C   s:  |	dur|	n| j j}	| j||||||||	d}
|
d }| |}| |}| |}d}|dur| j jdu r| jdkrQt	 }|
d|j}|||
d}n| dks^|ddkr|dk }| }|ddkrt|d||d|d}t|d|
d}t }||
d| j |
d}n^td|}nUtd}||| d  }nC| j jdkrt	 }| jdkr|| | }n+|||}n%| j jdkrt }||
d| j|
d}n| j jdkrt }|||}|	s|f|
dd  }|dur|f| S |S t|||
j|
jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        N)r   r   r   r   r   r   r   r   r   r(   Z
regressionZsingle_label_classificationZmulti_label_classificationr
  )rD   r   r   r  rB   r  Zproblem_typer  r   r   rO   r/   r*   rj   r"   ZnonzerorN   r   rl   rV   r   r+   ra   Z
LogSoftmaxsumr,   r   r   r   r0   r   )r!   r   r   r   r   r   r  r   r   r   r  Zencoder_layerr  r  r  Zloss_fnZlabel_indexZlabeled_logitsr  Zlog_softmaxr   r&   r&   r'   r2     sh   



 


z(DebertaForSequenceClassification.forwardr  )r4   r5   r6   r   r   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r    sF    	

r  c                       s   e Zd Z fddZe									ddeej deej deej deej deej d	eej d
ee dee dee de	e
ef fddZ  ZS )DebertaForTokenClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r   )r   r   r  r   r   r   r@   rA   rB   r;   r<   r  r   rC   r$   r&   r'   r   0  s   
z&DebertaForTokenClassification.__init__Nr   r   r   r   r   r  r   r   r   r   c
              
   C   s   |	dur|	n| j j}	| j||||||||	d}
|
d }| |}| |}d}|dur;t }||d| j|d}|	sQ|f|
dd  }|durO|f| S |S t|||
j	|
j
dS )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr	  r   r(   r   r
  )rD   r   r   rB   r  r   rO   r  r   r0   r   )r!   r   r   r   r   r   r  r   r   r   r  r   r  r  r  r   r&   r&   r'   r2   ;  s0   

z%DebertaForTokenClassification.forwardr  )r4   r5   r6   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r  .  sB    	

r  c                       s   e Zd Z fddZe										ddeej deej deej deej deej d	eej d
eej dee dee dee de	e
ef fddZ  ZS )DebertaForQuestionAnsweringc                    s<   t  | |j| _t|| _t|j|j| _| 	  d S r   )
r   r   r  r   r   r   r;   r<   
qa_outputsr   rC   r$   r&   r'   r   n  s
   
z$DebertaForQuestionAnswering.__init__Nr   r   r   r   r   start_positionsend_positionsr   r   r   r   c              
   C   sF  |
d ur|
n| j j}
| j|||||||	|
d}|d }| |}|jddd\}}|d }|d }d }|d ur|d urt| dkrN|d}t| dkr[|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|
s||f|dd   }|d ur|f| S |S t||||j|jdS )	Nr	  r   r   r(   r   )Zignore_indexr)   )r  start_logits
end_logitsr0   r   )rD   r   r   r  splitr   r   r   r"   r   r   r   r0   r   )r!   r   r   r   r   r   r  r  r   r   r   r  r   r  r  r  Z
total_lossZignored_indexr  Z
start_lossZend_lossr   r&   r&   r'   r2   x  sN   






z#DebertaForQuestionAnswering.forward)
NNNNNNNNNN)r4   r5   r6   r   r   r   r   r   r   r   r   r   r2   r8   r&   r&   r$   r'   r  l  sH    
	

r  )r  r  r  r  r   r   )Ar7   typingr   r   r   r   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   Zactivationsr
   Zmodeling_outputsr   r   r   r   r   Zmodeling_utilsr   utilsr   r   Zconfiguration_debertar   Z
get_loggerr4   loggerModuler   r:   ZjitscriptrT   rY   r[   r^   r   rw   rb   rd   rh   rm   rn   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r  __all__r&   r&   r&   r'   <module>   sv   




 GQ#!]j
Vj=K