o
    Zhod                     @   s   d dl Z d dlmZ d dlmZmZmZmZmZm	Z	 d dl
Z
d dlmZ ddlmZmZmZmZ ddlmZ ddlmZmZ ddlmZ dd	lmZ erRdd
lmZ eeZeG dd deZ G dd deZ!dS )    N)	dataclass)TYPE_CHECKINGDictListOptionalTupleUnion   )GenerateDecoderOnlyOutputGenerationConfigGenerationMixinGenerationMode)LogitsProcessorList)MaxLengthCriteriaStoppingCriteriaList)GenerateNonBeamOutput)logging)BaseStreamerc                   @   s(   e Zd ZU dZdZeeej  e	d< dS )CsmGenerateOutputap  
    Outputs of CsmForConditionalGeneration.generate.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format, check
        audio (`list(torch.FloatTensor)` of length `batch_size`):
            The generated audio.
    Naudio)
__name__
__module____qualname____doc__r   r   r   torchTensor__annotations__ r   r   U/var/www/auris/lib/python3.10/site-packages/transformers/models/csm/generation_csm.pyr   *   s   
 r   c                       s  e Zd Zdef fddZ	ddee dee dede	eef f fdd	Z
d
ejdededededed deeejf fddZ									dd
eej deej deej dee dee dee dee ded dee deeejf f fddZ  ZS )CsmGenerationMixinreturnc                    sR   t  j|i |}t }|D ]}t|ts!td|jj d q|	| q|S )NzCsm does not support z' stopping criteria, it will be ignored.)
super_get_stopping_criteriar   
isinstancer   loggerwarning	__class__r   append)selfargskwargscriteriaZkept_criteria	criterionr&   r   r   r"   K   s   
z)CsmGenerationMixin._get_stopping_criteriaNgeneration_configuse_model_defaultsr*   c           	         s  dd |  D }dd |  D }t j||fi |\}}| jjjdi | t| jjdp5| jjd }t| jjdpB| jjd }||h| jjd hkr`t	d| d| d	| jjd  d
| jjj
rotd d| jj_
|| jj_|| jj_|j d fdd	}||_||fS )z
        This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].
        It ensures that the depth decoder generation config is initialized and that passed args as depth_decoder_* are properly handled.
        c                 S   s,   i | ]\}}| d r|td d |qS )depth_decoder_N)
startswithlen.0kvr   r   r   
<dictcomp>d   s
    zACsmGenerationMixin._prepare_generation_config.<locals>.<dictcomp>c                 S   s    i | ]\}}| d s||qS )r0   )r1   r3   r   r   r   r7   i   s     min_new_tokens   max_new_tokensz2depth_decoder_generation_config's min_new_tokens (z) and max_new_tokens (z2) must be equal to self.config.num_codebooks - 1 ()zdepth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generateFNc                    s,    | }|t jt jfvrtd| d|S )NzGeneration mode zg is not supported for CSM model. Please set generation parameters to use greedy or sampling generation.)r   ZGREEDY_SEARCHZSAMPLE
ValueError)Zassistant_modelZgeneration_modeZoriginal_get_generation_moder   r   patched_get_generation_mode   s   
zRCsmGenerationMixin._prepare_generation_config.<locals>.patched_get_generation_moder   N)itemsr!   _prepare_generation_configdepth_decoderr.   updategetattrconfigZnum_codebooksr<   return_dict_in_generater$   r%   r8   r:   Zget_generation_mode)	r(   r.   r/   r*   Zdepth_decoder_kwargsmodel_kwargsZdepth_decoder_min_new_tokensZdepth_decoder_max_new_tokensr>   r-   r=   r   rA   \   s>   






	z-CsmGenerationMixin._prepare_generation_config	input_idslogits_processorstopping_criteriasynced_gpusstreamerr   c           '   	   K   s
  | j j}|jdu}	|j}
|j}|j}|j}|j}|j}|r!|r!dnd}|r)|r)dnd}|r1|
r1dnd}|r9|r9dnd}|j	dd \}}d}t
j|t
j|jd}| ||j|}|jdkru|ddu ru|D ]}t|trt| j|8  _qf| j}| ||}|rdtjd< | |j}d	}| j|||jd
r| j|fi |}||
rd|
ini  |dd	i |r| di |dd	i}d}n|di |dd	i}| ||}|r|rq|jdddddf    }|!|j}|||}|r|r||f7 }|r||f7 }|
r||j"f7 }|r||j#f7 }|r-t$j%j&|dd} t
j'| dd(d}!nt
j)|dd}!|!dddf }"t$j%j*|"ddd}#|j#d dddddf }$| j+j,|#|$ d}%t|%t
j-rf|%n|%j.}&|&ddddf }&|&}!|	r|!|/d |d|/d   }!|jdkr|!dddddf }nt
j0||!dddddf gdd}|dur|1|!2  ||dddddf | j j3k4d @ }|||| @ }|5 dk}|d7 }~~%| j|||jd
s|dur|6  |rt7||||||ddS |S )a!  
        This method overrides [~generation.utils.GenerationMixin._sample].
        To ease maintenance, modifications are marked with the comment "Csm specific".

        Indeed, Csm model requires a custom generation sampling step:
        1. Infer the backbone model to sample the first codebook token
        2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens
        3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model
        4. Repeat until stopping criteria is met

        Csm supports two stopping criterias:
        - stop when the generated sequence is at max_length
        - stop when all the generated codebook tokens are the codebook_eos_token_id
        Nr      F)ZdtypedeviceZinputs_embeds0ZTOKENIZERS_PARALLELISMT)rN   output_attentionsoutput_hidden_statesZreturn_dictdimr9   )Znum_samples)r9   r   r   )value)rH   backbone_last_hidden_statepast_key_values)	sequencesscoreslogits
attentionshidden_statesrW   )8rE   Zcodebook_pad_token_idZ_eos_token_tensorrP   rQ   output_scoresoutput_logitsrF   	do_sampleshaper   ZoneslongrN   Z_get_initial_cache_positionndimgetr#   r   
max_length__call__Z_valid_auto_compile_criteriaosenvironZget_compiled_callZcompile_configZ_has_unfinished_sequencesZprepare_inputs_for_generationrC   Z#_update_model_kwargs_for_generationrZ   clonefloattor[   r\   nnZ
functionalZsoftmaxZmultinomialZsqueezeZargmaxpadrB   generater   rX   	unsqueezecatputcpucodebook_eos_token_idallmaxendr
   )'r(   rH   rI   rJ   r.   rK   rL   rG   Zpad_token_idZhas_eos_stopping_criteriarP   rQ   r]   r^   rF   r_   rY   Z
raw_logitsZdecoder_attentionsZdecoder_hidden_statesZ
batch_sizecur_lenZthis_peer_finishedZunfinished_sequencesr,   Zmodel_forwardZcompile_forwardZ
is_prefillZmodel_inputsZoutputsZnext_token_logitsZnext_token_scoresZprobsZnext_tokensZfirst_codebook_idsZdepth_decoder_input_idsrV   Zdepth_decoder_outputsZcodebook_idsr   r   r   _sample   s   


 



$

m	zCsmGenerationMixin._sampleFinput_valuesinput_values_cutoffsoutput_audioc
                    s  t  jd	||||||||d|
}t|tj }d}|	rv|r"|jn|}g }t D |D ]9}|| jjkj	dd
 }| dkrF| }n|jd }|d| }| j|ddd}||jd  q-W d   n1 sqw   Y  |rtd	d|i|S |	r|S |S )
aI  
        This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
        Indeed, Csm model requires a custom generation sampling step:
        1. Infer the backbone model to sample the first codebook token
        2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
        3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
        4. Repeat until stopping criteria is met

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
        </Tip>

        Parameters:
            inputs_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*):
                The sequence used as a prompt for the backbone model.
            input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*):
                The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
                These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
            input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*):
                Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
                If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
                where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
                the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
            generation_config ([`~generation.GenerationConfig`], *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and
                generation config. If a logit processor is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complements the default stopping criteria built from arguments and a
                generation config. If a stopping criteria is passed that is already created with the arguments or a
                generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
                sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
                intended for advanced users.
            synced_gpus (`bool`, *optional*):
                Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
                to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
                deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            output_audio (`bool`, *optional*):
                Whether to return the generated audio.
            kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*.

        Return:
            [`CsmGenerateOutput`] or `torch.LongTensor` or `List[torch.FloatTensor]`: A [`CsmGenerateOutput`]
            (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` when `output_audio=False`
            or a `List[torch.FloatTensor]` otherwise.

        Example:

        ```python
        >>> from transformers import CsmProcessor, CsmForConditionalGeneration
        >>> from datasets import load_dataset, Audio

        >>> model_id = "eustlb/csm-1b"
        >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> processor = AutoProcessor.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
        >>> # ensure the audio is 24kHz
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))

        >>> conversation = []
        >>> # prepare a conversation with text and corresponding audio
        >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
        ...     conversation.append(
        ...         {
        ...             "role": f"{speaker_id}",
        ...             "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        ...         }
        ...     )

        >>> # text prompt
        >>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})

        >>> inputs = processor.apply_chat_template(
        ...     conversation,
        ...     tokenize=True,
        ...     return_dict=True,
        ... ).to(torch_device)

        >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
        >>> audio = model.generate(**inputs, output_audio=True)
        >>> processor.save_audio(audio, "output.wav")
        ```
        )rH   rx   ry   r.   rI   rJ   rK   rL   NrR   rS   r   r9   )r   r   r   r   )r!   rm   r#   r   r   rX   Zno_gradrE   rr   rs   ZnonzeroZnumelminr`   Zcodec_modeldecodeZ	transposern   r'   Zaudio_valuesr   )r(   rH   rx   ry   r.   rI   rJ   rK   rL   rz   r*   Zgenerate_outputZgenerate_returned_dictr   Zgenerated_audio_codesZaudio_codes_batchZeos_idxsZ
cutoff_idxZcodec_decode_outputr-   r   r   rm   R  sB   q	


zCsmGenerationMixin.generater?   )	NNNNNNNNF)r   r   r   r   r"   r   r   boolr   r   rA   r   Z
LongTensorr   r   r   rw   r   rm   __classcell__r   r   r-   r   r   J   sx    
:	
 ?	
r   )"rf   dataclassesr   typingr   r   r   r   r   r   r   Ztorch.nnrk   Z
generationr
   r   r   r   Zgeneration.logits_processr   Zgeneration.stopping_criteriar   r   Zgeneration.utilsr   utilsr   Zgeneration.streamersr   Z
get_loggerr   r$   r   r   r   r   r   r   <module>   s     
