o
    Zh`                     @   s   d Z ddlZddlmZmZmZ ddlZddlmZ ddlm	Z	m
Z
mZ ddlmZ ddlmZmZmZmZmZmZmZ dd	lmZ dd
lmZmZ ddlmZ eeZG dd dej Z!G dd dej Z"G dd dej#Z$G dd dej Z%G dd dej Z&G dd dej Z'G dd dej Z(G dd dej Z)G dd dej Z*G dd  d ej Z+G d!d" d"ej Z,G d#d$ d$ej Z-eG d%d& d&eZ.eG d'd( d(e.Z/eG d)d* d*e.Z0ed+d,G d-d. d.e.Z1eG d/d0 d0e.Z2eG d1d2 d2e.Z3eG d3d4 d4e.Z4g d5Z5dS )6zPyTorch SqueezeBert model.    N)OptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BaseModelOutputBaseModelOutputWithPoolingMaskedLMOutputMultipleChoiceModelOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)auto_docstringlogging   )SqueezeBertConfigc                       s*   e Zd ZdZ fddZdddZ  ZS )SqueezeBertEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _| jdt|jddd d S )N)padding_idxepsposition_ids)r   F)
persistent)super__init__r   	Embedding
vocab_sizeembedding_sizeZpad_token_idword_embeddingsZmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddings	LayerNormhidden_sizelayer_norm_epsDropouthidden_dropout_probdropoutZregister_buffertorchZarangeexpandselfconfig	__class__ c/var/www/auris/lib/python3.10/site-packages/transformers/models/squeezebert/modeling_squeezebert.pyr   0   s   

zSqueezeBertEmbeddings.__init__Nc           
      C   s   |d ur	|  }n|  d d }|d }|d u r$| jd d d |f }|d u r3tj|tj| jjd}|d u r<| |}| |}| |}|| | }	| 	|	}	| 
|	}	|	S )Nr   r   dtypedevice)sizer   r,   zeroslongr7   r#   r$   r%   r&   r+   )
r/   	input_idstoken_type_idsr   inputs_embedsinput_shapeZ
seq_lengthr$   r%   
embeddingsr3   r3   r4   forward@   s    





zSqueezeBertEmbeddings.forward)NNNN__name__
__module____qualname____doc__r   r@   __classcell__r3   r3   r1   r4   r   -   s    r   c                       (   e Zd ZdZ fddZdd Z  ZS )MatMulWrapperz
    Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call
    torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul.
    c                    s   t    d S N)r   r   r/   r1   r3   r4   r   _      zMatMulWrapper.__init__c                 C   s   t ||S )a0  

        :param inputs: two torch tensors :return: matmul of these tensors

        Here are the typical dimensions found in BERT (the B is optional) mat1.shape: [B, <optional extra dims>, M, K]
        mat2.shape: [B, <optional extra dims>, K, N] output shape: [B, <optional extra dims>, M, N]
        )r,   matmul)r/   Zmat1Zmat2r3   r3   r4   r@   b   s   zMatMulWrapper.forwardrA   r3   r3   r1   r4   rH   Y   s    rH   c                   @   s"   e Zd ZdZdddZdd ZdS )	SqueezeBertLayerNormz
    This is a nn.LayerNorm subclass that accepts NCW data layout and performs normalization in the C dimension.

    N = batch C = channels W = sequence length
    -q=c                 C   s   t jj| ||d d S )N)Znormalized_shaper   )r   r&   r   )r/   r'   r   r3   r3   r4   r   t   s   zSqueezeBertLayerNorm.__init__c                 C   s*   | ddd}tj| |}| dddS )Nr      r   )permuter   r&   r@   )r/   xr3   r3   r4   r@   w   s   zSqueezeBertLayerNorm.forwardN)rN   )rB   rC   rD   rE   r   r@   r3   r3   r3   r4   rM   m   s    
rM   c                       rG   )ConvDropoutLayerNormz8
    ConvDropoutLayerNorm: Conv, Dropout, LayerNorm
    c                    s8   t    tj||d|d| _t|| _t|| _d S Nr   Zin_channelsZout_channelsZkernel_sizegroups)	r   r   r   Conv1dconv1drM   	layernormr)   r+   )r/   cincoutrU   dropout_probr1   r3   r4   r      s   

zConvDropoutLayerNorm.__init__c                 C   s*   |  |}| |}|| }| |}|S rI   )rW   r+   rX   )r/   hidden_statesZinput_tensorrQ   r3   r3   r4   r@      s
   


zConvDropoutLayerNorm.forwardrA   r3   r3   r1   r4   rR   }   s    rR   c                       rG   )ConvActivationz*
    ConvActivation: Conv, Activation
    c                    s,   t    tj||d|d| _t| | _d S rS   )r   r   r   rV   rW   r
   act)r/   rY   rZ   rU   r^   r1   r3   r4   r      s   
zConvActivation.__init__c                 C   s   |  |}| |S rI   )rW   r^   )r/   rQ   outputr3   r3   r4   r@      s   

zConvActivation.forwardrA   r3   r3   r1   r4   r]      s    r]   c                       s>   e Zd Zd fdd	Zdd Zdd Zdd	 Zd
d Z  ZS )SqueezeBertSelfAttentionr   c                    s   t    ||j dkrtd| d|j d|j| _t||j | _| j| j | _tj||d|d| _	tj||d|d| _
tj||d|d| _t|j| _tjdd| _t | _t | _d	S )
z
        config = used for some things; ignored for others (work in progress...) cin = input channels = output channels
        groups = number of groups to use in conv1d layers
        r   zcin (z6) is not a multiple of the number of attention heads ()r   rT   r   dimN)r   r   num_attention_heads
ValueErrorintattention_head_sizeall_head_sizer   rV   querykeyvaluer)   Zattention_probs_dropout_probr+   ZSoftmaxsoftmaxrH   	matmul_qk
matmul_qkv)r/   r0   rY   q_groupsk_groupsv_groupsr1   r3   r4   r      s   
z!SqueezeBertSelfAttention.__init__c                 C   s:   |  d | j| j|  d f}|j| }|ddddS )z
        - input: [N, C, W]
        - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
        r   r   r   r	   rO   )r8   rd   rg   viewrP   r/   rQ   Znew_x_shaper3   r3   r4   transpose_for_scores   s    
z-SqueezeBertSelfAttention.transpose_for_scoresc                 C   s.   |  d | j| j|  d f}|j| }|S )z
        - input: [N, C, W]
        - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
        r   r   )r8   rd   rg   rr   rs   r3   r3   r4   transpose_key_for_scores   s    
z1SqueezeBertSelfAttention.transpose_key_for_scoresc                 C   s>   | dddd }| d | j| d f}|j| }|S )zE
        - input: [N, C1, W, C2]
        - output: [N, C, W]
        r   r   r	   rO   )rP   
contiguousr8   rh   rr   rs   r3   r3   r4   transpose_output   s   
z)SqueezeBertSelfAttention.transpose_outputc                 C   s   |  |}| |}| |}| |}| |}| |}	| ||}
|
t| j }
|
| }
| 	|
}| 
|}| ||	}| |}d|i}|rO|
|d< |S )z
        expects hidden_states in [N, C, W] data layout.

        The attention_mask data layout is [N, W], and it does not need to be transposed.
        context_layerattention_score)ri   rj   rk   rt   ru   rm   mathsqrtrg   rl   r+   rn   rw   )r/   r\   attention_maskoutput_attentionsZmixed_query_layerZmixed_key_layerZmixed_value_layerZquery_layerZ	key_layerZvalue_layerry   Zattention_probsrx   resultr3   r3   r4   r@      s"   








z SqueezeBertSelfAttention.forward)r   r   r   )	rB   rC   rD   r   rt   ru   rw   r@   rF   r3   r3   r1   r4   r`      s    	

