"""PyTorch Pop2Piano model."""

import copy
import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers.generation import GenerationConfig

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    auto_docstring,
    is_torch_flex_attn_available,
    is_torch_fx_proxy,
    is_torchdynamo_compiling,
    logging,
)
from .configuration_pop2piano import Pop2PianoConfig


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)

_load_pop2piano_layer_norm = True

try:
    from apex.normalization import FusedRMSNorm

    _load_pop2piano_layer_norm = False

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm")
except ImportError:
    # using the normal Pop2PianoLayerNorm
    pass
except Exception:
    logger.warning("Discovered apex but it failed to load, falling back to Pop2PianoLayerNorm")
    pass


class Pop2PianoLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Pop2Piano uses a layer norm which only scales and doesn't shift, also known as Root Mean Square Layer
        # Normalization (https://arxiv.org/abs/1910.07467): the variance is computed without the mean and there is
        # no bias. The accumulation for half-precision inputs is done in fp32.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert back into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


if not _load_pop2piano_layer_norm:
    Pop2PianoLayerNorm = FusedRMSNorm  # noqa

ALL_LAYERNORM_LAYERS.append(Pop2PianoLayerNorm)


class Pop2PianoDenseActDense(nn.Module):
    def __init__(self, config: Pop2PianoConfig):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # `self.wo` may be kept in a different dtype (e.g. float32 for quantized models),
        # so cast the activations to its dtype unless it is int8
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)
        hidden_states = self.wo(hidden_states)
        return hidden_states
zPop2PianoDenseActDense.forwardr8   r9   r:   r   r#   r7   r;   r-   r-   r+   r.   r=   _   s    r=   c                       r<   )Pop2PianoDenseGatedActDenser>   c                    sj   t    tj|j|jdd| _tj|j|jdd| _tj|j|jdd| _t	|j
| _t|j | _d S r?   )r"   r#   r   rB   rC   rD   wi_0wi_1rF   rG   rH   rI   r	   rJ   rK   rL   r+   r-   r.   r#   w   s   
z$Pop2PianoDenseGatedActDense.__init__c                 C   sz   |  | |}| |}|| }| |}t| jjtjr6|j	| jjj	kr6| jjj	tj
kr6|| jjj	}| |}|S rM   )rK   rS   rT   rI   rN   rF   r'   r%   rO   r4   rP   r1   )r)   r6   Zhidden_geluZhidden_linearr-   r-   r.   r7      s   


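
# The gated variant computes wo(dropout(act(wi_0(x)) * wi_1(x))), a GLU-style feed-forward
# in the spirit of "GLU Variants Improve Transformer" (Shazeer, 2020). Which variant a
# checkpoint uses is selected by `config.is_gated_act` in Pop2PianoLayerFF below.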


class Pop2PianoLayerFF(nn.Module):
    def __init__(self, config: Pop2PianoConfig):
        super().__init__()
        if config.is_gated_act:
            self.DenseReluDense = Pop2PianoDenseGatedActDense(config)
        else:
            self.DenseReluDense = Pop2PianoDenseActDense(config)

        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


class Pop2PianoAttention(nn.Module):
    def __init__(
        self,
        config: Pop2PianoConfig,
        has_relative_attention_bias=False,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended "
                "and will lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value states from the cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # reuse k, v, cross-attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value states to the cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that this layer's cross-attention cache is already filled, so it can be re-used later
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states)
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs


class Pop2PianoLayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.SelfAttention = Pop2PianoAttention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class Pop2PianoLayerCrossAttention(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs
r   c                       sJ   e Zd Zd	dee f fddZ												d
ddZ  ZS )Pop2PianoBlockFNr[   c                    s`   t    |j| _t | _| jt|||d | jr&| jt||d | jt	| d S )Nr   )r[   )
r"   r#   r\   r   
ModuleListlayerappendr   r   rU   rp   r+   r-   r.   r#     s   

zPop2PianoBlock.__init__Tc                 C   s  | j d |||||	|
||d}|d d \}}	|dd  }|jtjkrDtt| t|jjd t|jj}tj	|| |d}| j
oJ|d u}|r| j d ||||||	|d d |
|d	}|d d \}}	|jtjkrtt| t|jjd t|jj}tj	|| |d}||dd   }| j d |}|jtjkrtt| t|jjd t|jj}tj	|| |d}|f}|
r||	f | }|S || }|S )	Nr   )r   r   r   r   r   r   r   r/   i  )r{   maxr   r0   )r   r   r   r   r   r   r   r   )r   r4   r%   r5   r   isinfanyfinfor   clampr\   )r)   r6   r   r   encoder_hidden_statesencoder_attention_maskencoder_decoder_position_biasr   cross_attn_layer_head_maskr   r   r   return_dictr   Zself_attention_outputsZattention_outputsZclamp_valueZdo_cross_attentionZcross_attention_outputsr   r-   r-   r.   r7     sn   

zPop2PianoBlock.forwardr   )NNNNNNNNFFTNr   r-   r-   r+   r.   r     s    r   c                   @   s@   e Zd ZeZdZdZdZdZdZ	dgZ
dgZdd Zdd	 Zd
S )Pop2PianoPreTrainedModelZtransformerFTr   rF   c                 C   s  | j j}t|tr|jj|d  dS t|tr'|jjjj	d|d d dS t|t
rS|jjjj	d|d d t|drO| j jsQ|jjjj	d|d d dS dS dS t|tr|jjjj	d|| j jd  d t|jdr{|jjdur{|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  dS dS dS t|tr|jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  dS dS dS t|tr}| j j}| j j}| j j}|jjjj	d||| d  d |jjjj	d||d  d |jjjj	d||d  d |jjjj	d||| d  d |j r|j!jjj	d||d  d dS dS dS )zInitialize the weights      ?        )r3   Zstdlm_head      rA   N)"r>   Zinitializer_factorrN   r    r'   dataZfill_Pop2PianoConcatEmbeddingToMel	embeddingZnormal_!Pop2PianoForConditionalGenerationsharedhasattrtie_word_embeddingsr   r=   rE   rC   rA   Zzero_rF   rD   rR   rS   rT   rZ   r`   rb   rg   rh   ri   rj   r]   rl   )r)   modulefactorrC   ra   rc   r-   r-   r.   _init_weightsI  sR   



        
z&Pop2PianoPreTrainedModel._init_weightsc                 C   s   | j j}| j j}|d u rtdt|r1t|jd d d |}tj||dd df gdd}n|	|j}|dd df 
 |ddd f< ||d< |d u rStd||d	k| |S )
Nzoself.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id.r0   )r   .rq   r   ).r   z1self.model.config.pad_token_id has to be defined.)r>   decoder_start_token_idpad_token_id
ValueErrorr   r%   fullr   catZ	new_zeroscloneZmasked_fill_)r)   	input_idsr   r   Zshifted_input_idsr-   r-   r.   _shift_rightw  s      z%Pop2PianoPreTrainedModel._shift_rightN)r8   r9   r:   r   Zconfig_classZbase_model_prefixZis_parallelizableZsupports_gradient_checkpointingZ_supports_cache_classZ_supports_static_cacheZ_no_split_modulesZ_keep_in_fp32_modulesr   r   r-   r-   r-   r.   r   >  s    .r   c                       s   e Zd Zd fdd	Zdd Zdd Z													ddd	Z	
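
# Example of the right shift performed by `_shift_right` above (illustrative only):
# with decoder_start_token_id = 0 and pad_token_id = 0,
#
#     labels            = [[ 10, 11, -100, -100 ]]
#     shifted_input_ids = [[  0, 10,   11,    0 ]]
#
# i.e. labels are shifted one position to the right, the start token is prepended, and
# any remaining -100 sentinel values are replaced by the pad token.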
ddeej	df dej	dej	de
def
ddZedej	dededejdej	defddZ  ZS )Pop2PianoStackNc                    sx   t    || _ j| _t fddt jD | _t	 j
 jd| _t j| _|   d| _d | _d| _d S )Nc                    s"   g | ]}t  t|d k|dqS )r   r   )r   r   ).0ir>   r-   r.   
<listcomp>  s    z+Pop2PianoStack.__init__.<locals>.<listcomp>rV   F)r"   r#   embed_tokensr\   r   r   range
num_layersblockr    rC   rX   final_layer_normrG   rH   rI   	post_initZmodel_parallelZ
device_mapro   )r)   r>   r   r+   r   r.   r#     s   

zPop2PianoStack.__init__c                 C      | j S rM   r   r)   r-   r-   r.   get_input_embeddings     z#Pop2PianoStack.get_input_embeddingsc                 C   
   || _ d S rM   r   r)   Znew_embeddingsr-   r-   r.   set_input_embeddings     
z#Pop2PianoStack.set_input_embeddingsc           )      C   s  |	d ur|	n| j j}	|
d ur|
n| j j}
|d ur|n| j j}|d ur$|n| j j}|d urB|d urB| jr5dnd}td| d| d|d urS| }|d|d }n|d ur`| d d }n| jrednd}td| d| d	| j	r| j
r|	rtd
 d}	|d u r| jd u rtd| |}|\}}|	du r| jstd|  dd}d}| jr|	s|d urt|trt|tsd}t|t }n#t|tsd}td t|}n|d u rtt t }n| jsd }|d ur| nd}|d u rtj||| |jd}|d u rt s|| }tj|||jd}| j jr0| ||||d ur+|jnd |
}n|d d d d d d f }|j|jd}d| t|jj }| jru|d uru| \}}}||f}|d u rotj||jd}| |}nd }|  || j j!}|  || j j!}|rdnd }|
rdnd }|
r| jrdnd }d }d } | "|}!t#| j$D ]\}"}#||" }$||" }%|r||!f }| j	r| j
r| %|#j&|!||||| |$|%d |	|
|}&n|#|!||||| |$|%||	|
|d}&|	du r |&d d d |&dd   }&|&d d \}!}'|&d }| jr|d ur|&|
rdnd } |
r3||&d f }| jr3||&d f }q| '|!}!| "|!}!|rG||!f }|	rL|'nd }(|rT|j}(|r[|( }(|slt)dd |!|(|||fD S t*|!|(|||dS ) NZdecoder_ zYou cannot specify both zinput_ids and zinputs_embeds at the same timer0   zYou have to specify either zinput_ids or inputs_embedszZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fz<You have to initialize the model with valid token embeddingsTz)`use_cache` can only be set to `True` if z is used as a decoderzPassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.r   r   )r4   r   r-   )r   r   r   r   r   r   r   r   r   r   r   r   rM   r/      r      c                 s   s    | ]	}|d ur|V  qd S rM   r-   )r   ri   r-   r-   r.   	<genexpr>t  s    z)Pop2PianoStack.forward.<locals>.<genexpr>)last_hidden_statepast_key_valuesr6   
attentionscross_attentions)+r>   r   r   output_hidden_statesuse_return_dictr\   r   sizer   ro   r   re   rf   r   rN   r
   r   r   Zfrom_legacy_cacheget_seq_lengthr%   r   r   r   r&   _update_causal_maskr   r1   r4   r   r{   Zinvert_attention_maskZget_head_maskr   rI   	enumerater   Z_gradient_checkpointing_funcr7   r   Zto_legacy_cachetupler   ))r)   r   r   r   r   r   	head_maskcross_attn_head_maskr   r   r   r   r   r   Zerr_msg_prefixZinput_shaper   r   Zreturn_legacy_cacheZreturn_self_attention_cachepast_key_values_lengthZmask_seq_lengthr   Zencoder_batch_sizeZencoder_sequence_length_Zencoder_hidden_shapeZencoder_extended_attention_maskZall_hidden_statesZall_attentionsZall_cross_attentionsr   r   r6   r   Zlayer_moduler   r   Zlayer_outputsZnext_decoder_cacheZ
next_cacher-   r-   r.   r7     s2  











zPop2PianoStack.forwardFr   r   input_tensorr   r   r   c                 C   s:  | j jdkr|d ur|dk r|S d S | j jdkr&t|tjr$t|}|S |d ur.| nd}|d ur7|jnd}| j jdkrO|sO|sOt	j
|||| jdrOd S |j}|jd }	|r^| }
nt|tjri|jd	 n||	 d }
| j||	|
|||jd d
}| j jdkr|d ur|jjdv r|st|j}t	||}|S )NZflash_attention_2r   Zflex_attentionr   FZsdpa)r   r   Zis_trainingr   r0   )sequence_lengthtarget_lengthr4   r   r   )cudaZxpuZnpu)r>   Z_attn_implementationr   rN   r%   rO   r   r   Zis_compileabler   Z_ignore_causal_mask_sdpar   r4   r   Zget_max_cache_shape5_prepare_4d_causal_attention_mask_with_cache_positionr   typer   r{   Z_unmask_unattended)r)   r   r   r   r   r   Zpast_seen_tokensZusing_compilable_cacher4   r  r  r   	min_dtyper-   r-   r.   r     sT   




z"Pop2PianoStack._update_causal_maskr  r  r4   r   c                 K   sD  | dur|   dkr| }|S t|j}tj||f|||jd}|dkr+tj|dd}|tj||jd|ddk9 }|ddddddf 	|ddd}| dur|
 }| jd }	|ddddddd|	f | ddddddf |j }
|
dk}
|ddddddd|	f |
||ddddddd|	f< |S )	aM  
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        Nr   )Z
fill_valuer4   r   r   )Zdiagonalr   r0   r   )rr   r%   r   r{   r   r   Ztriur   Zreshapeexpandr   r   r1   Zmasked_fill)r   r  r  r4   r   r   kwargsr   r  Zmask_lengthZpadding_maskr-   r-   r.   r    s,    $
6  zDPop2PianoStack._prepare_4d_causal_attention_mask_with_cache_positionrM   )NNNNNNNNNNNNN)F)r8   r9   r:   r#   r   r   r7   r   r%   rO   r
   r   r   r   r   r4   r  r;   r-   r-   r+   r.   r     sZ    
 ]
Dr   c                       s(   e Zd ZdZ fddZdd Z  ZS )r   z'Embedding Matrix for `composer` tokens.c                    s"   t    tj|j|jd| _d S )N)Znum_embeddingsZembedding_dim)r"   r#   r   rk   composer_vocab_sizerC   r   rL   r+   r-   r.   r#     s   
z&Pop2PianoConcatEmbeddingToMel.__init__c                 C   s.   || }|  |d}tj||gdd}|S )Nr   rq   )r   r   r%   r   )r)   featureindex_valueembedding_offsetZindex_shiftedZcomposer_embeddingr   r-   r-   r.   r7     s   z%Pop2PianoConcatEmbeddingToMel.forward)r8   r9   r:   __doc__r#   r7   r;   r-   r-   r+   r.   r     s    r   zA
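
# Shape sketch for the composer conditioning above (illustrative only; the token id
# values are hypothetical - the real mapping lives in the checkpoint's
# `generation_config.composer_to_feature_token`):
#
#     cond = Pop2PianoConcatEmbeddingToMel(config)
#     feature = torch.randn(2, 100, config.d_model)            # (batch, seq, d_model) mel features
#     index_value = torch.tensor([246, 247])                    # one raw composer token id per example
#     out = cond(feature, index_value, embedding_offset=246)    # offset = smallest composer token id
#     out.shape                                                 # (2, 101, config.d_model): embedding prepended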


@auto_docstring(
    custom_intro="""
    Pop2Piano Model with a `language modeling` head on top.
    """
)
class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: Pop2PianoConfig):
        super().__init__(config)
        self.config = config
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        self.mel_conditioner = Pop2PianoConcatEmbeddingToMel(config)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = Pop2PianoStack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = Pop2PianoStack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_mel_conditioner_outputs(
        self,
        input_features: torch.FloatTensor,
        composer: str,
        generation_config: GenerationConfig,
        attention_mask: Optional[torch.FloatTensor] = None,
    ):
        """
        This method is used to concatenate mel conditioner tokens at the front of the input_features in order to
        control the type of MIDI token generated by the model.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                input features extracted from the feature extractor.
            composer (`str`):
                composer token which determines the type of MIDI tokens to be generated.
            generation_config (`~generation.GenerationConfig`):
                The generation config is used to get the composer-feature_token pair.
            attention_mask (`torch.FloatTensor`, *optional*):
                For batched generation `input_features` are padded to have the same shape across all examples.
                `attention_mask` helps to determine which areas were padded and which were not.
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
        """
        composer_to_feature_token = generation_config.composer_to_feature_token
        if composer not in composer_to_feature_token.keys():
            raise ValueError(
                f"Please choose a composer from {list(composer_to_feature_token.keys())}. Composer received - {composer}"
            )
        composer_value = composer_to_feature_token[composer]
        composer_value = torch.tensor(composer_value, device=self.device)
        composer_value = composer_value.repeat(input_features.shape[0])

        embedding_offset = min(composer_to_feature_token.values())

        input_features = self.mel_conditioner(
            feature=input_features,
            index_value=composer_value,
            embedding_offset=embedding_offset,
        )
        if attention_mask is not None:
            input_features[~attention_mask[:, 0].bool()] = 0.0

            # since self.mel_conditioner adds a new array at the front of inputs_embeds we need to do the same for
            # attention_mask to keep the shapes the same
            attention_mask = torch.concatenate([attention_mask[:, 0].view(-1, 1), attention_mask], axis=1)
            return input_features, attention_mask

        return input_features, None

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Pop2Piano is a model with relative position embeddings
            so you should be able to pad the inputs on both the right and the left. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail.
            [What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining
            take a look at [Pop2Piano Training](./Pop2Piano#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids) Pop2Piano uses the `pad_token_id` as the
            starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
            `decoder_input_ids` have to be input (see `past_key_values`).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Does the same task as `inputs_embeds`. If `inputs_embeds` is not present but `input_features` is present
            then `input_features` will be considered as `inputs_embeds`.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is not None and input_features is not None:
            raise ValueError("Both `inputs_embeds` and `input_features` received! Please provide only one of them")
        elif input_features is not None and inputs_embeds is None:
            inputs_embeds = input_features

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        input_features,
        attention_mask=None,
        composer="composer1",
        generation_config=None,
        **kwargs,
    ):
        """
        Generates token ids for midi outputs.

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation
        strategies and code examples, check out the [following guide](./generation_strategies).

        </Tip>

        Parameters:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`.
            attention_mask:
                For batched generation `input_features` are padded to have the same shape across all examples.
                `attention_mask` helps to determine which areas were padded and which were not.
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
            composer (`str`, *optional*, defaults to `"composer1"`):
                This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each
                `"composer"`. Please make sure that the composer value is present in `composer_to_feature_token` in
                `generation_config`. For an example please see
                https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json .
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            kwargs:
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
                Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:
                    - [`~generation.GenerateEncoderDecoderOutput`],
                    - [`~generation.GenerateBeamEncoderDecoderOutput`]
        """
        if generation_config is None:
            generation_config = self.generation_config
        generation_config.update(**kwargs)

        # check for composer_to_feature_token
        if not hasattr(generation_config, "composer_to_feature_token"):
            raise ValueError(
                "`composer_to_feature_token` was not found! Please refer to "
                "https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json "
                "and parse a dict like that."
            )

        if len(generation_config.composer_to_feature_token) != self.config.composer_vocab_size:
            raise ValueError(
                "config.composer_vocab_size must be same as the number of keys in "
                f"generation_config.composer_to_feature_token! "
                f"Found {self.config.composer_vocab_size} vs {len(generation_config.composer_to_feature_token)}."
            )

        # to control the variation of generated MIDI tokens we concatenate mel-conditioner tokens (which depend on
        # composer_token) at the front of input_features.
        input_features, attention_mask = self.get_mel_conditioner_outputs(
            input_features=input_features,
            attention_mask=attention_mask,
            composer=composer,
            generation_config=generation_config,
        )

        return super().generate(
            inputs=None,
            inputs_embeds=input_features,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **kwargs,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
                raise ValueError(
                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and "
                    f"layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
                )
            if len(reordered_layer_past_states) != len(layer_past_states):
                raise ValueError(
                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and "
                    f"length of layer_past_states {len(layer_past_states)} mismatched"
                )

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past
 Yr   )Fr  r  r~   typingr   r   r   r%   r   Ztorch.nnr   Ztransformers.generationr   Zactivationsr	   Zcache_utilsr
   r   r   Z
generationr   Zmodeling_attn_mask_utilsr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   r   r   r   Zconfiguration_pop2pianor   Z!torch.nn.attention.flex_attentionr   Zintegrations.flex_attentionr   Z
get_loggerr8   re   Z_load_pop2piano_layer_normZapex.normalizationr   infoImportError	Exceptionr-  Moduler    r   r=   rR   rU   rZ   r   r   r   r   r   r   r   __all__r-   r-   r-   r.   <module>   sr   


 f%'fS  u  i