o
    Zh                     @   s  d dl Z d dlZd dlmZmZmZ d dlZd dlZd dl	m
Z
 d dlm
  mZ d dl	mZ ddlmZ ddlmZ ddlmZ ddlmZmZmZmZmZmZ dd	lmZ dd
lmZm Z m!Z! ddl"m#Z# e!$e%Z&G dd de
j'Z(G dd de
j'Z)G dd de
j'Z*G dd de
j'Z+G dd de
j'Z,G dd de
j'Z-G dd de
j'Z.G dd de
j'Z/G dd de
j'Z0G dd  d e
j'Z1eG d!d" d"eZ2G d#d$ d$e
j'Z3G d%d& d&e
j'Z4G d'd( d(e
j'Z5G d)d* d*e
j'Z6G d+d, d,e
j'Z7G d-d. d.e
j'Z8		 dKd/ee9e9f d0e:d1e9d2eej; d3e9d4ej<fd5d6Z=eZ>eG d7d8 d8e2Z?d9Z@ed:d;G d<d= d=e2ZAed>d;G d?d@ d@e2ZBeG dAdB dBe2ZCG dCdD dDe
j'ZDG dEdF dFe
j'ZEedGd;G dHdI dIe2ZFg dJZGdS )L    N)OptionalTupleUnion)CrossEntropyLoss   )ACT2FN)is_deepspeed_zero3_enabled)is_fsdp_managed_module)BaseModelOutputCausalLMOutputSequenceClassifierOutputTokenClassifierOutputWav2Vec2BaseModelOutputXVectorOutput)PreTrainedModel)auto_docstringis_peft_availablelogging   )WavLMConfigc                       $   e Zd Z fddZdd Z  ZS )WavLMSamePadLayerc                    s*   t    |d dkrd| _d S d| _d S N   r   r   )super__init__num_pad_remove)selfnum_conv_pos_embeddings	__class__ W/var/www/auris/lib/python3.10/site-packages/transformers/models/wavlm/modeling_wavlm.pyr   %   s   
 zWavLMSamePadLayer.__init__c                 C   s,   | j dkr|d d d d d | j  f }|S Nr   )r   r   hidden_statesr!   r!   r"   forward)   s   
zWavLMSamePadLayer.forward__name__
__module____qualname__r   r&   __classcell__r!   r!   r   r"   r   $   s    r   c                       r   )WavLMPositionalConvEmbeddingc                    s$  t    tj|j|j|j|jd |jd| _tjj	}t
tjjdr'tjjj	}t r{dd l}|jj| jjdd || jddd| _W d    n1 sLw   Y  t
| jdrd| jjjj}| jjjj}n| jj}| jj}|j| | |j| | n	|| jddd| _t|j| _t|j | _d S )	Nr   )kernel_sizepaddinggroupsweight_normr   )Zmodifier_rankweight)namedimparametrizations)r   r   nnConv1dhidden_sizer   Znum_conv_pos_embedding_groupsconvutilsr0   hasattrr4   r   	deepspeedzeroZGatheredParametersr1   Z	original0Z	original1weight_gweight_vZregister_external_parameterr   r.   r   feat_extract_activation
activation)r   configr0   r;   r=   r>   r   r!   r"   r   0   s4   

z%WavLMPositionalConvEmbedding.__init__c                 C   s:   | dd}| |}| |}| |}| dd}|S Nr   r   )	transposer8   r.   r@   r$   r!   r!   r"   r&   Q   s   


z$WavLMPositionalConvEmbedding.forwardr'   r!   r!   r   r"   r,   /   s    !r,   c                       r   )WavLMFeatureProjectionc                    sJ   t    tj|jd |jd| _t|jd |j| _	t
|j| _d S )NZeps)r   r   r5   	LayerNormconv_dimlayer_norm_eps
layer_normLinearr7   
projectionDropoutZfeat_proj_dropoutdropoutr   rA   r   r!   r"   r   ]   s   
zWavLMFeatureProjection.__init__c                 C   s&   |  |}| |}| |}||fS N)rJ   rL   rN   )r   r%   Znorm_hidden_statesr!   r!   r"   r&   c   s   


zWavLMFeatureProjection.forwardr'   r!   r!   r   r"   rD   \   s    rD   c                       s   e Zd ZdZ				d"dededed	ed
edef fddZ				d#dej	de
ej	 de
ej	 dedeej	e
ej	 e
eej	  f f
ddZdejdeejejf dejdedejejff
ddZdededejfddZdejdejfd d!Z  ZS )$WavLMAttentionz=Multi-headed attention from 'Attention Is All You Need' paper        @     T	embed_dim	num_headsrN   num_bucketsmax_distancehas_relative_position_biasc                    s   t    || _|| _|| _|| | _| j| | jkr'td| j d| d| jd | _t	||| _
t	||| _t	||| _t	||| _|| _|| _ttd| jdd| _t	| jd| _|rqt| j| j| _d S d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      r      )r   r   rU   rV   rN   Zhead_dim
ValueErrorZscalingr5   rK   k_projv_projq_projout_projrW   rX   	Parametertorchonesgru_rel_pos_constgru_rel_pos_linearZ	Embeddingrel_attn_embed)r   rU   rV   rN   rW   rX   rY   r   r!   r"   r   n   s.   
	