r`   c                       $   e Zd Z fddZdd Z  ZS )SqueezeBertModulec                    s   t    |j}|j}|j}|j}t|||j|j|jd| _t	|||j
|jd| _t|||j|jd| _t	|||j|jd| _dS )a  
        - hidden_size = input chans = output chans for Q, K, V (they are all the same ... for now) = output chans for
          the module
        - intermediate_size = output chans for intermediate layer
        - groups = number of groups for all layers in the BertModule. (eventually we could change the interface to
          allow different groups for different layers)
        )r0   rY   ro   rp   rq   )rY   rZ   rU   r[   )rY   rZ   rU   r^   N)r   r   r'   Zintermediate_sizer`   ro   rp   rq   	attentionrR   Zpost_attention_groupsr*   post_attentionr]   Zintermediate_groups
hidden_actintermediateZoutput_groupsr_   )r/   r0   Zc0c1c2c3r1   r3   r4   r      s   
zSqueezeBertModule.__init__c           
      C   sT   |  |||}|d }| ||}| |}| ||}d|i}	|r(|d |	d< |	S )Nrx   feature_mapry   )r   r   r   r_   )
r/   r\   r|   r}   ZattZattention_outputZpost_attention_outputZintermediate_outputlayer_outputZoutput_dictr3   r3   r4   r@     s   
zSqueezeBertModule.forwardrB   rC   rD   r   r@   rF   r3   r3   r1   r4   r      s    r   c                       s0   e Zd Z fddZ					dddZ  ZS )	SqueezeBertEncoderc                    sB   t     j jksJ dt fddt jD | _d S )NzIf you want embedding_size != intermediate hidden_size, please insert a Conv1d layer to adjust the number of channels before the first SqueezeBertModule.c                 3   s    | ]}t  V  qd S rI   )r   ).0_r0   r3   r4   	<genexpr>.  s    z.SqueezeBertEncoder.__init__.<locals>.<genexpr>)	r   r   r"   r'   r   Z
ModuleListrangenum_hidden_layerslayersr.   r1   r   r4   r   %  s
   
