o
    Zh8                     @   s   d Z ddlZddlZddlmZmZmZmZ ddlZ	ddl
mZ ddlmZmZmZmZmZ ddlmZmZmZmZmZ ddlmZ e rIddlZeeZd	d
iZdd Zdd Z eddG dd deZ!dgZ"dS )z!Tokenization class for Pop2Piano.    N)ListOptionalTupleUnion   )BatchFeature)
AddedTokenBatchEncodingPaddingStrategyPreTrainedTokenizerTruncationStrategy)
TensorTypeis_pretty_midi_availableloggingrequires_backendsto_numpy)requiresvocabz
vocab.jsonc                 C   s   || 7 }|d urt ||}|S N)minnumbercutoff_time_idxcurrent_idx r   c/var/www/auris/lib/python3.10/site-packages/transformers/models/pop2piano/tokenization_pop2piano.pytoken_time_to_note(   s   
r   c           	      C   sZ   ||  d ur'||  }||k r%|}| ||| |g |dkrd n|}||| < |S ||| < |S )Nr   )append)	r   current_velocitydefault_velocitynote_onsets_readyr   notes	onset_idx
offset_idxZonsets_readyr   r   r   token_note_to_note0   s   r$   )pretty_midiZtorch)backendsc                       s  e Zd ZdZddgZeZ							dC fd
d	Zedd Z	dd Z
dedefddZdDdefddZdejdededefddZ			dEdejdejdededef
d d!Z	"dFdejd#edee fd$d%ZdGd'ejdejd(efd)d*ZdFd+ed,ee dee fd-d.Z	"	"dHd'eejeej f d/ee d0ee defd1d2Z 	"	"dHd'eejeej f d/ee d0ee defd3d4Z!	5	"	"	"	"	"	6dId'eejeej eeej  f d7ee"ee#f d8ee"eef d0ee d9ee d:ee" d;eeee$f  d<e"defd=d>Z%	6dJd?e&d@e"fdAdBZ'  Z(S )KPop2PianoTokenizera  
    Constructs a Pop2Piano tokenizer. This tokenizer does not require training.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab (`str`):
            Path to the vocab file which contains the vocabulary.
        default_velocity (`int`, *optional*, defaults to 77):
            Determines the default velocity to be used while creating midi Notes.
        num_bars (`int`, *optional*, defaults to 2):
            Determines cutoff_time_idx in for each token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"-1"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 1):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 0):
             A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
            attention mechanisms or loss computation.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to 2):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
    	token_idsattention_maskM      -1102c           
         s   t |trt|dddn|}t |trt|dddn|}t |tr(t|dddn|}t |tr6t|dddn|}|| _|| _t|d}	t|	| _W d    n1 sTw   Y  dd | j	 D | _
t jd||||d| d S )NF)lstriprstriprbc                 S   s   i | ]\}}||qS r   r   ).0kvr   r   r   
<dictcomp>u       z/Pop2PianoTokenizer.__init__.<locals>.<dictcomp>)	unk_token	eos_token	pad_token	bos_tokenr   )
isinstancestrr   r   num_barsopenjsonloadencoderitemsdecodersuper__init__)
selfr   r   r>   r8   r9   r:   r;   kwargsfile	__class__r   r   rF   ]   s$   
zPop2PianoTokenizer.__init__c                 C   s
   t | jS )z-Returns the vocabulary size of the tokenizer.)lenrB   rG   r   r   r   
vocab_size   s   
zPop2PianoTokenizer.vocab_sizec                 C   s   t | jfi | jS )z(Returns the vocabulary of the tokenizer.)dictrB   Zadded_tokens_encoderrM   r   r   r   	get_vocab   s   zPop2PianoTokenizer.get_vocabtoken_idreturnc                 C   sH   | j || j d}|d}d|dd t|d }}||gS )a?  
        Decodes the token ids generated by the transformer into notes.

        Args:
            token_id (`int`):
                This denotes the ids generated by the transformers to be converted to Midi tokens.

        Returns:
            `List`: A list consists of token_type (`str`) and value (`int`).
        Z_TOKEN_TIME_   Nr   )rD   getr8   splitjoinint)rG   rQ   Ztoken_type_value
token_typevaluer   r   r   _convert_id_to_token   s   
 z'Pop2PianoTokenizer._convert_id_to_token
TOKEN_TIMEc                 C   s   | j | d| t| jS )a  
        Encodes the Midi tokens to transformer generated token ids.

        Args:
            token (`int`):
                This denotes the token value.
            token_type (`str`):
                This denotes the type of the token. There are four types of midi tokens such as "TOKEN_TIME",
                "TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL".

        Returns:
            `int`: returns the id of the token.
        rS   )rB   rU   rX   r8   )rG   tokenrY   r   r   r   _convert_token_to_id   s   z'Pop2PianoTokenizer._convert_token_to_idtokensbeat_offset_idxbars_per_batchr   c                 C   s   d}t t|D ]1}|| }||| d  }|| }	| j|||	d}
