o
    Zh                    @   sj  d Z ddlZddlZddlmZ ddlmZmZmZm	Z	m
Z
 ddlZddlm  mZ ddlmZ ddlmZ ddlmZmZmZ dd	lmZ dd
lmZmZmZmZ ddlmZmZm Z m!Z! ddl"m#Z#m$Z$m%Z% e &e'Z(dd Z)dd Z*dd Z+d\ddZ,dej-dej-fddZ.eG dd deZ/eG dd deZ0eG dd deZ1G d d! d!ej2Z3G d"d# d#ej2Z4G d$d% d%ej2Z5G d&d' d'ej2Z6G d(d) d)ej2Z7G d*d+ d+ej2Z8G d,d- d-ej2Z9G d.d/ d/ej2Z:G d0d1 d1ej2Z;G d2d3 d3ej2Z<G d4d5 d5ej2Z=G d6d7 d7ej2Z>G d8d9 d9ej2Z?G d:d; d;ej2Z@G d<d= d=ej2ZAG d>d? d?ej2ZBd@eAiZCG dAdB dBej2ZDG dCdD dDej2ZEG dEdF dFej2ZFG dGdH dHej2ZGG dIdJ dJej2ZHG dKdL dLej2ZIeG dMdN dNeZJG dOdP dPeJZKedQdRG dSdT dTeJZLeG dUdV dVeJZMeG dWdX dXeJZNeG dYdZ dZeJZOg d[ZPdS )]zPyTorch CLAP model.    N)	dataclass)AnyListOptionalTupleUnion)nn   )ACT2FN))BaseModelOutputWithPastAndCrossAttentionsBaseModelOutputWithPooling,BaseModelOutputWithPoolingAndCrossAttentions)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesmeshgridprune_linear_layer)ModelOutputauto_docstringlogging	torch_int   )ClapAudioConfig
ClapConfigClapTextConfigc                 C   sJ   | j \}}}| dddddddf dd|d}|||| |}|S )ae  
    Interpolate data in time domain. This is used to compensate the resolution reduction in downsampling of a CNN.

    Args:
        hidden_states (`torch.FloatTensor` of shape (batch_size, time_length, classes_num)):
            Input hidden states
        ratio (`int`):
            The ratio of the length of the output to the length of the input.
    Nr   )shaperepeatreshape)hidden_statesratio
