a
    hZ                     @   s.  d dl Z d dlmZmZmZmZ d dlZd dlmZmZ d dl	m
Z dgZdeeeeejjdd	d
Zd eeeeeeeeee f  eeeejjd	ddZeedddZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZdS )!    N)ListOptionalTupleUnion)nnTensor)
functional	Tacotron2Tlinear)in_dimout_dimbiasw_init_gainreturnc                 C   s4   t jj| ||d}t jjj|jt jj|d |S )a  Linear layer with xavier uniform initialization.

    Args:
        in_dim (int): Size of each input sample.
        out_dim (int): Size of each output sample.
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias. (Default: ``True``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Returns:
        (torch.nn.Linear): The corresponding linear layer.
    r   Zgain)torchr   Linearinitxavier_uniform_weightcalculate_gain)r   r   r   r   r
    r   I/var/www/auris/lib/python3.9/site-packages/torchaudio/models/tacotron2.py_get_linear_layer)   s    r      )	in_channelsout_channelskernel_sizestridepaddingdilationr   r   r   c           	   	   C   sl   |du r0|d dkrt dt||d  d }tjj| ||||||d}tjjj|jtjj|d |S )al  1D convolution with xavier uniform initialization.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int, optional): Number of channels in the input image. (Default: ``1``)
        stride (int, optional): Number of channels in the input image. (Default: ``1``)
        padding (str, int or tuple, optional): Padding added to both sides of the input.
            (Default: dilation * (kernel_size - 1) / 2)
        dilation (int, optional): Number of channels in the input image. (Default: ``1``)
        w_init_gain (str, optional): Parameter passed to ``torch.nn.init.calculate_gain``
            for setting the gain parameter of ``xavier_uniform_``. (Default: ``linear``)

    Returns:
        (torch.nn.Conv1d): The corresponding Conv1D layer.
    N   r   zkernel_size must be odd)r   r   r    r!   r   r   )	
ValueErrorintr   r   Conv1dr   r   r   r   )	r   r   r   r   r    r!   r   r   Zconv1dr   r   r   _get_conv1d_layer;   s    
r&   )lengthsr   c                 C   sF   t |  }t jd|| j| jd}|| dk  }t |d}|S )al  Returns a binary mask based on ``lengths``. The ``i``-th row and ``j``-th column of the mask
    is ``1`` if ``j`` is smaller than ``i``-th element of ``lengths.

    Args:
        lengths (Tensor): The length of each element in the batch, with shape (n_batch, ).

    Returns:
        mask (Tensor): The binary mask, with shape (n_batch, max of ``lengths``).
    r   )devicedtyper   )	r   maxitemZaranger(   r)   	unsqueezebytele)r'   max_lenidsmaskr   r   r   _get_mask_from_lengthsi   s
    
r2   c                       s:   e Zd ZdZeeed fddZeedddZ  ZS )_LocationLayera  Location layer used in the Attention model.

    Args:
        attention_n_filter (int): Number of filters for attention model.
        attention_kernel_size (int): Kernel size for attention model.
        attention_hidden_dim (int): Dimension of attention hidden representation.
    )attention_n_filterattention_kernel_sizeattention_hidden_dimc              	      sH   t    t|d d }td|||dddd| _t||ddd| _d S )Nr   r"   F)r   r    r   r   r!   tanhr   r   )super__init__r$   r&   location_convr   location_dense)selfr4   r5   r6   r    	__class__r   r   r:      s    
	z_LocationLayer.__init__)attention_weights_catr   c                 C   s$   |  |}|dd}| |}|S )a  Location layer used in the Attention model.

        Args:
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            processed_attention (Tensor): Cumulative and previous attention weights
                with shape (n_batch, ``attention_hidden_dim``).
        r   r"   )r;   	transposer<   )r=   r@   Zprocessed_attentionr   r   r   forward   s    

z_LocationLayer.forward	__name__
__module____qualname____doc__r$   r:   r   rB   __classcell__r   r   r>   r   r3   z   s   
r3   c                       sd   e Zd ZdZeeeeedd fddZeeeedddZeeeeeeeef d	d
dZ	  Z
S )
_Attentiona  Locally sensitive attention model.

    Args:
        attention_rnn_dim (int): Number of hidden units for RNN.
        encoder_embedding_dim (int): Number of embedding dimensions in the Encoder.
        attention_hidden_dim (int): Dimension of attention hidden representation.
        attention_location_n_filter (int): Number of filters for Attention model.
        attention_location_kernel_size (int): Kernel size for Attention model.
    N)attention_rnn_dimencoder_embedding_dimr6   attention_location_n_filterattention_location_kernel_sizer   c                    s\   t    t||ddd| _t||ddd| _t|ddd| _t|||| _td | _	d S )NFr7   r8   r   r   inf)
