o
    ZhW                    @   s  d Z ddlZddlZddlmZ ddlmZmZmZm	Z	m
Z
 ddlZddlZddlmZ ddlmZmZmZ ddlmZmZ dd	lmZ dd
lmZmZmZmZmZmZmZmZ ddl m!Z! ddl"m#Z#m$Z$m%Z% ddl&m'Z'm(Z(m)Z) ddl*m+Z+ e),e-Z.dGddZ/G dd dej0Z1G dd dej0Z2G dd dej0Z3de2iZ4G dd dej0Z5G dd dej0Z6G dd dej0Z7G d d! d!ej0Z8G d"d# d#ej0Z9G d$d% d%ej0Z:G d&d' d'ej0Z;e(G d(d) d)e!Z<eG d*d+ d+e'Z=e(G d,d- d-e<Z>G d.d/ d/ej0Z?G d0d1 d1ej0Z@e(d2d3G d4d5 d5e<ZAe(d6d3G d7d8 d8e<ZBe(d9d3G d:d; d;e<ZCe(d<d3G d=d> d>e<ZDe(G d?d@ d@e<ZEe(G dAdB dBe<ZFe(dCd3G dDdE dEe<eZGg dFZHdS )HzPyTorch ELECTRA model.    N)	dataclass)CallableListOptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FNget_activation)GenerationMixin)"BaseModelOutputWithCrossAttentions)BaseModelOutputWithPastAndCrossAttentions!CausalLMOutputWithCrossAttentionsMaskedLMOutputMultipleChoiceModelOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputauto_docstringlogging   )ElectraConfigdiscriminatorc                 C   s  zddl }ddl}ddl}W n ty   td  w tj|}t	d|  |j
|}g }	g }
|D ] \}}t	d| d|  |j
||}|	| |
| q6t|	|
D ]#\}}|}z t| trp|dd}|d	kr|d
d}|dd
}|dd}|dd}|d}tdd |D rt	d|  W q\| }|D ]f}|d|r|d|}n|g}|d dks|d dkrt|d}n1|d dks|d dkrt|d}n|d dkrt|d}n|d dkrt|d}nt||d }t|d krt|d! }|| }q|d"rt|d}n
|dkr%||}z|j|jkr:td#|j d$|j d%W n tyT } z| j|j|jf7  _ d}~ww td&| | t||_ W q\ t!y } ztd| || W Y d}~q\d}~ww | S )'z'Load tf checkpoints in a pytorch model.r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z&Converting TensorFlow checkpoint from zLoading TF weight z with shape zelectra/embeddings/zgenerator/embeddings/	generatorzelectra/zdiscriminator/z
generator/Zdense_1dense_predictionz!generator_predictions/output_biaszgenerator_lm_head/bias/c                 s   s    | ]}|d v V  qdS ))Zglobal_stepZtemperatureN ).0nr%   r%   [/var/www/auris/lib/python3.10/site-packages/transformers/models/electra/modeling_electra.py	<genexpr>^   s    z-load_tf_weights_in_electra.<locals>.<genexpr>z	Skipping z[A-Za-z]+_\d+z_(\d+)ZkernelgammaweightZoutput_biasbetabiasZoutput_weightsZsquad
classifier   r   Z_embeddingszPointer shape z and array shape z mismatchedzInitialize PyTorch weight )"renumpyZ
tensorflowImportErrorloggererrorospathabspathinfotrainZlist_variablesZload_variableappendzip
isinstanceElectraForMaskedLMreplacesplitany	fullmatchgetattrlenintendswith	transposeshape
ValueErrorargsprinttorchZ
from_numpydataAttributeError)modelconfigZtf_checkpoint_pathZdiscriminator_or_generatorr0   nptfZtf_pathZ	init_varsnamesZarraysnamerG   arrayoriginal_nameZpointerZm_nameZscope_namesnumer%   r%   r(   load_tf_weights_in_electra4   s   




rX   c                       sh   e Zd ZdZ fddZ					ddeej deej deej d	eej d
e	dej
fddZ  ZS )ElectraEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _| jdt|jddd t|dd| _| jd	tj| j tjd
dd d S )N)padding_idxZepsposition_ids)r   F)
persistentposition_embedding_typeabsolutetoken_type_idsdtype)super__init__r   	Embedding
vocab_sizeembedding_sizeZpad_token_idword_embeddingsmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutZregister_bufferrK   arangeexpandrB   r_   zerosr\   sizelongselfrO   	__class__r%   r(   re      s   

zElectraEmbeddings.__init__Nr   	input_idsra   r\   inputs_embedspast_key_values_lengthreturnc                 C   s   |d ur	|  }n|  d d }|d }|d u r&| jd d ||| f }|d u rPt| drE| jd d d |f }||d |}	|	}ntj|tj| jjd}|d u rY| 	|}| 
|}
||
 }| jdkrp| |}||7 }| |}| |}|S )Nr]   r   ra   r   rc   devicer`   )ru   r\   hasattrra   rs   rK   rt   rv   r   ri   rl   r_   rk   rm   rq   )rx   r{   ra   r\   r|   r}   input_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedrl   
embeddingsrk   r%   r%   r(   forward   s,   







zElectraEmbeddings.forward)NNNNr   )__name__
__module____qualname____doc__re   r   rK   
LongTensorFloatTensorrD   Tensorr   __classcell__r%   r%   ry   r(   rY      s*    rY   c                       s   e Zd Zd fdd	ZdejdejfddZ						dd	ejd
eej deej deej deej dee	e	ej   dee
 de	ej fddZ  ZS )ElectraSelfAttentionNc                    s   t    |j|j dkrt|dstd|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _|p\t|dd| _| jdksh| jd	kry|j| _t	d
|j d | j| _|j| _d S )Nr   rh   zThe hidden size (z6) is not a multiple of the number of attention heads ()r_   r`   relative_keyrelative_key_queryr/   r   )rd   re   hidden_sizenum_attention_headsr   rH   rD   attention_head_sizeall_head_sizer   Linearquerykeyvaluero   Zattention_probs_dropout_probrq   rB   r_   rj   rf   distance_embedding
is_decoderrx   rO   r_   ry   r%   r(   re      s*   

zElectraSelfAttention.__init__xr~   c                 C   s6   |  d d | j| jf }||}|ddddS )Nr]   r   r/   r   r   )ru   r   r   viewpermute)rx   r   Znew_x_shaper%   r%   r(   transpose_for_scores   s   
z)ElectraSelfAttention.transpose_for_scoresFhidden_statesattention_mask	head_maskencoder_hidden_statesencoder_attention_maskpast_key_valueoutput_attentionsc                 C   s  |  |}|d u}	|	r|d ur|d }
|d }|}nP|	r/| | |}
| | |}|}n;|d urZ| | |}
| | |}tj|d |
gdd}
tj|d |gdd}n| | |}
| | |}| |}|d u}| jrz|
|f}t||
dd}| j	dks| j	dkr	|j
d |
j
d }}|rtj|d tj|jd	dd}ntj|tj|jd	dd}tj|tj|jd	dd}|| }| || j d }|j|jd
}| j	dkrtd||}|| }n| j	dkr	td||}td|
|}|| | }|t| j }|d ur|| }tjj|dd}| |}|d ur0|| }t||}|dddd }| d d | jf }||}|rX||fn|f}| jrd||f }|S )Nr   r   r/   dimr]   r   r   r   rb   zbhld,lrd->bhlrzbhrd,lrd->bhlrr   ) r   r   r   r   rK   catr   matmulrF   r_   rG   Ztensorrv   r   r   rr   r   rj   torc   Zeinsummathsqrtr   r   Z
functionalZsoftmaxrq   r   
contiguousru   r   )rx   r   r   r   r   r   r   r   Zmixed_query_layerZis_cross_attentionZ	key_layerZvalue_layerZquery_layer	use_cacheZattention_scoresZquery_lengthZ
key_lengthZposition_ids_lZposition_ids_rZdistanceZpositional_embeddingZrelative_position_scoresZrelative_position_scores_queryZrelative_position_scores_keyZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr%   r%   r(   r      sn   









zElectraSelfAttention.forwardNNNNNNF)r   r   r   re   rK   r   r   r   r   r   boolr   r   r%   r%   ry   r(   r      s4    	r   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )ElectraSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr[   )rd   re   r   r   r   denserm   rn   ro   rp   rq   rw   ry   r%   r(   re   Q     
zElectraSelfOutput.__init__r   input_tensorr~   c                 C   &   |  |}| |}| || }|S r   r   rq   rm   rx   r   r   r%   r%   r(   r   W     

zElectraSelfOutput.forwardr   r   r   re   rK   r   r   r   r%   r%   ry   r(   r   P      $r   eagerc                       s   e Zd Zd fdd	Zdd Z						ddejdeej d	eej d
eej deej dee	e	ej   dee
 de	ej fddZ  ZS )ElectraAttentionNc                    s4   t    t|j ||d| _t|| _t | _d S )Nr_   )	rd   re   ELECTRA_SELF_ATTENTION_CLASSESZ_attn_implementationrx   r   outputsetpruned_headsr   ry   r%   r(   re   e  s   

zElectraAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r   )rC   r   rx   r   r   r   r   r   r   r   r   r   r   union)rx   headsindexr%   r%   r(   prune_headsm  s   zElectraAttention.prune_headsFr   r   r   r   r   r   r   r~   c              	   C   s<   |  |||||||}| |d |}	|	f|dd   }
|
S )Nr   r   )rx   r   )rx   r   r   r   r   r   r   r   Zself_outputsattention_outputr   r%   r%   r(   r     s   
	zElectraAttention.forwardr   r   )r   r   r   re   r   rK   r   r   r   r   r   r   r   r%   r%   ry   r(   r   d  s4    	r   c                       s2   e Zd Z fddZdejdejfddZ  ZS )ElectraIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r   )rd   re   r   r   r   intermediate_sizer   r<   