batch_sizetime_lengthZclasses_numZ	upsampled r"   U/var/www/auris/lib/python3.10/site-packages/transformers/models/clap/modeling_clap.pyinterpolate*   s   
(r$   c                 C   sR   | j \}}}}| ||| ||| ||} | dddddd d|||}|S )aR  
    Returns the resized hidden states. The output shape should be `(batch_size * num_windows, window_size, window_size,
    num_channels)`

    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch_size, height, width, num_channels)`):
            Input hidden states
        window_size (`int`):
            Window size
    r   r   r	            r   viewpermute
contiguous)r   window_sizer    heightwidthnum_channelswindowsr"   r"   r#   window_partition;   s   $r2   c                 C   sN   | j d }| d|| || |||} | dddddd d|||} | S )a  
    Merges windows to produce higher resolution features.
    Args:
        windows (`torch.FloatTensor` of shape `(num_windows * batch_size, window_size, window_size, num_channels)`):
            Input windows
        window_size (`int`):
            Window size
        height (`int`):
            Height of the resized audio
        width (`int`):
            Width of the resized audio
    r(   r   r   r	   r%   r&   r'   r)   )r1   r-   r.   r/   r0   r"   r"   r#   window_reverseP   s   
$r3   c                 C   s6   |  | }tj|dd|| | }| | S )a  
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        x: torch.Tensor x:

    Returns: torch.Tensor
    r   dim)neinttorchZcumsumZtype_aslong)	input_idspadding_idxpast_key_values_lengthmaskZincremental_indicesr"   r"   r#   "create_position_ids_from_input_idsd   s   r>   logitsreturnc                 C   s"   t jt| | jd}tj| |S )Ndevice)r8   arangelenrB   r   
functionalZcross_entropy)r?   labelsr"   r"   r#   contrastive_lossv   s   rG   c                   @   j   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dS )ClapTextModelOutputa  
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Ntext_embedslast_hidden_state.r   
attentions)__name__
__module____qualname____doc__rJ   r   r8   FloatTensor__annotations__rK   r   r   rL   r"   r"   r"   r#   rI   {   s   
 rI   c                   @   rH   )ClapAudioModelOutputak  
    ClapAudio model output to mimic the output of the original implementation.

    Args:
        audio_embeds (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            The Audio embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Naudio_embedsrK   .r   rL   )rM   rN   rO   rP   rT   r   r8   rQ   rR   rK   r   r   rL   r"   r"   r"   r#   rS      s   
 rS   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeed< dZeed	< d
ee fddZdS )
ClapOutputa  
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for audio-text similarity.
        logits_per_audio (`torch.FloatTensor` of shape `(audio_batch_size, text_batch_size)`):
            The scaled dot product scores between `audio_embeds` and `text_embeds`. This represents the audio-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, audio_batch_size)`):
            The scaled dot product scores between `text_embeds` and `audio_embeds`. This represents the text-audio
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`ClapTextModel`].
        audio_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The audio embeddings obtained by applying the projection layer to the pooled output of [`ClapAudioModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`ClapTextModel`].
        audio_model_output (`BaseModelOutputWithPooling`):
            The output of the [`ClapAudioModel`].
    Nlosslogits_per_audiologits_per_textrJ   rT   text_model_outputaudio_model_outputr@   c                    s   t  fdd  D S )Nc                 3   s.    | ]}|d vr | nt  | V  qdS ))rY   rZ   N)getattrto_tuple).0kselfr"   r#   	<genexpr>   s
    
z&ClapOutput.to_tuple.<locals>.<genexpr>)tuplekeysr_   r"   r_   r#   r\      s   zClapOutput.to_tuple)rM   rN   rO   rP   rV   r   r8   rQ   rR   rW   rX   rJ   rT   rY   r   rZ   r   r   r\   r"   r"   r"   r#   rU      s   
 rU   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )ClapDropPathz
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is a slightly
    refactored version of the `SwinDropPath` implementation.
    Nc                    s   t    || _d S N)super__init__	drop_prob)r`   rh   	__class__r"   r#   rg      s   

zClapDropPath.__init__c                 C   sj   | j dks| js
|S d| j  }|jd fd|jd   }|tj||j|jd }|  |	|| }|S )N        r   r   )r   dtyperB   )
rh   trainingr   ndimr8   Zrandrm   rB   Zfloor_div)r`   r   Z	keep_probr   Zrandom_tensoroutputr"   r"   r#   forward   s   
zClapDropPath.forwardre   )rM   rN   rO   rP   rg   rr   __classcell__r"   r"   ri   r#   rd      s    rd   c                       s.   e Zd ZdZdef fddZdd Z  ZS )ClapAudioAFFBlockz
    ATTENTIONAL FEATURE FUSION Block from CLAP, since in CLAP we are always in 2D mode, it is not needed to implement
    the 1D version.
    configc                    s   t    |j}|j}t|| }ttj||ddddt|tj	ddtj||ddddt|| _
ttdtj||ddddt|tj	ddtj||ddddt|| _t | _d S )Nr   r   Zkernel_sizeZstridepaddingT)Zinplace)rf   rg   patch_embeds_hidden_sizeZaff_block_rr7   r   Z
SequentialConv2dBatchNorm2dZReLU	local_attZAdaptiveAvgPool2d
global_attZSigmoidsigmoid)r`   ru   channelsZdownsize_ratioZinter_channelsri   r"   r#   rg      s(   


	zClapAudioAFFBlock.__init__c                 C   sF   || }|  || | }| |}d| | d| d|   }|S )Nr%   r   )r{   r|   r}   )r`   r   ZresidualZattention_inputZfused_layer_outputrq   r"   r"   r#   rr     s
   
zClapAudioAFFBlock.forwardrM   rN   rO   rP   r   rg   rr   rs   r"   r"   ri   r#   rt      s    rt   c                       s0   e Zd ZdZdef fddZdddZ  ZS )	ClapAudioPatchEmbedz
    This module converts the hidden states reshaped as an image to patch embeddings ready to be passed to the
    Transformer block.
    ru   c                    s  t    t|jtr|j|jfn|j}t|jtr |j|jfn|j}t|jtr/|j|jfn|j}|| _|| _|d |d  |d |d  f| _| jd | jd  | _	|j
| _|j| _|d |d  d |d |d  d f}| jrz|jdkrzdnd}tj|j| |j|||d| _|jrt|jnt | _| jrt|| _tj|j|j|d |d d f|d |d d f|d| _d S d S )Nr   r   r%   Zchannel_mapr&   rv   r	   )rf   rg   
isinstance	spec_sizer7   
patch_sizepatch_strideimg_size	grid_sizeZnum_patchesZflatten_patch_embedsflattenenable_fusionZfusion_typer   ry   Zpatch_embed_input_channelsrx   projZenable_patch_layer_norm	LayerNormIdentitynormrt   fusion_model
mel_conv2d)r`   ru   r   r   r   rw   Zscale_factorri   r"   r#   rg   $  s@   
"(
zClapAudioPatchEmbed.__init__Nc              
   C   s  | j r|d d ddd d d d f }|j\}}}}|| jd ks(|| jd kr?td| d| d| jd  d| jd  d	| |}|d}t|dkr||dd d d d d f  }	|	j\}}}}|	|| d||}	| 	|	}	|	j\}
}}}|	|||||}	|	
d d	}	|	d}tjj|	d|| fd
d}	| || |	||< |}n1|j\}
}
}}|| jd ks|| jd krtd| d| d| jd  d| jd  d	| |}| jr|ddd}| |}|S )Nr   r   zInput audio size (*z) doesn't match model (z).r(   )r   r%   r	   r   r&   r	   Zconstantr%   )r   r   r   
ValueErrorr   sizerD   r,   r*   r   r+   r   r8   r   rE   padr   	transposer   )r`   r   Zis_longer_idxZglobal_hidden_statesr    r0   r.   r/   Zoutput_widthZlocal_hidden_states_featuresZlocal_widthr"   r"   r#   rr   N  sF    (

 

(

zClapAudioPatchEmbed.forwardre   r   r"   r"   ri   r#   r     s    *r   c                       b   e Zd Z fddZdd Z			ddejdeej d	eej d
ee	 de
ej f
ddZ  ZS )ClapAudioSelfAttentionc                    s
  t    || dkrtd| d| d|| _t|| | _| j| j | _t|tj	j
r0|n||f| _ttd| jd  d d| jd  d  || _t| jd }t| jd }tt||gdd}t|d}|d d d d d f |d d d d d f  }	|	ddd }	|	d d d d df  | jd d 7  < |	d d d d df  | jd d 7  < |	d d d d df  d| jd  d 9  < |	d	}
| d
|
 tj| j| j|jd| _tj| j| j|jd| _tj| j| j|jd| _t|j| _ d S )Nr   The hidden size (6) is not a multiple of the number of attention heads ()r%   r   Zij)Zindexingr(   relative_position_indexbias)!rf   rg   r   num_attention_headsr7   attention_head_sizeall_head_sizer   collectionsabcIterabler-   r   	Parameterr8   zerosrelative_position_bias_tablerC   stackr   r   r+   r,   sumregister_bufferLinearZqkv_biasquerykeyvalueDropoutattention_probs_dropout_probdropout)r`   ru   r5   	num_headsr-   Zcoords_hZcoords_wZcoordsZcoords_flattenZrelative_coordsr   ri   r"   r#   rg     s8   