t|
dkr)q|du r0|
}qtj||
fdd}q|du r@g S |S )a  
        Converts relative tokens to notes which are then used to generate pretty midi object.

        Args:
            tokens (`numpy.ndarray`):
                Tokens to be converted to notes.
            beat_offset_idx (`int`):
                Denotes beat offset index for each note in generated Midi.
            bars_per_batch (`int`):
                A parameter to control the Midi output generation.
            cutoff_time_idx (`int`):
                Denotes the cutoff time index for each note in generated Midi.
        N   )	start_idxr   r   )Zaxis)rangerL   relative_tokens_ids_to_notesnpZconcatenate)rG   r_   r`   ra   r   r!   index_tokensZ
_start_idxZ_cutoff_time_idxZ_notesr   r   r   "relative_batch_tokens_ids_to_notes   s$   z5Pop2PianoTokenizer.relative_batch_tokens_ids_to_notesr      beatstepc                 C   s:   |du rdn|}| j ||||d}| j|||| d}|S )al  
        Converts tokens to Midi. This method calls `relative_batch_tokens_ids_to_notes` method to convert batch tokens
        to notes then uses `notes_to_midi` method to convert them to Midi.

        Args:
            tokens (`numpy.ndarray`):
                Denotes tokens which alongside beatstep will be converted to Midi.
            beatstep (`np.ndarray`):
                We get beatstep from feature extractor which is also used to get Midi.
            beat_offset_idx (`int`, *optional*, defaults to 0):
                Denotes beat offset index for each note in generated Midi.
            bars_per_batch (`int`, *optional*, defaults to 2):
                A parameter to control the Midi output generation.
            cutoff_time_idx (`int`, *optional*, defaults to 12):
                Denotes the cutoff time index for each note in generated Midi.
        Nr   )r_   r`   ra   r   )
offset_sec)ri   notes_to_midi)rG   r_   rk   r`   ra   r   r!   Zmidir   r   r   !relative_batch_tokens_ids_to_midi   s   z4Pop2PianoTokenizer.relative_batch_tokens_ids_to_midiNrc   c              	      s^   fdd|D }|}d}dd t tdd  j D d D }g }|D ]5\}	}
|	dkr5|
dkr4 n(q&|	dkrAt|
||d	}q&|	d
krH|
}q&|	dkrXt|
| j|||d}q&tdt|D ]'\}}|dur|du rq|d }nt	||d }t	||}|
||| jg q`t|dkrg S t|}|dddf d |dddf  }||  }|S )a  
        Converts relative tokens to notes which will then be used to create Pretty Midi objects.

        Args:
            tokens (`numpy.ndarray`):
                Relative Tokens which will be converted to notes.
            start_idx (`float`):
                A parameter which denotes the starting index.
            cutoff_time_idx (`float`, *optional*):
                A parameter used while converting tokens to notes.
        c                    s   g | ]}  |qS r   )r[   )r3   r]   rM   r   r   
<listcomp>  r7   zCPop2PianoTokenizer.relative_tokens_ids_to_notes.<locals>.<listcomp>r   c                 S   s   g | ]}d qS r   r   r3   ir   r   r   ro         c                 S   s   g | ]}| d qS )NOTE)endswith)r3   r4   r   r   r   ro     r7   rT   ZTOKEN_SPECIALr\   r   TOKEN_VELOCITY
TOKEN_NOTE)r   r   r   r    r   r!   zToken type not understood!N   )rd   sumrB   keysr   r$   r   
ValueError	enumeratemaxr   rL   rf   arrayZargsort)rG   r_   rc   r   wordsr   r   r    r!   rY   r   pitchZ
note_onsetcutoffr#   Z
note_orderr   rM   r   re      sP   *	


$z/Pop2PianoTokenizer.relative_tokens_ids_to_notes        r!   rl   c                 C   s   t | dg tjddd}tjdd}g }|D ]\}}}	}
tj|
|	|| | || | d}|| q||_|j| |  |S )a  
        Converts notes to Midi.

        Args:
            notes (`numpy.ndarray`):
                This is used to create Pretty Midi objects.
            beatstep (`numpy.ndarray`):
                This is the extrapolated beatstep that we get from feature extractor.
            offset_sec (`int`, *optional*, defaults to 0.0):
                This represents the offset seconds which is used while creating each Pretty Midi Note.
        r%   i  g      ^@)
