o
    Zh)                     @   s  d Z ddlZddlZddlZddlmZ ddlmZmZm	Z	m
Z
mZmZ ddlZddlmZ ddlmZmZmZ ddlmZmZmZ dd	lmZ dd
lmZmZmZ ddlmZ ddlm Z m!Z!m"Z" ddl#m$Z$m%Z%m&Z& ddl'm(Z( e&)e*Z+dd Z,e- eeedZ.G dd dej/Z0G dd dej/Z1G dd dej/Z2G dd dej/Z3e%G dd deZ4eG dd de$Z5e%G dd  d e4Z6e%d!d"G d#d$ d$e4eZ7e%d%d"G d&d' d'e4Z8e%d(d"G d)d* d*e4Z9g d+Z:dS ),zPyTorch OpenAI GPT model.    N)	dataclass)AnyCallableDictOptionalTupleUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )gelu_newget_activationsilu)GenerationMixin)BaseModelOutputCausalLMOutputSequenceClassifierOutput)PreTrainedModel)Conv1D find_pruneable_heads_and_indicesprune_conv1d_layer)ModelOutputauto_docstringlogging   )OpenAIGPTConfigc                    s  ddl }ddl dv rtjtd  td ddd}t	|}W d   n1 s3w   Y  td	 ddd}t	|}W d   n1 sQw   Y   
 fd
d|D } fddtdD }	  |	d|dd }	dd t|	|D }	dd |	D }	| jjj|	d jkrtd| jjj d|	d j | jjj|	d jkrtd| jjj d|	d j t|	d | jj_t|	d | jj_|d |	d |	d t||	D ]\}
}|
dd }
|
dd dkrtd|
 d|
dd }
|
d}
| }|
D ]S}|d|r#|d|}n|g}|d dkr3t|d}n!|d d kr@t|d!}n|d d"krMt|d}nt||d }t|d#kret|d }|| }q|j|jkr{td$|j d%|j d&td'|
  t||_q| S )(zGLoad tf pre-trained weights in a pytorch model (from NumPy arrays here)r   Nz.ckptzLoading weights from z/parameters_names.jsonrzutf-8)encodingz/params_shapes.jsonc                    s   g | ]}  |qS  )prod).0shape)npr    Y/var/www/auris/lib/python3.10/site-packages/transformers/models/openai/modeling_openai.py
<listcomp>;   s    z1load_tf_weights_in_openai_gpt.<locals>.<listcomp>c                    s"   g | ]}  d | d qS )z/params_z.npy)load)r"   nr$   openai_checkpoint_folder_pathr    r%   r&   <   s   " 
   c                 S   s   g | ]	\}}| |qS r    )Zreshape)r"   paramr#   r    r    r%   r&   >   s    c                 S   s   g | ]}|  qS r    )squeeze)r"   Zarrr    r    r%   r&   C   s    r   ztokens_embed.weight.shape: z% does not match init_param[1].shape: zpositions_embed.weight.shape: z% does not match init_param[0].shape:    z:0zLayer z does not end with :0/z[A-Za-z]+\d+z(\d+)gweightbbiasw   zPointer shape z and array shape z mismatchedzInitialize PyTorch weight )renumpyospathdirnameloggerinfoopenjsonr'   ZcumsumrangesplitZconcatenateziptokens_embedr3   r#   
ValueErrorpositions_embedtorchZ
from_numpydatapop	fullmatchgetattrlenint)modelconfigr*   r8   Znames_handlenamesZshapes_handleZshapesoffsetsZinit_paramsnamearrayZpointerZm_nameZscope_namesnumr    r)   r%   load_tf_weights_in_openai_gpt,   sx   



rU   )Zrelur   ZgeluZswishc                       sL   e Zd Zd fdd	Zdd ZdddZd	d
 ZdddZdddZ  Z	S )	AttentionFc                    s   t    |}||j dkrtd| d|j | jdtt||dd||dd |j| _|| _	|| _