zWavLMAttention.__init__NFr   r%   attention_maskposition_biasoutput_attentionsreturnc                 C   s  |  \}}}|du r$| ||}|d|ddd|| j ||}||jdd | jdf }	|	dddd}	| |	}
|
|	jdd d 	d}
t
|
jddd\}}||| j d	  d
 }||| j dd| }|d||f}| ||||\}}|||fS )z'Attention layer with relative attentionNr   r   rE   r   r   )r      r3         ?g       @)sizecompute_bias	unsqueezerepeatviewrV   shapepermuterd   sumra   Zsigmoidchunkrc   torch_multi_head_self_attention)r   r%   rf   rg   rh   indexZbszZtgt_len_Zgated_hidden_statesZrelative_position_projZgate_aZgate_bZgate_outputgated_position_biasattn_outputattn_weightsr!   r!   r"   r&      s"   	$

zWavLMAttention.forwardry   c                 C   s   | dd } }}|dur|dnd}d }	}
d}tj|||| j| jtdgt| j	j
| jj
| jj
f|	|
|| j| jj| jj
| j|||d| j	j| jj| jjd\}}| dd}|durz|dddf |jdd | jf |jdd  }||fS )zCsimple wrapper around torch's multi_head_attention_forward functionr   r   NFT)Zuse_separate_proj_weightZq_proj_weightZk_proj_weightZv_proj_weight)rC   neFZmulti_head_attention_forwardrU   rV   ra   emptycatr^   biasr\   r]   rN   r_   r1   trainingbroadcast_torr   )r   r%   rf   ry   rh   querykeyvalueZkey_padding_maskZbias_kZbias_vZadd_zero_attnrz   r{   r!   r!   r"   rv      sB   	

"z.WavLMAttention.torch_multi_head_self_attentionquery_length
key_lengthc                 C   sv   t j|t jdd d d f }t j|t jdd d d f }|| }| |}|| jjj}| |}|g d}|S )Ndtype)r   r   r   )	ra   arangelong_relative_positions_buckettore   r1   devicers   )r   r   r   Zcontext_positionZmemory_positionZrelative_positionZrelative_position_bucketvaluesr!   r!   r"   rn      s   

zWavLMAttention.compute_biasrelative_positionsc                 C   s   | j d }|dktj| }t|}|d }||k }t| | }|t| j|  }|||  }|| tj}t	|t
||d }|t|||7 }|S r   )rW   r   ra   r   abslogfloatmathrX   minZ	full_likewhere)r   r   rW   Zrelative_bucketsZ	max_exactZis_smallZrelative_positions_if_largeZrelative_position_if_larger!   r!   r"   r      s   

z)WavLMAttention._relative_positions_bucket)rR   rS   rT   TNNFr   )r(   r)   r*   __doc__intr   boolr   ra   Tensorr   r   r&   FloatTensorr   
LongTensorZ
BoolTensorrv   rn   r   r+   r!   r!   r   r"   rQ   k   s^    '
)

7
rQ   c                       r   )WavLMFeedForwardc                    sp   t    t|j| _t|j|j| _	t
|jtr"t|j | _n|j| _t|j|j| _t|j| _d S rP   )r   r   r5   rM   Zactivation_dropoutintermediate_dropoutrK   r7   Zintermediate_sizeintermediate_dense
isinstanceZ
hidden_actstrr   intermediate_act_fnoutput_densehidden_dropoutoutput_dropoutrO   r   r!   r"   r     s   
zWavLMFeedForward.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S rP   )r   r   r   r   r   r$   r!   r!   r"   r&     s   




zWavLMFeedForward.forwardr'   r!   r!   r   r"   r         r   c                       s2   e Zd Zddedef fddZdd	d
Z  ZS )WavLMEncoderLayerTrA   rY   c                    n   t    t|j|j|j|j|j|d| _t	
|j| _t	j|j|jd| _t|| _t	j|j|jd| _d S N)rU   rV   rN   rW   rX   rY   rF   r   r   rQ   r7   Znum_attention_headsZattention_dropoutrW   Zmax_bucket_distance	attentionr5   rM   r   rN   rG   rI   rJ   r   feed_forwardfinal_layer_normr   rA   rY   r   r!   r"   r   *     

zWavLMEncoderLayer.__init__NFr   c           	      C   sl   |}| j |||||d\}}}| |}|| }| |}|| | }| |}||f}|r4||f7 }|S )Nrf   rg   rh   rw   )r   rN   rJ   r   r   )	r   r%   rf   rg   rh   rw   attn_residualr{   outputsr!   r!   r"   r&   9  s"   



zWavLMEncoderLayer.forwardTr   r(   r)   r*   r   r   r   r&   r+   r!   r!   r   r"   r   )      r   c                       s2   e Zd Zd
dedef fddZddd	Z  ZS ) WavLMEncoderLayerStableLayerNormTrA   rY   c                    r   r   r   r   r   r!   r"   r   S  r   z)WavLMEncoderLayerStableLayerNorm.__init__NFc                 C   sf   |}|  |}| j||||d\}}}| |}|| }|| | | }||f}|r1||f7 }|S )N)rf   rg   rh   )rJ   r   rN   r   r   )r   r%   rf   rg   rh   r   r{   r   r!   r!   r"   r&   b  s   


