o
    Zh                     @   s  d Z ddlZddlmZmZmZ ddlZddlZddlmZ ddl	m
Z
 ddlmZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZmZmZmZ ddlmZ ddlm Z  e!e"Z#dZ$dZ%dZ&g dZ'dZ(dZ)G dd dej*Z+G dd dej*Z,G dd dej*Z-G dd dej*Z.G dd dej*Z/G dd  d ej*Z0G d!d" d"ej*Z1G d#d$ d$ej*Z2G d%d& d&ej*Z3G d'd( d(eZ4d)Z5d*Z6G d+d, d,e4Z7ed-e5G d.d/ d/e4Z8ed0e5G d1d2 d2e4Z9g d3Z:dS )4zPyTorch M-CTC-T model.    N)OptionalTupleUnion)nn   )ACT2FN)add_code_sample_docstringsadd_start_docstrings%add_start_docstrings_to_model_forward)is_deepspeed_zero3_enabled)is_fsdp_managed_module)_prepare_4d_attention_mask)BaseModelOutputCausalLMOutput)PreTrainedModelapply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)logging   )MCTCTConfigr   zspeechbrain/m-ctc-t-large)r      i   zY"Mr. Quilter is the apostle of the middle classes, and we're glad to welcome his gospel."gv@c                       s(   e Zd ZdZ fddZdd Z  ZS )MCTCTConv1dSubsamplerz
    Convolutional subsampler: a stack of 1D convolution (along temporal dimension) followed by non-linear activation
    via gated linear units (https://arxiv.org/abs/1911.08460)
    c                    s   t    | _|j _t|j _|j	 _
|j|j  _ j
dkr1|jd u r,td|j _nd  _|jd  _|j _|j _t fddt jD  _d S )Nr   zbNeed to specify `conv_channels` configuration in `MCTCTConfig` to use multiple convolution layers.   c                 3   s\    | ])\}}t j|d kr jn j| | jd k r j| n j| j| ddV  qdS )r   r   Zvalid)kernel_sizestridepaddingN)r   Conv1din_channelsmid_channels
num_layersout_channelsr   ).0ikself b/var/www/auris/lib/python3.10/site-packages/transformers/models/deprecated/mctct/modeling_mctct.py	<genexpr>Y   s    
z1MCTCTConv1dSubsampler.__init__.<locals>.<genexpr>)super__init__configZconv_glu_dimglu_dimr   DropoutZconv_dropoutdropoutnum_conv_layersr    Zinput_feat_per_channelZinput_channelsr   Zconv_channels
ValueErrorr   hidden_sizer!   conv_kernelr   conv_strider   
ModuleList	enumerateconv_layersr&   r,   	__class__r%   r(   r+   =   s&   



zMCTCTConv1dSubsampler.__init__c                 C   s   t dd | jD }tjj|dd||fdd}|dd }| jD ]}||}tjj	|| j
d}| |}q#|dd }|S )Nc                 S   s   g | ]}|d  qS )r   r'   )r"   sizer'   r'   r(   
<listcomp>g       z1MCTCTConv1dSubsampler.forward.<locals>.<listcomp>r   Zconstantr   r   dim)sumr   torchr   
functionalpad	transpose
contiguousr7   Zglur-   r/   )r&   input_featuresr   hidden_statesconvr'   r'   r(   forwardd   s   
zMCTCTConv1dSubsampler.forward__name__
__module____qualname____doc__r+   rI   __classcell__r'   r'   r9   r(   r   7   s    'r   c                       s,   e Zd ZdZ fddZ	dddZ  ZS )	MCTCTEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _t | _t|j| _| jdt|jddd | jdtj| j tj| jjddd d S )N)padding_idxposition_ids)r   F)
persistenttoken_type_idsdtypedevice)r*   r+   r   	Embedding
vocab_sizer2   pad_token_idword_embeddingsmax_position_embeddingsZposition_embeddingsZtype_vocab_sizetoken_type_embeddingsMCTCTLayerNorm	LayerNormr.   hidden_dropout_probr/   Zregister_bufferrA   arangeexpandzerosrR   r;   longrX   r8   r9   r'   r(   r+   w   s   

zMCTCTEmbeddings.__init__Nr   c                 C   s   |d ur|  n|  d d }|d }|d u r%| jd d ||| f }|d u rOt| drD| jd d d |f }||d |}	|	}ntj|tj| jjd}|d u rX| 	|}| 
|}
||
 }| |}| |}|S )NrS   r   rU   r   rV   )r;   rR   hasattrrU   rc   rA   rd   re   rX   r\   r^   r`   r/   )r&   rF   rU   rR   inputs_embedsZpast_key_values_lengthZinput_shapeZ
seq_lengthZbuffered_token_type_idsZ buffered_token_type_ids_expandedr^   Z
embeddingsr'   r'   r(   rI      s"    




zMCTCTEmbeddings.forward)NNNNr   rJ   r'   r'   r9   r(   rP   t   s
    rP   c                       sD   e Zd Z fddZdd Zdd Zdd Z					
dddZ  ZS )MCTCTSelfAttentionc                    s   t    |j|j dkrt|dstd|j d|j d|j| _|j| _| j| j | _t	j
|j| jdd| _t	j
|j| jdd| _t	j
|j| jdd| _t	|j| _|j| _t	d|j d	 | j| _|j| _d S )
Nr   Zembedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()Fbiasr   r   )r*   r+   r2   num_attention_headsrf   r1   Zattention_head_dimattention_head_sizeall_head_sizer   Linearquerykeyvaluer.   Zattention_probs_dropout_probr/   r]   rY   distance_embedding
is_decoderr8   r9   r'   r(   r+      s"   

zMCTCTSelfAttention.__init__c                 C   s6   |  d d | j| jf }|j| }|ddddS )NrS   r   r   r      )r;   rl   rm   viewpermute)r&   xZnew_x_shaper'   r'   r(   transpose_for_scores   s   
z'MCTCTSelfAttention.transpose_for_scoresc                 C   sF   t |jdkr|jttt |j }|jt| jttt | S )Nr   )lenshaperw   reversedrangeZreshape)r&   rx   r{   r'   r'   r(   reshape_fortran   s    z"MCTCTSelfAttention.reshape_fortranc                 C   s   | dddd}|j\}}}}tj|tj||||f|jdfdd}| |||| | d|g}|d d d || d | f }| |||| d ||g}|d }|d d ||| f dd}| ddddS )Nr   r   ru   r   rX   r>   )rw   r{   rA   catrd   rX   r~   rD   )r&   ZscoresbatchZhidden_stateZseq_lenheadsZ	halfpointr'   r'   r(   "relative_position_embedding_rotate   s   &  z5MCTCTSelfAttention.relative_position_embedding_rotateNFc                 C   s   |  |}|t| j }| | |}| | |}| |}t||	dd}	| j
j}
td|
|	dd}| |}|	| }	|d urL|	| }	tjj|	dd}| |}|d ura|| }t||}|ddddjdd	}|r{||f}|S |f}|S )
NrS   zlh, bche -> bcler   ru   r>   r   r   )Z	start_dim)rp   mathsqrtrm   ry   rq   rr   rA   matmulrD   rs   weightZeinsumr   r   rB   Zsoftmaxr/   rw   flatten)r&   rG   attention_mask	head_maskoutput_attentionsZmixed_query_layerZ	key_layerZvalue_layerZquery_layerZattention_scoresZpositional_embeddingZrelative_position_scoresZattention_probsZcontext_layeroutputsr'   r'   r(   rI      s,   



zMCTCTSelfAttention.forwardNNF)	rK   rL   rM   r+   ry   r~   r   rI   rO   r'   r'   r9   r(   rh      s    rh   c                       $   e Zd Z fddZdd Z  ZS )r_   c                    s2   t    ttd| _ttd| _d S Nr   )	r*   r+   r   	ParameterrA   onessingleton_weightrd   singleton_biasr%   r9   r'   r(   r+     s   
zMCTCTLayerNorm.__init__c                 C   s   || j  | j S N)r   r   r&   rG   r'   r'   r(   rI      s   zMCTCTLayerNorm.forwardrK   rL   rM   r+   rI   rO   r'   r'   r9   r(   r_     s    r_   c                       r   )MCTCTSelfOutputc                    sL   t    || _tj|j|jdd| _tj|j|jd| _t	|j
| _d S NFrj   )Zeps)r*   r+   r,   r   ro   r2   denser`   layer_norm_epsr.   ra   r/   r8   r9   r'   r(   r+   %  s
   
