import logging
from typing import Optional

import torch

from transformers.generation.configuration_utils import GenerationConfig

from ..utils.import_utils import is_torch_available


if is_torch_available():
    from transformers import HybridCache, PreTrainedModel, StaticCache
    from transformers.pytorch_utils import (
        is_torch_greater_or_equal,
        is_torch_greater_or_equal_than_2_3,
    )


class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
    """
    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
    specifically for decoder-only LM with cache. This module ensures that the
    exported model is compatible with further lowering and execution in `ExecuTorch`.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        max_batch_size: int = 1,
        max_cache_len: int = 4096,
    ):
        """
        Initializes the exportable module with `HybridCache`.

        Args:
            model (`PreTrainedModel`): The pretrained model to wrap.
            max_batch_size (int): Maximum batch size for the cache.
            max_cache_len (int): Maximum sequence length for the cache.

        Raises:
            ValueError: If the model is configured with an unsupported cache implementation.
        """
        super().__init__()

        if not hasattr(model.config, "use_cache") or model.config.use_cache is False:
            raise ValueError("The model must have caching enabled to be performant.")

        if not hasattr(model.config, "cache_implementation"):
            # If `cache_implementation` is not specified in the config, fall back to `StaticCache` for export.
            logging.info("Using `StaticCache` for export as `cache_implementation` is not specified in the config.")
            self.model = TorchExportableModuleWithStaticCache(model)
        elif model.config.cache_implementation == "hybrid":
            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
        else:
            raise ValueError(
                f"Unsupported cache implementation: {model.config.cache_implementation}. "
                "Please use `hybrid` or `static`."
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the module, which is compatible with the ExecuTorch llm runner.

        Args:
            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

        Returns:
            torch.Tensor: Logits output from the model.
        """
        return self.model.forward(input_ids, cache_position)

    def export(
        self,
        input_ids: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        dynamic_shapes: Optional[dict] = None,
        strict: Optional[bool] = None,
    ) -> torch.export.ExportedProgram:
        """
        Export the wrapped module using `torch.export`.

        Args:
            input_ids (`Optional[torch.Tensor]`):
                Tensor representing current input token id to the module. If not provided, a default tensor will be used.
            cache_position (`Optional[torch.Tensor]`):
                Tensor representing current input position in the cache. If not provided, a default tensor will be used.
            dynamic_shapes (`Optional[dict]`):
                Dynamic shapes to use for export if specified.
            strict (`Optional[bool]`):
                Flag to instruct `torch.export` to use `torchdynamo`.
        """
        example_input_ids = input_ids if input_ids is not None else torch.tensor([[1]], dtype=torch.long)
        example_cache_position = cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)

        return torch.export.export(
            self.model,
            args=(example_input_ids, example_cache_position),
            kwargs={},
            dynamic_shapes=dynamic_shapes,
            strict=strict if strict is not None else True,
        )

    @staticmethod
    def generate(
        exported_program: torch.export.ExportedProgram,
        tokenizer,
        prompt: str,
        max_new_tokens: int = 20,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        device: str = "cpu",
    ) -> str:
        """
        Generate a sequence of tokens using an exported program.

        Args:
            exported_program (`torch.export.ExportedProgram`): The exported model being used for generate.
            tokenizer: The tokenizer to use.
            prompt (str): The input prompt.
            max_new_tokens (int): Maximum number of new tokens to generate.
            do_sample (bool): Whether to use sampling or greedy decoding.
            temperature (float): The temperature for sampling.
            top_k (int): The number of highest probability tokens to keep for top-k sampling.
            top_p (float): The cumulative probability for nucleus sampling.
            device (str): The device to use.

        Returns:
            str: The generated text.
        """
        exported_module = exported_program.module()

        # Tokenize the prompt and prefill the cache one token at a time
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        generated_ids = input_ids.clone()

        curr_position = 0
        for i in range(input_ids.shape[1]):
            curr_input_ids = input_ids[:, i : i + 1]
            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
            _ = exported_module(curr_input_ids, curr_cache_position)
            curr_position += 1

        # Decode new tokens one by one
        for _ in range(max_new_tokens):
            curr_input_ids = generated_ids[:, -1:]
            curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)

            outputs = exported_module(curr_input_ids, curr_cache_position)

            if do_sample:
                # Work on the logits of the last position so the filtering ops below see a 2D tensor
                logits = outputs[:, -1, :]

                # Apply temperature scaling
                if temperature > 0:
                    logits = logits / temperature

                # Top-k filtering
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = float("-inf")

                # Top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

                    # Remove tokens with cumulative probability above the threshold,
                    # shifting right so the first token above the threshold is kept
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float("-inf")

                # Sample from the filtered distribution
                probs = torch.softmax(logits, dim=-1)
                next_token_id = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding
                next_token_id = outputs.argmax(dim=-1, keepdim=True)

            # Make sure the token has shape (batch, 1) before concatenation
            if next_token_id.dim() > 2:
                next_token_id = next_token_id.squeeze(-1)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
            curr_position += 1

            # Stop if we generate an EOS token
            if next_token_id.item() == tokenizer.eos_token_id:
                break

        return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
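

# Illustrative usage sketch (not part of the original module): wrapping and exporting a decoder-only
# model with `TorchExportableModuleForDecoderOnlyLM`, then generating text from the exported program.
# The checkpoint name, cache sizes, and prompt below are placeholders/assumptions.
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
#
#     model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
#     model.generation_config = GenerationConfig(
#         use_cache=True,
#         cache_implementation="static",
#         cache_config={"batch_size": 1, "max_cache_len": 128, "device": "cpu"},
#     )
#     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
#
#     exportable = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=128)
#     exported_program = exportable.export()
#     text = TorchExportableModuleForDecoderOnlyLM.generate(
#         exported_program, tokenizer, "Hello", max_new_tokens=10
#     )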


class TorchExportableModuleWithStaticCache(torch.nn.Module):
    """
    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
    specifically for decoder-only LM to `StaticCache`. This module ensures that the
    exported model is compatible with further lowering and execution in `ExecuTorch`.

    Note:
        This class is specifically designed to support export process using `torch.export`
        in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
    """

    def __init__(self, model: PreTrainedModel):
        """
        Initializes the wrapper module with the pretrained model.

        Args:
            model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
            enabled and use a 'static' caching implementation.

        Raises:
            AssertionError: If the pretrained model does not have caching enabled or if it does
            not use a 'static' caching implementation in `model.generation_config`.
        """
        super().__init__()

        # Sanity checks on the generation config
        if model.generation_config is None:
            raise AssertionError(
                "The model must have a generation config to be exported with static caching. "
                "Please set `generation_config`."
            )

        if not model.generation_config.use_cache:
            raise AssertionError(
                "The model must have caching enabled to be exported with static caching. "
                "Please set `generation_config.use_cache=True`."
            )

        if model.generation_config.cache_implementation != "static":
            raise AssertionError(
                "The model must use a 'static' caching implementation to be exported with static caching. "
                "Please set `generation_config.cache_implementation='static'`."
            )

        self.model = model
        self.static_cache = StaticCache(
            config=self.model.config,
            max_batch_size=self.model.generation_config.cache_config.batch_size,
            max_cache_len=self.model.generation_config.cache_config.max_cache_len,
            device=self.model.generation_config.cache_config.device,
            dtype=self.model.dtype,
        )

        # Register the cache tensors as buffers so they show up as mutable state in the exported graph
        for i in range(len(self.static_cache.key_cache)):
            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
        if self.is_causal:
            causal_mask = torch.tril(
                torch.ones(
                    self.static_cache.max_cache_len,
                    self.static_cache.max_cache_len,
                    dtype=torch.bool,
                )
            )
            self.register_buffer("mask", causal_mask, persistent=False)

    def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
        """
        Forward pass of the module, which is compatible with the ExecuTorch runtime.

        Args:
            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

        Returns:
            torch.Tensor: Logits output from the model.

        This forward adapter serves two primary purposes:

        1. **Making the Model `torch.export`-Compatible**:
            The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs,
            enabling the model to be exportable using `torch.export` without encountering issues.

        2. **Ensuring Compatibility with `ExecuTorch` runtime**:
            The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
            ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
        """
        _, seqlen = input_ids.shape
        attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
        position_ids = cache_position.unsqueeze(0)
        past_key_values = self.static_cache

        outs = self.model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outs.logits

    @staticmethod
    def generate(
        exported_program: torch.export.ExportedProgram, prompt_token_ids: torch.Tensor, max_new_tokens: int
    ) -> torch.Tensor:
        """
        Generate a sequence of tokens using an exported program.

        This util function is designed to test exported models by simulating the generation process.
        It processes the input prompt tokens sequentially (no parallel prefill).
        This generate function is not intended to replace the original `generate` method, and the support
        for leveraging the original `generate` is potentially planned!

        Args:
            exported_program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
            prompt_token_ids (`torch.Tensor`): Tensor representing the input prompt token IDs.
            max_new_tokens (`int`): Maximum number of new tokens to generate. Note that the total generation
                length is limited by both `max_new_tokens` and the model's cache size.

        Returns:
            torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
        """
        prompt_token_len = prompt_token_ids.shape[-1]
        max_generation_length = prompt_token_len + max_new_tokens

        # Clamp the generation length to the size of the exported static cache
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("key_cache_"):
                max_cache_len = buffer.shape[2]
                max_generation_length = min(max_generation_length, max_cache_len)
                break

        # Sequentially feed the prompt tokens to prefill the cache
        response_tokens = []
        for input_pos in range(min(max_generation_length, prompt_token_len)):
            result = exported_program.module().forward(
                input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
                cache_position=torch.tensor([input_pos], dtype=torch.long),
            )
            response_tokens.append(prompt_token_ids[0][input_pos].item())

        # Greedily decode until the cache is full or `max_new_tokens` is reached
        current_token = torch.argmax(result[:, -1, :], dim=-1).item()
        response_tokens.append(current_token)

        while len(response_tokens) < max_generation_length:
            result = exported_program.module().forward(
                input_ids=torch.tensor([[current_token]], dtype=torch.long),
                cache_position=torch.tensor([len(response_tokens)], dtype=torch.long),
            )
            current_token = torch.argmax(result[:, -1, :], dim=-1).item()
            response_tokens.append(current_token)

        return torch.tensor([response_tokens], dtype=torch.long)
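

# Illustrative usage sketch (not part of the original module): running an exported static-cache program
# with the token-level `generate` helper above. `exported_program` and `tokenizer` are assumed to come
# from `convert_and_export_with_cache(model)` (defined later in this file) and from the model's tokenizer.
#
#     prompt_token_ids = tokenizer("Hello", return_tensors="pt").input_ids
#     output_ids = TorchExportableModuleWithStaticCache.generate(
#         exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=10
#     )
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))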


class TorchExportableModuleWithHybridCache(torch.nn.Module):
    """
    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
    specifically for decoder-only LM to `HybridCache`. This module ensures that the
    exported model is compatible with further lowering and execution in `ExecuTorch`.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        max_batch_size: int = 1,
        max_cache_len: int = 4096,
    ):
        """
        Initializes the exportable module with `HybridCache`.

        Args:
            model (`PreTrainedModel`): The pretrained model to wrap.
            max_batch_size (int): Maximum batch size for the cache.
            max_cache_len (int): Maximum sequence length for the cache.

        Raises:
            AssertionError: If the model doesn't have the expected configuration for HybridCache.
        """
        super().__init__()
        self.model = model

        # Verify the model is configured for HybridCache
        if not self.model.config.use_cache:
            raise AssertionError("Model must have caching enabled")

        if (
            not hasattr(self.model.config, "cache_implementation")
            or self.model.config.cache_implementation != "hybrid"
        ):
            raise AssertionError("Model must use 'hybrid' cache implementation")

        # Initialize the HybridCache
        self.cache = HybridCache(
            config=self.model.config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=self.model.device,
            dtype=self.model.dtype,
        )

        # Register all key and value cache tensors as buffers
        for i in range(len(self.cache.key_cache)):
            self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the module, which is compatible with the ExecuTorch llm runner.

        Args:
            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

        Returns:
            torch.Tensor: Logits output from the model.
        """
        batch_size, seq_len = input_ids.shape

        # Generate position_ids from cache_position
        position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)

        # Create an all-ones attention mask for token-by-token generation
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=input_ids.device)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=self.cache,
            use_cache=True,
            cache_position=cache_position,
        )

        # Return only the logits to simplify the exported graph
        return outputs.logits
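

# Illustrative usage sketch (not part of the original module): models whose config sets
# `cache_implementation="hybrid"` (e.g. Gemma-style architectures) are routed to this wrapper by
# `TorchExportableModuleForDecoderOnlyLM`. The checkpoint name below is a placeholder/assumption.
#
#     model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")  # config uses a hybrid cache
#     exportable = TorchExportableModuleForDecoderOnlyLM(model, max_batch_size=1, max_cache_len=1024)
#     exported_program = exportable.export()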


def convert_and_export_with_cache(
    model: PreTrainedModel,
    example_input_ids: Optional[torch.Tensor] = None,
    example_cache_position: Optional[torch.Tensor] = None,
    dynamic_shapes: Optional[dict] = None,
    strict: Optional[bool] = None,
):
    """
    Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
    ensuring the exported model is compatible with `ExecuTorch`.

    Args:
        model (`PreTrainedModel`): The pretrained model to be exported.
        example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
        example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
        dynamic_shapes (`Optional[dict]`): Dynamic shapes used by `torch.export`.
        strict (`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.

    Returns:
        Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
    """
    if not is_torch_greater_or_equal_than_2_3:
        raise ImportError("torch >= 2.3 is required.")

    import torch.export._trace

    with torch.no_grad():
        example_input_ids = (
            example_input_ids if example_input_ids is not None else torch.tensor([[1]], dtype=torch.long)
        )
        example_cache_position = (
            example_cache_position if example_cache_position is not None else torch.tensor([0], dtype=torch.long)
        )

        if is_torch_greater_or_equal("2.6.0"):
            exported_program = torch.export.export(
                TorchExportableModuleWithStaticCache(model),
                args=(example_input_ids, example_cache_position),
                kwargs={},
                dynamic_shapes=dynamic_shapes,
                strict=strict if strict is not None else True,
            )
        else:
            if dynamic_shapes is not None:
                logging.warning(
                    "Dynamic shapes spec will be ignored by convert_and_export_with_cache for torch < 2.6.0."
                )
            if strict is not None:
                logging.warning("The strict flag will be ignored by convert_and_export_with_cache for torch < 2.6.0.")

            exported_program = torch.export._trace._export(
                TorchExportableModuleWithStaticCache(model),
                args=(example_input_ids,),
                kwargs={"cache_position": example_cache_position},
                pre_dispatch=False,
                strict=True,
            )
        return exported_program
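

# Illustrative usage sketch (not part of the original module): the typical end-to-end flow with
# `convert_and_export_with_cache`. The checkpoint name and cache sizes are placeholders/assumptions,
# and the final lowering step is only a pointer to the ExecuTorch toolchain, not part of this module.
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "HuggingFaceTB/SmolLM2-135M",
#         generation_config=GenerationConfig(
#             use_cache=True,
#             cache_implementation="static",
#             cache_config={"batch_size": 1, "max_cache_len": 64},
#         ),
#     )
#     exported_program = convert_and_export_with_cache(model)
#     # To run on-device, the exported program can then be lowered with ExecuTorch
#     # (e.g. its `to_edge` / `to_executorch` APIs).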


class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
    """
    A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
    This module ensures that the exported encoder model is compatible with ExecuTorch.
    """

    def __init__(self, encoder_model):
        super().__init__()
        self.encoder = encoder_model

    def forward(self, input_ids):
        # Return only the hidden states so the exported graph has a plain tensor output
        return self.encoder(input_ids=input_ids).last_hidden_state


class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
    """
    A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
    specifically for use with static caching. This module ensures the exported decoder
    is compatible with ExecuTorch.
    """

    def __init__(self, model, max_static_cache_length, batch_size):
        super().__init__()

        # Get the decoder component
        self.decoder = model.get_decoder()
        self.lm_head = model.lm_head
        self.config = model.config

        # Initialize the static cache
        self.static_cache = StaticCache(
            config=self.config,
            max_batch_size=batch_size,
            max_cache_len=max_static_cache_length,
            device="cpu",
            dtype=torch.float32,
        )

        # Register cache buffers to make them exportable
        for i in range(len(self.static_cache.key_cache)):
            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

    def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
        # Get outputs from the decoder, reusing the static cache across calls
        outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=self.static_cache,
            use_cache=True,
            cache_position=cache_position,
        )

        # Apply the language model head
        lm_logits = self.lm_head(outputs[0])
        return lm_logits


class Seq2SeqLMExportableModule(torch.nn.Module):
    def __init__(
        self, model, batch_size=1, max_hidden_seq_length=4096, cache_implementation="static", max_cache_length=1024
    ):
        super().__init__()

        self.full_model = model
        self.encoder = model.get_encoder()
        self.config = model.config
        self.max_hidden_seq_length = max_hidden_seq_length
        self.generation_config = GenerationConfig(
            use_cache=True,
            max_length=max_cache_length,
            cache_implementation=cache_implementation,
            cache_config={"batch_size": batch_size, "max_cache_len": max_cache_length},
        )
        self.exported_encoder = None
        self.exported_decoder = None

    def _export_encoder(self, encoder_input_ids):
        wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to("cpu").eval()

        # Define a dynamic sequence length for the encoder input
        seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)

        # Export the encoder
        with torch.no_grad():
            exported_encoder = torch.export.export(
                wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
            )

        return exported_encoder

    def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
        wrapped_decoder = (
            Seq2SeqLMDecoderExportableModuleWithStaticCache(
                model=self.full_model,
                max_static_cache_length=self.generation_config.cache_config.max_cache_len,
                batch_size=self.generation_config.cache_config.batch_size,
            )
            .to("cpu")
            .eval()
        )

        # Define a dynamic dimension for the encoder output sequence length
        encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)

        # Export the decoder
        with torch.no_grad():
            exported_decoder = torch.export.export(
                wrapped_decoder,
                (decoder_input_ids, encoder_hidden_states, cache_position),
                dynamic_shapes={
                    "decoder_input_ids": None,
                    "encoder_hidden_states": {1: encoder_seq_len_dim},
                    "cache_position": None,
                },
                strict=True,
            )

        return exported_decoder

    def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None):
        example_encoder_input_ids = (
            encoder_input_ids if encoder_input_ids is not None else torch.ones((1, 10), dtype=torch.long)
        )
        example_decoder_input_ids = (
            decoder_input_ids if decoder_input_ids is not None else torch.tensor([[0]], dtype=torch.long)
        )
        example_cache_position = (
            cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long)
        )
        example_encoder_hidden_states = (
            encoder_hidden_states
            if encoder_hidden_states is not None
            else torch.zeros(
                (self.generation_config.cache_config.batch_size, 10, self.config.d_model), dtype=torch.float32
            )
        )

        self.exported_encoder = self._export_encoder(example_encoder_input_ids)
        self.exported_decoder = self._export_decoder(
            example_decoder_input_ids, example_encoder_hidden_states, example_cache_position
        )

        # Return self to allow chaining
        return self

    def generate(self, prompt_token_ids, max_new_tokens):
        with torch.no_grad():
            # Run the exported encoder once on the full prompt
            encoder_output = self.exported_encoder.module()(prompt_token_ids)

            # Start decoding from the decoder start token (0)
            decoder_input_ids = torch.tensor([[0]], dtype=torch.long)
            generated_ids = [0]

            # Generate tokens one by one
            for i in range(max_new_tokens - 1):
                # Run the exported decoder to get next-token logits
                logits = self.exported_decoder.module()(
                    decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long)
                )

                # Greedily pick the next token
                next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
                generated_ids.append(next_token)

                # Update the decoder input for the next step
                decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long)

                # Stop on EOS
                if next_token == self.config.eos_token_id:
                    break

            return generated_ids