z(WavLMEncoderLayerStableLayerNorm.forwardr   )NNFr   r!   r!   r   r"   r   R  r   r   c                       .   e Zd Z fddZ				dddZ  ZS )	WavLMEncoderc                    f   t     | _t | _tj j jd| _	t
 j| _t fddt jD | _d| _d S )NrF   c                       g | ]
}t  |d kdqS r   )rY   )r   .0irA   r!   r"   
<listcomp>  s    z)WavLMEncoder.__init__.<locals>.<listcomp>Fr   r   rA   r,   pos_conv_embedr5   rG   r7   rI   rJ   rM   r   rN   
ModuleListrangenum_hidden_layerslayersgradient_checkpointingrO   r   r   r"   r   x  s   


zWavLMEncoder.__init__NFTc                 C   s`  |rdnd }|r
dnd }|d ur"| ddd|jd }d|| < | |}	||	 }| |}| |}t p;t| }
d }t| j	D ]P\}}|rN||f }t
g }| jo_|dko_|| jjk }|rd|
r| jru| jru| |j||||}n	||||||d}|d d \}}|rd}|r||d f }qC|r||f }|stdd	 |||fD S t|||d
S )Nr!   rE   r   r   r   r   NNNc                 s       | ]	}|d ur|V  qd S rP   r!   r   vr!   r!   r"   	<genexpr>      z'WavLMEncoder.forward.<locals>.<genexpr>last_hidden_stater%   
attentions)ro   rp   rr   r   rJ   rN   r   r	   	enumerater   ra   randr   rA   	layerdropr   _gradient_checkpointing_func__call__tupler
   r   r%   rf   rh   output_hidden_statesreturn_dictZall_hidden_statesZall_self_attentionsZexpand_attention_maskZposition_embeddingsZsynced_gpusrg   r   layerZdropout_probabilityZskip_the_layerZlayer_outputsr!   r!   r"   r&     s^   






zWavLMEncoder.forwardNFFTr'   r!   r!   r   r"   r   w  s    r   c                       r   )	WavLMEncoderStableLayerNormc                    r   )NrF   c                    r   r   )r   r   r   r!   r"   r     s    z8WavLMEncoderStableLayerNorm.__init__.<locals>.<listcomp>Fr   rO   r   r   r"   r     s   



z$WavLMEncoderStableLayerNorm.__init__NFTc                 C   s^  |rdnd }|r
dnd }|d ur"| ddd|jd }d|| < | |}	||	 }| |}t p6t| }
d }t| jD ]O\}}|rI||f }t	
g }| joZ|dkoZ|| jjk }|r_|
r| jrp| jrp| |j||||}n|||||d}|d d \}}|rd}|r||d f }q>| |}|r||f }|stdd	 |||fD S t|||d
S )Nr!   rE   r   r   r   )rf   rh   rg   r   c                 s   r   rP   r!   r   r!   r!   r"   r     r   z6WavLMEncoderStableLayerNorm.forward.<locals>.<genexpr>r   )ro   rp   rr   r   rN   r   r	   r   r   ra   r   r   rA   r   r   r   r   rJ   r   r
   r   r!   r!   r"   r&     sX   






z#WavLMEncoderStableLayerNorm.forwardr   r'   r!   r!   r   r"   r     s    r   c                       s4   e Zd ZdZ fddZedd Zdd Z  ZS )WavLMGumbelVectorQuantizerz
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
    c                    s   t    |j| _|j| _|j| j dkr"td|j d| j dt	t
d| j| j |j| j | _t|jd | j| j | _d| _d S )Nr   z`config.codevector_dim z5 must be divisible by `config.num_codevector_groups` z for concatenation.r   rE   r   )r   r   Znum_codevector_groups
num_groupsZnum_codevectors_per_groupnum_varsZcodevector_dimr[   r5   r`   ra   r   codevectorsrK   rH   weight_projtemperaturerO   r   r!   r"   r   $  s   


z#WavLMGumbelVectorQuantizer.__init__c                 C   s8   | j dd}ttj|t|d  dd  }|S )Nr   rk   gHz>rE   )meanra   exprt   r   )ZprobsZmarginal_probs
perplexityr!   r!   r"   _compute_perplexity9  s   (z.WavLMGumbelVectorQuantizer._compute_perplexityc                 C   s  |j \}}}| |}||| | j d}| jrAtjj| | j	dd}|
|}tj||| | jd dd}| |}n$|jdd}|j|j  d|ddd}||| | jd}| |}||| d}|d| j }	|	|| | j| jd}
|
d||d}
|
|fS )NrE   T)tauZhardrk   r   rl   )rr   r   rq   r   r   r5   
functionalZgumbel_softmaxr   r   Ztype_asra   softmaxr   argmaxZ	new_zerosZscatter_ro   r   r   rt   )r   r%   
batch_sizesequence_lengthr7   Zcodevector_probsZcodevector_soft_distr   Zcodevector_idxZcodevectors_per_groupr   r!   r!   r"   r&   ?  s*   


z"WavLMGumbelVectorQuantizer.forward)	r(   r)   r*   r   r   staticmethodr   r&   r+   r!   r!   r   r"   r     s    
r   c                   @   sh   e Zd ZeZdZdZdZdZdZ	dd Z
	ddeejef d	ee fd
dZ	ddedejfddZdS )WavLMPreTrainedModelwavlminput_valuesTFc              	   C   s  t |tr|jjjjddd |jjj  tj	
|j dS t |trItj	j|jjddtd|jjd |jj   d tj	|jjd dS t |trqtd|jj }tj	j
|jj| |d tj	j
|jj| |d dS t |tjr|jjjd| jjd |jdur|jj  dS dS t |tjtjfr|jj  |jjd dS t |tjrtj	|j |jdurt|j|j|jd   }tj	j
|j| |d dS dS dS )	zInitialize the weightsrR   r   )r   stdr   r   )abNrl   )r   r   r   r1   dataZnormal_r   Zzero_r5   inituniform_r   r,   r8   r   sqrtr-   Zin_channelsZ	constant_rD   rL   Zin_featuresrK   rA   Zinitializer_rangerG   	GroupNormZfill_r6   Zkaiming_normal_r/   )r   modulekr!   r!   r"   _init_weightsm  s<   

 


z"WavLMPreTrainedModel._init_weightsNinput_lengthsadd_adapterc                 C   sn   |du r| j jn|}dd }t| j j| j jD ]
\}}||||}q|r5t| j jD ]
}||d| j j}q*|S )zH
        Computes the output length of the convolutional layers
        Nc                 S   s   t j| | |ddd S )Nfloor)Zrounding_moder   )ra   divinput_lengthr-   strider!   r!   r"   _conv_out_length  s   zOWavLMPreTrainedModel._get_feat_extract_output_lengths.<locals>._conv_out_lengthr   )rA   r   zipconv_kernelconv_strider   num_adapter_layersadapter_stride)r   r   r   r  r-   r  rx   r!   r!   r"    _get_feat_extract_output_lengths  s   z5WavLMPreTrainedModel._get_feat_extract_output_lengthsfeature_vector_lengthrf   c                 C   s   |j ddd d df }| j||d}|tj}|jd }tj||f|j|jd}d|tj	|jd |jd|d f< |
dg d
dg }|S )NrE   rk   r   r   )r   r   r   )r   )Zcumsumr  r   ra   r   rr   zerosr   r   r   flipr   )r   r  rf   r   Znon_padded_lengthsZoutput_lengthsr   r!   r!   r"   "_get_feature_vector_attention_mask  s   
"z7WavLMPreTrainedModel._get_feature_vector_attention_maskrP   )r(   r)   r*   r   Zconfig_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_supports_flash_attn_2Z_supports_sdpar   r   ra   r   r   r   r   r  r  r!   r!   r!   r"   r   d  s(    "
r   c                       &   e Zd Zd fdd	Zdd Z  ZS )WavLMNoLayerNormConvLayerr   c                    sj   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _d S )Nr   r   r-   r  r   )r   r   rH   in_conv_dimout_conv_dimr5   r6   r  r  	conv_biasr8   r   r?   r@   r   rA   layer_idr   r!   r"   r     s   
z"WavLMNoLayerNormConvLayer.__init__c                 C   s   |  |}| |}|S rP   )r8   r@   r$   r!   r!   r"   r&     s   

z!WavLMNoLayerNormConvLayer.forwardr   r'   r!   r!   r   r"   r    s    r  c                       r  )WavLMLayerNormConvLayerr   c                    s|   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
tj| jdd| _t|j | _d S )Nr   r   r  T)Zelementwise_affine)r   r   rH   r  r  r5   r6   r  r  r  r8   rG   rJ   r   r?   r@   r  r   r!   r"   r     s   
z WavLMLayerNormConvLayer.__init__c                 C   s:   |  |}|dd}| |}|dd}| |}|S )Nr   rE   )r8   rC   rJ   r@   r$   r!   r!   r"   r&     s   


zWavLMLayerNormConvLayer.forwardr  r'   r!   r!   r   r"   r    s    r  c                       r  )WavLMGroupNormConvLayerr   c                    s   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _tj| j| jdd| _d S )Nr   r   r  T)r   Znum_channelsZaffine)r   r   rH   r  r  r5   r6   r  r  r  r8   r   r?   r@   r   rJ   r  r   r!   r"   r     s   
z WavLMGroupNormConvLayer.__init__c                 C   s"   |  |}| |}| |}|S rP   )r8   rJ   r@   r$   r!   r!   r"   r&     s   


zWavLMGroupNormConvLayer.forwardr  r'   r!   r!   r   r"   r    s    r  c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )WavLMFeatureEncoderz.Construct the features from raw audio waveformc                    s   t     jdkr t ddg fddt jd D  }n jdkr2 fddt jD }n	td	 j d
t|| _	d| _
d| _d S )Ngroupr   r  c                    s   g | ]
}t  |d  dqS )r   r  )r  r   r   r!   r"   r   	  s    z0WavLMFeatureEncoder.__init__.<locals>.<listcomp>r   r   c                    s   g | ]}t  |d qS )r  )r  r   r   r!   r"   r     s    z`config.feat_extract_norm` is z), but has to be one of ['group', 'layer']FT)r   r   Zfeat_extract_normr  r   Znum_feat_extract_layersr[   r5   r   conv_layersr   _requires_grad)r   rA   r  r   r   r"   r     s   




zWavLMFeatureEncoder.__init__c                 C   s   |   D ]}d|_qd| _d S )NF)
parametersrequires_gradr   r   paramr!   r!   r"   _freeze_parameters  s   
z&WavLMFeatureEncoder._freeze_parametersc                 C   s\   |d d d f }| j r| jrd|_| jD ]}| j r'| jr'| jr'| |j|}q||}q|S )NT)r   r   r"  r  r   r   r   )r   r   r%   Z
conv_layerr!   r!   r"   r&     s   

zWavLMFeatureEncoder.forward)r(   r)   r*   r   r   r%  r&   r+   r!   r!   r   r"   r    s
    r  c                       r   )WavLMAdapterLayerc                    s0   t    tj|jd|j |j|jdd| _d S )Nr   r   )r  r.   )r   r   r5   r6   output_hidden_sizeZadapter_kernel_sizer
  r8   rO   r   r!   r"   r   /  s   
zWavLMAdapterLayer.__init__c                 C   s   |  |}tjj|dd}|S )Nr   rk   )r8   r5   r   Zglur$   r!   r!   r"   r&   9  s   
zWavLMAdapterLayer.forwardr'   r!   r!   r   r"   r&  .  s    
r&  c                       r   )WavLMAdapterc                    sp   t     j jkrt j j| _t j| _nd  | _| _t	 fddt
 jD | _ j| _d S )Nc                 3   s    | ]}t  V  qd S rP   )r&  r   rx   r   r!   r"   r   K  s    z(WavLMAdapter.__init__.<locals>.<genexpr>)r   r   r'  r7   r5   rK   projrG   proj_layer_normr   r   r	  r   r   rO   r   r   r"   r   A  s   
 zWavLMAdapter.__init__c                 C   sr   | j d ur| jd ur|  |}| |}|dd}| jD ]}tj }| jr,|| jkr0||}q|dd}|S rB   )r*  r+  rC   r   nprandomr   r   )r   r%   r   Zlayerdrop_probr!   r!   r"   r&   N  s   



zWavLMAdapter.forwardr'   r!   r!   r   r"   r(  @  r   r(  rr   	mask_probmask_lengthrf   	min_masksri   c                    s  | \}dk rt dkrt d d dtjd   fdd}|dur:| d	 n
fd
dt|D }tj	|ft
d}g }	|}
|
dkrZ|S |D ];}||}tjjt|d  |dd}t|dkr}d }n|d }t|tj|
| tjd| g}|	| q\t|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d krd |	|	d k< t||	dd	 |S )af  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    r   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `c                    sX   t |     }t|}| kr }| d  |k r*t| d  d}|S )z;Given input length, compute how many spans should be maskedr   r   )r   max)r  num_masked_spanepsilonr/  r.  r0  r   r!   r"   compute_num_masked_span  s   
z6_compute_mask_indices.<locals>.compute_num_masked_spanNrE   c                    s   g | ]} qS r!   r!   r)  )r   r!   r"   r     s    z)_compute_mask_indices.<locals>.<listcomp>r   r   F)replace)r[   r,  r-  r   itemdetachrt   tolistr   r  r   choicer   lenZconcatenaterb   Zint32appendarrayr   Zreshaper2  Zput_along_axis)rr   r.  r/  rf   r0  r   r6  r   Zspec_aug_maskZspec_aug_mask_idxsZmax_num_masked_spanr  r3  Zspec_aug_mask_idxZdummy_mask_idxoffsetsr!   r4  r"   _compute_mask_indices_  s\   