zMCTCTSelfOutput.__init__c                 C   &   |  |}| |}| || }|S r   r   r/   r`   r&   rG   Zinput_tensorr'   r'   r(   rI   ,     

zMCTCTSelfOutput.forwardr   r'   r'   r9   r(   r   $  s    r   c                       s4   e Zd Z fddZdd Z			d	ddZ  ZS )
MCTCTAttentionc                    s*   t    t|| _t|| _t | _d S r   )r*   r+   rh   r&   r   outputsetpruned_headsr8   r9   r'   r(   r+   4  s   


zMCTCTAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   r>   )rz   r   r&   rl   rm   r   r   rp   rq   rr   r   r   rn   union)r&   r   indexr'   r'   r(   prune_heads:  s   zMCTCTAttention.prune_headsNFc                 C   s6   |  ||||}| |d |}|f|dd   }|S )Nr   r   )r&   r   )r&   rG   r   r   r   Zself_outputsattention_outputr   r'   r'   r(   rI   L  s   zMCTCTAttention.forwardr   )rK   rL   rM   r+   r   rI   rO   r'   r'   r9   r(   r   3  s    r   c                       r   )MCTCTIntermediatec                    sH   t    tj|j|jdd| _t|jt	rt
|j | _d S |j| _d S )NFrj   )r*   r+   r   ro   r2   intermediate_sizer   
isinstanceZ
hidden_actstrr   intermediate_act_fnr8   r9   r'   r(   r+   `  s
   