hidden_actstrr   intermediate_act_fnrw   ry   r%   r(   re     s
   
zElectraIntermediate.__init__r   r~   c                 C   s   |  |}| |}|S r   )r   r   )rx   r   r%   r%   r(   r     s   

zElectraIntermediate.forwardr   r%   r%   ry   r(   r     s    r   c                       r   )ElectraOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )rd   re   r   r   r   r   r   rm   rn   ro   rp   rq   rw   ry   r%   r(   re     r   zElectraOutput.__init__r   r   r~   c                 C   r   r   r   r   r%   r%   r(   r     r   zElectraOutput.forwardr   r%   r%   ry   r(   r     r   r   c                       s   e Zd Z fddZ						ddejdeej deej deej d	eej d
eeeej   dee	 deej fddZ
dd Z  ZS )ElectraLayerc                    sr   t    |j| _d| _t|| _|j| _|j| _| jr-| js&t|  dt|dd| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is addedr`   r   )rd   re   chunk_size_feed_forwardseq_len_dimr   	attentionr   add_cross_attentionrH   crossattentionr   intermediater   r   rw   ry   r%   r(   re     s   


zElectraLayer.__init__NFr   r   r   r   r   r   r   r~   c              	   C   s  |d ur
|d d nd }| j |||||d}	|	d }
| jr(|	dd }|	d }n|	dd  }d }| jro|d urot| dsDtd|  d|d urN|d	d  nd }| |
||||||}|d }
||dd  }|d }|| }t| j| j| j|
}|f| }| jr||f }|S )
Nr/   )r   r   r   r   r]   r   z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`r   )	r   r   r   rH   r   r   feed_forward_chunkr   r   )rx   r   r   r   r   r   r   r   Zself_attn_past_key_valueZself_attention_outputsr   r   Zpresent_key_valueZcross_attn_present_key_valueZcross_attn_past_key_valueZcross_attention_outputslayer_outputr%   r%   r(   r     sP   


	

zElectraLayer.forwardc                 C   s   |  |}| ||}|S r   )r   r   )rx   r   Zintermediate_outputr   r%   r%   r(   r     s   
zElectraLayer.feed_forward_chunkr   )r   r   r   re   rK   r   r   r   r   r   r   r   r   r%   r%   ry   r(   r     s4    	
Ar   c                       s   e Zd Z fddZ									ddejdeej deej d	eej d
eej deeeej   dee	 dee	 dee	 dee	 de
eej ef fddZ  ZS )ElectraEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r%   )r   )r&   _rO   r%   r(   
<listcomp>  s    z+ElectraEncoder.__init__.<locals>.<listcomp>F)	rd   re   rO   r   Z
ModuleListrangenum_hidden_layerslayergradient_checkpointingrw   ry   r   r(   re     s   
 
zElectraEncoder.__init__NFTr   r   r   r   r   past_key_valuesr   r   output_hidden_statesreturn_dictr~   c                 C   s^  |	rdnd }|r
dnd }|r| j jrdnd }| jr%| jr%|r%td d}|r)dnd }t| jD ]^\}}|	r;||f }|d urC|| nd }|d urM|| nd }| jrc| jrc| |j	|||||||}n
||||||||}|d }|rz||d f7 }|r||d f }| j jr||d f }q0|	r||f }|
st
dd	 |||||fD S t|||||d
S )Nr%   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fr   r]   r   r/   c                 s   s    | ]	}|d ur|V  qd S r   r%   )r&   vr%   r%   r(   r)   W  s    z)ElectraEncoder.forward.<locals>.<genexpr>)Zlast_hidden_stater   r   
attentionscross_attentions)rO   r   r   Ztrainingr3   Zwarning_once	enumerater   Z_gradient_checkpointing_func__call__tupler   )rx   r   r   r   r   r   r   r   r   r   r   Zall_hidden_statesZall_self_attentionsZall_cross_attentionsZnext_decoder_cacheiZlayer_moduleZlayer_head_maskr   Zlayer_outputsr%   r%   r(   r     sz   


zElectraEncoder.forward)	NNNNNNFFT)r   r   r   re   rK   r   r   r   r   r   r   r   r   r   r%   r%   ry   r(   r     sD    		
r   c                       (   e Zd ZdZ fddZdd Z  ZS )ElectraDiscriminatorPredictionszEPrediction module for the discriminator, made up of two dense layers.c                    sB   t    t|j|j| _t|j| _t|jd| _	|| _
d S Nr   )rd   re   r   r   r   r   r   r   
activationr#   rO   rw   ry   r%   r(   re   n  s
   