r9   r:   r   query_layermemory_layervr3   location_layerfloatscore_mask_value)r=   rJ   rK   r6   rL   rM   r>   r   r   r:      s    
z_Attention.__init__)queryprocessed_memoryr@   r   c                 C   s@   |  |d}| |}| t|| | }|d}|S )a=  Get the alignment vector.

        Args:
            query (Tensor): Decoder output with shape (n_batch, n_mels * n_frames_per_step).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, attention_hidden_dim).
            attention_weights_cat (Tensor): Cumulative and previous attention weights
                with shape (n_batch, 2, max of ``text_lengths``).

        Returns:
            alignment (Tensor): attention weights, it is a tensor with shape (batch, max of ``text_lengths``).
        r   r"   )rO   r,   rR   rQ   r   r7   squeeze)r=   rU   rV   r@   Zprocessed_queryZprocessed_attention_weightsZenergies	alignmentr   r   r   _get_alignment_energies   s
    

z"_Attention._get_alignment_energies)attention_hidden_statememoryrV   r@   r1   r   c           	      C   sN   |  |||}||| j}tj|dd}t|d|}|d}||fS )a  Pass the input through the Attention model.

        Args:
            attention_hidden_state (Tensor): Attention rnn last output with shape (n_batch, ``attention_rnn_dim``).
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            attention_weights_cat (Tensor): Previous and cumulative attention weights
                with shape (n_batch, current_num_frames * 2, max of ``text_lengths``).
            mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).

        Returns:
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
        r   Zdim)	rY   Zmasked_fillrT   FZsoftmaxr   Zbmmr,   rW   )	r=   rZ   r[   rV   r@   r1   rX   attention_weightsattention_contextr   r   r   rB      s    
z_Attention.forward)rD   rE   rF   rG   r$   r:   r   rY   r   rB   rH   r   r   r>   r   rI      s    
rI   c                       s>   e Zd ZdZeee dd fddZeedddZ  Z	S )	_PrenetzPrenet Module. It is consists of ``len(output_size)`` linear layers.

    Args:
        in_dim (int): The size of each input sample.
        output_sizes (list): The output dimension of each linear layers.
    N)r   	out_sizesr   c                    s<   t    |g|d d  }tdd t||D | _d S )Nc                 S   s   g | ]\}}t ||d dqS )Fr   )r   ).0Zin_sizeZout_sizer   r   r   
<listcomp>      z$_Prenet.__init__.<locals>.<listcomp>)r9   r:   r   
ModuleListziplayers)r=   r   ra   Zin_sizesr>   r   r   r:   
  s
    
z_Prenet.__init__xr   c                 C   s*   | j D ]}tjt||ddd}q|S )zPass the input through Prenet.

        Args:
            x (Tensor): The input sequence to Prenet with shape (n_batch, in_dim).

        Return:
            x (Tensor): Tensor with shape (n_batch, sizes[-1])
              ?T)ptraining)rh   r]   dropoutrelu)r=   rj   r
   r   r   r   rB     s    

z_Prenet.forward)
rD   rE   rF   rG   r$   r   r:   r   rB   rH   r   r   r>   r   r`     s   r`   c                       s<   e Zd ZdZeeeed fddZeedddZ  ZS )_Postneta  Postnet Module.

    Args:
        n_mels (int): Number of mel bins.
        postnet_embedding_dim (int): Postnet embedding dimension.
        postnet_kernel_size (int): Postnet kernel size.
        postnet_n_convolution (int): Number of postnet convolutions.
    )n_melspostnet_embedding_dimpostnet_kernel_sizepostnet_n_convolutionc           
         s   t    t | _t|D ]}|dkr,|n|}||d kr@|n|}||d krTdnd}||d krh|n|}	| jtt|||dt	|d d d|dt
|	 qt| j| _d S )Nr   r   r
   r7   r"   r   r   r    r!   r   )r9   r:   r   rf   convolutionsrangeappend