zMCTCTIntermediate.__init__c                 C   s   |  |}| |}|S r   )r   r   r   r'   r'   r(   rI   h  s   

zMCTCTIntermediate.forwardr   r'   r'   r9   r(   r   _  s    r   c                       r   )MCTCTOutputc                    sF   t    tj|j|jdd| _tj|j|jd| _t	|j
| _d S r   )r*   r+   r   ro   r   r2   r   r`   r   r.   ra   r/   r8   r9   r'   r(   r+   o  s   
zMCTCTOutput.__init__c                 C   r   r   r   r   r'   r'   r(   rI   u  r   zMCTCTOutput.forwardr   r'   r'   r9   r(   r   n  s    r   c                       s:   e Zd Zdef fddZ			d
ddZdd	 Z  ZS )
MCTCTLayerr,   c                    sB   t    d| _|j| _t|| _t|| _|j| _t	|| _
d S r   )r*   r+   seq_len_dimchunk_size_feed_forwardr   intermediater   	attentionrt   r   r   r8   r9   r'   r(   r+   }  s   


zMCTCTLayer.__init__NFc           	      C   sH   | j ||||d}|d }|dd  }t| j| j| j|}|f| }|S )N)r   r   r   )r   r   feed_forward_chunkr   r   )	r&   rG   r   r   r   Zself_attention_outputsr   r   layer_outputr'   r'   r(   rI     s   
zMCTCTLayer.forwardc                 C   s   |  |}| ||}|S r   )r   r   )r&   r   Zintermediate_outputr   r'   r'   r(   r     s   
zMCTCTLayer.feed_forward_chunkr   )rK   rL   rM   r   r+   rI   r   rO   r'   r'   r9   r(   r   |  s    
r   c                   @   s@   e Zd ZdZeZdZdZdZdd Z	de
jfdd	Zd
d ZdS )MCTCTPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    mctctrF   Tc                 C   s  | j j}t|tjr|jjjd|d |jdur|jj	  nDt|tj
r=|jjjd|d |jdur<|jj|j 	  n&t|tjrQ|jj	  |jjd nt|trc|jjd |jj	  t|tjtjfr|jjjd|d |jdur|jj	  dS dS dS )zInitialize the weightsg        )meanstdNg      ?)r,   Zinitializer_ranger   r   ro   r   dataZnormal_rk   Zzero_rY   rQ   r`   Zfill_r_   r   r   r   )r&   moduler   r'   r'   r(   _init_weights  s.   



z"MCTCTPreTrainedModel._init_weightsinput_lengthsc                 C   sh   d}t t| jj| jj| jjD ]!\}}}|d }|d|  ||d   d }tj||ddd }q|S )zH
        Computes the output length of the convolutional layers
        r   r   trunc)Zrounding_mode)zipr}   r,   r0   r3   r4   rA   div)r&   r   Zdilation_Z	kernel_szr   r   r'   r'   r(    _get_feat_extract_output_lengths  s   z5MCTCTPreTrainedModel._get_feat_extract_output_lengthsc                 C   s   t |jdkr|d d d d df }| |d}| d }tj||f|j|jd}d|tj	||jd|d f< |
