o
    Zhe                     @   s   d Z ddlmZ ddlmZmZmZ ddlZddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZmZ d
dlmZ eeZeG dd deZG dd de	jZG dd de	jZG dd de	jZG dd de	jZeG dd deZdgZdS )zPyTorch UnivNetModel model.    )	dataclass)OptionalTupleUnionN)nn   )ModelOutput)PreTrainedModel)auto_docstringlogging   )UnivNetConfigc                   @   s6   e Zd ZU dZdZeej ed< dZ	eej ed< dS )UnivNetModelOutputa  
    Output class for the [`UnivNetModel`], which includes the generated audio waveforms and the original unpadded
    lengths of those waveforms (so that the padding can be removed by [`UnivNetModel.batch_decode`]).

    Args:
        waveforms (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Batched 1D (mono-channel) output audio waveforms.
        waveform_lengths (`torch.FloatTensor` of shape `(batch_size,)`):
            The batched length in samples of each unpadded waveform in `waveforms`.
    N	waveformswaveform_lengths)
__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r    r   r   [/var/www/auris/lib/python3.10/site-packages/transformers/models/univnet/modeling_univnet.pyr       s   
 r   c                       sF   e Zd ZdZdef fddZdejfddZdd	 Z	d
d Z
  ZS )#UnivNetKernelPredictorResidualBlockz
    Implementation of the residual block for the kernel predictor network inside each location variable convolution
    block (LVCBlock).

    Parameters:
        config: (`UnivNetConfig`):
            Config for the `UnivNetModel` model.
    configc                    s   t    |j| _|j| _|j| _|j| _| jd d }t	
| j| _t	j| j| j| j|dd| _t	j| j| j| j|dd| _d S )Nr      Tpaddingbias)super__init__model_in_channelsZchannelskernel_predictor_conv_sizekernel_sizeZkernel_predictor_dropoutZdropout_probleaky_relu_sloper   ZDropoutdropoutConv1dconv1conv2)selfr   r   	__class__r   r   r!   ;   s   
 z,UnivNetKernelPredictorResidualBlock.__init__hidden_statesc                 C   sJ   |}|  |}| |}tj|| j}| |}tj|| j}|| S N)r&   r(   r   
functional
leaky_relur%   r)   )r*   r-   residualr   r   r   forwardK   s   


z+UnivNetKernelPredictorResidualBlock.forwardc                 C   s8   t jj}tt jjdrt jjj}|| j || j d S Nweight_norm)r   utilsr4   hasattrparametrizationsr(   r)   r*   r4   r   r   r   apply_weight_normU   s
   

z5UnivNetKernelPredictorResidualBlock.apply_weight_normc                 C   s    t j| j t j| j d S r.   )r   r5   remove_weight_normr(   r)   r*   r   r   r   r:   ]   s   z6UnivNetKernelPredictorResidualBlock.remove_weight_norm)r   r   r   r   r   r!   r   r   r2   r9   r:   __classcell__r   r   r+   r   r   1   s    	
r   c                       sT   e Zd ZdZ		ddededef fddZd	ejfd
dZ	dd Z
dd Z  ZS )UnivNetKernelPredictora  
    Implementation of the kernel predictor network which supplies the kernel and bias for the location variable
    convolutional layers (LVCs) in each UnivNet LVCBlock.

    Based on the KernelPredictor implementation in
    [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L7).

    Parameters:
        config: (`UnivNetConfig`):
            Config for the `UnivNetModel` model.
        conv_kernel_size (`int`, *optional*, defaults to 3):
            The kernel size for the location variable convolutional layer kernels (convolutional weight tensor).
        conv_layers (`int`, *optional*, defaults to 4):
            The number of location variable convolutional layers to output kernels and biases for.
    r      r   conv_kernel_sizeconv_layersc                    s   t     j| _d j | _|| _|| _| j| j | j | j | _| j| j | _ j	| _
 j| _ j| _ j| _ j| _| jd d }tj| j
| jdddd| _t fddt| jD | _tj| j| j| j|dd| _tj| j| j| j|dd| _d S )Nr   r      Tr   c                    s   g | ]}t  qS r   )r   ).0_r   r   r   
<listcomp>   s    z3UnivNetKernelPredictor.__init__.<locals>.<listcomp>)r    r!   model_hidden_channelsconv_in_channelsconv_out_channelsr?   r@   Zkernel_channelsZbias_channelsZnum_mel_binsZresnet_in_channelsZ kernel_predictor_hidden_channelsZresnet_hidden_channelsr#   Zresnet_kernel_sizeZkernel_predictor_num_blocks
num_blocksr%   r   r'   
input_conv
ModuleListrange	resblockskernel_conv	bias_conv)r*   r   r?   r@   r   r+   rD   r   r!   s   s,   
 zUnivNetKernelPredictor.__init__spectrogramc                 C   s   |j \}}}| |}tj|| j}| jD ]}||}q| |}| |}|	|| j
