"""PyTorch XGLM model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_xglm import XGLMConfig


logger = logging.get_logger(__name__)

class XGLMScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class XGLMSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # keep the buffer on the dtype and device of the existing weights
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad for odd embedding dimensions
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, position_ids: torch.Tensor = None, past_key_values_length: int = 0):
        bsz, seq_len = position_ids.size()
        position_ids += self.offset

        # Expand the embedding table if needed
        max_pos = 2 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
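
# Illustrative sketch (not part of the original module): how the sinusoidal table above is laid out. The sizes
# below are arbitrary values chosen for the example, not constants taken from any XGLM config.
#
#   >>> pos_emb = XGLMSinusoidalPositionalEmbedding(num_positions=16, embedding_dim=8, padding_idx=1)
#   >>> pos_emb.weights.shape  # (num_positions + offset, embedding_dim)
#   torch.Size([18, 8])
#   >>> position_ids = torch.arange(4).unsqueeze(0)  # (batch_size=1, seq_len=4), before the internal +2 offset
#   >>> pos_emb(position_ids).shape
#   torch.Size([1, 4, 8])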
 fd
dZdej	dedefddZ
					ddej	deej	 deeej	  deej	 deej	 dedeej	eej	 eeej	  f fddZ  ZS )XGLMAttentionz=Multi-headed attention from 'Attention Is All You Need' paper        FT	embed_dim	num_headsdropout
is_decoderbiasc                    s   t    || _|| _|| _|| | _| j| | jkr'td| j d| d| jd | _|| _t	j
|||d| _t	j
|||d| _t	j
|||d| _t	j
|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      ࿩rT   )r   r   rP   rQ   rR   head_dim
ValueErrorscalingrS   r   Lineark_projv_projq_projout_proj)r   rP   rQ   rR   rS   rT   r   r   r    r   i   s"   


zXGLMAttention.__init__tensorrL   rK   c                 C   s    | ||| j| jdd S )Nr   r.   )rD   rQ   rV   	transpose
contiguous)r   r^   rL   rK   r   r   r    _shape   s    zXGLMAttention._shapeNhidden_stateskey_value_statespast_key_valueattention_masklayer_head_maskoutput_attentionsreturnc                 C   s  |du}|  \}}	}
| || j }|r"|dur"|d }|d }nZ|r9| | |d|}| | |d|}nC|durh| | |d|}| | |d|}tj|d |gdd}tj|d |gdd}n| | |d|}| | |d|}| jr||f}|| j	 d| j
f}| ||	|j| }|j| }|j| }| d}t||dd}|  || j	 |	|fkrtd|| j	 |	|f d|   |dur|  |d|	|fkrtd	|d|	|f d|   ||| j	|	|| }t|tjt|jj|jd
}||| j	 |	|}|jtjkr(tjj|dtjdtj}ntjj|dd}|durg|  | j	fkrLtd| j	f d|   |dddd||| j	|	| }||| j	 |	|}|r~||| j	|	|}||| j	 |	|}nd}tjj|| j| jd}t||}|  || j	 |	| j
fkrtd|| j	|	| j
f d|   ||| j	|	| j
}|dd}|||	| j}| |}|||fS )z#Input shape: Batch x Time x ChannelNr   r   r;   r.   r9   z$Attention weights should be of size z	, but is z!Attention mask should be of size )r4   )r:   r3   z/Head mask for a single layer should be of size ptrainingz `attn_output` should be of size ) rG   r\   rX   ra   rZ   r[   r)   rA   rS   rQ   rV   rD   Zbmmr_   rW   maxr^   Zfinfor3   minr4   Zfloat16r   
functionalZsoftmaxZfloat32r8   rR   rk   ZreshaperP   r]   )r   rb   rc   rd   re   rf   rg   Zis_cross_attentionrK   tgt_len_Zquery_statesZ
key_statesZvalue_statesZ
proj_shapeZsrc_lenZattn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr   r   r    r"      s   





"

zXGLMAttention.forward)rO   FT)NNNNF)r#   r$   r%   r&   r'   r(   boolr   r)   r*   ra   r   r   r"   r+   r   r   r   r    rN   f   sJ    rN   c                       s   e Zd Zdef fddZ								ddejdeej d	eej d
eej deej deej deeej  dee	 dee	 dejfddZ
  ZS )XGLMDecoderLayerconfigc                    s   t    |j| _t| j|j|jdd| _|j| _t	|j
 | _|j| _|jr9t| j|j|jdd| _t| j| _t| j| _t| j|j| _t|j| j| _t| j| _d S )NT)rP   rQ   rR   rS   )r   r   d_modelrP   rN   Zattention_headsZattention_dropout	self_attnrR   r   Zactivation_functionactivation_fnactivation_dropoutZadd_cross_attentionencoder_attnr   	LayerNormencoder_attn_layer_normself_attn_layer_normrY   Zffn_dimfc1fc2final_layer_normr   rs   r   r   r    r     s.   
zXGLMDecoderLayer.__init__NFTrb   re   encoder_hidden_statesencoder_attention_maskrf   cross_attn_layer_head_maskrd   rg   	use_cacherh   c
                 C   sZ  |}
|  |}|dur|dd nd}| j|||||d\}}}tjj|| j| jd}|
| }d}d}|durk|}
| |}|durH|dd nd}| j||||||d\}}}tjj|| j| jd}|
| }|| }|}
| |}| 	| 
|}tjj|| j| jd}| |}tjj|| j| jd}|
| }|f}|r|||f7 }|	r||f7 }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        Nr.   )rb   rd   re   rf   rg   ri   )rb   rc   re   rf   rd   rg   )r{   ru   r   rn   rR   rk   rz   rx   r~   rv   r|   rw   r}   )r   rb   re   r   r   rf   r   rd   rg   r   ZresidualZself_attn_past_key_valueZself_attn_weightsZpresent_key_valueZcross_attn_present_key_valueZcross_attn_weightsZcross_attn_past_key_valueoutputsr   r   r    r"     sT   




zXGLMDecoderLayer.forward)NNNNNNFT)r#   r$   r%   r   r   r)   r*   r   r   rq   r"   r+   r   r   r   r    rr      s>     	

@auto_docstring
class XGLMPreTrainedModel(PreTrainedModel):
    config_class = XGLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["XGLMDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

@auto_docstring
class XGLMModel(XGLMPreTrainedModel):
    def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None):
        r"""
        embed_tokens (`nn.Embedding`, *optional*):
            output embeddings
        """
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = XGLMScaledWordEmbedding(
                config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
            )

        self.embed_positions = XGLMSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            config.pad_token_id,
        )
        self.layers = nn.ModuleList([XGLMDecoderLayer(config) for _ in range(config.num_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                input_shape[-1] + past_key_values_length,
                dtype=torch.long,
                device=input_ids.device if input_ids is not None else inputs_embeds.device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length).to(
            inputs_embeds.device
        )
        hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
r   z
    The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    )Zcustom_introc                "       s  e Zd ZdZdgZ fddZdd Zdd Zd	d
 Zdd Z	e
														d!deej deej deej deej deej deej deej deeej  deej deej dee dee dee dee deeej ef fddZedd  Z  ZS )"XGLMForCausalLMr   zlm_head.weightc                    s8   t  | t|| _tj|j|jdd| _| 	  d S )NFrU   )
r   r   r   r   r   rY   Zhidden_sizer   lm_headr   r   r   r   r    r   l  s   
zXGLMForCausalLM.__init__c                 C   s   | j jS r   r   r   r   r   r   r    r   t  s   z$XGLMForCausalLM.get_input_embeddingsc                 C   s   || j _d S r   r   r   r   r   r    r   w  s   z$XGLMForCausalLM.set_input_embeddingsc                 C   r   r   r   r   r   r   r    get_output_embeddingsz  r   z%XGLMForCausalLM.get_output_embeddingsc                 C   r   r   r   )r   Znew_embeddingsr   r   r    set_output_embeddings}  r   z%XGLMForCausalLM.set_output_embeddingsNr!   re   rE   r   r   r   r   r   r   labelsr   rg   r   r   rh   c                 K   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}| j|||||||||	||||d}| |d }d}|
durN| j||
f| j j| j jd|}|sd|f|dd  }|durb|f| S |S t	|||j
|j|j|jdS )a  
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        N)r!   re   rE   r   r   r   r   r   r   r   rg   r   r   r   )r   r   r   )losslogitsr   rb   r   r   )rs   rg   r   r   r   r   Zloss_functionr   r   r   r   rb   r   r   )r   r!   re   rE   r   r   r   r   r   r   r   r   rg   r   r   kwargsr   r   r   outputr   r   r    r"     sT   *zXGLMForCausalLM.forwardc                    s.   d}| D ]}|t  fdd|D f7 }q|S )Nr   c                 3   s$    | ]}| d  |jV  qdS )r   N)rH   r8   r4   )r   Z
past_statebeam_idxr   r    r     s   " z1XGLMForCausalLM._reorder_cache.<locals>.<genexpr>)r   )r   r   Zreordered_pastZ
layer_pastr   r   r    _reorder_cache  s   zXGLMForCausalLM._reorder_cache)NNNNNNNNNNNNNN)r#   r$   r%   r   Z_tied_weights_keysr   r   r   r   r   r   r   r)   r*   r   r   rq   r   r   r   r"   rM   r   r+   r   r   r   r    r   b  sp    	
Yr   )r   r   r   )(r&   r<   typingr   r   r   r   r)   Ztorch.utils.checkpointr   Zactivationsr   Z
generationr	   Zmodeling_attn_mask_utilsr
   r   Zmodeling_outputsr   r   Zmodeling_utilsr   utilsr   r   Zconfiguration_xglmr   Z
get_loggerr#   r   r   r   Moduler,   rN   rr   r   r   r   __all__r   r   r   r    <module>   s:   
4 x W|