r@  c                       s   e Zd Zdef fddZdd Zdd Z		dd	ejd
e	ej de	ej
 fddZe					dde	ej de	ej d
e	ej de	e de	e de	e deeef fddZ  ZS )
WavLMModelrA   c                    s   t  | || _t|| _t|| _|jdks|jdkr)t	
t|j | _|jr2t|| _nt|| _|jr>t|nd | _|   d S )NrR   )r   r   rA   r  feature_extractorrD   feature_projectionmask_time_probmask_feature_probr5   r`   ra   r   r7   r   masked_spec_embedZdo_stable_layer_normr   encoderr   r   r(  adapter	post_initrO   r   r!   r"   r     s   


zWavLMModel.__init__c                 C      t dt |   dS z
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. Please use the equivalent `freeze_feature_encoder` method instead.NwarningswarnFutureWarningfreeze_feature_encoderr   r!   r!   r"   freeze_feature_extractor  
   z#WavLMModel.freeze_feature_extractorc                 C   s   | j   dS 
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        N)rB  r%  rR  r!   r!   r"   rQ    s   z!WavLMModel.freeze_feature_encoderNr%   mask_time_indicesrf   c                 C   s  t | jdds	|S | \}}}|dur| j|j||< n-| jjdkrK| jrKt||f| jj| jj	|| jj
d}tj||jtjd}| j|j||< | jjdkr| jrt||f| jj| jj| jjd}tj||jtjd}|dddf d|d}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        Zapply_spec_augmentTNr   )r.  r/  rf   r0  )r   r   )r.  r/  r0  rE   )getattrrA   rm   rF  r   r   rD  r   r@  Zmask_time_lengthZmask_time_min_masksra   Ztensorr   r   rE  Zmask_feature_lengthZmask_feature_min_masksexpand)r   r%   rW  rf   r   r   r7   Zmask_feature_indicesr!   r!   r"   _mask_hidden_states  s4   zWavLMModel._mask_hidden_statesr   rh   r   r   ri   c           
      C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}| |}|dd}|dur8| j|jd |dd}| |\}}| j	|||d}| j
|||||d}	|	d }| jdur_| |}|sk||f|	dd  S t|||	j|	jd	S )
a/  
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        Nr   r   Fr  )rW  rf   rf   rh   r   r   r   )r   extract_featuresr%   r   )rA   rh   r   use_return_dictrB  rC   r  rr   rC  rZ  rG  rH  WavLMBaseModelOutputr%   r   )
r   r   rf   rW  rh   r   r   r\  r%   Zencoder_outputsr!   r!   r"   r&   0  s@   


zWavLMModel.forward)NNNNNNN)r(   r)   r*   r   r   rS  rQ  ra   r   r   r   rZ  r   r   r   r   r   r^  r&   r+   r!   r!   r   r"   rA    sD    

.
rA  r   zm
    WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    )Zcustom_introc                       s   e Zd Zddee f fddZdd Zdd Zd	d
 Zdd Z	e
					ddeej deej dee dee dee deej deeef fddZ  ZS )WavLMForCTCNtarget_langc                    s~   t  | t|| _t|j| _|| _|j	du r#t
d| j dt|dr.|jr.|jn|j}t||j	| _|   dS )a/  
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`WavLMForCTC`] with adapters. Uses 'eng' by
            default.
        NzYou are trying to instantiate z with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `WavLMForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.r   )r   r   rA  r   r5   rM   Zfinal_dropoutrN   ra  