$zSqueezeBertEncoder.__init__NFTc                 C   s  |d u rd}n| d t|krd}nd}|du sJ d|ddd}|r(dnd }|r.dnd }	| jD ]+}
|rJ|ddd}||f7 }|ddd}|
|||}|d }|r^|	|d	 f7 }	q3|ddd}|rm||f7 }|s{td
d |||	fD S t|||	dS )NTFzAhead_mask is not yet supported in the SqueezeBert implementation.r   rO   r   r3   r   ry   c                 s   s    | ]	}|d ur|V  qd S rI   r3   )r   vr3   r3   r4   r   [  s    z-SqueezeBertEncoder.forward.<locals>.<genexpr>)last_hidden_stater\   
attentions)countlenrP   r   r@   tupler   )r/   r\   r|   	head_maskr}   output_hidden_statesreturn_dictZhead_mask_is_all_noneZall_hidden_statesZall_attentionslayerr   r3   r3   r4   r@   0  s6   	


zSqueezeBertEncoder.forward)NNFFTr   r3   r3   r1   r4   r   $  s    r   c                       r   )SqueezeBertPoolerc                    s*   t    t|j|j| _t | _d S rI   )r   r   r   Linearr'   denseZTanh
activationr.   r1   r3   r4   r   b  s   
zSqueezeBertPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r   )r/   r\   Zfirst_token_tensorpooled_outputr3   r3   r4   r@   g  s   

zSqueezeBertPooler.forwardr   r3   r3   r1   r4   r   a  s    r   c                       r   )"SqueezeBertPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S )Nr   )r   r   r   r   r'   r   
isinstancer   strr
   transform_act_fnr&   r(   r.   r1   r3   r4   r   q  s   
z+SqueezeBertPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S rI   )r   r   r&   r/   r\   r3   r3   r4   r@   z  s   