Sequentialr&   r$   BatchNorm1dlenn_convs)
r=   rq   rr   rs   rt   ir   r   Z	init_gainZnum_featuresr>   r   r   r:   *  s,    

	z_Postnet.__init__ri   c                 C   sZ   t | jD ]J\}}|| jd k r>tjt||d| jd}q
tj||d| jd}q
|S )a  Pass the input through Postnet.

        Args:
            x (Tensor): The input sequence with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).

        Return:
            x (Tensor): Tensor with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
        r   rk   )rm   )	enumeraterv   r|   r]   rn   r   r7   rm   )r=   rj   r}   convr   r   r   rB   J  s
    
z_Postnet.forwardrC   r   r   r>   r   rp      s    rp   c                       s>   e Zd ZdZeeedd fddZeeedddZ  ZS )	_Encodera  Encoder Module.

    Args:
        encoder_embedding_dim (int): Number of embedding dimensions in the encoder.
        encoder_n_convolution (int): Number of convolution layers in the encoder.
        encoder_kernel_size (int): The kernel size in the encoder.

    Examples
        >>> encoder = _Encoder(3, 512, 5)
        >>> input = torch.rand(10, 20, 30)
        >>> output = encoder(input)  # shape: (10, 30, 512)
    N)rK   encoder_n_convolutionencoder_kernel_sizer   c                    s   t    t | _t|D ]@}tt|||dt|d d dddt	|}| j
| qtj|t|d dddd| _| j  d S )Nr   r"   ro   ru   T)batch_firstbidirectional)r9   r:   r   rf   rv   rw   ry   r&   r$   rz   rx   ZLSTMlstmZflatten_parameters)r=   rK   r   r   _Z
conv_layerr>   r   r   r:   k  s0    

	
z_Encoder.__init__)rj   input_lengthsr   c                 C   sv   | j D ]}tt||d| j}q|dd}| }tjj	j
||dd}| |\}}tjj	j|dd\}}|S )a_  Pass the input through the Encoder.

        Args:
            x (Tensor): The input sequences with shape (n_batch, encoder_embedding_dim, n_seq).
            input_lengths (Tensor): The length of each input sequence with shape (n_batch, ).

        Return:
            x (Tensor): A tensor with shape (n_batch, n_seq, encoder_embedding_dim).
        rk   r   r"   T)r   )rv   r]   rn   ro   rm   rA   cpur   utilsZrnnZpack_padded_sequencer   Zpad_packed_sequence)r=   rj   r   r   outputsr   r   r   r   rB     s    
z_Encoder.forwardrC   r   r   r>   r   r   ]  s   !r   c                       s2  e Zd ZdZeeeeeeeeeeeeeedd fddZeedddZ	ee
eeeeeeeef dd	d
ZeedddZeeee
eeef dddZeeeeeeeeeeee
eeeeeeeeef	 dddZeeee
eeef dddZeedddZejjeee
eeeef dddZ  ZS )_Decodera,  Decoder with Attention model.

    Args:
        n_mels (int): number of mel bins
        n_frames_per_step (int): number of frames processed per step, only 1 is supported
        encoder_embedding_dim (int): the number of embedding dimensions in the encoder.
        decoder_rnn_dim (int): number of units in decoder LSTM
        decoder_max_step (int): maximum number of output mel spectrograms
        decoder_dropout (float): dropout probability for decoder LSTM
        decoder_early_stopping (bool): stop decoding when all samples are finished
        attention_rnn_dim (int): number of units in attention LSTM
        attention_hidden_dim (int): dimension of attention hidden representation
        attention_location_n_filter (int): number of filters for attention model
        attention_location_kernel_size (int): kernel size for attention model
        attention_dropout (float): dropout probability for attention LSTM
        prenet_dim (int): number of ReLU units in prenet layers
        gate_threshold (float): probability threshold for stop token
    N)rq   n_frames_per_steprK   decoder_rnn_dimdecoder_max_stepdecoder_dropoutdecoder_early_stoppingrJ   r6   rL   rM   attention_dropout