vocab_sizer[   r    r:   r   r'  r7   rK   lm_headrI  )r   rA   ra  r'  r   r!   r"   r   t  s   

zWavLMForCTC.__init__c                 C   sv   | j }|durt| jdddu rtd| d|du r,t| jdddur,td dS |dur9| j|dd dS dS )a'  
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        NZadapter_attn_dimzCannot pass `target_lang`: z- if `config.adapter_attn_dim` is not defined.z)By default `target_lang` is set to 'eng'.T)Z
force_load)ra  rX  rA   r[   loggerinfoZload_adapter)r   ra  r!   r!   r"   tie_weights  s   zWavLMForCTC.tie_weightsc                 C   rJ  rV  rL  NrM  rR  r!   r!   r"   rS    rT  z$WavLMForCTC.freeze_feature_extractorc                 C      | j j  dS rU  r   rB  r%  rR  r!   r!   r"   rQ       z"WavLMForCTC.freeze_feature_encoderc                 C      | j  D ]}d|_qdS z
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        FNr   r!  r"  r#  r!   r!   r"   freeze_base_model     zWavLMForCTC.freeze_base_modelr   rf   rh   r   r   labelsri   c              
   C   s|  |dur|n| j j}|dur| | j jkrtd| j j | j|||||d}|d }| |}| |}	d}