| j| j| j| }	|	|| j
| j| }
|	|
fS )a  
        Maps a conditioning log-mel spectrogram to a tensor of convolutional kernels and biases, for use in location
        variable convolutional layers. Note that the input spectrogram should have shape (batch_size, input_channels,
        seq_length).

        Args:
            spectrogram (`torch.FloatTensor` of shape `(batch_size, input_channels, seq_length)`):
                Tensor containing the log-mel spectrograms.

        Returns:
            Tuple[`torch.FloatTensor, `torch.FloatTensor`]: tuple of tensors where the first element is the tensor of
            location variable convolution kernels of shape `(batch_size, self.conv_layers, self.conv_in_channels,
            self.conv_out_channels, self.conv_kernel_size, seq_length)` and the second element is the tensor of
            location variable convolution biases of shape `(batch_size, self.conv_layers. self.conv_out_channels,
            seq_length)`.
        )shaperJ   r   r/   r0   r%   rM   rN   rO   viewr@   rG   rH   r?   
contiguous)r*   rP   Z
batch_sizerC   Z
seq_lengthr-   resblockZkernel_hidden_statesZbias_hidden_stateskernelsbiasesr   r   r   r2      s4   




zUnivNetKernelPredictor.forwardc                 C   sV   t jj}tt jjdrt jjj}|| j | jD ]}|  q|| j || j	 d S r3   )
r   r5   r4   r6   r7   rJ   rM   r9   rN   rO   r*   r4   layerr   r   r   r9      s   




z(UnivNetKernelPredictor.apply_weight_normc                 C   sB   t j| j | jD ]}|  q
t j| j t j| j d S r.   )r   r5   r:   rJ   rM   rN   rO   r*   rX   r   r   r   r:      s
   

z)UnivNetKernelPredictor.remove_weight_norm)r   r>   r   r   r   r   r   intr!   r   r   r2   r9   r:   r<   r   r   r+   r   r=   b   s    &.r=   c                       sr   e Zd ZdZdededef fddZddd	Z	
	ddej	dej	dej	dedef
ddZ
dd Zdd Z  ZS )UnivNetLvcResidualBlocka  
    Implementation of the location variable convolution (LVC) residual block for the UnivNet residual network.

    Parameters:
        config: (`UnivNetConfig`):
            Config for the `UnivNetModel` model.
        kernel_size (`int`):
            The kernel size for the dilated 1D convolutional layer.
        dilation (`int`):
            The dilation for the dilated 1D convolutional layer.
    r   r$   dilationc                    s\   t    |j| _|| _|| _|j| _| j| jd  d }tj| j| j| j|| jd| _	d S )Nr   r   )r   r]   )
r    r!   rF   hidden_channelsr$   r]   r%   r   r'   conv)r*   r   r$   r]   r   r+   r   r   r!      s   
z UnivNetLvcResidualBlock.__init__   c                 C   s   |}t j|| j}| |}t j|| j}| j||||d}t|d d d | jd d f t	|d d | jd d d f  }|| }|S N)hop_size)
r   r/   r0   r%   r_   location_variable_convolutionr   Zsigmoidr^   tanh)r*   r-   kernelr   rb   r1   r   r   r   r2      s   
$zUnivNetLvcResidualBlock.forwardr   r-   re   r   rb   c                 C   sB  |j \}}}|j \}}}	}
}||| kr!td||  d| d|t|
d d  }tj|||fdd}|d|d|  |}||k rPtj|d|fdd}|d||}|d	d	d	d	d	d	d	d	d	|f }|dd
}|d
|
d}t	d||}|j
tjd}|ddj
tjd}|| }| ||	d}|S )u  
        Performs location-variable convolution operation on the input sequence (hidden_states) using the local
        convolution kernel. This was introduced in [LVCNet: Efficient Condition-Dependent Modeling Network for Waveform
        Generation](https://arxiv.org/abs/2102.10815) by Zhen Zheng, Jianzong Wang, Ning Cheng, and Jing Xiao.

        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, in_channels, in_length)`):
                The input sequence of shape (batch, in_channels, in_length).
            kernel (`torch.FloatTensor` of shape `(batch_size, in_channels, out_channels, kernel_size, kernel_length)`):
                The local convolution kernel of shape (batch, in_channels, out_channels, kernel_size, kernel_length).
            bias (`torch.FloatTensor` of shape `(batch_size, out_channels, kernel_length)`):
                The bias for the local convolution of shape (batch, out_channels, kernel_length).
            dilation (`int`, *optional*, defaults to 1):
                The dilation of convolution.
            hop_size (`int`, *optional*, defaults to 256):
                The hop_size of the conditioning sequence.
        Returns:
            `torch.FloatTensor`: the output sequence after performing local convolution with shape (batch_size,
            out_channels, in_length).
        z#Dim 2 of `hidden_states` should be z
) but got zX. Please check `hidden_states` or `kernel` and `hop_size` to make sure they are correct.r   r   Zconstantr   r   Nr>   zbildsk,biokl->bolsd)Zmemory_format)rQ   
ValueErrorr[   r   r/   padZunfold	transposer   ZeinsumtoZchannels_last_3d	unsqueezerS   rR   )r*   r-   re   r   r]   rb   batchrC   Z	in_lengthZout_channelsr$   Zkernel_lengthr   Zoutput_hidden_statesr   r   r   rc     s*   &z5UnivNetLvcResidualBlock.location_variable_convolutionc                 C   s.   t jj}tt jjdrt jjj}|| j d S r3   )r   r5   r4   r6   r7   r_   r8   r   r   r   r9   N  s   
z)UnivNetLvcResidualBlock.apply_weight_normc                 C   s   t j| j d S r.   )r   r5   r:   r_   r;   r   r   r   r:   U  s   z*UnivNetLvcResidualBlock.remove_weight_normr`   )r   r`   )r   r   r   r   r   r[   r!   r2   r   r   rc   r9   r:   r<   r   r   r+   r   r\      s2    

Ar\   c                       sX   e Zd ZdZ	ddededef fddZdejd	ejfd
dZ	dd Z
dd Z  ZS )UnivNetLvcBlocka#  
    Implementation of the location variable convolution (LVC) residual block of the UnivNet residual block. Includes a
    `UnivNetKernelPredictor` inside to predict the kernels and biases of the LVC layers.

    Based on LVCBlock in
    [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L98)

    Parameters:
        config (`UnivNetConfig`):
            Config for the `UnivNetModel` model.
        layer_id (`int`):
            An integer corresponding to the index of the current LVC resnet block layer. This should be between 0 and
            `len(config.resblock_stride_sizes) - 1)` inclusive.
        lvc_hop_size (`int`, *optional*, defaults to 256):
            The hop size for the location variable convolutional layers.
    r`   r   layer_idlvc_hop_sizec                    s   t     j_ j| _ j| _ j| _	|_
 j_tj	_tjjjdj jjd jd  jd d_t jj_t fddtjD _d S )Nr   )strider   Zoutput_paddingc                    s    g | ]}t  jj| qS r   )r\   r$   	dilationsrB   ir   r*   r   r   rE     s     z,UnivNetLvcBlock.__init__.<locals>.<listcomp>)r    r!   rF   r^   resblock_kernel_sizesr$   resblock_stride_sizesrq   Zresblock_dilation_sizesrr   cond_hop_lengthr%   lenrI   r   ConvTranspose1d	convt_prer=   kernel_predictorrK   rL   rM   )r*   r   ro   rp   r+   ru   r   r!   k  s(   
	
zUnivNetLvcBlock.__init__r-   rP   c           	   	   C   s   t j|| j}| |}| |\}}t| jD ]/\}}|d d |d d d d d d d d f }|d d |d d d d f }||||| jd}q|S ra   )	r   r/   r0   r%   r{   r|   	enumeraterM   rx   )	r*   r-   rP   rU   rV   rt   rT   re   r   r   r   r   r2     s   
(zUnivNetLvcBlock.forwardc                 C   sL   t jj}tt jjdrt jjj}|| j | j  | jD ]}|  qd S r3   )	r   r5   r4   r6   r7   r{   r|   r9   rM   rW   r   r   r   r9     s   




z!UnivNetLvcBlock.apply_weight_normc                 C   s0   t j| j | j  | jD ]}|  qd S r.   )r   r5   r:   r{   r|   rM   rY   r   r   r   r:     s
   


z"UnivNetLvcBlock.remove_weight_normrm   rZ   r   r   r+   r   rn   Y  s    
rn   c                       s   e Zd ZeZdZdef fddZe				ddej	de
ej	 de
ej	 de
ej d	e
e d
eeej	 ef fddZdd Zdd Zdd Z  ZS )UnivNetModelinput_featuresr   c                    s   t    t j| _ j| _tj j j	ddddd| _
t j}d}g  jD ]}|| }| q*t fddt|D | _tj j	ddddd| _|   d S )	N   r   r   Zreflect)r$   rq   r   padding_modec                    s   g | ]}t  || d qS ))ro   rp   )rn   rs   r   Zhop_lengthsr   r   rE     s    z)UnivNetModel.__init__.<locals>.<listcomp>)r   r   )r    r!   ry   rv   Znum_kernelsr%   r   r'   r"   rF   conv_prerw   appendrK   rL   rM   	conv_postZ	post_init)r*   r   Z
num_layersZ
hop_lengthrq   r+   r   r   r!     s0   


zUnivNetModel.__init__Nnoise_sequencepadding_mask	generatorreturn_dictreturnc                 C   s  |dur|n| j j}| dk}|s|d}|j\}}}	|dur/| dk}
|
s.|d}n||| j jf}tj|||j|j	d}|jd }|dkrV|dkrV|
|dd}n|dkre|dkre|
|dd}||krttd| d| d|dur| dkr|d}|jd }||krtd	| d| d|d
d}|d
d}| |}| jD ]}|||}qtj|| j}| |}t|}|d}d}|durtj|dd}|s||f}|S t||dS )a  
        input_features (`torch.FloatTensor`):
            Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
            config.num_mel_channels)`, or un-batched and of shape `(sequence_length, config.num_mel_channels)`.
        noise_sequence (`torch.FloatTensor`, *optional*):
            Tensor containing a noise sequence of standard Gaussian noise. Can be batched and of shape `(batch_size,
            sequence_length, config.model_in_channels)`, or un-batched and of shape (sequence_length,
            config.model_in_channels)`. If not supplied, will be randomly generated.
        padding_mask (`torch.BoolTensor`, *optional*):
            Mask indicating which parts of each sequence are padded. Mask values are selected in `[0, 1]`:

            - 1 for tokens that are **not masked**
            - 0 for tokens that are **masked**

            The mask can be batched and of shape `(batch_size, sequence_length)` or un-batched and of shape
            `(sequence_length,)`.
        generator (`torch.Generator`, *optional*):
            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
            deterministic.
            return_dict:
            Whether to return a [`~utils.ModelOutput`] subclass instead of a plain tuple.

        Example:

         ```python
         >>> from transformers import UnivNetFeatureExtractor, UnivNetModel
         >>> from datasets import load_dataset, Audio

         >>> model = UnivNetModel.from_pretrained("dg845/univnet-dev")
         >>> feature_extractor = UnivNetFeatureExtractor.from_pretrained("dg845/univnet-dev")

         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         >>> # Resample the audio to the feature extractor's sampling rate.
         >>> ds = ds.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
         >>> inputs = feature_extractor(
         ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
         ... )
         >>> audio = model(**inputs).waveforms
         >>> list(audio.shape)
         [1, 140288]
         ```
        Nr   r   )r   dtypedevicer   z&The batch size of `noise_sequence` is z+ and the batch size of `input_features` is z', but the two are expected to be equal.z$The batch size of `padding_mask` is r   )dim)r   r   )r   Zuse_return_dictr   rk   rQ   r"   r   Zrandnr   r   repeatrg   ri   r   rM   r   r/   r0   r%   r   rd   Zsqueezesumr   )r*   r   r   r   r   r   Zspectrogram_batchedZspectrogram_batch_sizeZspectrogram_lengthrC   Znoise_sequence_batchedZnoise_sequence_shapeZnoise_sequence_batch_sizeZpadding_mask_batch_sizer-   rT   Zwaveformr   Zoutputsr   r   r   r2     sl   3









zUnivNetModel.forwardc                 C   sN   t |tjtjtjfr#|jjjd| jj	d |j
dur%|j
j  dS dS dS )zInitialize the weights.g        )meanZstdN)
isinstancer   ZLinearr'   rz   weightdataZnormal_r   Zinitializer_ranger   Zzero_)r*   moduler   r   r   _init_weightsS  s   
zUnivNetModel._init_weightsc                 C   sL   t jj}tt jjdrt jjj}|| j | jD ]}|  q|| j d S r3   )	r   r5   r4   r6   r7   r   rM   r9   r   rW   r   r   r   r9   Z  s   



zUnivNetModel.apply_weight_normc                 C   s4   t j| j | jD ]}|  q
t j| j d S r.   )r   r5   r:   r   rM   r   rY   r   r   r   r:   d  s   

zUnivNetModel.remove_weight_norm)NNNN)r   r   r   r   Zconfig_classZmain_input_namer!   r
   r   r   r   	Generatorboolr   r   r   r2   r   r9   r:   r<   r   r   r+   r   r~     s2    '}
r~   )r   dataclassesr   typingr   r   r   r   Ztorch.utils.checkpointr   Zmodeling_outputsr   Zmodeling_utilsr	   r5   r
   r   Zconfiguration_univnetr   Z
get_loggerr   loggerr   Moduler   r=   r\   rn   r~   __all__r   r   r   r   <module>   s*   
1xP 
B