z*SqueezeBertPredictionHeadTransform.forwardr   r3   r3   r1   r4   r   p  s    	r   c                       s.   e Zd Z fddZd	ddZdd Z  ZS )
SqueezeBertLMPredictionHeadc                    sL   t    t|| _tj|j|jdd| _t	t
|j| _| j| j_d S )NF)bias)r   r   r   	transformr   r   r'   r!   decoder	Parameterr,   r9   r   r.   r1   r3   r4   r     s
   

z$SqueezeBertLMPredictionHead.__init__returnNc                 C   s   | j | j_ d S rI   )r   r   rJ   r3   r3   r4   _tie_weights  rK   z(SqueezeBertLMPredictionHead._tie_weightsc                 C   s   |  |}| |}|S rI   )r   r   r   r3   r3   r4   r@     s   

z#SqueezeBertLMPredictionHead.forward)r   N)rB   rC   rD   r   r   r@   rF   r3   r3   r1   r4   r     s    
r   c                       r   )SqueezeBertOnlyMLMHeadc                    s   t    t|| _d S rI   )r   r   r   predictionsr.   r1   r3   r4   r     s   
zSqueezeBertOnlyMLMHead.__init__c                 C   s   |  |}|S rI   )r   )r/   sequence_outputprediction_scoresr3   r3   r4   r@     s   
zSqueezeBertOnlyMLMHead.forwardr   r3   r3   r1   r4   r     s    r   c                   @   s   e Zd ZeZdZdd ZdS )SqueezeBertPreTrainedModeltransformerc                 C   s   t |tjtjfr#|jjjd| jjd |j	dur!|j	j
  dS dS t |tjrF|jjjd| jjd |jdurD|jj|j 
  dS dS t |tjr[|j	j
  |jjd dS t |trh|j	j
  dS dS )zInitialize the weightsg        )meanZstdNg      ?)r   r   r   rV   weightdataZnormal_r0   Zinitializer_ranger   Zzero_r    r   r&   Zfill_r   )r/   moduler3   r3   r4   _init_weights  s    


z(SqueezeBertPreTrainedModel._init_weightsN)rB   rC   rD   r   Zconfig_classZbase_model_prefixr   r3   r3   r3   r4   r     s    r   c                       s   e Zd Z fddZdd Zdd Zdd Ze																		dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j dee dee dee deeef fddZ  ZS )SqueezeBertModelc                    s6   t  | t|| _t|| _t|| _|   d S rI   )	r   r   r   r?   r   encoderr   pooler	post_initr.   r1   r3   r4   r     s
   


zSqueezeBertModel.__init__c                 C   s   | j jS rI   r?   r#   rJ   r3   r3   r4   get_input_embeddings  s   z%SqueezeBertModel.get_input_embeddingsc                 C   s   || j _d S rI   r   r/   Znew_embeddingsr3   r3   r4   set_input_embeddings  s   z%SqueezeBertModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   Zprune_heads)r/   Zheads_to_pruner   Zheadsr3   r3   r4   _prune_heads  s   zSqueezeBertModel._prune_headsNr;   r|   r<   r   r   r=   r}   r   r   r   c
                 C   sZ  |d ur|n| j j}|d ur|n| j j}|	d ur|	n| j j}	|d ur*|d ur*td|d ur9| || | }
n|d urF| d d }
ntd|d urQ|jn|j}|d u r_tj	|
|d}|d u rltj
|
tj|d}| ||
}| || j j}| j||||d}| j||||||	d}|d }| |}|	s||f|d	d   S t|||j|jd
S )NzDYou cannot specify both input_ids and inputs_embeds at the same timer   z5You have to specify either input_ids or inputs_embeds)r7   r5   )r;   r   r<   r=   )r\   r|   r   r}   r   r   r   r   )r   Zpooler_outputr\   r   )r0   r}   r   use_return_dictre   Z%warn_if_padding_and_no_attention_maskr8   r7   r,   Zonesr9   r:   Zget_extended_attention_maskZget_head_maskr   r?   r   r   r   r\   r   )r/   r;   r|   r<   r   r   r=   r}   r   r   r>   r7   Zextended_attention_maskZembedding_outputZencoder_outputsr   r   r3   r3   r4   r@     sP   