|dur|durC|ntj	|tj
d}| |dtj
}|dk}|d}||}tjj|	dtjddd}tjjjd	d
 tjj||||| j j| j j| j jd}
W d   n1 sw   Y  |s|	f|td  }|
dur|
f| S |S t|
|	|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        Nz$Label values must be <= vocab_size: r[  r   r   rE   )r3   r   r   F)enabled)blankZ	reductionZzero_infinitylosslogitsr%   r   )rA   r]  r2  rb  r[   r   rN   rc  ra   Z	ones_liker   r  rt   r   Zmasked_selectr5   r   Zlog_softmaxZfloat32rC   backendsZcudnnflagsZctc_lossZpad_token_idZctc_loss_reductionZctc_zero_infinity_HIDDEN_STATES_START_POSITIONr   r%   r   )r   r   rf   rh   r   r   rp  r   r%   ru  rt  r   Zlabels_maskZtarget_lengthsZflattened_targetsZ	log_probsoutputr!   r!   r"   r&     sN   



zWavLMForCTC.forwardrP   r_  )r(   r)   r*   r   r   r   rf  rS  rQ  rn  r   ra   r   r   r   r   r   r&   r+   r!   r!   r   r"   r`  n  s6    
r`  z
    WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    c                       s   e Zd Z fddZdd Zdd Zdd Ze										dd
ee	j
 dee	j
 dee dee dee dee	j
 deeef fddZ  ZS )WavLMForSequenceClassificationc                    s   t  | t|dr|jrtdt|| _|jd }|jr*t	
t|| | _t	|j|j| _t	|j|j| _|   d S )Nr   z\Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)r   )r   r   r:   r   r[   rA  r   r   use_weighted_layer_sumr5   r`   ra   rb   layer_weightsrK   r7   Zclassifier_proj_size	projector
num_labels
classifierrI  r   rA   
num_layersr   r!   r"   r     s   