prenet_dimgate_thresholdr   c                    s   t    || _|| _|| _|| _|| _|| _|| _|| _	|| _
|| _|| _t|| ||g| _t|| || _t|||	|
|| _t|| |d| _t|| || | _t|| dddd| _d S )NTr   sigmoidr8   )r9   r:   rq   r   rK   rJ   r   r   r   r   r   r   r   r`   prenetr   ZLSTMCellattention_rnnrI   attention_layerdecoder_rnnr   linear_projection
gate_layer)r=   rq   r   rK   r   r   r   r   rJ   r6   rL   rM   r   r   r   r>   r   r   r:     s4    
z_Decoder.__init__)r[   r   c                 C   s4   | d}|j}|j}tj|| j| j ||d}|S )am  Gets all zeros frames to use as the first decoder input.

        Args:
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            decoder_input (Tensor): all zeros frames with shape
                (n_batch, max of ``text_lengths``, ``n_mels * n_frames_per_step``).
        r   r)   r(   sizer)   r(   r   zerosrq   r   r=   r[   n_batchr)   r(   decoder_inputr   r   r   _get_initial_frame  s
    
z_Decoder._get_initial_framec                 C   s   | d}| d}|j}|j}tj|| j||d}tj|| j||d}tj|| j||d}tj|| j||d}	tj||||d}
tj||||d}tj|| j||d}| j	|}||||	|
|||fS )a  Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory.

        Args:
            memory (Tensor): Encoder outputs with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        Returns:
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
        r   r   r   )
r   r)   r(   r   r   rJ   r   rK   r   rP   )r=   r[   r   Zmax_timer)   r(   attention_hiddenattention_celldecoder_hiddendecoder_cellr^   attention_weights_cumr_   rV   r   r   r   _initialize_decoder_states  s*    

z#_Decoder._initialize_decoder_states)decoder_inputsr   c                 C   s@   | dd}||dt|d| j d}| dd}|S )ak  Prepares decoder inputs.

        Args:
            decoder_inputs (Tensor): Inputs used for teacher-forced training, i.e. mel-specs,
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)

        Returns:
            inputs (Tensor): Processed decoder inputs with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``).
        r   r"   r   rb   )rA   viewr   r$   r   )r=   r   r   r   r   _parse_decoder_inputs.  s    z_Decoder._parse_decoder_inputs)mel_specgramgate_outputs
alignmentsr   c                 C   sb   | dd }| dd }| dd }|jd d| jf}|j| }| dd}|||fS )aq  Prepares decoder outputs for output

        Args:
            mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
            gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)

        Returns:
            mel_specgram (Tensor): mel spectrogram with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``)
            gate_outputs (Tensor): predicted stop token with shape (n_batch, max of ``mel_specgram_lengths``)
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``)
        r   r   rb   r"   )rA   
contiguousshaperq   r   )r=   r   r   r   r   r   r   r   _parse_decoder_outputsC  s    
z_Decoder._parse_decoder_outputs)r   r   r   r   r   r^   r   r_   r[   rV   r1   r   c              	   C   s   t ||fd}| |||f\}}t|| j| j}t j|d|dfdd}| ||	|
||\}}||7 }t ||fd}| 	|||f\}}t|| j
| j}t j||fdd}| |}| |}|||||||||f	S )a&	  Decoder step using stored states, attention and memory

        Args:
            decoder_input (Tensor): Output of the Prenet with shape (n_batch, ``prenet_dim``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
            memory (Tensor): Encoder output with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            processed_memory (Tensor): Processed Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``attention_hidden_dim``).
            mask (Tensor): Binary mask for padded data with shape (n_batch, current_num_frames).

        Returns:
            decoder_output: Predicted mel spectrogram for the current frame with shape (n_batch, ``n_mels``).
            gate_prediction (Tensor): Prediction of the stop token with shape (n_batch, ``1``).
            attention_hidden (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            attention_cell (Tensor): Hidden state of the attention LSTM with shape (n_batch, ``attention_rnn_dim``).
            decoder_hidden (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            decoder_cell (Tensor): Hidden state of the decoder LSTM with shape (n_batch, ``decoder_rnn_dim``).
            attention_weights (Tensor): Attention weights with shape (n_batch, max of ``text_lengths``).
            attention_weights_cum (Tensor): Cumulated attention weights with shape (n_batch, max of ``text_lengths``).
            attention_context (Tensor): Context vector with shape (n_batch, ``encoder_embedding_dim``).
        rb   r   r\   )r   catr   r]   rn   r   rm   r,   r   r   r   r   r   )r=   r   r   r   r   r   r^   r   r_   r[   rV   r1   Z
cell_inputr@   Z decoder_hidden_attention_contextZdecoder_outputZgate_predictionr   r   r   decodec  s0    )


z_Decoder.decode)r[   mel_specgram_truthmemory_lengthsr   c                 C   s  |  |d}| |}tj||fdd}| |}t|}| |\}}}	}
}}}}g g g   }}}t||	dd k r|t| }| 
||||	|
||||||\	}}}}}	}
}}}||dg7 }||dg7 }||g7 }qh| t|t|t|\}}}|||fS )a  Decoder forward pass for training.

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            mel_specgram_truth (Tensor): Decoder ground-truth mel-specs for teacher forcing
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch,  max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch,  max of ``mel_specgram_lengths``, max of ``text_lengths``).
        r   r\   r   )r   r,   r   r   r   r   r2   r   r{   r   r   rW   r   stack)r=   r[   r   r   r   r   r1   r   r   r   r   r^   r   r_   rV   Zmel_outputsr   r   Z
mel_outputgate_outputr   r   r   r   rB     s`    



