# transformers/models/csm/modular_csm.py
#
# Reconstructed source skeleton recovered from a compiled (.pyc) artifact. The imports, class
# hierarchy, attributes, signatures, and docstrings below come from the recoverable constants;
# method bodies that could not be restored verbatim are summarized in comments and elided with `...`.

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging
from ..auto import AutoModel
from ..llama.modeling_llama import (
    KwargsForCausalLM,
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
)
from .configuration_csm import CsmConfig, CsmDepthDecoderConfig
from .generation_csm import CsmGenerationMixin


logger = logging.get_logger(__name__)


@dataclass
class CsmOutputWithPast(ModelOutput):
    """
    Base class for the model autoregressive outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden states (keys and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction) of the depth decoder model.
        depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
        depth_decoder_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Same layout as `past_key_values`, for the depth decoder.
        depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Same layout as `hidden_states`, for the depth decoder.
        depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Same layout as `attentions`, for the depth decoder.
        backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction) of the backbone model.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    depth_decoder_loss: Optional[torch.FloatTensor] = None
    depth_decoder_logits: torch.FloatTensor = None
    depth_decoder_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    depth_decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    depth_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    backbone_loss: Optional[torch.FloatTensor] = None


@auto_docstring(
    custom_intro="""
    The bare Csm Model outputting raw hidden-states without any specific head on top.
    """
)
class CsmPreTrainedModel(PreTrainedModel):
    config_class = CsmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CsmDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, CsmCodebooksHead):
            num_codebooks = module.num_codebooks
            for i in range(num_codebooks - 1):
                module.weight.data[i].normal_(mean=0.0, std=std)
        elif isinstance(module, CsmRMSNorm):
            module.weight.data.fill_(1.0)


class CsmRMSNorm(LlamaRMSNorm):
    pass


class CsmRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class CsmMLP(LlamaMLP):
    pass


class CsmAttention(LlamaAttention):
    pass


class CsmDecoderLayer(LlamaDecoderLayer):
    pass


@auto_docstring
class CsmDepthDecoderModel(LlamaModel):
    config_class = CsmDepthDecoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.num_codebooks * config.vocab_size, config.backbone_hidden_size)
        self.inputs_embeds_projector = nn.Linear(config.backbone_hidden_size, config.hidden_size, bias=False)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
            The last hidden state of the backbone model. Such input is required when the first codebook token (the one
            generated by the backbone model) is provided in the `input_ids` argument.
        """
        # Body elided (not recoverable verbatim from the compiled artifact). Recoverable behavior:
        # - custom `position_ids` are ignored with a warning; positions are derived from `cache_position`,
        #   which is required to be identical across the batch;
        # - exactly one of `input_ids` / `inputs_embeds` must be provided, otherwise a ValueError is raised;
        # - `use_cache=True` is disabled under gradient checkpointing, and a `DynamicCache` is created when needed;
        # - codebook token ids are embedded with a per-position offset of `vocab_size`; the first-codebook position
        #   takes `backbone_last_hidden_state` instead (a warning is emitted when it is missing), and the embeddings
        #   are projected to the depth-decoder width through `inputs_embeds_projector`;
        # - the decoder layers then run as in `LlamaModel.forward`, returning a `BaseModelOutputWithPast`.
        ...


class CsmCodebooksHead(nn.Module):
    def __init__(self, hidden_size, num_codebooks, vocab_size):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.weight = nn.Parameter(torch.empty(self.num_codebooks - 1, hidden_size, vocab_size))

    def forward(self, hidden_states, cache_position=None):
        if cache_position is None:
            seq_length = hidden_states.shape[1]
            codebook_weight = self.weight[torch.arange(seq_length)]
        else:
            codebook_idxs = cache_position - 1
            codebook_weight = self.weight[codebook_idxs]

        # One linear projection per codebook position, then stack back into (batch, seq, vocab_size).
        hidden_states = [
            nn.functional.linear(hidden_states[:, codebook_idx, :], codebook_weight[codebook_idx].T)
            for codebook_idx in range(codebook_weight.shape[0])
        ]
        hidden_states = torch.stack(hidden_states, dim=1)

        return hidden_states