z'WavLMForSequenceClassification.__init__c                 C   rJ  rK  rM  rR  r!   r!   r"   rS  !  rT  z7WavLMForSequenceClassification.freeze_feature_extractorc                 C   rh  rU  ri  rR  r!   r!   r"   rQ  -  rj  z5WavLMForSequenceClassification.freeze_feature_encoderc                 C   rk  rl  rm  r#  r!   r!   r"   rn  4  ro  z0WavLMForSequenceClassification.freeze_base_modelNr   rf   rh   r   r   rp  ri   c                 C   sz  |dur|n| j j}| j jrdn|}| j|||||d}| j jrB|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}|du rV|jdd}
n+| |jd |}|ddd|jd }d	|| < |jdd|jdddd }
| |
}d}|durt }||d| j j|d}|s|f|td  }|dur|f| S |S t|||j|jd
S )  
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        NTr[  r   rk   rE   r   r   rR   rs  )rA   r]  r{  r   rx  ra   stackr5   r   r   r|  rq   rt   r}  r   r  rr   ro   rp   r  r   r~  r   r%   r   )r   r   rf   rh   r   r   rp  r   r%   norm_weightsZpooled_outputZpadding_maskZexpand_padding_maskru  rt  loss_fctry  r!   r!   r"   r&   <  sH   

 
z&WavLMForSequenceClassification.forwardr_  )r(   r)   r*   r   rS  rQ  rn  r   r   ra   r   r   r   r   r   r&   r+   r!   r!   r   r"   rz  	  s4    
rz  c                       s   e Zd Z fddZdd Zdd Zdd Ze										dd
ee	j
 dee	j
 dee	j
 dee dee dee deeef fddZ  ZS ) WavLMForAudioFrameClassificationc                    sz   t  | t|dr|jrtdt|| _|jd }|jr*t	
t|| | _t	|j|j| _|j| _|   d S )Nr   z_Audio frame classification does not support the use of WavLM adapters (config.add_adapter=True)r   )r   r   r:   r   r[   rA  r   r   r{  r5   r`   ra   rb   r|  rK   r7   r~  r  init_weightsr  r   r!   r"   r     s   

z)WavLMForAudioFrameClassification.__init__c                 C   rJ  rg  rM  rR  r!   r!   r"   rS    rT  z9WavLMForAudioFrameClassification.freeze_feature_extractorc                 C   rh  rU  ri  rR  r!   r!   r"   rQ    rj  z7WavLMForAudioFrameClassification.freeze_feature_encoderc                 C   rk  rl  rm  r#  r!   r!   r"   rn    ro  z2WavLMForAudioFrameClassification.freeze_base_modelNr   rf   rp  rh   r   r   ri   c                 C   s   |dur|n| j j}| j jrdn|}| j|||||d}| j jrB|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}
d}|durht }||
d| jtj|d| jdd}|su|
f|td  }|S t||
|j|jd	S )
r  NTr[  r   rk   rE   r   )Zaxisrs  )rA   r]  r{  r   rx  ra   r  r5   r   r   r|  rq   rt   r  r   r~  r   r   r%   r   )r   r   rf   rp  rh   r   r   r   r%   r  ru  rt  r  ry  r!   r!   r"   r&     s:   
(z(WavLMForAudioFrameClassification.forwardr_  )r(   r)   r*   r   rS  rQ  rn  r   r   ra   r   r   r   r   r   r&   r+   r!   r!   r   r"   r    s4    
r  c                       s&   e Zd Zd fdd	Zdd Z  ZS )AMSoftmaxLoss      >@皙?c                    sF   t t|   || _|| _|| _tjt	||dd| _
t | _d S )NT)r"  )r   r  r   scalemarginr~  r5   r`   ra   Zrandnr1   r   rt  )r   Z	input_dimr~  r  r  r   r!   r"   r     s   zAMSoftmaxLoss.__init__c           	      C   sx   |  }tjj| jdd}tjj|dd}t||}|| j }tj|| j	}| j
t| || }| ||}|S )Nr   rk   r   )flattenr5   r   	normalizer1   ra   mmr  Zone_hotr~  r  r   r   rt  )	r   r%   rp  r1   Z	cos_thetapsiZonehotru  rt  r!   r!   r"   r&     s   
zAMSoftmaxLoss.forward)r  r  r'   r!   r!   r   r"   r    s    r  c                       s4   e Zd Zd fdd	ZdejdejfddZ  ZS )		TDNNLayerr   c                    sv   t    |dkr|j|d  n|j| | _|j| | _|j| | _|j| | _t	
| j| j | j| _t	 | _d S )Nr   r   )r   r   tdnn_dimr  r  tdnn_kernelr-   Ztdnn_dilationdilationr5   rK   kernelZReLUr@   r  r   r!   r"   r     s   
"zTDNNLayer.__init__r%   ri   c                 C   s   t  r	ddlm} t  rt| j|rtd |dd}| jj	| j
| j| jdd}tjj||| jj| jd}|dd}| |}|S )Nr   )	LoraLayerzDetected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. You should exclude TDNNLayer from LoRA's target modules.r   r   )r  )r   Zpeft.tuners.lorar  r   r  rN  rO  rC   r1   rq   r  r-   r  r5   r   Zconv1dr   r  r@   )r   r%   r  r1   r!   r!   r"   r&     s    
zTDNNLayer.forwardr  )r(   r)   r*   r   ra   r   r&   r+   r!   r!   r   r"   r    s    
r  zi
    WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    c                       s   e Zd Z fddZdd Zdd Zdd Zd	eej	e
