"""PyTorch EnCodec model."""

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
from .configuration_encodec import EncodecConfig


logger = logging.get_logger(__name__)


@dataclass
class EncodecOutput(ModelOutput):
    """
    Args:
        audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
            Discrete code embeddings computed using `model.encode`.
        audio_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Decoded audio values, obtained using the decoder part of Encodec.
    """

    audio_codes: Optional[torch.LongTensor] = None
    audio_values: Optional[torch.FloatTensor] = None


@dataclass
class EncodecEncoderOutput(ModelOutput):
    """
    Args:
        audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
            Discrete code embeddings computed using `model.encode`.
        audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
            Scaling factor for each `audio_codes` input. This is used to unscale each chunk of audio when decoding.
    """

    audio_codes: Optional[torch.LongTensor] = None
    audio_scales: Optional[torch.FloatTensor] = None


@dataclass
class EncodecDecoderOutput(ModelOutput):
    """
    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, segment_length)`, *optional*):
            Decoded audio values, obtained using the decoder part of Encodec.
    """

    audio_values: Optional[torch.FloatTensor] = None


class EncodecConv1d(nn.Module):
    """Conv1d with asymmetric or causal padding and normalization."""

    def __init__(
        self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, dilation: int = 1
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.pad_mode = config.pad_mode
        self.norm_type = config.norm_type

        if self.norm_type not in ["weight_norm", "time_group_norm"]:
            raise ValueError(
                f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`, got {self.norm_type}'
            )

        # Warn the user on an unusual setup between dilation and stride.
        if stride > 1 and dilation > 1:
            logger.warning(
                "EncodecConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size}, stride={stride}, dilation={dilation})."
            )

        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, dilation=dilation)
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        if self.norm_type == "weight_norm":
            self.conv = weight_norm(self.conv)
        elif self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels)

        kernel_size = self.conv.kernel_size[0]
        stride = torch.tensor(self.conv.stride[0], dtype=torch.int64)
        dilation = self.conv.dilation[0]

        # Effective kernel size accounting for dilation.
        kernel_size = torch.tensor((kernel_size - 1) * dilation + 1, dtype=torch.int64)

        self.register_buffer("stride", stride, persistent=False)
        self.register_buffer("kernel_size", kernel_size, persistent=False)
        self.register_buffer("padding_total", kernel_size - stride, persistent=False)

    def _get_extra_padding_for_conv1d(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """See `pad_for_conv1d`."""
        length = hidden_states.shape[-1]
        n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
        n_frames = torch.ceil(n_frames).to(torch.int64) - 1
        ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
        return ideal_length - length

    @staticmethod
    def _pad1d(hidden_states: torch.Tensor, paddings: Tuple[int, int], mode: str = "zero", value: float = 0.0):
        """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input.
        If this is the case, we insert extra 0 padding to the right before the reflection happens.
        """
        length = hidden_states.shape[-1]
        padding_left, padding_right = paddings
        if not mode == "reflect":
            return nn.functional.pad(hidden_states, paddings, mode, value)

        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            hidden_states = nn.functional.pad(hidden_states, (0, extra_pad))
        padded = nn.functional.pad(hidden_states, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]

    def forward(self, hidden_states):
        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)

        if self.causal:
            # Left padding for causal convolutions.
            hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides.
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
            hidden_states = self._pad1d(
                hidden_states, (padding_left, padding_right + extra_padding), mode=self.pad_mode
            )

        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        return hidden_states


class EncodecConvTranspose1d(nn.Module):
    """ConvTranspose1d with asymmetric or causal padding and normalization."""

    def __init__(self, config, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1):
        super().__init__()
        self.causal = config.use_causal_conv
        self.trim_right_ratio = config.trim_right_ratio
        self.norm_type = config.norm_type
        if self.norm_type not in ["weight_norm", "time_group_norm"]:
            raise ValueError(
                f'self.norm_type must be one of `"weight_norm"`, `"time_group_norm"`, got {self.norm_type}'
            )

        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        if config.norm_type == "weight_norm":
            self.conv = weight_norm(self.conv)
        elif config.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels)

        if not (self.causal or self.trim_right_ratio == 1.0):
            raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions")

    def forward(self, hidden_states):
        kernel_size = self.conv.kernel_size[0]
        stride = self.conv.stride[0]
        padding_total = kernel_size - stride

        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        # Only the fixed padding is trimmed here. Extra padding added by the matching encoder
        # layer is removed at the very end, when keeping only the right length for the output.
        if self.causal:
            # Trim the padding on the right according to the specified ratio;
            # if trim_right_ratio = 1.0, trim everything from the right.
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Asymmetric padding required for odd strides.
            padding_right = padding_total // 2

        padding_left = padding_total - padding_right

        # Unpad.
        end = hidden_states.shape[-1] - padding_right
        hidden_states = hidden_states[..., padding_left:end]
        return hidden_states


class EncodecLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data. Expects input as convolutional layout.
    """

    def __init__(self, config, dimension):
        super().__init__()
        self.lstm = nn.LSTM(dimension, dimension, config.num_lstm_layers)

    def forward(self, hidden_states):
        hidden_states = hidden_states.permute(2, 0, 1)
        hidden_states = self.lstm(hidden_states)[0] + hidden_states
        hidden_states = hidden_states.permute(1, 2, 0)
        return hidden_states


class EncodecResnetBlock(nn.Module):
    """
    Residual block from SEANet model as used by EnCodec.
    """

    def __init__(self, config: EncodecConfig, dim: int, dilations: List[int]):
        super().__init__()
        kernel_sizes = (config.residual_kernel_size, 1)
        if len(kernel_sizes) != len(dilations):
            raise ValueError("Number of kernel sizes should match number of dilations")

        hidden = dim // config.compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [nn.ELU()]
            block += [EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
        self.block = nn.ModuleList(block)

        if config.use_conv_shortcut:
            self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, hidden_states):
        residual = hidden_states
        for layer in self.block:
            hidden_states = layer(hidden_states)
        return self.shortcut(residual) + hidden_states


class EncodecEncoder(nn.Module):
    """SEANet encoder as used by EnCodec."""

    def __init__(self, config: EncodecConfig):
        super().__init__()
        model = [EncodecConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
        scaling = 1

        # Downsample to raw audio scale.
        for ratio in reversed(config.upsampling_ratios):
            current_scale = scaling * config.num_filters
            # Add residual layers.
            for j in range(config.num_residual_layers):
                model += [EncodecResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
            # Add downsampling layers.
            model += [nn.ELU()]
            model += [EncodecConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
            scaling *= 2

        model += [EncodecLSTM(config, scaling * config.num_filters)]
        model += [nn.ELU()]
        model += [EncodecConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]

        self.layers = nn.ModuleList(model)

    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


class EncodecDecoder(nn.Module):
    """SEANet decoder as used by EnCodec."""

    def __init__(self, config: EncodecConfig):
        super().__init__()
        scaling = int(2 ** len(config.upsampling_ratios))
        model = [EncodecConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)]
        model += [EncodecLSTM(config, scaling * config.num_filters)]

        # Upsample to raw audio scale.
        for ratio in config.upsampling_ratios:
            current_scale = scaling * config.num_filters
            # Add upsampling layers.
            model += [nn.ELU()]
            model += [
                EncodecConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio)
            ]
            # Add residual layers.
            for j in range(config.num_residual_layers):
                model += [EncodecResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))]
            scaling //= 2

        # Add final layers.
        model += [nn.ELU()]
        model += [EncodecConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)]
        self.layers = nn.ModuleList(model)

    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states


class EncodecEuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance."""

    def __init__(self, config: EncodecConfig):
        super().__init__()
        embed = torch.zeros(config.codebook_size, config.codebook_dim)

        self.codebook_size = config.codebook_size

        self.register_buffer("inited", torch.Tensor([True]))
        self.register_buffer("cluster_size", torch.zeros(config.codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    def quantize(self, hidden_states):
        embed = self.embed.t()
        scaled_states = hidden_states.pow(2).sum(1, keepdim=True)
        dist = -(scaled_states - 2 * hidden_states @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def encode(self, hidden_states):
        shape = hidden_states.shape
        # pre-process
        hidden_states = hidden_states.reshape((-1, shape[-1]))
        # quantize
        embed_ind = self.quantize(hidden_states)
        # post-process
        embed_ind = embed_ind.view(*shape[:-1])
        return embed_ind

    def decode(self, embed_ind):
        quantize = nn.functional.embedding(embed_ind, self.embed)
        return quantize


class EncodecVectorQuantization(nn.Module):
    """
    Vector quantization implementation. Currently supports only euclidean distance.
    """

    def __init__(self, config: EncodecConfig):
        super().__init__()
        self.codebook = EncodecEuclideanCodebook(config)

    def encode(self, hidden_states):
        hidden_states = hidden_states.permute(0, 2, 1)
        embed_in = self.codebook.encode(hidden_states)
        return embed_in

    def decode(self, embed_ind):
        quantize = self.codebook.decode(embed_ind)
        quantize = quantize.permute(0, 2, 1)
        return quantize


class EncodecResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer."""

    def __init__(self, config: EncodecConfig):
        super().__init__()
        self.codebook_size = config.codebook_size
        self.frame_rate = config.frame_rate
        self.num_quantizers = config.num_quantizers
        self.layers = nn.ModuleList([EncodecVectorQuantization(config) for _ in range(config.num_quantizers)])

    def get_num_quantizers_for_bandwidth(self, bandwidth: Optional[float] = None) -> int:
        """Return num_quantizers based on specified target bandwidth."""
        bw_per_q = math.log2(self.codebook_size) * self.frame_rate
        num_quantizers = self.num_quantizers
        if bandwidth is not None and bandwidth > 0.0:
            num_quantizers = int(max(1, math.floor(bandwidth * 1000 / bw_per_q)))
        return num_quantizers

    def encode(self, embeddings: torch.Tensor, bandwidth: Optional[float] = None) -> torch.Tensor:
        """
        Encode a given input tensor with the specified frame rate at the given bandwidth. The RVQ encode method sets
        the appropriate number of quantizers to use and returns indices for each quantizer.
        """
        num_quantizers = self.get_num_quantizers_for_bandwidth(bandwidth)
        residual = embeddings
        all_indices = []
        for layer in self.layers[:num_quantizers]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        quantized_out = torch.tensor(0.0, device=codes.device)
        for i, indices in enumerate(codes):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out


@auto_docstring
class EncodecPreTrainedModel(PreTrainedModel):
    config_class = EncodecConfig
    base_model_prefix = "encodec"
    main_input_name = "input_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if "weight" in name:
                    nn.init.xavier_uniform_(param)
                elif "bias" in name:
                    nn.init.constant_(param, 0.0)


@auto_docstring(
    custom_intro="""
    The EnCodec neural audio codec model.
    """
)
class EncodecModel(EncodecPreTrainedModel):
    def __init__(self, config: EncodecConfig):
        super().__init__(config)
        self.config = config

        self.encoder = EncodecEncoder(config)
        self.decoder = EncodecDecoder(config)

        self.quantizer = EncodecResidualVectorQuantizer(config)

        self.bits_per_codebook = int(math.log2(self.config.codebook_size))
        if 2**self.bits_per_codebook != self.config.codebook_size:
            raise ValueError("The codebook_size must be a power of 2.")

        # Initialize weights and apply final processing.
        self.post_init()

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _encode_frame(
        self, input_values: torch.Tensor, bandwidth: float, padding_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encodes the given input using the underlying VQVAE. If `config.normalize` is set to `True` the input is first
        normalized. The padding mask is required to compute the correct scale.
        """
        length = input_values.shape[-1]
        duration = length / self.config.sampling_rate

        if self.config.chunk_length_s is not None and duration > 1e-5 + self.config.chunk_length_s:
            raise RuntimeError(f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}")

        scale = None
        if self.config.normalize:
            # Zero out the padded positions before computing the scale.
            input_values = input_values * padding_mask.unsqueeze(1)
            mono = torch.sum(input_values, 1, keepdim=True) / input_values.shape[1]
            scale = mono.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        codes = self.quantizer.encode(embeddings, bandwidth)
        codes = codes.transpose(0, 1)
        return codes, scale

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        bandwidth: Optional[float] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], EncodecEncoderOutput]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            bandwidth (`float`, *optional*):
                The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
                bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented
                as `bandwidth == 6.0`.

        Returns:
            A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
            factors for each chunk when `normalize` is True. Each frame is a tuple `(codebook, scale)`, with
            `codebook` of shape `[batch_size, num_codebooks, frames]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if bandwidth is None:
            bandwidth = self.config.target_bandwidths[0]
        if bandwidth not in self.config.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.config.target_bandwidths}."
            )

        _, channels, input_length = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")

        chunk_length = self.config.chunk_length
        if chunk_length is None:
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.config.chunk_stride

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()

        encoded_frames = []
        scales = []

        step = chunk_length - stride
        if (input_length % stride) - step != 0:
            raise ValueError(
                "The input length is not properly padded for batched chunked decoding. "
                "Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[..., offset : offset + chunk_length].bool()
            frame = input_values[:, :, offset : offset + chunk_length]
            encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        encoded_frames = torch.stack(encoded_frames)

        if not return_dict:
            return (encoded_frames, scales)

        return EncodecEncoderOutput(encoded_frames, scales)

    @staticmethod
    def _linear_overlap_add(frames: List[torch.Tensor], stride: int):
        # Generic overlap-add with a linear fade-in/fade-out. Each frame is weighted by a
        # triangle window peaking at the middle of the chunk; the weighted frames are summed
        # and the result divided by the summed weights. A position covered by a single frame
        # is thus left untouched, while overlapping frames cross-fade linearly.
        if len(frames) == 0:
            raise ValueError("`frames` cannot be an empty list.")

        device = frames[0].device
        dtype = frames[0].dtype
        shape = frames[0].shape[:-1]
        total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]

        frame_length = frames[0].shape[-1]
        time_vec = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1:-1]
        weight = 0.5 - (time_vec - 0.5).abs()

        sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
        out = torch.zeros(*shape, total_size, device=device, dtype=dtype)
        offset: int = 0

        for frame in frames:
            frame_length = frame.shape[-1]
            out[..., offset : offset + frame_length] += weight[:frame_length] * frame
            sum_weight[offset : offset + frame_length] += weight[:frame_length]
            offset += stride

        if sum_weight.min() == 0:
            raise ValueError(f"`sum_weight` minimum element must be bigger than zero: {sum_weight}`")

        return out / sum_weight

    def _decode_frame(self, codes: torch.Tensor, scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        codes = codes.transpose(0, 1)
        embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale.view(-1, 1, 1)
        return outputs

    def decode(
        self,
        audio_codes: torch.Tensor,
        audio_scales: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], EncodecDecoderOutput]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
                Scaling factor for each `audio_codes` input.
            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
                Padding mask used to pad the `input_values`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        chunk_length = self.config.chunk_length
        if chunk_length is None:
            if len(audio_codes) != 1:
                raise ValueError(f"Expected one frame, got {len(audio_codes)}")
            audio_values = self._decode_frame(audio_codes[0], audio_scales[0])
        else:
            decoded_frames = []
            for frame, scale in zip(audio_codes, audio_scales):
                frames = self._decode_frame(frame, scale)
                decoded_frames.append(frames)
            audio_values = self._linear_overlap_add(decoded_frames, self.config.chunk_stride or 1)

        # Truncate based on the padding mask.
        if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
            audio_values = audio_values[..., : padding_mask.shape[-1]]

        if not return_dict:
            return (audio_values,)
        return EncodecDecoderOutput(audio_values)

    @auto_docstring
    def forward(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        bandwidth: Optional[float] = None,
        audio_codes: Optional[torch.Tensor] = None,
        audio_scales: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], EncodecOutput]:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
            Raw audio input converted to Float and padded to the appropriate length in order to be encoded using chunks
            of length `self.chunk_length` and a stride of `config.chunk_stride`.
        padding_mask (`torch.BoolTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
            Mask to avoid computing scaling factors on padding token indices (can we avoid computing conv on these?).
            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            <Tip warning={true}>

            `padding_mask` should always be passed, unless the input was truncated or not padded. This is because in
            order to process tensors effectively, the input audio should be padded so that `input_length % stride =
            step` with `step = chunk_length - stride`. This ensures that all chunks are of the same shape.

            </Tip>
        bandwidth (`float`, *optional*):
            The target bandwidth. Must be one of `config.target_bandwidths`. If `None`, uses the smallest possible
            bandwidth. bandwidth is represented as a thousandth of what it is, e.g. 6kbps bandwidth is represented as
            `bandwidth == 6.0`
        audio_codes (`torch.LongTensor` of shape `(batch_size, nb_chunks, chunk_length)`, *optional*):
            Discrete code embeddings computed using `model.encode`.
        audio_scales (`torch.Tensor` of shape `(batch_size, nb_chunks)`, *optional*):
            Scaling factor for each `audio_codes` input.

        Examples:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import AutoProcessor, EncodecModel

        >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> model_id = "facebook/encodec_24khz"
        >>> model = EncodecModel.from_pretrained(model_id)
        >>> processor = AutoProcessor.from_pretrained(model_id)

        >>> inputs = processor(raw_audio=audio_sample, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> audio_codes = outputs.audio_codes
        >>> audio_values = outputs.audio_values
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if padding_mask is None:
            padding_mask = torch.ones_like(input_values).bool()

        if audio_codes is not None and audio_scales is None:
            raise ValueError("You specified `audio_codes` but did not specify the `audio_scales`")

        if audio_scales is not None and audio_codes is None:
            raise ValueError("You specified `audio_scales` but did not specify the `audio_codes`")

        if audio_scales is None and audio_codes is None:
            audio_codes, audio_scales = self.encode(input_values, padding_mask, bandwidth, False)

        audio_values = self.decode(audio_codes, audio_scales, padding_mask, return_dict=return_dict)[0]
        if not return_dict:
            return (audio_codes, audio_values)

        return EncodecOutput(audio_codes=audio_codes, audio_values=audio_values)


__all__ = ["EncodecModel", "EncodecPreTrainedModel"]