z(ElectraDiscriminatorPredictions.__init__c                 C   s(   |  |}| |}| |d}|S )Nr]   )r   r   r#   squeeze)rx   discriminator_hidden_statesr   logitsr%   r%   r(   r   v  s   

z'ElectraDiscriminatorPredictions.forwardr   r   r   r   re   r   r   r%   r%   ry   r(   r   k  s    r   c                       r   )ElectraGeneratorPredictionszAPrediction module for the generator, made up of two dense layers.c                    s>   t    td| _tj|j|jd| _t|j	|j| _
d S )Ngelur[   )rd   re   r   r   r   rm   rh   rn   r   r   r   rw   ry   r%   r(   re     s   

z$ElectraGeneratorPredictions.__init__c                 C   s"   |  |}| |}| |}|S r   )r   r   rm   )rx   generator_hidden_statesr   r%   r%   r(   r     s   


z#ElectraGeneratorPredictions.forwardr   r%   r%   ry   r(   r   ~  s    r   c                   @   s$   e Zd ZeZeZdZdZdd Z	dS )ElectraPreTrainedModelelectraTc                 C   s   t |tjr |jjjd| jjd |jdur|jj	  dS dS t |tj
rC|jjjd| jjd |jdurA|jj|j 	  dS dS t |tjrX|jj	  |jjd dS dS )zInitialize the weightsg        )meanZstdNg      ?)r<   r   r   r+   rL   Znormal_rO   Zinitializer_ranger-   Zzero_rf   rZ   rm   Zfill_)rx   moduler%   r%   r(   _init_weights  s   

z$ElectraPreTrainedModel._init_weightsN)
r   r   r   r    config_classrX   Zload_tf_weightsbase_model_prefixZsupports_gradient_checkpointingr   r%   r%   r%   r(   r     s    r   c                   @   sb   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeej  ed< dS )ElectraForPreTrainingOutputa  
    Output type of [`ElectraForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss of the ELECTRA objective.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Prediction scores of the head (scores for each token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlossr   r   r   )r   r   r   r   r  r   rK   r   __annotations__r   r   r   r   r%   r%   r%   r(   r    s   
 r  c                        s   e Zd Z fddZdd Zdd Zdd Ze																										dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 deee	j  dee dee dee dee deee	j
 ef fddZ  ZS )ElectraModelc                    sP   t  | t|| _|j|jkrt|j|j| _t	|| _
|| _|   d S r   )rd   re   rY   r   rh   r   r   r   embeddings_projectr   encoderrO   	post_initrw   ry   r%   r(   re     s   

zElectraModel.__init__c                 C   s   | j jS r   r   ri   rx   r%   r%   r(   get_input_embeddings  s   z!ElectraModel.get_input_embeddingsc                 C   s   || j _d S r   r
  )rx   r   r%   r%   r(   set_input_embeddings  s   z!ElectraModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  r   r   r   )rx   Zheads_to_pruner   r   r%   r%   r(   _prune_heads  s   zElectraModel._prune_headsNr{   r   ra   r\   r   r|   r   r   r   r   r   r   r   r~   c                 C   s  |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}|d ur*|d ur*td|d ur9| || | }n|d urF| d d }ntd|\}}|d urU|jn|j}|	d ure|	d d jd nd}|d u rrt	j
||d}|d u rt| jdr| jjd d d |f }|||}|}n	t	j|t	j|d}| ||}| j jr|d ur| \}}}||f}|d u rt	j
||d}| |}nd }| || j j}| j|||||d	}t| d
r| |}| j||||||	|
|||d
}|S )NzDYou cannot specify both input_ids and inputs_embeds at the same timer]   z5You have to specify either input_ids or inputs_embedsr   r/   )r   ra   r   )r{   r\   ra   r|   r}   r  )	r   r   r   r   r   r   r   r   r   )rO   r   r   use_return_dictrH   Z%warn_if_padding_and_no_attention_maskru   r   rG   rK   Zonesr   r   ra   rs   rt   rv   Zget_extended_attention_maskr   Zinvert_attention_maskZget_head_maskr   r  r  )rx   r{   r   ra   r\   r   r|   r   r   r   r   r   r   r   r   Z
batch_sizer   r   r}   r   r   Zextended_attention_maskZencoder_batch_sizeZencoder_sequence_lengthr   Zencoder_hidden_shapeZencoder_extended_attention_maskr   r%   r%   r(   r     sl   


zElectraModel.forward)NNNNNNNNNNNNN)r   r   r   re   r  r  r  r   r   rK   r   r   r   r   r   r   r   r   r   r%   r%   ry   r(   r    s`    	
r  c                       r   )ElectraClassificationHeadz-Head for sentence-level classification tasks.c                    s^   t    t|j|j| _|jd ur|jn|j}td| _	t
|| _t|j|j| _d S )Nr   )rd   re   r   r   r   r   classifier_dropoutrp   r   r   ro   rq   
num_labelsout_projrx   rO   r  ry   r%   r(   re   ?  s   

z"ElectraClassificationHead.__init__c                 K   sL   |d d dd d f }|  |}| |}| |}|  |}| |}|S )Nr   )rq   r   r   r  )rx   featureskwargsr   r%   r%   r(   r   I  s   




z!ElectraClassificationHead.forwardr   r%   r%   ry   r(   r  <  s    
r  c                       sJ   e Zd ZdZdef fddZ	ddejdeej	 dejfd	d
Z
  ZS )ElectraSequenceSummarya  
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`ElectraConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    rO   c                    s   t    t|dd| _| jdkrtt | _t|dr<|j	r<t|dr1|j
r1|jdkr1|j}n|j}t|j|| _t|dd }|rHt|nt | _t | _t|drc|jdkrct|j| _t | _t|d	r{|jdkr}t|j| _d S d S d S )
Nsummary_typelastattnsummary_use_projsummary_proj_to_labelsr   Zsummary_activationsummary_first_dropoutsummary_last_dropout)rd   re   rB   r  NotImplementedErrorr   ZIdentitysummaryr   r  r  r  r   r   r   r   first_dropoutr  ro   last_dropoutr  )rx   rO   Znum_classesZactivation_stringry   r%   r(   re   n  s&   




zElectraSequenceSummary.__init__Nr   	cls_indexr~   c                 C   s  | j dkr|dddf }ne| j dkr|dddf }nW| j dkr(|jdd}nK| j d	krl|du rItj|d
ddddf |jd d tjd}n|dd}|d| d  |	df }|
d|d}n| j dkrst| |}| |}| |}| |}|S )ak  
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        r  Nr]   firstr   r   r   r   r$  .r   rb   )r]   r  )r  r   rK   Z	full_likerG   rv   Z	unsqueezers   r   ru   gatherr   r   r"  r!  r   r#  )rx   r   r$  r   r%   r%   r(   r     s.   



"




zElectraSequenceSummary.forwardr   )r   r   r   r   r    re   rK   r   r   r   r   r   r%   r%   ry   r(   r  T  s    r  z
    ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    )Zcustom_introc                          e Zd Z fddZe										ddeej deej deej deej deej d	eej d
eej dee dee dee de	e
ej ef fddZ  ZS ) ElectraForSequenceClassificationc                    s:   t  | |j| _|| _t|| _t|| _|   d S r   )	rd   re   r  rO   r  r   r  r.   r	  rw   ry   r%   r(   re     s   

z)ElectraForSequenceClassification.__init__Nr{   r   ra   r\   r   r|   labelsr   r   r   r~   c                 C   sh  |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}d}|dur| j jdu rQ| jdkr7d| j _n| jdkrM|jtjksH|jtj	krMd| j _nd| j _| j jdkrot
 }| jdkri|| | }n+|||}n%| j jdkrt }||d| j|d}n| j jdkrt }|||}|
s|f|dd  }|dur|f| S |S t|||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   ra   r\   r   r|   r   r   r   r   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr]   r  r   r   r   )rO   r  r   r.   Zproblem_typer  rc   rK   rv   rD   r   r   r
   r   r	   r   r   r   )rx   r{   r   ra   r\   r   r|   r)  r   r   r   r   sequence_outputr   r  loss_fctr   r%   r%   r(   r     sT   


"


z(ElectraForSequenceClassification.forward
NNNNNNNNNN)r   r   r   re   r   r   rK   r   r   r   r   r   r   r   r%   r%   ry   r(   r(    sH    
	
r(  z
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

    It is recommended to load the discriminator checkpoint into that model.
    c                       r'  )ElectraForPreTrainingc                    s,   t  | t|| _t|| _|   d S r   )rd   re   r  r   r   discriminator_predictionsr	  rw   ry   r%   r(   re     s   

zElectraForPreTraining.__init__Nr{   r   ra   r\   r   r|   r)  r   r   r   r~   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}d}|dur_t }|durQ|d|jd dk}|d|jd | }|| }||| }n||d|jd | }|
su|f|dd  }|durs|f| S |S t	|||j
|jdS )am  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
            Indices should be in `[0, 1]`:

            - 0 indicates the token is an original token,
            - 1 indicates the token was replaced.

        Examples:

        ```python
        >>> from transformers import ElectraForPreTraining, AutoTokenizer
        >>> import torch

        >>> discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")

        >>> sentence = "The quick brown fox jumps over the lazy dog"
        >>> fake_sentence = "The quick brown fox fake over the lazy dog"

        >>> fake_tokens = tokenizer.tokenize(fake_sentence, add_special_tokens=True)
        >>> fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
        >>> discriminator_outputs = discriminator(fake_inputs)
        >>> predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)

        >>> fake_tokens
        ['[CLS]', 'the', 'quick', 'brown', 'fox', 'fake', 'over', 'the', 'lazy', 'dog', '[SEP]']

        >>> predictions.squeeze().tolist()
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        ```Nr*  r   r]   r   r+  )rO   r  r   r0  r   r	   r   rG   floatr  r   r   )rx   r{   r   ra   r\   r   r|   r)  r   r   r   r   discriminator_sequence_outputr   r  r-  Zactive_lossZactive_logitsZactive_labelsr   r%   r%   r(   r      s@   -
zElectraForPreTraining.forwardr.  )r   r   r   re   r   r   rK   r   r   r   r   r  r   r   r%   r%   ry   r(   r/    sH    	
r/  z
    Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
    the two to have been trained for the masked language modeling task.
    c                       s   e Zd ZdgZ fddZdd Zdd Ze										dd	ee	j
 d
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee dee dee deee	j
 ef fddZ  ZS )r=   generator_lm_head.weightc                    s>   t  | t|| _t|| _t|j|j	| _
|   d S r   )rd   re   r  r   r   generator_predictionsr   r   rh   rg   generator_lm_headr	  rw   ry   r%   r(   re     s
   

zElectraForMaskedLM.__init__c                 C      | j S r   r5  r  r%   r%   r(   get_output_embeddings     z(ElectraForMaskedLM.get_output_embeddingsc                 C   
   || _ d S r   r7  )rx   ri   r%   r%   r(   set_output_embeddings     
z(ElectraForMaskedLM.set_output_embeddingsNr{   r   ra   r\   r   r|   r)  r   r   r   r~   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dur>t }||d| j j|d}|
sT|f|dd  }|durR|f| S |S t	|||j
|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        Nr*  r   r]   r   r+  )rO   r  r   r4  r5  r   r
   r   rg   r   r   r   )rx   r{   r   ra   r\   r   r|   r)  r   r   r   r   Zgenerator_sequence_outputprediction_scoresr  r-  r   r%   r%   r(   r     s8   

zElectraForMaskedLM.forwardr.  )r   r   r   _tied_weights_keysre   r8  r;  r   r   rK   r   r   r   r   r   r   r   r%   r%   ry   r(   r=   u  sN    	
	
r=   z
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    c                       r'  )ElectraForTokenClassificationc                    s^   t  | |j| _t|| _|jd ur|jn|j}t|| _	t
|j|j| _|   d S r   )rd   re   r  r  r   r  rp   r   ro   rq   r   r   r.   r	  r  ry   r%   r(   re     s   
z&ElectraForTokenClassification.__init__Nr{   r   ra   r\   r   r|   r)  r   r   r   r~   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dur<t }||d| j|d}|
sR|f|dd  }|durP|f| S |S t|||j	|j
dS )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr*  r   r]   r   r+  )rO   r  r   rq   r.   r
   r   r  r   r   r   )rx   r{   r   ra   r\   r   r|   r)  r   r   r   r   r2  r   r  r-  r   r%   r%   r(   r     s8   

z%ElectraForTokenClassification.forwardr.  )r   r   r   re   r   r   rK   r   r   r   r   r   r   r   r%   r%   ry   r(   r?    sH    	
r?  c                       s   e Zd ZeZdZ fddZe											ddee	j
 dee	j
 dee	j
 dee	j
 d	ee	j
 d
ee	j
 dee	j
 dee	j
 dee dee dee deee	j
 ef fddZ  ZS )ElectraForQuestionAnsweringr   c                    s<   t  | |j| _t|| _t|j|j| _| 	  d S r   )
rd   re   r  r  r   r   r   r   
qa_outputsr	  rw   ry   r%   r(   re     s
   
z$ElectraForQuestionAnswering.__init__Nr{   r   ra   r\   r   r|   start_positionsend_positionsr   r   r   r~   c              
   C   sF  |d ur|n| j j}| j|||||||	|
d}|d }| |}|jddd\}}|d }|d }d }|d ur|d urt| dkrN|d}t| dkr[|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|s||f|dd   }|d ur|f| S |S t||||j|jdS )	N)r   ra   r\   r   r|   r   r   r   r   r]   r   )Zignore_indexr/   )r  start_logits
end_logitsr   r   )rO   r  r   rA  r?   r   r   rC   ru   clampr
   r   r   r   )rx   r{   r   ra   r\   r   r|   rB  rC  r   r   r   r   r,  r   rD  rE  Z
total_lossZignored_indexr-  Z
start_lossZend_lossr   r%   r%   r(   r   !  sV   







z#ElectraForQuestionAnswering.forward)NNNNNNNNNNN)r   r   r   r    r  r  re   r   r   rK   r   r   r   r   r   r   r   r%   r%   ry   r(   r@    sR    
	
r@  c                       r'  )ElectraForMultipleChoicec                    s<   t  | t|| _t|| _t|jd| _	| 
  d S r   )rd   re   r  r   r  sequence_summaryr   r   r   r.   r	  rw   ry   r%   r(   re   g  s
   

z!ElectraForMultipleChoice.__init__Nr{   r   ra   r\   r   r|   r)  r   r   r   r~   c                 C   sn  |
dur|
n| j j}
|dur|jd n|jd }|dur%|d|dnd}|dur4|d|dnd}|durC|d|dnd}|durR|d|dnd}|dure|d|d|dnd}| j||||||||	|
d	}|d }| |}| |}|d|}d}|durt }|||}|
s|f|dd  }|dur|f| S |S t	|||j
|jdS )a[  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        Nr   r]   r   r*  r   r+  )rO   r  rG   r   ru   r   rH  r.   r
   r   r   r   )rx   r{   r   ra   r\   r   r|   r)  r   r   r   Znum_choicesr   r,  Zpooled_outputr   Zreshaped_logitsr  r-  r   r%   r%   r(   r   q  sL   ,


z ElectraForMultipleChoice.forwardr.  )r   r   r   re   r   r   rK   r   r   r   r   r   r   r   r%   r%   ry   r(   rG  e  sH    
	
rG  zS
    ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.
    c                "       s   e Zd ZdgZ fddZdd Zdd Ze														dd	ee	j
 d
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 deee	j
  dee dee dee dee deee	j
 ef fddZdd Z  ZS )ElectraForCausalLMr3  c                    sN   t  | |jstd t|| _t|| _t	
|j|j| _|   d S )NzOIf you want to use `ElectraForCausalLM` as a standalone, add `is_decoder=True.`)rd   re   r   r3   warningr  r   r   r4  r   r   rh   rg   r5  Zinit_weightsrw   ry   r%   r(   re     s   


zElectraForCausalLM.__init__c                 C   r6  r   r7  r  r%   r%   r(   r8    r9  z(ElectraForCausalLM.get_output_embeddingsc                 C   r:  r   r7  )rx   Znew_embeddingsr%   r%   r(   r;    r<  z(ElectraForCausalLM.set_output_embeddingsNr{   r   ra   r\   r   r|   r   r   r)  r   r   r   r   r   r~   c                 K   s   |dur|n| j j}|	durd}| j|||||||||
||||d}|d }| | |}d}|	durB| j||	fd| j ji|}|sX|f|dd  }|durV|f| S |S t|||j|j	|j
|jdS )a3  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ElectraForCausalLM, ElectraConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-generator")
        >>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
        >>> config.is_decoder = True
        >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```NF)r   ra   r\   r   r|   r   r   r   r   r   r   r   r   rg   r   )r  r   r   r   r   r   )rO   r  r   r5  r4  Zloss_functionrg   r   r   r   r   r   )rx   r{   r   ra   r\   r   r|   r   r   r)  r   r   r   r   r   r  r   r,  r=  Zlm_lossr   r%   r%   r(   r     sR   )zElectraForCausalLM.forwardc                    s.   d}|D ]}|t  fdd|D f7 }q|S )Nr%   c                 3   s$    | ]}| d  |jV  qdS )r   N)Zindex_selectr   r   )r&   Z
past_statebeam_idxr%   r(   r)   B  s   " z4ElectraForCausalLM._reorder_cache.<locals>.<genexpr>)r   )rx   r   rL  Zreordered_pastZ
layer_pastr%   rK  r(   _reorder_cache>  s   z!ElectraForCausalLM._reorder_cache)NNNNNNNNNNNNNN)r   r   r   r>  re   r8  r;  r   r   rK   r   r   r   r   r   r   r   rM  r   r%   r%   ry   r(   rI    sh    	
VrI  )
rI  r=   rG  r/  r@  r(  r?  r  r   rX   )r!   )Ir   r   r5   dataclassesr   typingr   r   r   r   r   rK   Ztorch.utils.checkpointr   Ztorch.nnr	   r
   r   Zactivationsr   r   Z
generationr   Zmodeling_outputsr   r   r   r   r   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   r   Zconfiguration_electrar    Z
get_loggerr   r3   rX   ModulerY   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r(  r/  r=   r?  r@  rG  rI  __all__r%   r%   r%   r(   <module>   s   (


RC 4W]vcS^KCRgu