z_Decoder.forwardc                 C   s4   | d}|j}|j}tj|| j| j ||d}|S )aU  Gets all zeros frames to use as the first decoder input

        args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

        returns:
            decoder_input (Tensor): All zeros frames with shape(n_batch, ``n_mels`` * ``n_frame_per_step``).
        r   r   r   r   r   r   r   _get_go_frame  s
    
z_Decoder._get_go_frame)r[   r   r   c                 C   s  | d|j }}| |}t|}| |\}}}	}
}}}}tj|gtj|d}tj|gtj|d}g }g }g }t	| j
D ]}| |}| ||||	|
||||||\	}}}}}	}
}}}||d ||dd || ||   d7  < |t|d| jkO }| jr,t|r, q2|}q|t|| j
krLtd tj|dd}tj|dd}tj|dd}| |||\}}}||||fS )a  Decoder inference

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            mel_specgram_lengths (Tensor): the length of the predicted mel spectrogram (n_batch, ))
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch,  max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch,  max of ``mel_specgram_lengths``, max of ``text_lengths``).
        r   r   r   zZReached max decoder steps. The generated spectrogram might not cover the whole transcript.r\   )r   r(   r   r2   r   r   r   int32boolrw   r   r   r   rx   r,   rA   r   rW   r   r   allr{   warningswarnr   r   )r=   r[   r   Z
batch_sizer(   r   r1   r   r   r   r   r^   r   r_   rV   mel_specgram_lengthsfinishedZmel_specgramsr   r   r   r   r   r   r   r   infer
  sx    



z_Decoder.infer)rD   rE   rF   rG   r$   rS   r   r:   r   r   r   r   r   r   r   rB   r   r   jitexportr   rH   r   r   r>   r   r     sX   31"KLr   c                       s   e Zd ZdZdeeeeeeeeeeeeeeeeeeeeeedd fddZeeeee	eeeef dddZ
ejjdeee e	eeef dddZ  ZS )r	   a	  Tacotron2 model from *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
    :cite:`shen2018natural` based on the implementation from
    `Nvidia Deep Learning Examples <https://github.com/NVIDIA/DeepLearningExamples/>`_.

    See Also:
        * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.

    Args:
        mask_padding (bool, optional): Use mask padding (Default: ``False``).
        n_mels (int, optional): Number of mel bins (Default: ``80``).
        n_symbol (int, optional): Number of symbols for the input text (Default: ``148``).
        n_frames_per_step (int, optional): Number of frames processed per step, only 1 is supported (Default: ``1``).
        symbol_embedding_dim (int, optional): Input embedding dimension (Default: ``512``).
        encoder_n_convolution (int, optional): Number of encoder convolutions (Default: ``3``).
        encoder_kernel_size (int, optional): Encoder kernel size (Default: ``5``).
        encoder_embedding_dim (int, optional): Encoder embedding dimension (Default: ``512``).
        decoder_rnn_dim (int, optional): Number of units in decoder LSTM (Default: ``1024``).
        decoder_max_step (int, optional): Maximum number of output mel spectrograms (Default: ``2000``).
        decoder_dropout (float, optional): Dropout probability for decoder LSTM (Default: ``0.1``).
        decoder_early_stopping (bool, optional): Continue decoding after all samples are finished (Default: ``True``).
        attention_rnn_dim (int, optional): Number of units in attention LSTM (Default: ``1024``).
        attention_hidden_dim (int, optional): Dimension of attention hidden representation (Default: ``128``).
        attention_location_n_filter (int, optional): Number of filters for attention model (Default: ``32``).
        attention_location_kernel_size (int, optional): Kernel size for attention model (Default: ``31``).
        attention_dropout (float, optional): Dropout probability for attention LSTM (Default: ``0.1``).
        prenet_dim (int, optional): Number of ReLU units in prenet layers (Default: ``256``).
        postnet_n_convolution (int, optional): Number of postnet convolutions (Default: ``5``).
        postnet_kernel_size (int, optional): Postnet kernel size (Default: ``5``).
        postnet_embedding_dim (int, optional): Postnet embedding dimension (Default: ``512``).
        gate_threshold (float, optional): Probability threshold for stop token (Default: ``0.5``).
    FP      r                 皙?T             rk   N)mask_paddingrq   n_symbolr   symbol_embedding_dimrK   r   r   r   r   r   r   rJ   r6   rL   rM   r   r   rt   rs   rr   r   r   c                    s   t    || _|| _|| _t||| _tjj	