resolutionZinitial_tempor   )program)velocityr   startend)	r   r%   Z
PrettyMIDIZ
InstrumentNoter   r!   instrumentsZremove_invalid_notes)rG   r!   rk   rl   Znew_pmZnew_instZ	new_notesr"   r#   r   r   Znew_noter   r   r   rm   8  s    

z Pop2PianoTokenizer.notes_to_midisave_directoryfilename_prefixc                 C   s   t j|std| d dS t j||r|d ndtd  }t|d}|t	
| j W d   |fS 1 s=w   Y  |fS )a}  
        Saves the tokenizer's vocabulary dictionary to the provided save_directory.

        Args:
            save_directory (`str`):
                A path to the directory where to saved. It will be created if it doesn't exist.
            filename_prefix (`Optional[str]`, *optional*):
                A prefix to add to the names of the files saved by the tokenizer.
        zVocabulary path (z) should be a directoryN- r   w)ospathisdirloggererrorrW   VOCAB_FILES_NAMESr?   writer@   dumpsrB   )rG   r   r   Zout_vocab_filerI   r   r   r   save_vocabularyX  s   

z"Pop2PianoTokenizer.save_vocabularytruncation_strategy
max_lengthc                 K   s~  t | dg t|d tjrtdd |D dd}t|tj	}|ddddf 
 }d	d t|d
 D }|D ]\}}}	}
|| |	|
g || |	dg q>g }d}t|D ]9\}}t|dkrjq_|| |d |D ]"\}	}
t|
dk}
||
kr|
}|| |
d || |	d quq_t|}|tjkr|r||kr| jd||| |d|\}}}td|iS )a  
        This is the `encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
        generated token ids. It only works on a single batch, to process multiple batches please use
        `batch_encode_plus` or `__call__` method.

        Args:
            notes (`numpy.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
                This represents the midi notes. If `notes` is a `numpy.ndarray`:
                    - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
                If `notes` is a `list` containing `pretty_midi.Note` objects:
                    - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
            truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
                Indicates the truncation strategy that is going to be used during truncation.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).

        Returns:
            `BatchEncoding` containing the tokens ids.
        r%   r   c                 S   s    g | ]}|j |j|j|jgqS r   )r   r   r   r   )r3   Z	each_noter   r   r   ro     s     z2Pop2PianoTokenizer.encode_plus.<locals>.<listcomp>rb   Nr+   c                 S   s   g | ]}g qS r   r   rp   r   r   r   ro     rr   rT   r\   ru   rv   )ZidsZnum_tokens_to_remover   r(   r   )r   r<   r%   r   rf   r}   ZreshaperoundZastypeZint32r|   rd   r   r{   rL   r^   rX   r   ZDO_NOT_TRUNCATEZtruncate_sequencesr	   )rG   r!   r   r   rH   Zmax_time_idxtimesZonsetoffsetr   r   r_   r   rq   timeZ	total_lenrS   r   r   r   encode_pluso  sH   zPop2PianoTokenizer.encode_plusc                 K   sH   g }t t|D ]}|| j|| f||d|d  qtd|iS )a  
        This is the `batch_encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
        generated token ids. It works on multiple batches by calling `encode_plus` multiple times in a loop.

        Args:
            notes (`numpy.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
                This represents the midi notes. If `notes` is a `numpy.ndarray`:
                    - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
                If `notes` is a `list` containing `pretty_midi.Note` objects:
                    - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
            truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
                Indicates the truncation strategy that is going to be used during truncation.
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).

        Returns:
            `BatchEncoding` containing the tokens ids.
        )r   r   r(   )rd   rL   r   r   r	   )rG   r!   r   r   rH   Zencoded_batch_token_idsrq   r   r   r   batch_encode_plus  s   	z$Pop2PianoTokenizer.batch_encode_plusFTpadding
