"""Flax Wav2Vec2 model."""

from functools import partial
from typing import Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_wav2vec2 import Wav2Vec2Config


logger = logging.get_logger(__name__)


@flax.struct.dataclass
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
    """
    Output type of [`FlaxWav2Vec2BaseModelOutput`], with potential hidden states and attentions.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
            Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
            being the dimension of the last convolutional layer.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlast_hidden_stateextract_featureshidden_states
attentions)__name__
__module____qualname____doc__r   jnpndarray__annotations__r   r   r   r   r    r&   r&   b/var/www/auris/lib/python3.10/site-packages/transformers/models/wav2vec2/modeling_flax_wav2vec2.pyr   ,   s   
 r   c                   @   sh   e Zd ZU dZdZejed< dZejed< dZ	ejed< dZ
eeej  ed< dZeeej  ed< dS ) FlaxWav2Vec2ForPreTrainingOutputa%  
    Output type of [`FlaxWav2Vec2ForPreTrainingOutput`], with potential hidden states and attentions.

    Args:
        loss (*optional*, returned when model is in train mode, `jnp.ndarray` of shape `(1,)`):
            Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
            paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss.
        projected_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
            Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
            projected quantized states.
        projected_quantized_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
            Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
            target vectors for contrastive loss.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nprojected_statesprojected_quantized_statescodevector_perplexityr   r   )r   r    r!   r"   r)   r#   r$   r%   r*   r+   r   r   r   r   r&   r&   r&   r'   r(   J   s   
 r(   shape	mask_probmask_lengthattention_mask	min_masksreturnc           	         sB  | \} dk rt d krt d  d dt|   tjd  t|  kr:  tj|ftd}t	 fddt
|D }t|d	d	d	d	d	f | f}||  }t d	d	d	d	f }t|| f|  }|| }t||dd
 |d	urt||d}|S )aw  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        mask_prob:
            probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_length: size of the mask
        min_masks: minimum number of masked spans

    r   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `dtypec                    s,   g | ]}t jjt  d   ddqS )r   F)replace)nprandomchoicearange).0_r.   Znum_masked_spanssequence_lengthr&   r'   
<listcomp>   s    z)_compute_mask_indices.<locals>.<listcomp>NF)
ValueErrorintr6   r7   Zranditemmaxzerosboolarrayrangebroadcast_toreshaper9   Zput_along_axiswhere)	r,   r-   r.   r/   r0   
batch_sizeZspec_aug_maskZspec_aug_mask_idxsoffsetsr&   r<   r'   _compute_mask_indicesm   s<    
$rM   features_shapenum_negativesc                 C   s   | \}}}|dkrt d|||f dg }t|D ]#}|dur(||  d n|d }tjjd||| fd}	||	 qtj|tjd}t	t
|dddf ||f }
|||
k  d7  < td|D ]}||  || 7  < qh|S )z>
    Sample `num_negatives` vectors from feature vectors.
    r   zl`features should have `sequence_length` > 1, but are of shape (batch_size, sequence_length, hidden_size) = ().Nr   )sizer3   )r@   rG   sumr6   r7   randintappendZasarrayZint32rH   r9   flatten)rN   rO   r/   rK   r=   hidden_sizeZsampled_negative_indicesZ	batch_idxhighZsampled_indices_sliceZfeature_indicesr&   r&   r'   _sample_negative_indices   s$   
 &rX   a  
    Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
    Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
    Auli.

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

WAV2VEC2_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `jnp.ndarray`. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
            if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
            has `config.return_attention_mask == False`, such as
            [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
            passed to avoid degraded performance when doing batched inference. For such models `input_values` should
            simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
            different results depending on whether `input_values` is padded or not.
        mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
    config: Wav2Vec2Config
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxConvWithWeightNorm(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            features=self.config.hidden_size,
            kernel_size=(self.config.num_conv_pos_embeddings,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=self.config.num_conv_pos_embedding_groups,
            dtype=self.dtype,
        )
        weight_shape = (
            self.conv.features,
            self.conv.features // self.conv.feature_group_count,
            self.conv.kernel_size[0],
        )
        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
        self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
        self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
        self.prev_padding = self.conv.kernel_size[0] // 2

    def _get_normed_weights(self):
        weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
        normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
        return normed_kernel

    def __call__(self, hidden_states):
        kernel = self._get_normed_weights()
        hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
        hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
        return hidden_states


class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    def __call__(self, hidden_states):
        hidden_states = hidden_states.transpose((0, 1, 2))

        hidden_states = self.conv(hidden_states)

        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
        hidden_states = self.activation(hidden_states)

        hidden_states = hidden_states.transpose((0, 1, 2))
        return hidden_states


class FlaxConvLayersCollection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.feat_extract_norm == "layer":
            self.layers = [
                FlaxWav2Vec2LayerNormConvLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_feat_extract_layers)
            ]
        elif self.config.feat_extract_norm == "group":
            raise NotImplementedError("At the moment only ``config.feat_extract_norm == 'layer'`` is supported")
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group',"
                " 'layer']"
            )

    def __call__(self, hidden_states):
        for i, conv_layer in enumerate(self.layers):
            hidden_states = conv_layer(hidden_states)
        return hidden_states


class FlaxWav2Vec2FeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, input_values, freeze_feature_encoder=False):
        hidden_states = input_values[:, :, None]
        hidden_states = self.conv_layers(hidden_states)
        if freeze_feature_encoder:
            hidden_states = jax.lax.stop_gradient(hidden_states)
        return hidden_states


class FlaxWav2Vec2FeatureProjection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.projection = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)

    def __call__(self, hidden_states, deterministic=True):
        norm_hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.projection(norm_hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states, norm_hidden_states


class FlaxWav2Vec2Attention(nn.Module):
    config: Wav2Vec2Config
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""

        # get query, key and value projections
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        if attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights


class FlaxWav2Vec2FeedForward(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)

        self.intermediate_dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        if isinstance(self.config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
        else:
            self.intermediate_act_fn = self.config.hidden_act

        self.output_dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)

        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attention = FlaxWav2Vec2Attention(
            config=self.config,
            embed_dim=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights = self.attention(
            hidden_states, attention_mask=attention_mask, deterministic=deterministic
        )
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(
            self.final_layer_norm(hidden_states), deterministic=deterministic
        )

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = [
            FlaxWav2Vec2EncoderLayerStableLayerNorm(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic=True,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        if attention_mask is not None:
            # make sure padded tokens are not attended to
            hidden_states = jnp.where(
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
            )

        position_embeddings = self.pos_conv_embed(hidden_states)

        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = self.layer_norm(outputs[0])

        # update the last element in `hidden_states` so that it is always the last hidden state after layer norm
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
        )


class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
    """
    Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.num_groups = self.config.num_codevector_groups
        self.num_vars = self.config.num_codevectors_per_group

        if self.config.codevector_dim % self.num_groups != 0:
            raise ValueError(
                f"`config.codevector_dim {self.config.codevector_dim} must be divisible by"
                f" `config.num_codevector_groups` {self.num_groups} for concatenation"
            )

        # storage for codebook variables (codewords)
        self.codevectors = self.param(
            "codevectors",
            jax.nn.initializers.uniform(),
            (1, self.num_groups * self.num_vars, self.config.codevector_dim // self.num_groups),
        )
        self.weight_proj = nn.Dense(
            self.num_groups * self.num_vars,
            kernel_init=jax.nn.initializers.normal(1.0),
            dtype=self.dtype,
        )

    @staticmethod
    def _compute_perplexity(probs, mask=None):
        if mask is not None:
            mask_extended = jnp.broadcast_to(mask.flatten()[:, None, None], probs.shape)
            probs = jnp.where(mask_extended, probs, jnp.zeros_like(probs))
            marginal_probs = probs.sum(axis=0) / mask.sum()
        else:
            marginal_probs = probs.mean(axis=0)

        perplexity = jnp.exp(-jnp.sum(marginal_probs * jnp.log(marginal_probs + 1e-7), axis=-1)).sum()
        return perplexity

    def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, temperature=1):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.reshape(batch_size * sequence_length * self.num_groups, -1)

        if not deterministic:
            # sample code vector probs via gumbel in a differentiable way
            gumbel_rng = self.make_rng("gumbel")
            gumbels = jax.random.gumbel(gumbel_rng, hidden_states.shape)
            codevector_probs = nn.softmax((hidden_states + gumbels) / temperature)

            # compute perplexity
            codevector_soft_dist = nn.softmax(
                hidden_states.reshape(batch_size * sequence_length, self.num_groups, -1), axis=-1
            )
            perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
        else:
            # take argmax in non-differentiable way; compute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(axis=-1)
            codevector_probs = jax.nn.one_hot(codevector_idx, hidden_states.shape[-1]) * 1.0
            codevector_probs = codevector_probs.reshape(batch_size * sequence_length, self.num_groups, -1)
            perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)

        codevector_probs = codevector_probs.reshape(batch_size * sequence_length, -1)
        # use probs to retrieve codevectors
        codevectors_per_group = jnp.expand_dims(codevector_probs, axis=-1) * self.codevectors
        codevectors = codevectors_per_group.reshape(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
        codevectors = codevectors.sum(-2).reshape(batch_size, sequence_length, -1)

        return codevectors, perplexity


class FlaxWav2Vec2Adapter(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # hidden_states require down-projection if feature dims don't match
        if self.config.output_hidden_size != self.config.hidden_size:
            self.proj = nn.Dense(
                self.config.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
            )
            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        else:
            self.proj = self.proj_layer_norm = None

        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        # down-project hidden_states if required
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj(hidden_states)
            hidden_states = self.proj_layer_norm(hidden_states)

        hidden_states = self.layers(hidden_states)

        return hidden_states


class FlaxWav2Vec2AdapterLayer(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            features=2 * self.config.output_hidden_size,
            kernel_size=(self.config.adapter_kernel_size,),
            strides=(self.config.adapter_stride,),
            padding=((1, 1),),
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = nn.glu(hidden_states, axis=2)
        return hidden_states


class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = [
            FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_adapter_layers)
        ]

    def __call__(self, hidden_states):
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)

        return hidden_states


class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Wav2Vec2Config
    base_model_prefix: str = "wav2vec2"
    main_input_name = "input_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: Wav2Vec2Config,
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_values = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_values)
        params_rng, dropout_rng = jax.random.split(rng, 2)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_values.shape

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        return self.module.apply(
            inputs,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            not train,
            output_attentions,
            output_hidden_states,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)


class FlaxWav2Vec2Module(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
        self.masked_spec_embed = self.param(
            "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
        )

        if self.config.do_stable_layer_norm:
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        # make sure that no loss is computed on padded inputs
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        # Effectively attention_mask.sum(-1), but not in-place so it can run in inference mode.
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # these two operations make sure that all values before the output lengths indices are attended to
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask


@add_start_docstrings(
    "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
    WAV2VEC2_START_DOCSTRING,
)
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    module_class = FlaxWav2Vec2Module


FLAX_WAV2VEC2_MODEL_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoProcessor, FlaxWav2Vec2Model
    >>> from datasets import load_dataset
    >>> import soundfile as sf

    >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-lv60")
    >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")


    >>> def map_to_array(batch):
    ...     speech, _ = sf.read(batch["file"])
    ...     batch["speech"] = speech
    ...     return batch


    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> ds = ds.map(map_to_array)

    >>> input_values = processor(
    ...     ds["speech"][0], sampling_rate=16_000, return_tensors="np"
    ... ).input_values  # Batch size 1
    >>> hidden_states = model(input_values).last_hidden_state
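    >>> # illustrative addition (not part of the original example): the model also
    >>> # returns the raw convolutional features that pre-training quantizes
    >>> extract_features = model(input_values).extract_features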
    ```
"""

overwrite_call_docstring(
    FlaxWav2Vec2Model,
    WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config
)


class FlaxWav2Vec2ForCTCModule(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.final_dropout)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

    def _get_feat_extract_output_lengths(
        self,
        input_lengths: Union[jnp.ndarray, int],
        add_adapter: Optional[bool] = None,
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths


@add_start_docstrings(
    "Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).",
    WAV2VEC2_START_DOCSTRING,
)
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
    module_class = FlaxWav2Vec2ForCTCModule


FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> import jax.numpy as jnp
    >>> from transformers import AutoProcessor, FlaxWav2Vec2ForCTC
    >>> from datasets import load_dataset
    >>> import soundfile as sf

    >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-960h-lv60")
    >>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")


    >>> def map_to_array(batch):
    ...     speech, _ = sf.read(batch["file"])
    ...     batch["speech"] = speech
    ...     return batch


    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> ds = ds.map(map_to_array)

    >>> input_values = processor(
    ...     ds["speech"][0], sampling_rate=16_000, return_tensors="np"
    ... ).input_values  # Batch size 1
    >>> logits = model(input_values).logits
    >>> predicted_ids = jnp.argmax(logits, axis=-1)

    >>> transcription = processor.decode(predicted_ids[0])
    >>> # should give:  "A MAN SAID TO THE UNIVERSE SIR I EXIST"
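    >>> # illustrative addition (not part of the original example): CTC emits one
    >>> # logit vector per ~20 ms feature frame, i.e. far fewer frames than samples
    >>> assert logits.shape[1] < input_values.shape[1]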
    ```
"""

overwrite_call_docstring(
    FlaxWav2Vec2ForCTC,
    WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING,
)
append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config)


class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.dropout_features = nn.Dropout(self.config.feat_quantizer_dropout)

        self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype)
        self.project_q = nn.Dense(
            self.config.proj_codevector_dim,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.project_hid = nn.Dense(
            self.config.proj_codevector_dim,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        gumbel_temperature: int = 1,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        r"""
        Returns:

        Example:

        ```python

        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            mask_time_indices=mask_time_indices,
            deterministic=deterministic,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        # project all transformer outputs to final vq dim
        transformer_features = self.project_hid(outputs[0])

        # quantize all (unmasked) extracted features and project to final vq dim
        extract_features = self.dropout_features(outputs[1], deterministic=deterministic)

        quantized_features, codevector_perplexity = self.quantizer(
            extract_features, mask_time_indices, deterministic=deterministic, temperature=gumbel_temperature
        )
        quantized_features = self.project_q(quantized_features)

        if not return_dict:
            return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]

        return FlaxWav2Vec2ForPreTrainingOutput(
            projected_states=transformer_features,
            projected_quantized_states=quantized_features,
            codevector_perplexity=codevector_perplexity,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths


@add_start_docstrings("Wav2Vec2 Model with a quantizer and `VQ` head on top.", WAV2VEC2_START_DOCSTRING)
class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel):
    module_class = FlaxWav2Vec2ForPreTrainingModule

    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
    # overwrite since has `gumbel_temperature` input
    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        gumbel_temperature: int = 1,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        gumbel_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_values.shape

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        if gumbel_rng is not None:
            rngs["gumbel"] = gumbel_rng

        inputs = {"params": params or self.params}

        return self.module.apply(
            inputs,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            gumbel_temperature,
            not train,
            output_attentions,
            output_hidden_states,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )


FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> import optax
    >>> import numpy as np
    >>> import jax.numpy as jnp
    >>> from transformers import AutoFeatureExtractor, FlaxWav2Vec2ForPreTraining
    >>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices
    >>> from datasets import load_dataset
    >>> import soundfile as sf

    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60")
    >>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60")


    >>> def map_to_array(batch):
    ...     speech, _ = sf.read(batch["file"])
    ...     batch["speech"] = speech
    ...     return batch


    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    >>> ds = ds.map(map_to_array)

    >>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values  # Batch size 1

    >>> # compute masked indices
    >>> batch_size, raw_sequence_length = input_values.shape
    >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
    >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
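    >>> # illustrative addition (not part of the original example): negatives for the
    >>> # contrastive objective can be drawn with the module-level sampling helper
    >>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _sample_negative_indices
    >>> features_shape = (batch_size, sequence_length, model.config.proj_codevector_dim)
    >>> sampled_negative_indices = _sample_negative_indices(features_shape, model.config.num_negatives)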

    >>> outputs = model(input_values, mask_time_indices=mask_time_indices)

    >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
    >>> cosine_sim = optax.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states)

    >>> # show that cosine similarity is much higher than random
    >>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5
    ```
"""

overwrite_call_docstring(
    FlaxWav2Vec2ForPreTraining,
    WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config
)

__all__ = ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"]