f fd
dZe					ddeej deej dee dee dee deej deeef fddZ  ZS )WavLMForXVectorc                    s   t    t | _ jd } jrtt	|| | _
t j jd | _ fddtt jD }t|| _t jd d  j| _t j j| _t j j| _|   d S )Nr   r   c                    s   g | ]}t  |qS r!   )r  r   r   r!   r"   r   0  s    z,WavLMForXVector.__init__.<locals>.<listcomp>rE   r   )r   r   rA  r   r   r{  r5   r`   ra   rb   r|  rK   r7   r  r}  r   r<  r   tdnnZxvector_output_dimrB  r  r  r~  	objectiver  )r   rA   r  Ztdnn_layersr   r   r"   r   '  s   

zWavLMForXVector.__init__c                 C   rJ  rg  rM  rR  r!   r!   r"   rS  :  rT  z(WavLMForXVector.freeze_feature_extractorc                 C   rh  rU  ri  rR  r!   r!   r"   rQ  F  rj  z&WavLMForXVector.freeze_feature_encoderc                 C   rk  rl  rm  r#  r!   r!   r"   rn  M  ro  z!WavLMForXVector.freeze_base_modelr   c                 C   s&   dd }| j jD ]}|||d}q|S )z?
        Computes the output length of the TDNN layers
        c                 S   s   | | | d S )Nr   r!   r  r!   r!   r"   r  Z  s   zBWavLMForXVector._get_tdnn_output_lengths.<locals>._conv_out_lengthr   )rA   r  )r   r   r  r-   r!   r!   r"   _get_tdnn_output_lengthsU  s   z(WavLMForXVector._get_tdnn_output_lengthsNr   rf   rh   r   r   rp  ri   c                 C   s  |dur|n| j j}| j jrdn|}| j|||||d}| j jrB|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}| jD ]}
|
|}qN|du rf|jdd}|jdd}nC| |jdd}| |}g }g }t|D ]"\}}|||d|f jdd |||d|f jdd q|t|}t|}tj||gdd}| |}| |}d}|dur| ||}|s||f|td  }|dur|f| S |S t||||j|jdS )	r  NTr[  r   rk   rE   r   )rt  ru  Z
embeddingsr%   r   )rA   r]  r{  r   rx  ra   r  r5   r   r   r|  rq   rt   r}  r  r   r   r  r  r   r=  r   rB  r  r  r   r%   r   )r   r   rf   rh   r   r   rp  r   r%   r  Z
tdnn_layerZmean_featuresZstd_featuresZfeat_extract_output_lengthsZtdnn_output_lengthsr   lengthZstatistic_poolingZoutput_embeddingsru  rt  ry  r!   r!   r"   r&   d  s\   



 



zWavLMForXVector.forwardr_  )r(   r)   r*   r   rS  rQ  rn  r   ra   r   r   r  r   r   r   r   r   r   r&   r+   r!   r!   r   r"   r  !  s6    
r  )r  r`  rz  r  rA  r   r#   )Hr   rN  typingr   r   r   numpyr,  ra   Ztorch.nnr5   Ztorch.nn.functionalr   r}   r   Zactivationsr   Zintegrations.deepspeedr   Zintegrations.fsdpr	   Zmodeling_outputsr
   r   r   r   r   r   Zmodeling_utilsr   r9   r   r   r   Zconfiguration_wavlmr   Z
get_loggerr(   rd  Moduler   r,   rD   rQ   r   r   r   r   r   r   r   r  r  r  r  r&  r(  r   r   r   Zndarrayr@  r^  rA  rx  r`  rz  r  r  r  r  __all__r!   r!   r!   r"   <module>   s    
- ')%STFU,#

w  rh  