"""PyTorch MPT model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F

from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_mpt import MptConfig


logger = logging.get_logger(__name__)
|ddddddf |ddddddf gddddd| df }|| }|dS )	a  
    Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
    the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
    r   )dtypedevice         ?N.dimr   )torcharangeint32viewmathceillog2Zint64floatpowconcatsqueeze)	num_headssequence_lengthalibi_bias_maxr   alibiZnum_heads_power_of_2baseZslopes r/   S/var/www/auris/lib/python3.10/site-packages/transformers/models/mpt/modeling_mpt.pybuild_mpt_alibi_tensor+   s   $L
r1   c                
class MptAttention(nn.Module):
    """Multi-head self attention.
    Using torch or triton attention implementation enables user to also use additive bias.
    """

    def __init__(self, config: MptConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.n_heads = config.n_heads
        self.max_seq_length = config.max_seq_len
        self.head_dim = self.hidden_size // self.n_heads
        self.softmax_scale = config.attn_config.softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.hidden_size / self.n_heads)

        self.attn_dropout_p = config.attn_config.attn_pdrop
        self.clip_qkv = config.attn_config.clip_qkv
        self.Wqkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_bias: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        mixed_qkv = self.Wqkv(hidden_states)
        if self.clip_qkv:
            mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)

        query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
        query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            if len(past_key_value) != 0:
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)
            past_key_value = (key_states, value_states)
        else:
            past_key_value = (key_states, value_states)

        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale

        query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]

        if position_bias is not None:
            if len(position_bias.shape) != 3:
                raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
            key_length = key_states.shape[-2]

            position_bias_query_index = max(0, position_bias.size(1) - query_length)
            position_bias_key_index = max(0, position_bias.size(2) - key_length)

            position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]

            attention_scores = attention_scores + position_bias

        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).to(value_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)

        context_states = torch.matmul(attn_weights, value_states)
        context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
        attn_output = self.out_proj(context_states)

        return attn_output, attn_weights, past_key_value
__module____qualname____doc__r   r7   r   Tensorr   r   rc   __classcell__r/   r/   rG   r0   r2   B   s    r2   c                       s>   e Zd Zdef fddZdejdejdejfddZ  ZS )	MptMLPr3   c                    sX   t    |j}tj|d| dd| _tjdd| _tjd| |dd| _|j	j
| _d S )N   Fr4   none)Zapproximate)r6   r7   r8   r   rB   up_projZGELUact	down_projr<   r?   hidden_dropoutrF   r3   r8   rG   r/   r0   r7      s   
zMptMLP.__init__rI   residualreturnc                 C   s:   |  | |}| |}tj|| j| jd}|| }|S )NrQ   )rn   rm   ro   Fr^   rp   rS   )rF   rI   rr   Zintermediate_outputoutputr/   r/   r0   rc      s
   
zMptMLP.forward)	rd   re   rf   r   r7   r   rh   rc   ri   r/   r/   rG   r0   rj      s    $	rj   c                       sb   e Zd Zdef fddZ			ddejdejdejd	eeejejf  d
e	de	fddZ
  ZS )MptBlockr3   c                    sx   t    |j}t||jd| _d | j_|j| _t	|| _
t||jd| _d | j_t|| _|jj| _t| j| _d S )NZeps)r6   r7   r8   r   layer_norm_epsilonnorm_1r5   r9   r*   r2   attnnorm_2rj   ffnr<   r?   Zdropout_rater   Dropoutresid_attn_dropoutrq   rG   r/   r0   r7      s   



zMptBlock.__init__NFrI   rJ   rL   
layer_past	use_cacheoutput_attentionsc                 C   st   |  |}|}| j||||d\}	}
}| |	| }| |}|}| ||}|f}|r1||f7 }|r8||
f7 }|S )N)rJ   rL   rK   )ry   rz   r~   r{   r|   )rF   rI   rJ   rL   r   r   r   Zlayernorm_outputrr   Zattn_outputsrb   rK   ru   outputsr/   r/   r0   rc      s$   



zMptBlock.forward)NFF)rd   re   rf   r   r7   r   rh   r   r   boolrc   ri   r/   r/   rG   r0   rv      s$    rv   c                       sz   e Zd ZeZdZdZdgZdgZ fddZ	de
jfdd	Zed
eeejejf  deeejejf  fddZ  ZS )MptPreTrainedModeltransformerTrv   z
lm_head.*.c                    s   t  j|i | d S N)r6   r7   )rF   ZinputskwargsrG   r/   r0   r7      s   zMptPreTrainedModel.__init__modulec                 C   s   t |tjr |jjjd| jjd |jdur|jj	  dS dS t |tj
rC|jjjd| jjd |jdurA|jj|j 	  dS dS t |tr\|jdurS|jj	  |jjd dS dS )zInitialize the weights.g        )meanZstdNr   )
isinstancer   rB   weightdataZnormal_r3   Zinitializer_ranger5   Zzero_	EmbeddingZpadding_idxr   Zfill_)rF   r   r/   r/   r0   _init_weights   s   



z MptPreTrainedModel._init_weightsrK   rs   c                    s8   | d d j \}}||  t fdd| D S )zw
        Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...]))
        r   c                 3   s4    | ]}|d    |d   fV  qdS r   r   N)rW   .0r   Zbatch_size_times_num_headsr;   ra   r/   r0   	<genexpr>
  s    
z;MptPreTrainedModel._convert_to_mpt_cache.<locals>.<genexpr>)rT   tuple)rK   r`   r*   r/   r   r0   _convert_to_mpt_cache   s
   z(MptPreTrainedModel._convert_to_mpt_cache)rd   re   rf   r   Zconfig_classZbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_keys_to_ignore_on_load_missingr7   r   Moduler   staticmethodr   r   rh   r   ri   r/   r/   rG   r0   r      s    r   c                       s   e Zd Zdef fddZdd Zddd	Zd
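# Cache layout sketch for `_convert_to_mpt_cache` above (illustrative sizes assumed:
# batch_size=2, num_heads=8, head_dim=16, seq_length=10), fusing batch and head dims
# into batch_size * num_heads = 16:
#   key:   (2, 8, 16, 10) -> (16, 16, 10)
#   value: (2, 8, 10, 16) -> (16, 10, 16)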
ejfddZ	e
								ddeej deeeejejf df  deej deej dee dee dee dee deeejdf ef fddZ  ZS )MptModelr3   c                    sz   t     j| _ j| _t j| j| _t	 fddt
 jD | _t| j jd| _d | j_d| _|   d S )Nc                    s   g | ]}t  qS r/   )rv   )r   _r3   r/   r0   
<listcomp>  s    z%MptModel.__init__.<locals>.<listcomp>rw   F)r6   r7   r8   r9   r*   r   r   
vocab_sizewteZ
ModuleListrangeZn_layersblocksr   rx   norm_fr5   gradient_checkpointing	post_initrE   rG   r   r0   r7     s    zMptModel.__init__c                 C      | j S r   r   rF   r/   r/   r0   get_input_embeddings+     zMptModel.get_input_embeddingsr   Nc                 C   s   t ||||S r   )r1   )rF   r*   r+   r,   r   r/   r/   r0   r1   .  s   zMptModel.build_mpt_alibi_tensornew_embeddingsc                 C   
   || _ d S r   r   rF   r   r/   r/   r0   set_input_embeddings1     
zMptModel.set_input_embeddings	input_idspast_key_values.rL   inputs_embedsr   r   output_hidden_statesreturn_dictrs   c	              
   K   s~  |dur|n| j j}|dur|n| j j}|dur|n| j j}|dur$|n| j j}|dur4|dur4td|dur>|j\}
}n|durI|j\}
}}ntd|du r[tdgt| j	 }|du rd| 
|}|}|rjdnd}|rpdnd}|rvdnd}| jr| jr|rtd d}|}d}|d dur|d d jd }|| }|du rtj|
|f|jd	}n||j}| j| j| j j|jd	}t||
|f||}| }t| j	|D ]G\}}|r||f }| jr| jr| |j||||||}n
|||||||d
}|d }|du r
||d f }|r|||rdnd f }q| |}|r'||f }|s7tdd ||||fD S t||||dS )  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        NzDYou cannot specify both input_ids and inputs_embeds at the same timez5You have to specify either input_ids or inputs_embedsr/   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fr   r   r   )r   rL   r   r   rJ   Tr   c                 s   s    | ]	}|d ur|V  qd S r   r/   )r   vr/   r/   r0   r     s    z#MptModel.forward.<locals>.<genexpr>)Zlast_hidden_stater   rI   
attentions)r3   r   r   r   use_return_dictr[   rT   r   rX   r   r   r   rS   loggerwarning_oncer   Zonesr   r]   r1   r*   r:   r   r   zipZ_gradient_checkpointing_func__call__r   r   )rF   r   r   rL   r   r   r   r   r   r   r`   ra   r   rI   ZpresentsZall_self_attentionsZall_hidden_statesZseq_length_with_pastZpast_key_values_lengthr-   Zcausal_maskblockr   r   r/   r/   r0   rc   4  s   


	


zMptModel.forwardr   NNNNNNNNN)rd   re   rf   r   r7   r   r1   r   rh   r   r   r   
LongTensorr   r   r   r   rc   ri   r/   r/   rG   r0   r     sB    
	r   z
@auto_docstring(
    custom_intro="""
    The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class MptForCausalLM(MptPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MptConfig):
        super().__init__(config)
        self.transformer = MptModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to the correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            loss = self.loss_function(lm_logits, labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in past
        )
        return reordered_past
@auto_docstring(
    custom_intro="""
    The MPT Model transformer with a sequence classification head on top (linear layer).

    [`MptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class MptForSequenceClassification(MptPreTrainedModel):
    def __init__(self, config: MptConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = MptModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
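# Pooling sketch for the classifier above (input values assumed for illustration):
# with pad_token_id=0 and
#   input_ids = [[5, 6, 0, 0],
#                [7, 8, 9, 0]]
# `last_non_pad_token` resolves to [1, 2], so the logits at those positions (tokens
# 6 and 9) become the per-sequence `pooled_logits`.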
r   c                       r   )MptForTokenClassificationr3   c                    s   t  | |j| _t|| _t|dr|jd ur|j}nt|dr+|jd ur+|j}nd}t	|| _
t|j|j| _|   d S )Nclassifier_dropoutrp   g?)r6   r7   r   r   r   hasattrr   rp   r   r}   r^   rB   r8   
classifierr   )rF   r3   r   rG   r/   r0   r7     s   
z"MptForTokenClassification.__init__Nr   r   .rL   r   r   r   r   r   r   rs   c
              
   K   s   |	dur|	n| j j}	| j||||||||	d}|d }| |}| |}d}|durJ||j}|j\}}t }||	|| | j
|	|| }|	s`|f|dd  }|dur^|f| S |S t|||j|jdS )r   Nr   r   r   )r   r   rI   r   )r3   r   r   r^   r   r]   r   rT   r   r"   r   r   rI   r   )rF   r   r   rL   r   r   r   r   r   r   Zdeprecated_argumentsr   rI   r   r   r`   ra   r   ru   r/   r/   r0   rc     s>   


z!MptForTokenClassification.forwardr   )rd   re   rf   r   r7   r   r   r   r   r   rh   r   r   r   rc   ri   r/   r/   rG   r0   r     sB    	
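# Loss sketch for the token classifier above (sizes assumed for illustration): logits
# of shape (batch_size, seq_length, num_labels) and labels of shape
# (batch_size, seq_length) are both flattened to (batch_size * seq_length, ...) so a
# single CrossEntropyLoss call scores every token position at once.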
r   c                       s   e Zd Z fddZe								ddeej deej deej deej deej d	ee	 d
ee	 dee	 de
eef fddZ  ZS )MptForQuestionAnsweringc                    s2   t  | t|| _t|jd| _|   d S )Nr   )	r6   r7   r   r   r   rB   r8   
qa_outputsr   rE   rG   r/   r0   r7     s   
z MptForQuestionAnswering.__init__Nr   rL   r   start_positionsend_positionsr   r   r   rs   c	                 C   sB  |dur|n| j j}| j||||||d}	|	d }
| |
}|jddd\}}|d }|d }d}|dur|durt| dkrL|d}t| dkrY|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|s||f|	dd  }|dur|f| S |S t||||	j|	jd	S )
r   N)rL   r   r   r   r   r   r   rO   r   )Zignore_indexr   )r   start_logits
end_logitsrI   r   )r3   r   r   r   splitr)   r_   rX   r\   rU   r   r   rI   r   )rF   r   rL   r   r   r   r   r   r   r   Zsequence_outputr   r   r   Z
total_lossZignored_indexr   Z
start_lossZend_lossru   r/   r/   r0   rc     sJ   	






zMptForQuestionAnswering.forwardr   )rd   re   rf   r7   r   r   r   r   ZFloatTensorr   r   r   r   rc   ri   r/   r/   rG   r0   r     s<    	

r   )r   r   r   r   r   r   r   )0rg   r#   typingr   r   r   r   Ztorch.utils.checkpointr   Ztorch.nnr   r   r   r	   r
   rt   Z
generationr   Zmodeling_attn_mask_utilsr   Zmodeling_outputsr   r   r   r   r   Zmodeling_utilsr   utilsr   r   Zconfiguration_mptr   Z
get_loggerrd   r   r1   r   r2   rj   rv   r   r   r   r   r   r   __all__r/   r/   r/   r0   <module>   sJ   

L@/  prXR