"""PyTorch FNet model."""

import warnings
from dataclasses import dataclass
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...utils import auto_docstring, is_scipy_available


if is_scipy_available():
    from scipy import linalg

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    MaskedLMOutput,
    ModelOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import logging
from .configuration_fnet import FNetConfig


logger = logging.get_logger(__name__)


# Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
def _two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
    """Applies 2D matrix multiplication to 3D input arrays."""
    seq_length = x.shape[1]
    matrix_dim_one = matrix_dim_one[:seq_length, :seq_length]
    x = x.type(torch.complex64)
    return torch.einsum("bij,jk,ni->bnk", x, matrix_dim_two, matrix_dim_one)


# Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
def two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
    return _two_dim_matmul(x, matrix_dim_one, matrix_dim_two)


# Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
def fftn(x):
    """
    Applies n-dimensional Fast Fourier Transform (FFT) to input array.

    Args:
        x: Input n-dimensional array.

    Returns:
        n-dimensional Fourier transform of input n-dimensional array.
    """
    out = x
    for axis in reversed(range(x.ndim)[1:]):  # exclude the batch dimension (axis 0)
        out = torch.fft.fft(out, axis=axis)
    return out
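
# A minimal sanity-check sketch (assuming SciPy is installed): the DFT-matrix path and
# the FFT path compute the same unnormalized 2D transform over the sequence and hidden
# axes, so on a small random input the two should agree.
#
#     x = torch.randn(1, 4, 8)
#     dft_seq = torch.tensor(linalg.dft(4), dtype=torch.complex64)
#     dft_hidden = torch.tensor(linalg.dft(8), dtype=torch.complex64)
#     expected = torch.fft.fftn(x, dim=(1, 2))
#     actual = two_dim_matmul(x, dft_seq, dft_hidden)
#     torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)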

class FNetEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # NOTE: original FNet allows the embedding and model dimensions to differ, hence this projection layer.
        self.projection = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # Setting the token_type_ids to the registered buffer (all zeros) helps users when tracing
        # the model without passing token_type_ids.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings

        position_embeddings = self.position_embeddings(position_ids)
        embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.projection(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
__module____qualname____doc__r=   r_   __classcell__r&   r&   rT   r'   r3   U   s    r3   c                       ,   e Zd Z fddZdd Zdd Z  ZS )FNetBasicFourierTransformc                    s   t    | | d S r)   )r<   r=   _init_fourier_transformrQ   rT   r&   r'   r=         
z"FNetBasicFourierTransform.__init__c                 C   s   |j sttjjdd| _d S |jdkrLt rB| dtj	t
|jtjd | dtj	t
|jtjd tt| j| jd| _d S td t| _d S t| _d S )	N)r      dim   dft_mat_hiddenr:   dft_mat_seq)r#   r$   zpSciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier transform instead.)use_tpu_fourier_optimizationsr   r    r0   r2   fourier_transformrB   r   rL   Ztensorr   Zdftr@   r!   tpu_short_seq_lengthr*   rn   rm   r   warningrQ   r&   r&   r'   rg      s$   



z1FNetBasicFourierTransform._init_fourier_transformc                 C   s   |  |j}|fS r)   )rp   real)rR   hidden_statesoutputsr&   r&   r'   r_      s   z!FNetBasicFourierTransform.forward)r`   ra   rb   r=   rg   r_   rd   r&   r&   rT   r'   rf      s    rf   c                       $   e Zd Z fddZdd Z  ZS )FNetBasicOutputc                    s"   t    tj|j|jd| _d S Nr5   )r<   r=   r   rE   r@   rF   rQ   rT   r&   r'   r=      s   
zFNetBasicOutput.__init__c                 C   s   |  || }|S r)   )rE   rR   rt   input_tensorr&   r&   r'   r_      s   zFNetBasicOutput.forwardr`   ra   rb   r=   r_   rd   r&   r&   rT   r'   rw          rw   c                       rv   )FNetFourierTransformc                    s"   t    t|| _t|| _d S r)   )r<   r=   rf   rR   rw   outputrQ   rT   r&   r'   r=      s   

zFNetFourierTransform.__init__c                 C   s$   |  |}| |d |}|f}|S Nr   )rR   r~   )rR   rt   Zself_outputsfourier_outputru   r&   r&   r'   r_      s   
zFNetFourierTransform.forwardr{   r&   r&   rT   r'   r}          r}   c                       2   e Zd Z fddZdejdejfddZ  ZS )FNetIntermediatec                    sD   t    t|j|j| _t|jt	rt

class FNetIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class FNetOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class FNetLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # The dimension which has the sequence length
        self.fourier = FNetFourierTransform(config)
        self.intermediate = FNetIntermediate(config)
        self.output = FNetOutput(config)

    def forward(self, hidden_states):
        self_fourier_outputs = self.fourier(hidden_states)
        fourier_output = self_fourier_outputs[0]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, fourier_output
        )

        outputs = (layer_output,)
        return outputs

    def feed_forward_chunk(self, fourier_output):
        intermediate_output = self.intermediate(fourier_output)
        layer_output = self.output(intermediate_output, fourier_output)
        return layer_output

class FNetEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
            else:
                layer_outputs = layer_module(hidden_states)

            hidden_states = layer_outputs[0]

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)

class FNetPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class FNetPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class FNetLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = FNetPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

    def _tie_weights(self) -> None:
        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
        if self.decoder.bias.device.type == "meta":
            self.decoder.bias = self.bias
        else:
            self.bias = self.decoder.bias


class FNetOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = FNetLMPredictionHead(config)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class FNetOnlyNSPHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


class FNetPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = FNetLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

@auto_docstring
class FNetPreTrainedModel(PreTrainedModel):
    config_class = FNetConfig
    base_model_prefix = "fnet"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class FNetForPreTrainingOutput(ModelOutput):
    """
    Output type of [`FNetForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
    Nlossprediction_logitsseq_relationship_logitsrt   )r`   ra   rb   rc   r   r   r    FloatTensor__annotations__r   r   rt   r   r&   r&   r&   r'   r     s   
 r   c                       s   e Zd ZdZd fdd	Zdd Zdd Ze												dd
ee	j
 dee	j
 dee	j
 dee	j dee dee deeef fddZ  ZS )	FNetModelz

    The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
    Transforms](https://arxiv.org/abs/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.

    """

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = FNetEmbeddings(config)
        self.encoder = FNetEncoder(config)

        self.pooler = FNetPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if (
            self.config.use_tpu_fourier_optimizations
            and seq_length <= 4096
            and self.config.tpu_short_seq_length != seq_length
        ):
            raise ValueError(
                "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to"
                " the model when using TPU optimizations."
            )

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        pooler_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooler_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooler_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    )Zcustom_introc                       s   e Zd ZddgZ fddZdd Zdd Ze																dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee dee deeef fddZ  ZS )FNetForPreTrainingcls.predictions.decoder.biascls.predictions.decoder.weightc                    ,   t  | t|| _t|| _|   d S r)   )r<   r=   r   r   r   clsr   rQ   rT   r&   r'   r=   &     

zFNetForPreTraining.__init__c                 C   
   | j jjS r)   r   r   r   r   r&   r&   r'   get_output_embeddings/     
z(FNetForPreTraining.get_output_embeddingsc                 C      || j j_|j| j j_d S r)   r   r   r   r   rR   Znew_embeddingsr&   r&   r'   set_output_embeddings2     
z(FNetForPreTraining.set_output_embeddingsNrY   r9   r6   rZ   labelsnext_sentence_labelr   r   r   c	                 C   s   |dur|n| j j}| j||||||d}	|	dd \}
}| |
|\}}d}|durP|durPt }||d| j j|d}||dd|d}|| }|sg||f|	dd  }|dure|f| S |S t||||	jdS )aH  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return FNetForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
        )


@auto_docstring
class FNetForMaskedLM(FNetPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.cls = FNetOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states)


@auto_docstring(
    custom_intro="""
    FNet Model with a `next sentence prediction (classification)` head on top.
    c                          e Zd Z fddZe							ddeej deej deej deej deej d	ee d
ee de	e
ef fddZ  ZS )FNetForNextSentencePredictionc                    r   r)   )r<   r=   r   r   r   r   r   rQ   rT   r&   r'   r=     r   z&FNetForNextSentencePrediction.__init__NrY   r9   r6   rZ   r   r   r   r   c                 K   s   d|v rt dt |d}|dur|n| jj}| j||||||d}	|	d }
| |
}d}|durBt }||	dd|	d}|sX|f|	dd  }|durV|f| S |S t
|||	jdS )	a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```"""
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
        )


@auto_docstring(
    custom_intro="""
    FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    c                       r   )FNetForSequenceClassificationc                    J   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r)   r<   r=   
num_labelsr   r   r   rI   rJ   rK   rG   r@   
classifierr   rQ   rT   r&   r'   r=     s   
z&FNetForSequenceClassification.__init__NrY   r9   r6   rZ   r   r   r   r   c                 C   sh  |dur|n| j j}| j||||||d}|d }	| |	}	| |	}
d}|dur| j jdu rS| jdkr9d| j _n| jdkrO|jtj	ksJ|jtj
krOd| j _nd| j _| j jdkrqt }| jdkrk||
 | }n+||
|}n%| j jdkrt }||
d| j|d}n| j jdkrt }||
|}|s|
f|dd  }|dur|f| S |S t||
|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)


@auto_docstring
class FNetForMultipleChoice(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.fnet = FNetModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states)


@auto_docstring
class FNetForTokenClassification(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.fnet = FNetModel(config)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep the active parts of the loss
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)


@auto_docstring
class FNetForQuestionAnswering(FNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels

        self.fnet = FNetModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.fnet(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
        )


__all__ = [
    "FNetForMaskedLM",
    "FNetForMultipleChoice",
    "FNetForNextSentencePrediction",
    "FNetForPreTraining",
    "FNetForQuestionAnswering",
    "FNetForSequenceClassification",
    "FNetForTokenClassification",
    "FNetLayer",
    "FNetModel",
    "FNetPreTrainedModel",
]