zSqueezeBertModel.forward)	NNNNNNNNN)rB   rC   rD   r   r   r   r   r   r   r,   TensorZFloatTensorboolr   r   r   r@   rF   r3   r3   r1   r4   r     sH    
	

r   c                       s   e Zd ZddgZ fddZdd Zdd Ze																				dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee dee dee deeef fddZ  ZS )SqueezeBertForMaskedLMzcls.predictions.decoder.weightzcls.predictions.decoder.biasc                    s,   t  | t|| _t|| _|   d S rI   )r   r   r   r   r   clsr   r.   r1   r3   r4   r     s   

zSqueezeBertForMaskedLM.__init__c                 C   s
   | j jjS rI   )r   r   r   rJ   r3   r3   r4   get_output_embeddings&  s   
z,SqueezeBertForMaskedLM.get_output_embeddingsc                 C   s   || j j_|j| j j_d S rI   )r   r   r   r   r   r3   r3   r4   set_output_embeddings)  s   
z,SqueezeBertForMaskedLM.set_output_embeddingsNr;   r|   r<   r   r   r=   labelsr}   r   r   r   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}d}|dur8t }||d| j j|d}|
sN|f|dd  }|durL|f| S |S t|||j|j	dS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        Nr|   r<   r   r   r=   r}   r   r   r   r   rO   losslogitsr\   r   )
r0   r   r   r   r   rr   r!   r   r\   r   )r/   r;   r|   r<   r   r   r=   r   r}   r   r   outputsr   r   Zmasked_lm_lossloss_fctr_   r3   r3   r4   r@   -  s6   
zSqueezeBertForMaskedLM.forward
NNNNNNNNNN)rB   rC   rD   Z_tied_weights_keysr   r   r   r   r   r,   r   r   r   r   r   r@   rF   r3   r3   r1   r4   r     sN    		

r   z
    SqueezeBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    )Zcustom_introc                          e Zd Z fddZe										ddeej deej deej deej deej d	eej d
eej dee dee dee de	e
ef fddZ  ZS )$SqueezeBertForSequenceClassificationc                    sR   t  | |j| _|| _t|| _t|j| _	t
|j| jj| _|   d S rI   )r   r   
num_labelsr0   r   r   r   r)   r*   r+   r   r'   
classifierr   r.   r1   r3   r4   r   j  s   
z-SqueezeBertForSequenceClassification.__init__Nr;   r|   r<   r   r   r=   r   r}   r   r   r   c                 C   sr  |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dur| j jdu rV| jdkr<d| j _n| jdkrR|jtj	ksM|jtj
krRd| j _nd| j _| j jdkrtt }| jdkrn|| | }n+|||}n%| j jdkrt }||d| j|d}n| j jdkrt }|||}|
s|f|dd  }|dur|f| S |S t|||j|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr   rO   r   )r0   r   r   r+   r   Zproblem_typer   r6   r,   r:   rf   r   squeezer   rr   r   r   r\   r   )r/   r;   r|   r<   r   r   r=   r   r}   r   r   r   r   r   r   r   r_   r3   r3   r4   r@   v  sV   



"


z,SqueezeBertForSequenceClassification.forwardr   )rB   rC   rD   r   r   r   r,   r   r   r   r   r   r@   rF   r3   r3   r1   r4   r   c  sH    	

r   c                       r   )SqueezeBertForMultipleChoicec                    s@   t  | t|| _t|j| _t|j	d| _
|   d S )Nr   )r   r   r   r   r   r)   r*   r+   r   r'   r   r   r.   r1   r3   r4   r     s
   