*,((,
zClapAudioSelfAttention.__init__c                 C   6   |  d d | j| jf }||}|ddddS Nr(   r   r%   r   r	   r   r   r   r*   r+   r`   xZnew_x_shaper"   r"   r#   transpose_for_scores     
z+ClapAudioSelfAttention.transpose_for_scoresNFr   attention_mask	head_maskoutput_attentionsr@   c                 C   s  |j \}}}| |}| | |}	| | |}
| |}t||	dd}|t	| j
 }| j| jd }|| jd | jd  | jd | jd  d}|ddd }||d }|d ur|j d }||| || j||}||dd }|d| j||}tjj|dd}| |}|d ur|| }t||
}|dddd }| d d | jf }||}|r||f}|S |f}|S )Nr(   r   r   r%   r4   r	   )r   r   r   r   r   r8   matmulr   mathsqrtr   r   r   r*   r-   r+   r,   	unsqueezer   r   rE   softmaxr   r   r   )r`   r   r   r   r   r    r5   r0   mixed_query_layer	key_layervalue_layerquery_layerattention_scoresZrelative_position_biasZ
mask_shapeattention_probscontext_layernew_context_layer_shapeoutputsr"   r"   r#   rr     s@   

&


zClapAudioSelfAttention.forwardNNF)rM   rN   rO   rg   r   r8   Tensorr   rQ   boolr   rr   rs   r"   r"   ri   r#   r     s"    %r   c                       8   e Zd Z fddZdejdejdejfddZ  ZS )ClapAudioSelfOutputc                    s*   t    t||| _t|j| _d S re   )rf   rg   r   r   denser   r   r   r`   ru   r5   ri   r"   r#   rg     s   
zClapAudioSelfOutput.__init__r   input_tensorr@   c                 C      |  |}| |}|S re   r   r   r`   r   r   r"   r"   r#   rr     s   

zClapAudioSelfOutput.forwardrM   rN   rO   rg   r8   r   rr   rs   r"   r"   ri   r#   r     s    $r   c                       r   )ClapAudioAttentionc                    s2   t    t||||| _t||| _t | _d S re   )rf   rg   r   r`   r   rq   setpruned_heads)r`   ru   r5   r   r-   ri   r"   r#   rg     s   
zClapAudioAttention.__init__c                 C      t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S Nr   r   r4   rD   r   r`   r   r   r   r   r   r   r   rq   r   r   unionr`   Zheadsindexr"   r"   r#   prune_heads     zClapAudioAttention.prune_headsNFr   r   r   r   r@   c                 C   s6   |  ||||}| |d |}|f|dd   }|S Nr   r   r`   rq   )r`   r   r   r   r   self_outputsattention_outputr   r"   r"   r#   rr     s   zClapAudioAttention.forwardr   )rM   rN   rO   rg   r   r8   r   r   rQ   r   r   rr   rs   r"   r"   ri   r#   r     s"    r   c                       2   e Zd Z fddZdejdejfddZ  ZS )ClapAudioIntermediatec                    sJ   t    t|t|j| | _t|jt	rt
|j | _d S |j| _d S re   )rf   rg   r   r   r7   	mlp_ratior   r   
hidden_actstrr
   intermediate_act_fnr   ri   r"   r#   rg     s
   
zClapAudioIntermediate.__init__r   r@   c                 C   r   re   r   r   r`   r   r"   r"   r#   rr   $     

zClapAudioIntermediate.forwardr   r"   r"   ri   r#   r         r   c                       r   )ClapAudioOutputc                    s4   t    tt|j| || _t|j| _	d S re   )
rf   rg   r   r   r7   r   r   r   hidden_dropout_probr   r   ri   r"   r#   rg   ,  s   
zClapAudioOutput.__init__r   r@   c                 C   r   re   r   r   r"   r"   r#   rr   1  r   zClapAudioOutput.forwardr   r"   r"   ri   r#   r   +      r   c                       s   e Zd Zd fdd	Zdd Zdd Zd	d
 Z			ddejde	e
e
f deej dee dee de	ejejf fddZ  ZS )ClapAudioLayerrk   r   c                    s   t    |j| _|| _|j| _|| _tj||jd| _	t
|||| jd| _|dkr.t|nt | _tj||jd| _t||| _t||| _d S )NZeps)r-   rk   )rf   rg   chunk_size_feed_forward
shift_sizer-   input_resolutionr   r   layer_norm_epslayernorm_beforer   	attentionrd   r   	drop_pathlayernorm_afterr   intermediater   rq   )r`   ru   r5   r   r   drop_path_rater   ri   r"   r#   rg   9  s   
zClapAudioLayer.__init__c                 C   sD   t || jkr td| _tj rt t|nt || _d S d S Nr   )minr-   r   r   r8   Zjit
is_tracingtensor)r`   r   r"   r"   r#   set_shift_and_window_sizeF  s
   
 z(ClapAudioLayer.set_shift_and_window_sizec              	   C   s  | j dkrtjd||df||d}td| j t| j | j  t| j  d f}td| j t| j | j  t| j  d f}d}|D ]}	|D ]}
||d d |	|
d d f< |d7 }qEqAt|| j}|d| j| j }|d|d }||dkt	d|dkt	d}|S d }|S )Nr   r   rl   r(   r%   g      Yrk   )
r   r8   r   slicer-   r2   r*   r   Zmasked_fillfloat)r`   r.   r/   rm   rB   Zimg_maskZheight_slicesZwidth_slicescountZheight_sliceZwidth_sliceZmask_windows	attn_maskr"   r"   r#   get_attn_maskN  s.   

$zClapAudioLayer.get_attn_maskc                 C   sR   | j || j   | j  }| j || j   | j  }ddd|d|f}tj||}||fS r   )r-   r   rE   r   )r`   r   r.   r/   	pad_rightZ
pad_bottom
pad_valuesr"   r"   r#   	maybe_padj  s
   zClapAudioLayer.maybe_padNFr   input_dimensionsr   r   always_partitionr@   c                 C   s  |s|  | n	 |\}}| \}}	}
|}| |}|||||
}| |||\}}|j\}	}}}	| jdkrGtj|| j | j fdd}n|}t	|| j
}|d| j
| j
 |
}| j|||j|jd}| j||||d}|d }|d| j
| j
|
}t|| j
||}| jdkrtj|| j| jfdd}n|}|d dkp|d dk}|r|d d d |d |d d f  }|||| |
}|| | }| |}| |}|| | }|r||d	 f}|S |f}|S )
Nr   )r   r%   )Zshiftsdimsr(   rl   )r   r	   r'   r   )r  r   r   r*   r  r   r   r8   Zrollr2   r-   r  rm   rB   r   r3   r,   r   r   r   rq   )r`   r   r  r   r   r  r.   r/   r    r   r~   Zshortcutr
  Z
height_padZ	width_padZshifted_hidden_statesZhidden_states_windowsr  Zattention_outputsr   Zattention_windowsZshifted_windowsZ
was_paddedlayer_outputlayer_outputsr"   r"   r#   rr   q  sN   


$

zClapAudioLayer.forward)rk   r   NFF)rM   rN   rO   rg   r  r  r  r8   r   r   r7   r   rQ   r   rr   rs   r"   r"   ri   r#   r   8  s*    
r   c                       sd   e Zd Z fddZ			ddejdeeef deej	 dee
 d	ee
 d
eej fddZ  ZS )ClapAudioStagec                    sh   t     | _| _t fddt|D | _|d ur,|tjd| _	nd | _	d| _
d S )Nc              
      s:   g | ]}t  | |d  dkrdn jd  dqS )r%   r   )ru   r5   r   r   r   r   )r   r-   r]   iru   r5   r   r   r   r"   r#   