dgd
dg }|S )Nr   rS   r   rV   r   r   )rz   r{   r   r@   r;   rA   rd   rW   rX   rb   flipZcumsumre   )r&   Zfeature_vector_lengthr   Zsubsampled_lengthsZbszr'   r'   r(   "_get_feature_vector_attention_mask  s   z7MCTCTPreTrainedModel._get_feature_vector_attention_maskN)rK   rL   rM   rN   r   config_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   rA   
LongTensorr   r   r'   r'   r'   r(   r     s    r   aH  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`MCTCTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a  
    Args:
        input_features (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`Wav2Vec2CTCTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
c                       s^   e Zd Zdef fddZ			ddejdejdejd	ed
ededee	e
f fddZ  ZS )MCTCTEncoderr,   c                    sP   t     j| _t | _t | _t fddt	 j
D | _d| _d S )Nc                    s   g | ]}t  qS r'   )r   )r"   r   r,   r'   r(   r<     r=   z)MCTCTEncoder.__init__.<locals>.<listcomp>F)r*   r+   ra   r_   
layer_normr   rH   r   r5   r}   Znum_hidden_layerslayersgradient_checkpointingr8   r9   r   r(   r+     s   
 
zMCTCTEncoder.__init__FTrF   r   r   r   output_hidden_statesreturn_dictreturnc                 C   s  |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}| |}| |}|d ur5| |jd |}tj	j
|| j| jd}|d urJt||j}|rNdnd }	|rTdnd }
|d urw| d t| jkrwtdt| j d| d  dt p}t| }t| jD ]R\}}|r|	|f }	tg }| jr|| j jk rdnd	}|r|r| jr| jr| |j|||d ur|| nd |}n||||d
}|d }|rd}|r|
|d f }
q|r|	|f }	|stdd ||	|
fD S t||	|
dS )Nr   )ptrainingr'   r   z&The head_mask should be specified for z layers, but it is for .TF)rG   r   r   )NNc                 s   s    | ]	}|d ur|V  qd S r   r'   )r"   vr'   r'   r(   r)   m  s    z'MCTCTEncoder.forward.<locals>.<genexpr>Zlast_hidden_staterG   
attentions)r,   r   r   use_return_dictr   rH   r   r{   r   rB   r/   ra   r   r   rW   r;   rz   r   r1   r   r   r6   rA   ZrandZ	layerdropr   Z_gradient_checkpointing_func__call__tupler   )r&   rF   r   r   r   r   r   rg   rG   Zencoder_statesZall_attentionsZsynced_gpusidxZencoder_layerZdropout_probabilityZskip_the_layerZlayer_outputsr'   r'   r(   rI     sj   	





zMCTCTEncoder.forward)FFT)rK   rL   rM   r   r+   rA   Tensorboolr   r   r   rI   rO   r'   r'   r9   r(   r     s(    
r   zaThe bare M-CTC-T Model transformer outputting raw hidden-states without any specific head on top.c                       s   e Zd Z fddZeedeee	e
ded					ddejdeej d	eej d
ee dee dee deee	f fddZ  ZS )
MCTCTModelc                    s(   t  | || _t|| _|   d S r   )r*   r+   r,   r   encoder	post_initr8   r9   r'   r(   r+   x  s   
zMCTCTModel.__init__zbatch_size, sequence_lengthZaudio)
checkpointoutput_typer   Zmodalityexpected_outputNrF   r   r   r   r   r   r   c           	      C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}|d u r&td| j||||||d}|d }|s@|f|dd   S t||j|jdS )Nz#You have to specify input_features.r   r   r   r   r   r   r   r   )	r,   r   r   r   r1   r   r   rG   r   )	r&   rF   r   r   r   r   r   Zencoder_outputsZsequence_outputr'   r'   r(   rI     s,   zMCTCTModel.forward)NNNNN)rK   rL   rM   r+   r
   MCTCT_INPUTS_DOCSTRINGformatr   _CHECKPOINT_FOR_DOCr   _CONFIG_FOR_DOC_EXPECTED_OUTPUT_SHAPErA   r   r   r   r   r   rI   rO   r'   r'   r9   r(   r   s  s<    	

r   zcMCTCT Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).c                       s   e Zd Z fddZeeeeee	e
ed						ddejdeej deej dee d	ee d
ee deej deeef fddZ  ZS )MCTCTForCTCc                    sT   t  | t|| _|jd u rtd| j d|j}t	||j| _
|   d S )NzYou are trying to instantiate z with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `MCTCTForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.)r*   r+   r   r   rZ   r1   r:   r2   r   ro   ctc_headr   )r&   r,   Zoutput_hidden_sizer9   r'   r(   r+     s   