@auto_docstring(
    custom_intro="""
    The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
    which can be seen as a position-specific language modeling head, allowing the use of a different linear layer for
    each codebook (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
    """
)
class CsmDepthDecoderForCausalLM(LlamaForCausalLM, GenerationMixin):
    _tied_weights_keys = None
    _tp_plan = None
    _pp_plan = None

    def __init__(self, config):
        super().__init__(config)
        del self.lm_head
        self.codebooks_head = CsmCodebooksHead(config.hidden_size, config.num_codebooks, config.vocab_size)
        self.model = CsmDepthDecoderModel(config)

    def get_output_embeddings(self):
        raise AttributeError("Not needed for Csm")

    def set_output_embeddings(self, new_embeddings):
        raise AttributeError("Not needed for Csm")

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # Body elided (not recoverable verbatim). Recoverable behavior: the standard inputs are built with
        # `super().prepare_inputs_for_generation(...)`, `backbone_last_hidden_state` is kept only for the first
        # generation step (`cache_position[0] == 0`), and `position_ids` are dropped since the depth decoder
        # infers them from `cache_position`.
        ...

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
            The last hidden state of the backbone model. Such input is required when the first codebook token (the one
            generated by the backbone model) is provided in the `input_ids` argument.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        # Body elided (not recoverable verbatim). Recoverable behavior:
        # - runs `self.model(...)` with the arguments above;
        # - applies `self.codebooks_head` to the positions selected by `logits_to_keep` (an int of 0 keeps every
        #   position except the first, which corresponds to the backbone token; otherwise the last `logits_to_keep`
        #   positions or the given indices are kept);
        # - when `labels` are provided, computes the loss against the shifted labels via
        #   `self.loss_function(..., vocab_size=config.vocab_size, shift_labels=...)`;
        # - returns a `CausalLMOutputWithPast` with logits, past key values, hidden states and attentions.
        ...


class CsmBackboneModelEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embed_audio_tokens = nn.Embedding(config.num_codebooks * config.vocab_size, config.hidden_size)
        self.register_buffer(
            "audio_tokens_offsets", torch.arange(config.num_codebooks) * config.vocab_size, persistent=False
        )

    def forward(self, input_ids):
        input_embeds = self.embed_audio_tokens(input_ids + self.audio_tokens_offsets)
        input_embeds = input_embeds.sum(dim=2)
        return input_embeds