<listcomp>  s    	z+ClapAudioStage.__init__.<locals>.<listcomp>)r5   
norm_layerF)rf   rg   ru   r5   r   
ModuleListrangeblocksr   
downsampleZpointing)r`   ru   r5   r   depthr   r   r  ri   r  r#   rg     s   
	
zClapAudioStage.__init__NFr   r  r   r   r  r@   c                 C   s   |\}}t | jD ]\}}	|d ur|| nd }
|	|||
||}|d }q	|}| jd urE|d d |d d }}||||f}| ||}n||||f}|||f}|rZ||dd  7 }|S )Nr   r   r%   )	enumerater  r  )r`   r   r  r   r   r  r.   r/   r  layer_modulelayer_head_maskr  !hidden_states_before_downsamplingZheight_downsampledZwidth_downsampledoutput_dimensionsZstage_outputsr"   r"   r#   rr     s"   



zClapAudioStage.forwardr  )rM   rN   rO   rg   r8   r   r   r7   r   rQ   r   rr   rs   r"   r"   ri   r#   r    s$    
r  c                	       sh   e Zd ZdZejfdee dedejddf fddZ	d	d
 Z
dejdeeef dejfddZ  ZS )ClapAudioPatchMerginga'  
    Patch Merging Layer.

    Args:
        input_resolution (`Tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    r   r5   r  r@   Nc                    sB   t    || _|| _tjd| d| dd| _|d| | _d S )Nr&   r%   Fr   )rf   rg   r   r5   r   r   	reductionr   )r`   r   r5   r  ri   r"   r#   rg      s
   
zClapAudioPatchMerging.__init__c                 C   sF   |d dkp|d dk}|r!ddd|d d|d f}t j||}|S )Nr%   r   r   )r   rE   r   )r`   input_featurer.   r/   Z
should_padr
  r"   r"   r#   r    s
   zClapAudioPatchMerging.maybe_padr$  r  c                 C   s   |\}}|j \}}}|||||}| |||}|d d dd ddd dd d f }|d d dd ddd dd d f }	|d d dd ddd dd d f }
|d d dd ddd dd d f }t||	|
|gd}||dd| }| |}| |}|S )Nr   r%   r   r(   r&   )r   r*   r  r8   catr   r#  )r`   r$  r  r.   r/   r    r5   r0   Zinput_feature_0Zinput_feature_1Zinput_feature_2Zinput_feature_3r"   r"   r#   rr     s   $$$$

zClapAudioPatchMerging.forward)rM   rN   rO   rP   r   r   r   r7   Modulerg   r  r8   r   rr   rs   r"   r"   ri   r#   r"    s
    **r"  c                       s   e Zd Z fddZdd Z							ddeej d	eej d
ee dee dee dee dee de	e
ef fddZ  ZS )ClapAudioEncoderc                    s  t    t j_ _t _ j_jj	_	 j
_
 j
 j _t jdjd   _dd tjd jt jddD jjfddtjD _t fd	dtjD _d
_t j_tj_ j_td_ d S )Nr%   r   c                 S   s   g | ]}|  qS r"   )item)r]   r   r"   r"   r#   r  7      z-ClapAudioEncoder.__init__.<locals>.<listcomp>r   cpurA   c                    s,   g | ]} d  d|   d d|  fqS )r   r%   r   r"   r  )r   r"   r#   r  :  s   , c                    s|   g | ]:}t  t jd |  j|  j|  j| t jd| t jd|d   |jd k r8tnddqS )r%   Nr   )ru   r5   r   r  r   r   r  )	r  r7   rx   input_resolutionsdepthsr   r   
num_layersr"  )r]   Zi_layer)ru   r   r`   r"   r#   r  =  s    
*F)!rf   rg   rD   r,  r-  ru   r   patch_embedr   r   r   Znum_mel_bins
freq_ratior7   rx   Znum_featuresr8   Zlinspacer   r   r   r  r+  r   r  layersgradient_checkpointingrz   
batch_normr   r   ZAdaptiveAvgPool1davgpoolr`   ru   ri   )ru   r   r   r`   r#   rg   *  s,   


$
zClapAudioEncoder.__init__c                 C   s   |j \}}}}t| j| j }| j| j }||ks||kr!td||k r1tjj|||fddd}||k rAtjj|||fddd}|j \}}}	}
|||| j |	| j |
}|	dddd
 }||||
| j |	| j }|S )	z
        The input is 4 normalized log mel spectrograms. It is reshape to the common shape of images. Each channel
        should represent 1 of the 4 crops of the spectrogram. For more details, refer to the [`ClapFeatureExtractor`].
        z@the wav size should be less than or equal to the swin input sizeZbicubicT)modeZalign_cornersr   r   r	   r%   )r   r7   r   r/  r   r   rE   r$   r   r+   r,   )r`   normalized_input_featuresr   r!   Zfreq_lengthZ
spec_widthZspec_heightbatchr~   timefreqr"   r"   r#   reshape_mel2imgR  s,   z ClapAudioEncoder.reshape_mel2imgNFT	is_longerr   r   output_hidden_states(output_hidden_states_before_downsamplingr  return_dictr@   c	           $      C   sD  | dd}| |}	|	 dd}	d }
| jr%||j}t|dkd }
| |	}|jd }| 	||
}|r9dnd }|r?dnd }|rEdnd }| j
d }|rq|j\}}}|j|g||R  }|dddd}||f7 }||f7 }t| jD ]\}}|d ur|| nd }| j
| }| jr| jr| |j||||}n||||||}|d }|d }|d }|d |d f}|r|r|j\}}}|j|g|d |d f|R  }|dddd}||f7 }||f7 }n)|r|s|j\}}}|j|g||R  }|dddd}||f7 }||f7 }|r||dd  7 }qv| |}|j\}}}|dt| jd   | jd  }|dt| jd   | jd  }|ddd ||||}|j\}}} }!| | j }"|||| |" |"|!}|ddddd |||"d}| t|d}#t|#d}#|std	d
 ||#||fD S t||#||dS )Nr   r	   r   r%   r"   r   r(   r&   c                 s       | ]	}|d ur|V  qd S re   r"   r]   vr"   r"   r#   ra     s    z+ClapAudioEncoder.forward.<locals>.<genexpr>)rK   pooler_outputr   rL   )r   r2  r   torB   r8   wherer:  r   r.  r+  r*   r+   r  r0  r1  rn   _gradient_checkpointing_func__call__r   rD   r,  r   r,   r   r/  r3  r   rb   r   )$r`   input_featuresr;  r   r   r<  r=  r  r>  r6  Zis_longer_list_idxZis_longer_listr   Z
frames_numall_hidden_statesZall_reshaped_hidden_statesall_self_attentionsr  r    r   hidden_sizeZreshaped_hidden_stater  r  r  r  r   r!  rK   Z
n_channelsZ
freq_shapeZtemporal_shapeZn_frequenciesZn_tempZ
c_freq_binZlatent_outputr"   r"   r#   rr   v  s   











  
 zClapAudioEncoder.forward)NNFFFFT)rM   rN   rO   rg   r:  r   r8   rQ   r   r   r   rS   rr   rs   r"   r"   ri   r#   r'  )  s6    ('	

