o
    Zhw|                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlZd dlmZmZ d dl	m
Z
mZ d dlmZ d dlmZ d dlZd dlmZmZ ddlmZ e d	krPd dlZe red d
lmZ d dlmZ d dlmZ e rzd dlZd dlm Z m!Z!m"Z"m#Z#m$Z$ e%ej&ej' Z(e%ej&ej) ej' d Z*ddiddiddiddiddiddidZ+dZ,dd-e+.  dZ/g dZ0G dd dZ1e
G dd dZ2d efd!d"Z3G d#d$ d$eZ4dS )%    N)ArgumentParser	Namespace)	dataclassfield)Thread)Optional)is_rich_availableis_torch_available   )BaseTransformersCLICommandWindows)Console)Live)Markdown)AutoModelForCausalLMAutoTokenizerBitsAndBytesConfigGenerationConfigTextIteratorStreamerz .!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~textz5There is a Llama in my lawn, how can I get rid of it?zyWrite a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].z4How many helicopters can a human eat in one sitting?z4Count to 10 but skip every number ending with an 'e'zWhy aren't birds real?z2Why is it important to eat socks after meditating?)llamacode
helicopternumbersZbirdssocksaa  

**TRANSFORMERS CHAT INTERFACE**

Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
- **!help**: shows all available commands
- **!status**: shows the current status of the model and generation settings
- **!clear**: clears the current conversation and starts a new one
- **!exit**: closes the interface
am  

**TRANSFORMERS CHAT INTERFACE HELP**

Full command list:
- **!help**: shows this help message
- **!clear**: clears the current conversation and starts a new one
- **!status**: shows the current status of the model and generation settings
- **!example {NAME}**: loads example named `{NAME}` from the config and uses it as the user input.
Available example names: `z`, `a%  `
- **!set {ARG_1}={VALUE_1} {ARG_2}={VALUE_2}** ...: changes the system prompt or generation settings (multiple
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
- **!save {SAVE_NAME} (optional)**: saves the current chat and settings to file by default to
`./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- **!exit**: closes the interface
)	)max_new_tokens   r   )	do_sampleTr   )	num_beamsr
   r   )temperature      ?r   )top_k2   r!   )top_pr    r#   )repetition_penaltyr    r$   )
eos_tokensNeos_token_id)eos_token_idsNr&   c                   @   s   e Zd Zddee dee fddZdedefdd	Zdefd
dZdd Z	defddZ
dedefddZddefddZdededefddZdS )RichInterfaceN
model_name	user_namec                 C   s:   t  | _|d u rd| _n|| _|d u rd| _d S || _d S )N	assistantuser)r   _consoler)   r*   )selfr)   r*    r/   I/var/www/auris/lib/python3.10/site-packages/transformers/commands/chat.py__init__x   s   

zRichInterface.__init__output_streamreturnc           	      C   s   d}| j d| j d t| j ddJ}t|D ]=\}}|r#|dkr$q||7 }g }| D ]}|| |dr@|d q.|d	 q.td	|
 d
d}|| qW d   n1 saw   Y  | j   |S )zRStream output from a role, and return the generated text after it's done steaming. z[bold blue]<z>:   )consolerefresh_per_secondr   z```
z  
zgithub-dark)Z
code_themeN)r-   printr)   r   	enumerate
splitlinesappend
startswithr   joinstripupdate)	r.   r2   r   liveiZoutputslineslinemarkdownr/   r/   r0   stream_output   s&   


zRichInterface.stream_outputc                 C   s$   | j d| j d}| j   |S )z!Gets user input from the console.[bold red]<z>:
)r-   inputr*   r9   )r.   rH   r/   r/   r0   rH      s   
zRichInterface.inputc                 C   s   | j   dS )zClears the console.N)r-   clear)r.   r/   r/   r0   rI      s   zRichInterface.clearr   c                 C   s(   | j d| j d|  | j   dS )z%Prints a user message to the console.rG   z>:[/ bold red]
N)r-   r9   r*   )r.   r   r/   r/   r0   print_user_message   s   z RichInterface.print_user_messagecolorc                 C   s&   | j d| d|  | j   dS )z,Prints text in a given color to the console.z[bold ]Nr-   r9   )r.   r   rK   r/   r/   r0   print_color      zRichInterface.print_colorFminimalc                 C   s&   | j t|rtnt | j   dS )z'Prints the help message to the console.N)r-   r9   r   HELP_STRING_MINIMALHELP_STRING)r.   rP   r/   r/   r0   
print_help   rO   zRichInterface.print_helpgeneration_configmodel_kwargsc                 C   sJ   | j d| d |r| j d|  | j d|  | j   dS )zFPrints the status of the model and generation settings to the console.z[bold blue]Model: r8   z[bold blue]Model kwargs: z[bold blue]NrM   )r.   r)   rT   rU   r/   r/   r0   print_status   s
   zRichInterface.print_status)NN)F)__name__