t|d || _t||| _t|j| _t|j| _t | _d S )	Nr   zAttention n_state shape: z$ must be divisible by config.n_head r5   r   F
persistentr   )super__init__n_headrE   register_bufferrG   ZtrilZonesview
split_sizescaler   c_attnc_projr	   DropoutZ
attn_pdropattn_dropoutresid_pdropresid_dropoutsetpruned_heads)selfnxn_positionsrO   r_   n_state	__class__r    r%   rZ   ~   s"   
zAttention.__init__c                 C   s   t |dkrd S t|| j| j| j | j\}}t||| j |d| j  g}t| j|dd| _t| j	|dd| _	| j| j | jt |  | _| jt | | _| j
|| _d S )Nr   r7   r   dim)rL   r   r[   r^   rg   rG   catr   r`   ra   union)rh   headsindexZ
index_attnr    r    r%   prune_heads   s    zAttention.prune_headsNc           
      C   s   t ||}| jr|t|d }| jd d d d d |dd |df }|| dd|   }|d ur<|| }tjj	|dd}| 
|}|d urQ|| }t ||g}	|r_|	| |	S )Nr,   r0   g     r   rn   )rG   matmulr_   mathsqrtsizer5   r	   Z
functionalZsoftmaxrc   append)
rh   qkvattention_mask	head_maskoutput_attentionsr6   r4   outputsr    r    r%   _attn   s   .

zAttention._attnc                 C   sD   | dddd }| d d |d|d f }|j| S )Nr   r7   r   r   r0   r,   )permute
contiguousrx   r]   )rh   xnew_x_shaper    r    r%   merge_heads   s   &
zAttention.merge_headsc                 C   sT   |  d d | j| d| j f }|j| }|r"|ddddS |ddddS )Nr,   r   r7   r   r   )rx   r[   r]   r   )rh   r   r{   r   r    r    r%   split_heads   s
   &
zAttention.split_headsc                 C   s   |  |}|j| jdd\}}}| |}| j|dd}| |}| ||||||}|d }	| |	}	| |	}	| |	}	|	g|dd   }
|
S )Nr7   rn   T)r{   r   r   )r`   rB   r^   r   r   r   ra   re   )rh   r   r}   r~   r   querykeyvalueattn_outputsar   r    r    r%   forward   s   





zAttention.forwardFNNF)
__name__
__module____qualname__rZ   rt   r   r   r   r   __classcell__r    r    rl   r%   rV   }   s    

rV   c                       s$   e Zd Z fddZdd Z  ZS )MLPc                    sF   t    |j}t||| _t||| _t|j | _t	
|j| _d S N)rY   rZ   n_embdr   c_fcra   ACT_FNSZafnactr	   rb   rd   dropout)rh   rk   rO   ri   rl   r    r%   rZ      s   
zMLP.__init__c                 C   s$   |  | |}| |}| |S r   )r   r   ra   r   )rh   r   hZh2r    r    r%   r      s   