@auto_docstring
class CsmBackboneModel(LlamaModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = CsmBackboneModelEmbeddings(config)

    @can_return_tuple
    @auto_docstring
    def forward(self, **super_kwargs) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)` or `(batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
            requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.

            2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        return super().forward(**super_kwargs)


@auto_docstring(
    custom_intro="""
    The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the
    first codebook token and a depth decoder that predicts the other codebook tokens.
    """
)
class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin):
    _tied_weights_keys = [
        "backbone_model.embed_tokens.embed_audio_tokens.weight",
        "depth_decoder.model.embed_tokens.weight",
    ]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.embed_text_tokens = nn.Embedding(config.text_vocab_size, config.hidden_size)
        self.backbone_model = CsmBackboneModel._from_config(config)
        self.depth_decoder = CsmDepthDecoderForCausalLM._from_config(config.depth_decoder_config)
        self.codec_model = AutoModel.from_config(config.codec_config)
        self.post_init()

    def get_input_embeddings(self):
        return self.backbone_model.embed_tokens

    def set_input_embeddings(self, value):
        self.backbone_model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _tie_weights(self):
        if self.config.tie_codebooks_embeddings:
            self._tie_or_clone_weights(
                self.backbone_model.embed_tokens.embed_audio_tokens,
                self.depth_decoder.model.embed_tokens,
            )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Body elided (not recoverable verbatim). Recoverable behavior: after the standard loading
        # (optionally returning loading info when `output_loading_info=True`), any `depth_decoder_*` attributes
        # found on the loaded `generation_config` are moved (without the `depth_decoder_` prefix) onto
        # `model.depth_decoder.generation_config` and removed from the main generation config.
        ...

    def save_pretrained(self, *args, **kwargs):
        # Body elided (not recoverable verbatim). Recoverable behavior: the depth decoder's generation config
        # (except `transformers_version`) is copied back onto the main generation config under `depth_decoder_*`
        # keys before delegating to `super().save_pretrained(...)`.
        ...

    def _merge_input_ids_with_input_values(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_values: Optional[torch.Tensor] = None,
        input_values_cutoffs: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """
        Merges the input_ids and input_values to produce a single inputs_embeds tensor:
        1 - Infers the codec model on the input_values to retrieve codebook tokens.
        2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
        3 - If labels are provided, expands them to match codebook dimensions and positions the target codebook
            tokens in the labels tensor.

        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The input ids to embed.
            input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
                The audio input values to embed.
            input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
                The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio.
            labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                The labels to expand to the codebook dimension.
        """
        # Body elided (not recoverable verbatim). Recoverable behavior: each audio segment (delimited by
        # `input_values_cutoffs`) is encoded with `self.codec_model`; positions holding `config.audio_token_id`
        # receive the summed codebook embeddings, positions holding `config.audio_eos_token_id` receive an EOS
        # audio frame embedding; when `labels` are given they are expanded to `(batch, seq, num_codebooks)` with
        # `-100` where frames must be ignored by the depth decoder. Returns
        # `{"inputs_embeds": ..., "labels": ...}`.
        ...

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # Body elided (not recoverable verbatim). Recoverable behavior: on the prefill step (2D `input_ids` and no
        # precomputed `inputs_embeds`), text and audio are merged through `_merge_input_ids_with_input_values`
        # using `input_values`, `input_values_cutoffs` and `labels` from `kwargs`, and the result is passed on as
        # `inputs_embeds` / `labels` with `input_ids=None`.
        ...

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        input_values_cutoffs: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CsmOutputWithPast]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)` or `(batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
            requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.

            2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
            Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
            If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
            where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
            the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
            Requires targeted `input_values` to be provided as audio tokens will be inferred from it using the `codec_model`.
            - `config.audio_token_id` indicates an audio frame (considering sequence length elements as frames)
            - `-100` will be ignored in the loss computation
            - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)

            Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            Kept for compatibility. Does not support another value than:
            1. `0`, which is equivalent to keeping all logits, used in the training regime
            2. `1`, which is equivalent to keeping only the last logit, used in the generation regime

        Example:

        ```python
        >>> import torch
        >>> from transformers import CsmForConditionalGeneration, AutoProcessor
        >>> from datasets import load_dataset, Audio

        >>> model_id = "eustlb/csm-1b"
        >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> processor = AutoProcessor.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
        >>> # ensure the audio is 24kHz
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))

        >>> conversation = []
        >>> # prepare a conversation with text and corresponding audio
        >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
        ...     conversation.append(
        ...         {
        ...             "role": f"{speaker_id}",
        ...             "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        ...         }
        ...     )

        >>> inputs = processor.apply_chat_template(
        ...     conversation,
        ...     tokenize=True,
        ...     return_dict=True,
        ...     output_labels=True,
        ... ).to(torch_device)

        >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
        >>> output = model(**inputs)
        >>> output.loss.backward()
        ```"""
        # Body elided (not recoverable verbatim). Recoverable behavior:
        # - 2D `input_ids` with `input_values` are merged into `inputs_embeds` (and expanded `labels`) through
        #   `_merge_input_ids_with_input_values`;
        # - `self.backbone_model` is run and `self.lm_head` applied to the positions selected by `logits_to_keep`,
        #   giving the first-codebook logits and, when labels are provided, the `backbone_loss`;
        # - for training frames whose codebook labels are all targets, depth-decoder inputs are built from the
        #   ground-truth codebook tokens and the backbone's last hidden states, and `self.depth_decoder` is run to
        #   obtain the `depth_decoder_loss`;
        # - `loss = backbone_loss + depth_decoder_loss`, and everything is returned in a `CsmOutputWithPast`.
        ...

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # The mask already comes in the inverted 4D form and requires no further processing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


__all__ = [
    "CsmForConditionalGeneration",
    "CsmPreTrainedModel",
    "CsmDepthDecoderModel",
    "CsmBackboneModel",
    "CsmDepthDecoderForCausalLM",
]