r'  c                       s2   e Zd Zdeeef f fddZdd Z  ZS )ClapProjectionLayerru   c                    sH   t    || _|j}|j}t||| _t|j	 | _
t||| _d S re   )rf   rg   ru   rJ  projection_dimr   r   linear1r
   Zprojection_hidden_act
activationlinear2)r`   ru   rJ  rL  ri   r"   r#   rg     s   
zClapProjectionLayer.__init__c                 C   s"   |  |}| |}| |}|S re   )rM  rN  rO  r   r"   r"   r#   rr     s   


zClapProjectionLayer.forward)	rM   rN   rO   r   r   r   rg   rr   rs   r"   r"   ri   r#   rK    s    
rK  c                       s4   e Zd ZdZ fddZ	d
ddZdd	 Z  ZS )ClapTextEmbeddingszV
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _t|dd| _| jdt|jddd | jd	tj| j tjd
dd |j| _tj|j|j| jd| _	d S )N)r;   r   position_embedding_typeabsoluteposition_ids)r   r(   T)
persistenttoken_type_idsrm   )rf   rg   r   	EmbeddingZ
vocab_sizerJ  Zpad_token_idword_embeddingsmax_position_embeddingsposition_embeddingsZtype_vocab_sizetoken_type_embeddingsr   r   r   r   r   r[   rQ  r   r8   rC   expandr   rS  r   r9   r;   r4  ri   r"   r#   rg     s"   
zClapTextEmbeddings.__init__Nr   c                 C   s   |d u r|d urt || j|}n| |}|d ur| }n| d d }|d }|d u rTt| drI| jd d d |f }||d |}	|	}ntj|tj	| j
jd}|d u r]| |}| |}
||
 }| jdkrt| |}||7 }| |}| |}|S )Nr(   r   rU  r   rl   rR  )r>   r;   &create_position_ids_from_inputs_embedsr   hasattrrU  r\  r8   r   r9   rS  rB   rX  r[  rQ  rZ  r   r   )r`   r:   rU  rS  inputs_embedsr<   input_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedr[  
embeddingsrZ  r"   r"   r#   rr   %  s0   








zClapTextEmbeddings.forwardc                 C   sN   |  dd }|d }tj| jd || j d tj|jd}|d|S )z
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        Nr(   r   rl   r   )r   r8   rC   r;   r9   rB   r   r\  )r`   r_  r`  Zsequence_lengthrS  r"   r"   r#   r]  M  s   	z9ClapTextEmbeddings.create_position_ids_from_inputs_embeds)NNNNr   )rM   rN   rO   rP   rg   rr   r]  rs   r"   r"   ri   r#   rP    s    
(rP  c                       s   e Zd Zd fdd	ZdejdejfddZ						dd	ejd
eej deej deej deej dee	e	ej   dee
 de	ej fddZ  ZS )ClapTextSelfAttentionNc                    s   t    |j|j dkrt|dstd|j d|j d|j| _t|j|j | _| j| j | _t	
|j| j| _t	
|j| j| _t	
|j| j| _t	|j| _|p\t|dd| _| jdksh| jd	kry|j| _t	d
|j d | j| _|j| _d S )Nr   Zembedding_sizer   r   r   rQ  rR  relative_keyrelative_key_queryr%   r   )rf   rg   rJ  r   r^  r   r7   r   r   r   r   r   r   r   r   r   r   r[   rQ  rY  rW  distance_embedding
is_decoderr`   ru   rQ  ri   r"   r#   rg   a  s*   

zClapTextSelfAttention.__init__r   r@   c                 C   r   r   r   r   r"   r"   r#   r   {  r   z*ClapTextSelfAttention.transpose_for_scoresFr   r   r   encoder_hidden_statesencoder_attention_maskpast_key_valuer   c                 C   s  |  |}|d u}	|	r|d ur|d }
|d }|}nP|	r/| | |}
| | |}|}n;|d urZ| | |}
| | |}tj|d |
gdd}
tj|d |gdd}n| | |}
| | |}| |}|d u}| jrz|
|f}t||
dd}| j	dks| j	dkr	|j
d |
j
d }}|rtj|d tj|jd	dd}ntj|tj|jd	dd}tj|tj|jd	dd}|| }| || j d }|j|jd
}| j	dkrtd||}|| }n| j	dkr	td||}td|
|}|| | }|t| j }|d ur|| }tjj|dd}| |}|d ur0|| }t||}|dddd }| d d | jf }||}|rX||fn|f}| jrd||f }|S )Nr   r   r%   r4   r(   r   rf  rg  rl   rV  zbhld,lrd->bhlrzbhrd,lrd->bhlrr	   ) r   r   r   r   r8   r%  ri  r   r   rQ  r   r  r9   rB   r*   rC   rh  rY  rC  rm   Zeinsumr   r   r   r   rE   r   r   r+   r,   r   r   )r`   r   r   r   rk  rl  rm  r   r   Zis_cross_attentionr   r   r   	use_cacher   Zquery_lengthZ
key_lengthZposition_ids_lZposition_ids_rZdistanceZpositional_embeddingZrelative_position_scoresZrelative_position_scores_queryZrelative_position_scores_keyr   r   r   r   r"   r"   r#   rr     sn   









zClapTextSelfAttention.forwardre   NNNNNF)rM   rN   rO   rg   r8   r   r   r   rQ   r   r   rr   rs   r"   r"   ri   r#   re  `  s4    	re  c                       r   )ClapTextSelfOutputc                    sB   t    t|j|j| _tj|j|jd| _t|j	| _
d S Nr   )rf   rg   r   r   rJ  r   r   r   r   r   r   r4  ri   r"   r#   rg        
zClapTextSelfOutput.__init__r   r   r@   c                 C   &   |  |}| |}| || }|S re   r   r   r   r   r"   r"   r#   rr        

zClapTextSelfOutput.forwardr   r"   r"   ri   r#   rp        $rp  eagerc                       s   e Zd Zd fdd	Zdd Z						ddejdeej d	eej d
eej deej dee	e	ej   dee
 de	ej fddZ  ZS )ClapTextAttentionNc                    s4   t    t|j ||d| _t|| _t | _d S )NrQ  )	rf   rg    CLAP_TEXT_SELF_ATTENTION_CLASSESZ_attn_implementationr`   rp  rq   r   r   rj  ri   r"   r#   rg     s   

zClapTextAttention.__init__c                 C   r   r   r   r   r"   r"   r#   r     r   zClapTextAttention.prune_headsFr   r   r   rk  rl  rm  r   r@   c              	   C   s<   |  |||||||}| |d |}	|	f|dd   }
|
S r   r   )r`   r   r   r   rk  rl  rm  r   r   r   r   r"   r"   r#   rr     s   
	zClapTextAttention.forwardre   ro  )rM   rN   rO   rg   r   r8   r   r   rQ   r   r   rr   rs   r"   r"   ri   r#   rx    s4    	rx  c                       r   )ClapTextIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S re   )rf   rg   r   r   rJ  intermediate_sizer   r   r   r   r
   r   r4  ri   r"   r#   rg   0  s
   