__module____qualname__r   strr1   r   rF   rH   rI   rJ   rN   boolrS   r   dictrV   r/   r/   r/   r0   r(   w   s    'r(   c                   @   s  e Zd ZU dZedddidZee ed< edddidZ	ee ed< eddd	idZ
ee ed
< edddidZeed< edddidZee ed< edddidZee ed< edddidZeed< edddidZeed< edddidZeed< edddidZeed< edddidZeed < eddd!idZeed"< eddd#idZeed$< eddd%idZee ed&< eddd'idZee ed(< ed)dd*idZeed+< ed,dd-idZeed.< ed/d0g d1d2dZee ed3< ed4dd5idZeed6< eddd7idZee ed8< ed4dd9idZeed:< ed4dd;idZ eed<< ed=d>d?d=gd2dZ!eed@< ed4ddAidZ"eedB< dS )CChatArgumentsz
    Arguments for the chat CLI.

    See the metadata arg for each argument's description -- the medatata will be printed with
    `transformers chat --help`
    Nhelpz_Name of the pre-trained model. The positional argument will take precedence if both are passed.)defaultmetadatamodel_name_or_pathzKUsername to display in chat interface. Defaults to the current user's name.r,   zSystem prompt.system_promptz./chat_history/zFolder to save chat history.save_folderz"Path to a yaml file with examples.examples_pathzPath to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config.rT   r   z%Maximum number of tokens to generate.r   Tz,Whether to sample outputs during generation.r   r
   z Number of beams for beam search.r   r    z%Temperature parameter for generation.r   r"   zValue of k for top-k sampling.r!   z Value of p for nucleus sampling.r#   zRepetition penalty.r$   zNEOS tokens to stop the generation. If multiple they should be comma separated.r%   zQEOS token IDs to stop the generation. If multiple they should be comma separated.r'   mainzLSpecific model version to use (can be a branch name, tag name or commit id).model_revisioncpuzDevice to use for inference.deviceautozOverride the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype will be automatically derived from the model's weights.)ri   Zbfloat16Zfloat16Zfloat32)r^   choicestorch_dtypeFz2Whether to trust remote code when loading a model.trust_remote_codezWhich attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`.attn_implementationzIWhether to use 8 bit precision for the base model - works only with LoRA.load_in_8bitzIWhether to use 4 bit precision for the base model - works only with LoRA.load_in_4bitZnf4zQuantization type.Zfp4bnb_4bit_quant_typez#Whether to use nested quantization.use_bnb_nested_quant)#rW   rX   rY   __doc__r   ra   r   rZ   __annotations__r,   rb   rc   rd   rT   r   intr   r[   r   r   floatr!   r#   r$   r%   r'   rf   rh   rk   rl   rm   rn   ro   rp   rq   r/   r/   r/   r0   r]      s   
 r]   argsc                 C   s   t | S )z;
    Factory function used to chat with a local model.
    )ChatCommand)rv   r/   r/   r0   chat_command_factory%  s   rx   c                   @   sx  e Zd ZedefddZdd Zdedefdd	Zede	fd
dZ
ed-dedee	 de	fddZed-dee	 dee fddZdee	 defddZdededeeef fddZededee	 dee	 deeee f fddZededed fdd Zdedeeef fd!d"Zd#e	ded$ed%ee	ee	e	f f d&ed'ed(ee deee eef fd)d*Zd+d, ZdS ).rw   parserc                 C   sT   t f}| jd|d}|d}|jdtddd |jdtdd	d
d |jtd dS )z
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        chat)dataclass_typeszPositional argumentsmodel_name_or_path_positionalNzName of the pre-trained model.)typer_   r^   generate_flagsa  Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, and lists of integers, more advanced parameterization should be set through --generation-config. Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options*)r}   r_   r^   nargs)func)r]   
add_parseradd_argument_groupadd_argumentrZ   set_defaultsrx   )ry   r{   Zchat_parsergroupr/   r/   r0   register_subcommand-  s   
zChatCommand.register_subcommandc                 C   s   |  |}|| _d S N)_handle_deprecated_argsrv   )r.   rv   r/   r/   r0   r1   J  s   

zChatCommand.__init__rv   r3   c              
   C   s   d}|j p|j|_ |j du rtd|jdurd}tdt tD ] \}}}t||}||krAd}td| d| d| d	t q!|rKtd
 t	  |S )z
        Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
        args.
        FNzOne of the following must be provided:
- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`
- the optional --model_name_or_path argument, containing the model repo (deprecated)TzThe --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional argument instead, e.g. `transformers chat <model_repo>`.zThe --a.   argument is deprecated will be removed in v4.54.0. There are two alternative solutions to specify this generation option: 
1. Pass `--generation-config <path_to_file/Hub repo>` to specify a generation config.
2. Pass `generate` flags through positional arguments, e.g. `transformers chat <model_repo> =`z
(Press enter to continue))
r|   ra   
ValueErrorwarningswarnFutureWarning_DEPRECATION_MAPgetattrr9   rH   )r.   rv   Zhas_warningsZdeprecated_argdefault_valueZnew_argvaluer/   r/   r0   r   N  s:   


	z#ChatCommand._handle_deprecated_argsc                   C   s$   t  dkr
t S tt jS )z)Returns the username of the current user.r   )platformsystemosgetloginpwdgetpwuidgetuidpw_namer/   r/   r/   r0   get_usernamey  s   zChatCommand.get_usernameNfilenamec                 C   s   i }t ||d< | |d< |j}|du r(td}|j d| d}tj||}tjtj	|dd t
|d	}tj||d
d W d   n1 sKw   Y  tj|S )z!Saves the chat history to a file.settingsZchat_historyNz%Y-%m-%d_%H-%M-%Sz/chat_.jsonT)exist_okwr5   )indent)varsrc   timestrftimer|   r   pathr>   makedirsdirnameopenjsondumpabspath)rz   rv   r   Zoutput_dictfolderZtime_strfr/   r/   r0   	save_chat  s   
zChatCommand.save_chatrb   c                 C   s    | du rg }|S d| dg}|S )zClears the chat history.Nr   Zrolecontentr/   )rb   rz   r/   r/   r0   clear_chat_history  s
   zChatCommand.clear_chat_historyr~   c                    s   t |dkri S dd |D }dd | D }dd | D }dtdtfdd	  fd
d| D }ddd | D }d| d }|dd}|dd}|dd}|dd}|dd}|dd}zt|}W |S  tjy   t	dw )zUParses the generate flags from the user input into a dictionary of `generate` kwargs.r   c                 S   s.   i | ]}d | dd  d  | dd qS )"r   r   r
   )split).0flagr/   r/   r0   
<dictcomp>  s   . z4ChatCommand.parse_generate_flags.<locals>.<dictcomp>c                 S   s*   i | ]\}}||  d v r|  n|qS ))truefalse)lowerr   kvr/   r/   r0   r     s    c                 S   s"   i | ]\}}||d krdn|qS )Nonenullr/   r   r/   r/   r0   r     s   " sr3   c                 S   s   |  ddd S )N.r4   r
   )replaceisdigit)r   r/   r/   r0   	is_number  s   z3ChatCommand.parse_generate_flags.<locals>.is_numberc                    s*   i | ]\}}| |sd | d n|qS )r   r/   r   r   r/   r0   r     s   * z, c                 S   s   g | ]\}}| d | qS )z: r/   r   r/   r/   r0   
<listcomp>  s    z4ChatCommand.parse_generate_flags.<locals>.<listcomp>{}z"null"r   z"true"r   z"false"r   z"[[z]"rL   r   :zFailed to convert `generate_flags` into a valid JSON object.
`generate_flags` = {generate_flags}
Converted JSON string = {generate_flags_string})
lenitemsrZ   r[   r>   r   r   loadsJSONDecodeErrorr   )r.   r~   Zgenerate_flags_as_dictZgenerate_flags_stringZprocessed_generate_flagsr/   r   r0   parse_generate_flags  s2   z ChatCommand.parse_generate_flags	tokenizerc              
   C   s   |j du r/t }| ||j|j\}}|j|j|j|j|j	|j
|j||d	}|jdi | n d|j v rItj|j }tj|j }t||}nt|j }| |j}	|jdi |	}
||
fS )zj
        Returns a GenerationConfig object holding the generation parameters for the CLI command.
        N)	r   r   r   r   r!   r#   r$   pad_token_idr&   r   r/   )rT   r   parse_eos_tokensr%   r'   r   r   r   r   r!   r#   r$   r@   r   r   r   basenamefrom_pretrainedr   r~   )r.   rv   r   rT   r   r'   Zdeprecated_kwargsr   r   Zparsed_generate_flagsrU   r/   r/   r0   get_generation_parameterization  s,   

z+ChatCommand.get_generation_parameterizationr%   r'   c                 C   s|   | j du r	| j}n| j }g }|dur|| |d |dur.|dd |dD  t|dkr:|| j ||fS )z:Retrieves the pad token ID and all possible EOS token IDs.N,c                 S   s   g | ]}t |qS r/   )rt   )r   Ztoken_idr/   r/   r0   r   	  s    z0ChatCommand.parse_eos_tokens.<locals>.<listcomp>r   )r   r&   extendZconvert_tokens_to_idsr   r   r<   )r   r%   r'   r   Zall_eos_token_idsr/   r/   r0   r     s   
zChatCommand.parse_eos_tokens
model_argsr   c                 C   s@   | j rtd| j| j| j| jd}|S | jrtdd}|S d }|S )NT)ro   Zbnb_4bit_compute_dtyperp   Zbnb_4bit_use_double_quantZbnb_4bit_quant_storage)rn   )ro   r   rk   rp   rq   rn   )r   quantization_configr/   r/   r0   get_quantization_config  s    z#ChatCommand.get_quantization_configc                 C   s   t j|j|j|jd}|jdv r|jntt|j}| |}|j|j	|d|d}t
j|jfd|ji|}t|dd d u rC||j}||fS )N)revisionrl   )ri   Nri   )r   rm   rk   Z
device_mapr   rl   Zhf_device_map)r   r   r|   rf   rl   rk   r   torchr   rm   r   torh   )r.   rv   r   rk   r   rU   modelr/   r/   r0   load_model_and_tokenizer&  s.   
z$ChatCommand.load_model_and_tokenizer
user_input	interfaceexamplesrT   rU   rz   c                 C   s  |dkr|  |j}|  n|dkr|  n|drIt| dk rI| }t|dkr4|d }	nd}	| |||	}	|jd|	 dd	d
 n|dr|dd 	 }
|
 }
|
D ]}d|vro|jd| ddd
  nq\| 
|
}|jdi |}|jdi | ne|drt| dkr| d }||v r|  g }||| d  |d|| d d n2d| dt|  d}|j|dd
 n|dkr|j|j||d n|jd| ddd
 |  |||fS )z
        Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
        generation config (e.g. set a new flag).
        z!clearz!helpz!save   r
   NzChat saved in !green)r   rK   z!setr5   r   z(Invalid flag format, missing `=` after `z;`. Please use the format `arg_1=value_1 arg_2=value_2 ...`.red!exampler   r,   r   zExample z* not found in list of available examples: r   z!status)r)   rT   rU   'z/' is not a valid command. Showing help message.r/   )r   rb   rI   rS   r=   r   r   r   rN   r?   r   r@   rJ   r<   listkeysrV   r|   )r.   r   rv   r   r   rT   rU   rz   Zsplit_inputr   Znew_generate_flagsr   Zparsed_new_generate_flagsZnew_model_kwargsZexample_nameZexample_errorr/   r/   r0   handle_non_exit_user_commandsA  s\   






z)ChatCommand.handle_non_exit_user_commandsc              	   C   s  t  stdt std| j}|jd u rt}nt|j}t|}W d    n1 s.w   Y  |j	d u r=| 
 }n|j	}| |\}}t|ddd}| ||\}}	t|j|d}
|
  | |j}|
jdd 	 zg|
 }|dr|dkrW d S | j|||
|||	|d	\}}}	|d
sW qmn|d|d |j|ddd|j}t|}||||d|	}t|j|d}|  |
 |}|!  |d|d W n
 t"y   Y d S w qn)NzHYou need to install rich to use the chat interface. (`pip install rich`)zJYou need to install torch to use the chat interface. (`pip install torch`)T)Zskip_special_tokensZskip_prompt)r)   r*   )rP   r   z!exit)r   rv   r   r   rT   rU   rz   r   r,   r   pt)Zreturn_tensorsZadd_generation_prompt)inputsattention_maskstreamerrT   )targetkwargsr+   )#r   ImportErrorr	   rv   rd   DEFAULT_EXAMPLESr   yamlZ	safe_loadr,   r   r   r   r   r(   r|   rI   r   rb   rS   rH   r=   r   r<   Zapply_chat_templater   rh   r   Z	ones_liker   generatestartrF   r>   KeyboardInterrupt)r.   rv   r   r   r,   r   r   Zgeneration_streamerrT   rU   r   rz   r   r   r   Zgeneration_kwargsthreadZmodel_outputr/   r/   r0   run  sv   







zChatCommand.runr   )rW   rX   rY   staticmethodr   r   r1   r]   r   rZ   r   r   r   r   r\   r   r   r   tupler   r   rt   r   r   r   r   r(   r   r   r/   r/   r/   r0   rw   ,  sd    +
3

'	
Rrw   )5r   r   r   stringr   r   argparser   r   dataclassesr   r   	threadingr   typingr   r   Ztransformers.utilsr   r	   r4   r   r   r   Zrich.consoler   Z	rich.liver   Zrich.markdownr   r   Ztransformersr   r   r   r   r   setascii_letters
whitespaceZALLOWED_KEY_CHARSdigitsZALLOWED_VALUE_CHARSr   rQ   r>   r   rR   r   r(   r]   rx   rw   r/   r/   r/   r0   <module>   sZ   		UX