| jj t|||| _t||||	|
|||||||||| _t||||| _d S )N)r9   r:   r   rq   r   r   Z	Embedding	embeddingr   r   r   r   r   encoderr   decoderrp   postnet)r=   r   rq   r   r   r   rK   r   r   r   r   r   r   rJ   r6   rL   rM   r   r   rt   rs   rr   r   r>   r   r   r:     s0    
zTacotron2.__init__)tokenstoken_lengthsr   r   r   c                 C   s   |  |dd}| ||}| j|||d\}}}| |}	||	 }	| jrt|}
|
| j|
	d|
	d}
|

ddd}
||
d |	|
d ||
dddddf d ||	||fS )a  Pass the input through the Tacotron2 model. This is in teacher
        forcing mode, which is generally used for training.

        The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
        The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.

        Args:
            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
            token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
            mel_specgram (Tensor): The target mel spectrogram
                with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
            mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.

        Returns:
            [Tensor, Tensor, Tensor, Tensor]:
                Tensor
                    Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    Mel spectrogram after Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
                Tensor
                    Sequence of attention weights from the decoder with
                    shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
        r   r"   )r   r   g        Ng     @@)r   rA   r   r   r   r   r2   expandrq   r   ZpermuteZmasked_fill_)r=   r   r   r   r   embedded_inputsencoder_outputsr   r   Zmel_specgram_postnetr1   r   r   r   rB     s    !
zTacotron2.forward)r   r'   r   c                 C   s   |j \}}|du r0t|g||j|j}|dus<J | |dd}| 	||}| j
||\}}}	}
| |}|| }|
d||dd}
|||
fS )a  Using Tacotron2 for inference. The input is a batch of encoded
        sentences (``tokens``) and its corresponding lengths (``lengths``). The
        output is the generated mel spectrograms, its corresponding lengths, and
        the attention weights from the decoder.

        The input `tokens` should be padded with zeros to length max of ``lengths``.

        Args:
            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
            lengths (Tensor or None, optional):
                The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
                If ``None``, it is assumed that the all the tokens are valid. Default: ``None``

        Returns:
            (Tensor, Tensor, Tensor):
                Tensor
                    The predicted mel spectrogram with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    The length of the predicted mel spectrogram with shape `(n_batch, )`.
                Tensor
                    Sequence of attention weights from the decoder with shape
                    `(n_batch, max of mel_specgram_lengths, max of lengths)`.
        Nr   r"   r   )r   r   Ztensorr   tor(   r)   r   rA   r   r   r   r   Zunfold)r=   r   r'   r   
max_lengthr   r   r   r   r   r   Zmel_outputs_postnetr   r   r   r     s    

zTacotron2.infer)Fr   r   r   r   r   r   r   r   r   r   Tr   r   r   r   r   r   r   r   r   rk   )N)rD   rE   rF   rG   r   r$   rS   r:   r   r   rB   r   r   r   r   r   rH   r   r   r>   r   r	   e  sp   "                      56)Tr
   )r   r   Nr   Tr
   )r   typingr   r   r   r   r   r   r   Ztorch.nnr   r]   __all__r$   r   strr   r   r%   r&   r2   Moduler3   rI   r`   rp   r   r   r	   r   r   r   r   <module>   sF         .1W=H   C