zClapTextIntermediate.__init__r   r@   c                 C   r   re   r   r   r"   r"   r#   rr   8  r   zClapTextIntermediate.forwardr   r"   r"   ri   r#   r{  /  r   r{  c                       r   )ClapTextOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S rq  )rf   rg   r   r   r|  rJ  r   r   r   r   r   r   r4  ri   r"   r#   rg   @  rr  zClapTextOutput.__init__r   r   r@   c                 C   rs  re   rt  r   r"   r"   r#   rr   F  ru  zClapTextOutput.forwardr   r"   r"   ri   r#   r}  ?  rv  r}  c                       s   e Zd Z fddZ						ddejdeej deej deej d	eej d
eeeej   dee	 deej fddZ
dd Z  ZS )ClapTextLayerc                    sr   t    |j| _d| _t|| _|j| _|j| _| jr-| js&t|  dt|dd| _	t
|| _t|| _d S )Nr   z> should be used as a decoder model if cross attention is addedrR  ry  )rf   rg   r   seq_len_dimrx  r   ri  add_cross_attentionr   crossattentionr{  r   r}  rq   r4  ri   r"   r#   rg   O  s   


zClapTextLayer.__init__NFr   r   r   rk  rl  rm  r   r@   c              	   C   s  |d ur
|d d nd }| j |||||d}	|	d }
| jr(|	dd }|	d }n|	dd  }d }| jro|d urot| dsDtd|  d|d urN|d	d  nd }| |
||||||}|d }
||dd  }|d }|| }t| j| j| j|
}|f| }| jr||f }|S )
Nr%   )r   rm  r   r   r(   r  z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`r   )	r   ri  r^  r   r  r   feed_forward_chunkr   r  )r`   r   r   r   rk  rl  rm  r   Zself_attn_past_key_valueZself_attention_outputsr   r   Zpresent_key_valueZcross_attn_present_key_valueZcross_attn_past_key_valueZcross_attention_outputsr  r"   r"   r#   rr   ]  sP   


	

zClapTextLayer.forwardc                 C   s   |  |}| ||}|S re   )r   rq   )r`   r   Zintermediate_outputr  r"   r"   r#   r    s   
z ClapTextLayer.feed_forward_chunkro  )rM   rN   rO   rg   r8   r   r   rQ   r   r   rr   r  rs   r"   r"   ri   r#   r~  N  s4    	
Ar~  c                       s   e Zd Z fddZ									ddejdeej deej d	eej d
eej deeeej   dee	 dee	 dee	 dee	 de
eej ef fddZ  ZS )ClapTextEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r"   )r~  )r]   r   ru   r"   r#   r    r)  z,ClapTextEncoder.__init__.<locals>.<listcomp>F)	rf   rg   ru   r   r  r  num_hidden_layerslayerr1  r4  ri   r  r#   rg     s   
 
zClapTextEncoder.__init__NFTr   r   r   rk  rl  past_key_valuesrn  r   r<  r>  r@   c                 C   s^  |	rdnd }|r
dnd }|r| j jrdnd }| jr%| jr%|r%td d}|r)dnd }t| jD ]^\}}|	r;||f }|d urC|| nd }|d urM|| nd }| jrc| jrc| |j	|||||||}n
||||||||}|d }|rz||d f7 }|r||d f }| j jr||d f }q0|	r||f }|
st
dd	 |||||fD S t|||||d
S )Nr"   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fr   r(   r   r%   c                 s   r?  re   r"   r@  r"   r"   r#   ra     s    z*ClapTextEncoder.forward.<locals>.<genexpr>)rK   r  r   rL   cross_attentions)ru   r  r1  rn   loggerZwarning_oncer  r  rE  rF  rb   r   )r`   r   r   r   rk  rl  r  rn  r   r<  r>  rH  rI  Zall_cross_attentionsZnext_decoder_cacher  r  r  rm  r  r"   r"   r#   rr     sz   


zClapTextEncoder.forward)	NNNNNNFFT)rM   rN   rO   rg   r8   r   r   rQ   r   r   r   r   rr   rs   r"   r"   ri   r#   r    sD    		
r  c                       r   )ClapTextPoolerc                    s*   t    t|j|j| _t | _d S re   )rf   rg   r   r   rJ  r   ZTanhrN  r4  ri   r"   r#   rg     s   
zClapTextPooler.__init__r   r@   c                 C   s(   |d d df }|  |}| |}|S r   )r   rN  )r`   r   Zfirst_token_tensorpooled_outputr"   r"   r#   rr   	  s   

zClapTextPooler.forwardr   r"   r"   ri   r#   r    r   r  c                   @   s    e Zd ZeZdZdZdd ZdS )ClapPreTrainedModelclapFc                 C   s6  | j j}t|tr#|jjjjd|d d |jjjjd|d d dS t|t	r@t
jj|j|d d t
jj|j|d d dS t|t
jrS|jjjd|d d dS t|t
jrh|jj  |jjd dS t|t
jt
jfr| j jd d| j j d  | }t
jj|j|d |jdur|jj  dS dS dS )	zInitialize the weightsrk   g{Gz?)meanstd)r  g      ?g      r%   N)ru   Zinitializer_factorr   rP  rZ  weightdataZnormal_r[  	ClapModelr   initlogit_scale_alogit_scale_trW  r   r   Zzero_Zfill_ry   r   rJ  r  )r`   modulefactorZin_proj_stdr"   r"   r#   _init_weights  s&   

 
z!ClapPreTrainedModel._init_weightsN)rM   rN   rO   r   config_classZbase_model_prefixZsupports_gradient_checkpointingr  r"   r"   r"   r#   r    s
    r  c                          e Zd ZeZdZdef fddZdejfddZ	e
					ddeej d	eej d
ee dee dee deeef fddZ  ZS )ClapAudioModelrG  ru   c                    s"   t  | t|| _|   d S re   )rf   rg   r'  audio_encoder	post_initr4  ri   r"   r#   rg   3  s   
zClapAudioModel.__init__r@   c                 C   
   | j jjS re   )r  r.  r   r_   r"   r"   r#   get_input_embeddings9     