truncationpad_to_multiple_ofreturn_attention_maskreturn_tensorsverbosec	              	   K   s   t |tjr|jdknt |d t}
| jd|||||d|	\}}}}	|
r;|du r,dn|}| jd|||d|	}n| jd|||d|	}| j|||||||d}|S )	a  
        This is the `__call__` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer generated
        token ids.

        Args:
            notes (`numpy.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
                This represents the midi notes.

                If `notes` is a `numpy.ndarray`:
                    - Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
                If `notes` is a `list` containing `pretty_midi.Note` objects:
                    - Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
            padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
                Activates and controls padding. Accepts the following values:

                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
                Activates and controls truncation. Accepts the following values:

                - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
                  to the maximum acceptable input length for the model if that argument is not provided. This will
                  truncate token by token, removing a token from the longest sequence in the pair if a pair of
                  sequences (or a batch of pairs) is provided.
                - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
                  greater than the model maximum admissible input size).
            max_length (`int`, *optional*):
                Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
                `None`, this will use the predefined model maximum length if a maximum length is required by one of the
                truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
                truncation/padding to a maximum length will be deactivated.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
            return_attention_mask (`bool`, *optional*):
                Whether to return the attention mask. If left to the default, will return the attention mask according
                to the specific tokenizer's default, defined by the `return_outputs` attribute.

                [What are attention masks?](../glossary#attention-mask)
            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.
            verbose (`bool`, *optional*, defaults to `True`):
                Whether or not to print more information and warnings.

        Returns:
            `BatchEncoding` containing the token_ids.
        r   r   )r   r   r   r   r   NT)r!   r   r   )r   r   r   r   r   r   r   )	r<   rf   ndarrayndimlistZ"_get_padding_truncation_strategiesr   r   pad)rG   r!   r   r   r   r   r   r   r   rH   Z
is_batchedZpadding_strategyr   r(   r   r   r   __call__  sH   $R	
zPop2PianoTokenizer.__call__feature_extractor_outputreturn_midic                 C   s  t t|dot|dot|d}|s |d jd dkr td|r~t|d dddf dk|d jd ksE|d jd |d	 jd kr_td
|jd  d|d jd  d|d	 jd  |d jd |jd kr}td|d jd  d|jd  n'|d jd dks|d	 jd dkrtd|d jd  d|d	 jd  d|rt|d dddf dkd }n|jd g}g }g }d}t|D ]\}	}
|||
 }|dddtt|t	| j
kd d f }|d |	 }|d	 |	 }|r+|d |	 }|d |	 }|dtt|dkd d  }|dtt|dkd d  }t|}t|}t|}| j||| j| jd d d}|jd jD ]}| j|d 7  _| j|d 7  _|| qL|| ||
d 7 }q|r}t||dS td|iS )aF  
        This is the `batch_decode` method for `Pop2PianoTokenizer`. It converts the token_ids generated by the
        transformer to midi_notes and returns them.

        Args:
            token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`):
                Output token_ids of `Pop2PianoConditionalGeneration` model.
            feature_extractor_output (`BatchFeature`):
                Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatstep"` and
                `"extrapolated_beatstep"`. Also `"attention_mask_beatsteps"` and
                `"attention_mask_extrapolated_beatstep"`
                 should be present if they were returned by the feature extractor.
            return_midi (`bool`, *optional*, defaults to `True`):
                Whether to return midi object or not.
        Returns:
            If `return_midi` is True:
                - `BatchEncoding` containing both `notes` and `pretty_midi.pretty_midi.PrettyMIDI` objects.
            If `return_midi` is False:
                - `BatchEncoding` containing `notes`.
        r)   attention_mask_beatsteps$attention_mask_extrapolated_beatstep	beatstepsr   rT   zattention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present for batched inputs! But one of them were not present.Nextrapolated_beatstepzbLength mistamtch between token_ids, beatsteps and extrapolated_beatstep! Found token_ids length - z, beatsteps shape - z$ and extrapolated_beatsteps shape - z!Found attention_mask of length - z but token_ids of length - zLength mistamtch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, But found beatsteps length - z", extrapolated_beatsteps length - .rb   )r_   rk   ra   r   )r!   Zpretty_midi_objectsr!   )boolhasattrshaperz   rx   rf   wherer{   r|   rX   r9   r   rn   r>   r   r!   r   r   r   r	   )rG   r(   r   r   Zattention_masks_presentZ	batch_idxZ
notes_listZpretty_midi_objects_listrc   rg   Zend_idxZeach_tokens_idsr   r   r   r   Zpretty_midi_objectZnoter   r   r   batch_decodeY  s   
$2$
zPop2PianoTokenizer.batch_decode)r*   r+   r,   r-   r.   r/   )r\   )r   r+   rj   r   )r   )NN)FNNNNNT)T))__name__
__module____qualname____doc__Zmodel_input_namesr   Zvocab_files_namesrF   propertyrN   rP   rX   r   r[   r^   rf   r   ri   rn   floatr   re   rm   r=   r   r   r   r   r%   r   r   r	   r   r   r   r
   r   r   r   r   __classcell__r   r   rJ   r   r'   ?   s    "

0
%
>  
J
.	

 r'   )#r   r@   r   typingr   r   r   r   numpyrf   Zfeature_extraction_utilsr   Ztokenization_utilsr   r	   r
   r   r   utilsr   r   r   r   r   Zutils.import_utilsr   r%   Z
get_loggerr   r   r   r   r$   r'   __all__r   r   r   r   <module>   s0   
     