zMCTCTForCTC.__init__)r   r   r   r   Zexpected_lossNrF   r   r   r   r   r   labelsr   c              
   C   s~  |dur|  | jjkrtd| jj |dur|n| jj}| j||||||d}|d }	| |	}
d}|dur|dur?|ntj|j	dd tj
d}| |dtj
}|dk}|d}||}tjj|
dtjddd}tjjjd	d
 tjj||||| jj| jj| jjd}W d   n1 sw   Y  |s|
f|td  }|dur|f| S |S t||
|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        Nz$Label values must be <= vocab_size: r   r   rS   )rW   )r?   rW   r   F)enabled)blankZ	reductionZzero_infinity)losslogitsrG   r   )maxr,   rZ   r1   r   r   r   rA   r   r{   re   r   r@   toZmasked_selectr   rB   Zlog_softmaxZfloat32rD   backendsZcudnnflagsZctc_lossr[   Zctc_loss_reductionZctc_zero_infinity_HIDDEN_STATES_START_POSITIONr   rG   r   )r&   rF   r   r   r   r   r   r   r   rG   r   r   r   Zlabels_maskZtarget_lengthsZflattened_targetsZ	log_probsr   r'   r'   r(   rI     sR   	


zMCTCTForCTC.forward)NNNNNN)rK   rL   rM   r+   r
   r   r   r   r   r   _CTC_EXPECTED_OUTPUT_CTC_EXPECTED_LOSSrA   r   r   r   r   r   r   rI   rO   r'   r'   r9   r(   r     sB    

	r   )r   r   r   );rN   r   typingr   r   r   rA   Ztorch.utils.checkpointr   Zactivationsr   Z
file_utilsr   r	   r
   Zintegrations.deepspeedr   Zintegrations.fsdpr   Zmodeling_attn_mask_utilsr   Zmodeling_outputsr   r   Zmodeling_utilsr   r   r   r   utilsr   Zconfiguration_mctctr   Z
get_loggerrK   loggerr   r   r   r   r   r   Moduler   rP   rh   r_   r   r   r   r   r   r   ZMCTCT_START_DOCSTRINGr   r   r   r   __all__r'   r'   r'   r(   <module>   s\   
=:l
,'E `8d