z#ClapAudioModel.get_input_embeddingsNr;  r   r<  r>  c                 C   sP   |dur|n| j j}|dur|n| j j}|dur|n| j j}| j|||||dS )a  
        input_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Input audio features. This should be returned by the [`ClapFeatureExtractor`] class that you can also
            retrieve from [`AutoFeatureExtractor`]. See [`ClapFeatureExtractor.__call__`] for details.
        is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
            Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
            the features.

        Examples:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import AutoProcessor, ClapAudioModel

        >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> model = ClapAudioModel.from_pretrained("laion/clap-htsat-fused")
        >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-fused")

        >>> inputs = processor(audios=audio_sample, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```NrG  r;  r   r<  r>  )ru   use_return_dictr   r<  r  )r`   rG  r;  r   r<  r>  r"   r"   r#   rr   <  s   "zClapAudioModel.forwardNNNNN)rM   rN   rO   r   r  main_input_namerg   r   r&  r  r   r   r8   rQ   
BoolTensorr   r   r   r   rr   rs   r"   r"   ri   r#   r  /  s0    
r  a(  
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in *Attention is
    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
    Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.

    .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
    )Zcustom_introc                        s   e Zd ZeZd fdd	Zdd Zdd Ze													dd	e	e
j d
e	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	ee
j  de	e de	e de	e de	e deee
j ef fddZ  ZS )ClapTextModelTc                    sD   t  | || _t|| _t|| _|rt|nd| _| 	  dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)
rf   rg   ru   rP  rd  r  encoderr  poolerr  )r`   ru   Zadd_pooling_layerri   r"   r#   rg   ~  s   

zClapTextModel.__init__c                 C   s   | j jS re   rd  rX  r_   r"   r"   r#   r    s   z"ClapTextModel.get_input_embeddingsc                 C   s   || j _d S re   r  r`   r   r"   r"   r#   set_input_embeddings     z"ClapTextModel.set_input_embeddingsNr:   r   rU  rS  r   r_  rk  rl  r  rn  r   r<  r>  r@   c                 C   sP  |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}| j jr-|
d ur(|
n| j j}
nd}
|d ur;|d ur;td|d urJ| || | }n|d urW| d d }ntd|\}}|d urf|j	n|j	}|	d urv|	d d j
d nd}|d u rtj||| f|d}|d u rt| jdr| jjd d d |f }|||}|}n	tj|tj|d	}| ||}| j jr|d ur| \}}}||f}|d u rtj||d}| |}nd }| || j j}| j|||||d
}| j||||||	|
|||d
}|d }| jd ur| |nd }|s||f|dd   S t|||j|j|j|jdS )NFzDYou cannot specify both input_ids and inputs_embeds at the same timer(   z5You have to specify either input_ids or inputs_embedsr   r%   rA   rU  rl   )r:   rS  rU  r_  r<   )	r   r   rk  rl  r  rn  r   r<  r>  r   )rK   rB  r  r   rL   r  )ru   r   r<  r  ri  rn  r   Z%warn_if_padding_and_no_attention_maskr   rB   r   r8   Zonesr^  rd  rU  r\  r   r9   Zget_extended_attention_maskZinvert_attention_maskZget_head_maskr  r  r  r   r  r   rL   r  )r`   r:   r   rU  rS  r   r_  rk  rl  r  rn  r   r<  r>  r`  r    ra  rB   r<   rb  rc  Zextended_attention_maskZencoder_batch_sizeZencoder_sequence_lengthr   Zencoder_hidden_shapeZencoder_extended_attention_maskZembedding_outputZencoder_outputsZsequence_outputr  r"   r"   r#   rr     s   
zClapTextModel.forward)T)NNNNNNNNNNNNN)rM   rN   rO   r   r  rg   r  r  r   r   r8   r   r   rQ   r   r   r   r   rr   rs   r"   r"   ri   r#   r  m  s`    	
r  c                       sV  e Zd ZeZdef fddZe						ddeej	 deej	 deej	 dee
 d	ee
 d
ee
 dejfddZe						ddeej	 deej	 deej	 dee
 d	ee
 d
ee
 dejfddZe									ddeej deej deej deej	 deej dee
 dee
 d	ee
 d
ee
 deeef fddZ  ZS )r  ru   c                    s   t  | t|jtstdt|j dt|jts(tdt|j d|j}|j}t	
tt|j| _t	
tt|j| _|j| _t|| _t|| _t|| _t|| _|   d S )NzKconfig.text_config is expected to be of type ClapTextConfig but is of type .zMconfig.audio_config is expected to be of type ClapAudioConfig but is of type )rf   rg   r   text_configr   	TypeErrortypeaudio_configr   r   r   r8   r  r   logZlogit_scale_init_valuer  r  rL  r  
text_modelrK  text_projectionr  audio_modelaudio_projectionr  )r`   ru   r  r  ri   r"   r#   rg     s.   



zClapModel.__init__Nr:   r   rS  r   r<  r>  r@   c           
      C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}| j||||||d}|dur1|d n|j}| |}	tj|	dd}	|	S )a  
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`ClapTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, ClapModel

        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

        >>> inputs = tokenizer(["the sound of a cat", "the sound of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```Nr:   r   rS  r   r<  r>  r   r(   r4   )	ru   r   r<  r  r  rB  r  F	normalize)
r`   r:   r   rS  r   r<  r>  text_outputsr  Ztext_featuresr"   r"   r#   get_text_features(  s    	
zClapModel.get_text_featuresrG  r;  c           
      C   sz   |dur|n| j j}|dur|n| j j}|dur|n| j j}| j|||d}|s,|d n|j}| |}	tj|	dd}	|	S )a  
        input_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Input audio features. This should be returned by the [`ClapFeatureExtractor`] class that you can also
            retrieve from [`AutoFeatureExtractor`]. See [`ClapFeatureExtractor.__call__`] for details.
        is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
            Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
            the features.

        Returns:
            audio_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The audio embeddings obtained by
            applying the projection layer to the pooled output of [`ClapAudioModel`].

        Examples:

        ```python
        >>> from transformers import AutoFeatureExtractor, ClapModel
        >>> import torch

        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused")
        >>> random_audio = torch.rand((16_000))
        >>> inputs = feature_extractor(random_audio, return_tensors="pt")
        >>> audio_features = model.get_audio_features(**inputs)
        ```N)rG  r;  r>  r   r(   r4   )	ru   r   r<  r  r  rB  r  r  r  )
r`   rG  r;  r   r   r<  r>  audio_outputsr  Zaudio_featuresr"   r"   r#   get_audio_featuresX  s   "
zClapModel.get_audio_featuresreturn_lossc
              	   C   sj  |dur|n| j j}|dur|n| j j}|	dur|	n| j j}	| j|||||	d}