z%SqueezeBertForMultipleChoice.__init__Nr;   r|   r<   r   r   r=   r   r}   r   r   r   c                 C   sn  |
dur|
n| j j}
|dur|jd n|jd }|dur%|d|dnd}|dur4|d|dnd}|durC|d|dnd}|durR|d|dnd}|dure|d|d|dnd}| j||||||||	|
d	}|d }| |}| |}|d|}d}|durt }|||}|
s|f|dd  }|dur|f| S |S t	|||j
|jdS )a[  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
            *input_ids* above)
        Nr   r   r   rO   r   )r0   r   shaperr   r8   r   r+   r   r   r   r\   r   )r/   r;   r|   r<   r   r   r=   r   r}   r   r   Znum_choicesr   r   r   Zreshaped_logitsr   r   r_   r3   r3   r4   r@     sL   ,


z$SqueezeBertForMultipleChoice.forwardr   )rB   rC   rD   r   r   r   r,   r   r   r   r   r   r@   rF   r3   r3   r1   r4   r     sH    
	

r   c                       r   )!SqueezeBertForTokenClassificationc                    sJ   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S rI   )r   r   r   r   r   r   r)   r*   r+   r   r'   r   r   r.   r1   r3   r4   r   *  s   
z*SqueezeBertForTokenClassification.__init__Nr;   r|   r<   r   r   r=   r   r}   r   r   r   c                 C   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}| |}d}|dur<t }||d| j|d}|
sR|f|dd  }|durP|f| S |S t|||j	|j
dS )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr   r   r   rO   r   )r0   r   r   r+   r   r   rr   r   r   r\   r   )r/   r;   r|   r<   r   r   r=   r   r}   r   r   r   r   r   r   r   r_   r3   r3   r4   r@   5  s8   

z)SqueezeBertForTokenClassification.forwardr   )rB   rC   rD   r   r   r   r,   r   r   r   r   r   r@   rF   r3   r3   r1   r4   r   (  sH    	

r   c                       s   e Zd Z fddZe											ddeej deej deej deej deej d	eej d
eej deej dee dee dee de	e
ef fddZ  ZS )SqueezeBertForQuestionAnsweringc                    s<   t  | |j| _t|| _t|j|j| _| 	  d S rI   )
r   r   r   r   r   r   r   r'   
qa_outputsr   r.   r1   r3   r4   r   m  s
   
z(SqueezeBertForQuestionAnswering.__init__Nr;   r|   r<   r   r   r=   start_positionsend_positionsr}   r   r   r   c                 C   sH  |d ur|n| j j}| j|||||||	|
|d	}|d }| |}|jddd\}}|d }|d }d }|d ur|d urt| dkrO|d}t| dkr\|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|s||f|dd   }|d ur|f| S |S t||||j|jdS )	Nr   r   r   r   rb   )Zignore_indexrO   )r   start_logits
end_logitsr\   r   )r0   r   r   r   splitr   rv   r   r8   clampr   r   r\   r   )r/   r;   r|   r<   r   r   r=   r   r   r}   r   r   r   r   r   r   r   Z
total_lossZignored_indexr   Z
start_lossZend_lossr_   r3   r3   r4   r@   w  sP   






z'SqueezeBertForQuestionAnswering.forward)NNNNNNNNNNN)rB   rC   rD   r   r   r   r,   r   r   r   r   r   r@   rF   r3   r3   r1   r4   r   k  sN    
	

r   )r   r   r   r   r   r   r   r   )6rE   rz   typingr   r   r   r,   r   Ztorch.nnr   r   r   Zactivationsr
   Zmodeling_outputsr   r   r   r   r   r   r   Zmodeling_utilsr   utilsr   r   Zconfiguration_squeezebertr   Z
get_loggerrB   loggerModuler   rH   r&   rM   rR   r]   r`   r   r   r   r   r   r   r   r   r   r   r   r   r   __all__r3   r3   r3   r4   <module>   sR   $	
,Z*=
^IWgBM