zMLP.forwardr   r   r   rZ   r   r   r    r    rl   r%   r      s    r   c                       s(   e Zd Zd fdd	ZdddZ  ZS )	BlockFc                    sX   t    |j}t||||| _tj||jd| _t	d| || _
tj||jd| _d S )N)Zeps   )rY   rZ   r   rV   attnr	   	LayerNormZlayer_norm_epsilonln_1r   mlpln_2)rh   rj   rO   r_   ri   rl   r    r%   rZ      s   
zBlock.__init__Nc                 C   sV   | j ||||d}|d }| || }| |}| || }	|	g|dd   }
|
S )N)r}   r~   r   r   r   )r   r   r   r   )rh   r   r}   r~   r   r   r   r(   mr   r   r    r    r%   r      s   
zBlock.forwardr   r   r   r    r    rl   r%   r      s    r   c                       sJ   e Zd ZdZdef fddZ	ddejdeej	 dejfd	d
Z
  ZS )OpenAIGPTSequenceSummarya  
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`OpenAIGPTConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    rO   c                    s   t    t|dd| _| jdkrtt | _t|dr<|j	r<t|dr1|j
r1|jdkr1|j}n|j}t|j|| _t|dd }|rHt|nt | _t | _t|drc|jdkrct|j| _t | _t|d	r{|jdkr}t|j| _d S d S d S )
Nsummary_typelastr   summary_use_projsummary_proj_to_labelsr   Zsummary_activationsummary_first_dropoutsummary_last_dropout)rY   rZ   rK   r   NotImplementedErrorr	   ZIdentitysummaryhasattrr   r   
num_labelsZhidden_sizeLinearr   
activationfirst_dropoutr   rb   last_dropoutr   )rh   rO   Znum_classesZactivation_stringrl   r    r%   rZ     s&   




z!OpenAIGPTSequenceSummary.__init__Nhidden_states	cls_indexreturnc                 C   s  | j dkr|dddf }ne| j dkr|dddf }nW| j dkr(|jdd}nK| j d	krl|du rItj|d
ddddf |jd d tjd}n|dd}|d| d  |	df }|
d|d}n| j dkrst| |}| |}| |}| |}|S )ak  
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        r   Nr,   firstr   meanr   rn   r   .r0   dtype)r,   r   )r   r   rG   Z	full_liker#   long	unsqueezeexpandro   rx   gatherr.   r   r   r   r   r   )rh   r   r   outputr    r    r%   r   ;  s.   



"




z OpenAIGPTSequenceSummary.forwardr   )r   r   r   __doc__r   rZ   rG   FloatTensorr   
LongTensorr   r   r    r    rl   r%   r     s    r   c                   @   s    e Zd ZeZeZdZdd ZdS )OpenAIGPTPreTrainedModeltransformerc                 C   s   t |tjtfr"|jjjd| jjd |j	dur |j	j
  dS dS t |tjrE|jjjd| jjd |jdurC|jj|j 
  dS dS t |tjrZ|j	j
  |jjd dS dS )zInitialize the weights.g        )r   ZstdN      ?)
isinstancer	   r   r   r3   rH   Znormal_rO   Zinitializer_ranger5   Zzero_	EmbeddingZpadding_idxr   Zfill_)rh   moduler    r    r%   _init_weightsm  s   

z&OpenAIGPTPreTrainedModel._init_weightsN)	r   r   r   r   Zconfig_classrU   Zload_tf_weightsZbase_model_prefixr   r    r    r    r%   r   g  s
    r   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeeej  ed< dZeeej  ed< dS )	OpenAIGPTDoubleHeadsModelOutputa  
    Base class for outputs of models predicting if two sentences are consecutive or not.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss.
        mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided):
            Multiple choice classification loss.
        logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlossmc_losslogits	mc_logitsr   
attentions)r   r   r   r   r   r   rG   r   __annotations__r   r   r   r   r   r   r    r    r    r%   r   ~  s   
 r   c                       s   e Zd Z fddZdd Zdd Zdd Ze																		dd
ee	j
 dee	j dee	j
 dee	j
 dee	j dee	j dee dee dee deee	j ef fddZ  ZS )OpenAIGPTModelc                    s   t    t j j| _t j j| _t	 j
| _t fddt jD | _| jdt jdd |   d S )Nc                    s   g | ]
}t  j d dqS )T)r_   )r   rj   )r"   _rO   r    r%   r&     s    z+OpenAIGPTModel.__init__.<locals>.<listcomp>position_idsFrW   )rY   rZ   r	   r   
vocab_sizer   rD   rj   rF   rb   Z
embd_pdropdropZ
ModuleListrA   n_layerr   r\   rG   arange	post_initrh   rO   rl   r   r%   rZ     s    zOpenAIGPTModel.__init__c                 C      | j S r   rD   rh   r    r    r%   get_input_embeddings     z#OpenAIGPTModel.get_input_embeddingsc                 C   
   || _ d S r   r   rh   Znew_embeddingsr    r    r%   set_input_embeddings     
z#OpenAIGPTModel.set_input_embeddingsc                 C   s(   |  D ]\}}| j| j| qdS )zv
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        N)itemsr   r   rt   )rh   Zheads_to_prunelayerrr   r    r    r%   _prune_heads  s   zOpenAIGPTModel._prune_headsN	input_idsr}   token_type_idsr   r~   inputs_embedsr   output_hidden_statesreturn_dictr   c
                 C   s0  |d ur|n| j j}|d ur|n| j j}|	d ur|	n| j j}	|d ur*|d ur*td|d urA| || | }
|d|
d }n|d urN| d d }
ntd|d u ra| jd d |
d f }|d ur|	d	d}|j
t|  jd}d| t| jj }| || j j}|d u r| |}| |}|d ur|d|d}| |}nd}|| | }| |}|
|df }|rd	nd }|rd	nd }t| jD ]"\}}|r||f }||||| |d
}|d }|r||d f }q|j| }|r||f }|	stdd |||fD S t|||dS )NzDYou cannot specify both input_ids and inputs_embeds at the same timer,   z5You have to specify either input_ids or inputs_embedsr   r7   r   r   r   r    )r   c                 s   s    | ]	}|d ur|V  qd S r   r    )r"   r|   r    r    r%   	<genexpr>  s    z)OpenAIGPTModel.forward.<locals>.<genexpr>)Zlast_hidden_stater   r   )rO   r   r   use_return_dictrE   Z%warn_if_padding_and_no_attention_maskrx   r]   r   r   tonext
parametersr   rG   ZfinfominZget_head_maskr   rD   rF   r   	enumerater   tupler   )rh   r   r}   r   r   r~   r   r   r   r   Zinput_shapeZposition_embedsZtoken_type_embedsr   Zoutput_shapeZall_attentionsZall_hidden_statesiblockr   r    r    r%   r     sd   





zOpenAIGPTModel.forward)	NNNNNNNNN)r   r   r   rZ   r   r   r   r   r   rG   r   r   boolr   r   Tensorr   r   r   r    r    rl   r%   r     sH    	
r   z
    OpenAI GPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    )Zcustom_introc                       s   e Zd ZdgZ fddZdd Zdd Ze										dd	ee	j
 d
ee	j dee	j
 dee	j
 dee	j dee	j dee	j
 dee dee dee deee	j ef fddZd	e	j
deeef fddZ  ZS )OpenAIGPTLMHeadModellm_head.weightc                    s8   t  | t|| _tj|j|jdd| _| 	  d S NFr5   )
rY   rZ   r   r   r	   r   r   r   lm_headr   r   rl   r    r%   rZ   "  s   
zOpenAIGPTLMHeadModel.__init__c                 C   r   r   r   r   r    r    r%   get_output_embeddings*  r   z*OpenAIGPTLMHeadModel.get_output_embeddingsc                 C   r   r   r   r   r    r    r%   set_output_embeddings-  r   z*OpenAIGPTLMHeadModel.set_output_embeddingsNr   r}   r   r   r~   r   labelsr   r   r   r   c                 K   s   |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}d}|dur5| j||fd| j ji|}|
sK|f|dd  }|durI|f| S |S t|||j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        Nr}   r   r   r~   r   r   r   r   r   r   r   r   r   r   r   )	rO   r   r   r   Zloss_functionr   r   r   r   )rh   r   r}   r   r   r~   r   r   r   r   r   kwargstransformer_outputsr   	lm_logitsr   r   r    r    r%   r   0  sB   
zOpenAIGPTLMHeadModel.forwardc                 K   s   d|iS )Nr   r    )rh   r   r  r    r    r%   prepare_inputs_for_generationj  s   z2OpenAIGPTLMHeadModel.prepare_inputs_for_generation
NNNNNNNNNN)r   r   r   _tied_weights_keysrZ   r   r   r   r   rG   r   r   r   r   r   r   r   r   r   strr   r  r   r    r    rl   r%   r     sP    	
$9r   a  
        OpenAI GPT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
    RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
    input embeddings, the classification head takes as input the input of a specified classification token index in the
    input sequence).
    c                       s   e Zd ZdgZ fddZdd Zdd Ze												dd	ee	j
 d
ee	j dee	j
 dee	j
 dee	j dee	j dee	j
 dee	j
 dee	j
 dee dee dee deee	j ef fddZ  ZS )OpenAIGPTDoubleHeadsModelr   c                    sH   t  | d|_t|| _tj|j|jdd| _	t
|| _|   d S )Nr   Fr   )rY   rZ   r   r   r   r	   r   r   r   r   r   multiple_choice_headr   r   rl   r    r%   rZ   z  s   

z"OpenAIGPTDoubleHeadsModel.__init__c                 C   r   r   r   r   r    r    r%   r     r   z/OpenAIGPTDoubleHeadsModel.get_output_embeddingsc                 C   r   r   r   r   r    r    r%   r     r   z/OpenAIGPTDoubleHeadsModel.set_output_embeddingsNr   r}   r   r   r~   r   mc_token_idsr   	mc_labelsr   r   r   r   c                 C   s:  |dur|n| j j}| j|||||||
||d	}|d }| |}| ||d}d\}}|	durDt }||d|d|	d}|durq|dddddf 	 }|dddf 	 }t }||d|d|d}|s||f|dd  }|dur|f| }|dur|f| S |S t
|||||j|jdS )	a  
        mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
            Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
            1]`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-1, 0, ..., config.vocab_size]` All labels set to `-100` are
            ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]`
            where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above)

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, OpenAIGPTDoubleHeadsModel
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/openai-gpt")
        >>> model = OpenAIGPTDoubleHeadsModel.from_pretrained("openai-community/openai-gpt")
        >>> tokenizer.add_special_tokens(
        ...     {"cls_token": "[CLS]"}
        ... )  # Add a [CLS] to the vocabulary (we should train it also!)
        >>> model.resize_token_embeddings(len(tokenizer))

        >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
        >>> input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
        >>> mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0)  # Batch size 1

        >>> outputs = model(input_ids, mc_token_ids=mc_token_ids)
        >>> lm_logits = outputs.logits
        >>> mc_logits = outputs.mc_logits
        ```Nr  r   r,   )NN.r   )r   r   r   r   r   r   )rO   r   r   r   r  r.   r   r]   rx   r   r   r   r   )rh   r   r}   r   r   r~   r   r  r   r  r   r   r   r  r   r  r   Zlm_lossr   loss_fctZshift_logitsZshift_labelsr   r    r    r%   r     sJ   1

z!OpenAIGPTDoubleHeadsModel.forward)NNNNNNNNNNNN)r   r   r   r  rZ   r   r   r   r   rG   r   r   r   r   r   r   r   r   r   r    r    rl   r%   r
  o  sZ    		
r
  a  
    The Original OpenAI GPT Model transformer with a sequence classification head on top (linear layer).
    [`OpenAIGPTForSequenceClassification`] uses the last token in order to do the classification, as other causal
    models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the
    last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding
    token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since
    it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take
    the last value in each row of the batch).
    c                       s   e Zd Z fddZe										ddeej deej deej deej deej d	eej d
eej dee	 dee	 dee	 de
eej ef fddZ  ZS )"OpenAIGPTForSequenceClassificationc                    s@   t  | |j| _t|| _tj|j| jdd| _| 	  d S r   )
rY   rZ   r   r   r   r	   r   r   scorer   r   rl   r    r%   rZ     s
   
z+OpenAIGPTForSequenceClassification.__init__Nr   r}   r   r   r~   r   r   r   r   r   r   c                 C   sB  |
dur|
n| j j}
| j||||||||	|
d	}|d }| |}|dur/|jdd \}}n	|jdd \}}| j jdu rF|dkrFtd| j jdu rOd}n1|durt|| j jk|jt	j
}t	j|jd |jt	j
d}|| d}nd}t| jj d	 |t	j||jd
|f }d}|dur| j jdu r| jdkrd| j _n| jdkr|jt	jks|jt	jkrd| j _nd| j _| j jdkrt }| jdkr|| | }n+|||}n%| j jdkrt }||d| j|d}n| j jdkrt }|||}|
s|f|dd  }|dur|f| S |S t|||j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr  r   r7   r   z=Cannot handle batch sizes > 1 if no padding token is defined.r,   )devicer   z will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`)r  Z
regressionZsingle_label_classificationZmulti_label_classificationr  )rO   r   r   r  r#   Zpad_token_idrE   r   r  rG   Zint32r   Zargmaxr=   Zwarning_oncerm   r   Zproblem_typer   r   r   rM   r   r.   r   r]   r
   r   r   r   )rh   r   r}   r   r   r~   r   r   r   r   r   r  r   r   Z
batch_sizeZsequence_lengthZlast_non_pad_tokenZnon_pad_maskZtoken_indicesZpooled_logitsr   r  r   r    r    r%   r     st   


"


z*OpenAIGPTForSequenceClassification.forwardr  )r   r   r   rZ   r   r   rG   r   r   r   r   r   r   r   r   r   r    r    rl   r%   r    sH    		
r  )r
  r  r   r   r   rU   );r   r@   rv   r:   dataclassesr   typingr   r   r   r   r   r   rG   r	   Ztorch.nnr
   r   r   Zactivationsr   r   r   Z
generationr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   r   Zconfiguration_openair   Z
get_loggerr   r=   rU   ZReLUr   ModulerV   r   r   r   r   r   r   r   r
  r  __all__r    r    r    r%   <module>   sT    
N]c"wPqi