| j||||||	d}|	s9|
d n|
j}| |}|	sG|d n|j}| |}||j	dddd }||j	dddd }| j
 }| j }t|| | }t|| | }d}|rt|}t| }|| d	 }|	s||||||
f}|dur|f| S |S t|||||||
d
S )a  
        input_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Input audio features. This should be returned by the [`ClapFeatureExtractor`] class that you can also
            retrieve from [`AutoFeatureExtractor`]. See [`ClapFeatureExtractor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
            Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
            the features.

        Examples:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import AutoProcessor, ClapModel

        >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> model = ClapModel.from_pretrained("laion/clap-htsat-unfused")
        >>> processor = AutoProcessor.from_pretrained("laion/clap-htsat-unfused")

        >>> input_text = ["Sound of a dog", "Sound of vaccum cleaner"]

        >>> inputs = processor(text=input_text, audios=audio_sample, return_tensors="pt", padding=True)

        >>> outputs = model(**inputs)
        >>> logits_per_audio = outputs.logits_per_audio  # this is the audio-text similarity score
        >>> probs = logits_per_audio.softmax(dim=-1)  # we can take the softmax to get the label probabilities
        ```Nr  r  r   r%   r(   T)pr5   Zkeepdimg       @)rV   rW   rX   rJ   rT   rY   rZ   )ru   r   r<  r  r  r  rB  r  r  r   r  expr  r8   r   trG   rU   )r`   r:   rG  r;  r   rS  r  r   r<  r>  r  r  rT   rJ   Zlogit_scale_textZlogit_scale_audiorX   rW   rV   Zcaption_lossZ
audio_lossrq   r"   r"   r#   rr     s\   ,	



zClapModel.forwardNNNNNN)	NNNNNNNNN)rM   rN   rO   r   r  rg   r   r   r8   r   r   rQ   r  r  Z
LongTensorr  r   r   rU   rr   rs   r"   r"   ri   r#   r    s     /4	

r  c                       s   e Zd ZeZdef fddZdejfddZdd Z	e
												dd
eej deej deej dee dee dee deeef fddZ  ZS )ClapTextModelWithProjectionru   c                    ,   t  | t|| _t|| _|   d S re   )rf   rg   r  r  rK  r  r  r4  ri   r"   r#   rg        

z$ClapTextModelWithProjection.__init__r@   c                 C   r  re   r  rd  rX  r_   r"   r"   r#   r     r  z0ClapTextModelWithProjection.get_input_embeddingsc                 C   s   || j j_d S re   r  r  r"   r"   r#   r    s   z0ClapTextModelWithProjection.set_input_embeddingsNr:   r   rS  r   r<  r>  c                 C   s   |dur|n| j j}| j||||||d}|s|d n|j}| |}	|s:|	|d f|dd  }
tdd |
D S t|	|j|j|j	dS )	a  
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, ClapTextModelWithProjection

        >>> model = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
        >>> tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

        >>> inputs = tokenizer(["a sound of a cat", "a sound of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> text_embeds = outputs.text_embeds
        ```Nr  r   r   r%   c                 s   r?  re   r"   r]   rq   r"   r"   r#   ra   /      z6ClapTextModelWithProjection.forward.<locals>.<genexpr>)rJ   rK   r   rL   )
ru   r  r  rB  r  rb   rI   rK   r   rL   )r`   r:   r   rS  r   r<  r>  r  r  rJ   r   r"   r"   r#   rr     s(   	
z#ClapTextModelWithProjection.forwardr  )rM   rN   rO   r   r  rg   r   r&  r  r  r   r   r8   r   r   r   r   rI   rr   rs   r"   r"   ri   r#   r    s6    
r  c                       r  )ClapAudioModelWithProjectionrG  ru   c                    r  re   )rf   rg   r  r  rK  r  r  r4  ri   r"   r#   rg   >  r  z%ClapAudioModelWithProjection.__init__r@   c                 C   s   | j jjjS re   )r  r  r.  r   r_   r"   r"   r#   r  E  r  z1ClapAudioModelWithProjection.get_input_embeddingsNr;  r   r<  r>  c           
      C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}| j|||||d}|s.|d n|j}| |}|sM||d f|dd  }	tdd |	D S t||j	|j
|jdS )	a  
        input_features (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Input audio features. This should be returned by the [`ClapFeatureExtractor`] class that you can also
            retrieve from [`AutoFeatureExtractor`]. See [`ClapFeatureExtractor.__call__`] for details.
        is_longer (`torch.FloatTensor`, of shape `(batch_size, 1)`, *optional*):
            Whether the audio clip is longer than `max_length`. If `True`, a feature fusion will be enabled to enhance
            the features.

        Examples:

        ```python
        >>> from datasets import load_dataset
        >>> from transformers import ClapAudioModelWithProjection, ClapProcessor

        >>> model = ClapAudioModelWithProjection.from_pretrained("laion/clap-htsat-fused")
        >>> processor = ClapProcessor.from_pretrained("laion/clap-htsat-fused")

        >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
        >>> audio_sample = dataset["train"]["audio"][0]["array"]

        >>> inputs = processor(audios=audio_sample, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> audio_embeds = outputs.audio_embeds
        ```Nr  r   r   r%   c                 s   r?  re   r"   r  r"   r"   r#   ra   }  r  z7ClapAudioModelWithProjection.forward.<locals>.<genexpr>)rT   rK   rL   r   )ru   r  r   r<  r  rB  r  rb   rS   rK   rL   r   )
r`   rG  r;  r   r<  r>  r  r  rT   r   r"   r"   r#   rr   H  s,   !
z$ClapAudioModelWithProjection.forwardr  )rM   rN   rO   r   r  r  rg   r   r&  r  r   r   r8   rQ   r  r   r   r   rS   rr   rs   r"   r"   ri   r#   r  9  s0    
r  )r  r  r  r  r  r  )r   )QrP   r   r   dataclassesr   typingr   r   r   r   r   r8   Ztorch.nn.functionalr   rE   r  Zactivationsr
   Zmodeling_outputsr   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   r   utilsr   r   r   r   Zconfiguration_clapr   r   r   Z
get_loggerrM   r  r$   r2   r3   r>   r   rG   rI   rS   rU   r&  rd   rt   r   r   r   r   r   r   r   r  r"  r'  rK  rP  re  rp  rz  rx  r{  r}  r~  r  r  r  r  r  r  r  r  __all__r"   r"   r"   r#   <module>   s   

%(ce'~=6 KZ 4W^> 
 qCM