o
    Zh                     @   sr  d Z ddlZddlZddlmZmZmZ ddlZddlmZ ddl	m
Z
 ddlmZ ddlmZ dd	lmZmZ dd
lmZmZ ddlmZ ddlmZmZ ddlmZ eeZG dd dejZ G dd dejZ!G dd dej"Z#G dd dej"Z$G dd dej"Z%eG dd deZ&G dd de&Z'eddG dd  d e&Z(ed!dG d"d# d#e&eZ)d#dgZ*dS )$z/PyTorch TrOCR decoder model (based on RoBERTa).    N)OptionalTupleUnion)nn)CrossEntropyLoss   )ACT2FN)GenerationMixin)_prepare_4d_attention_mask!_prepare_4d_causal_attention_mask))BaseModelOutputWithPastAndCrossAttentions!CausalLMOutputWithCrossAttentions)PreTrainedModel)auto_docstringlogging   )TrOCRConfigc                       sJ   e Zd ZdZdedef fddZddejd	ed
ejf fddZ  Z	S )TrOCRLearnedPositionalEmbeddingzN
    This module learns positional embeddings up to a fixed maximum size.
    num_embeddingsembedding_dimc                    s   d| _ t || j  | d S )N   )offsetsuper__init__)selfr   r   	__class__ W/var/www/auris/lib/python3.10/site-packages/transformers/models/trocr/modeling_trocr.pyr   .   s   z(TrOCRLearnedPositionalEmbedding.__init__r   N	input_idspast_key_values_lengthposition_idsc                    s\   |du r |j dd \}}tj||| tj| jjd|d}n|d}t 	|| j
 S )z3`input_ids' shape is expected to be [bsz x seqlen].Nr   )dtypedevicer   )shapetorcharangelongweightr#   expand	unsqueezer   forwardr   )r   r   r    r!   bszseq_lenr   r   r   r,   4   s   
z'TrOCRLearnedPositionalEmbedding.forwardr   N)
__name__
__module____qualname____doc__intr   r&   Tensorr,   __classcell__r   r   r   r   r   )   s    (r   c                
       sL   e Zd ZdZddedededee f fddZd	ej	f fd
dZ
  ZS )TrOCRScaledWordEmbeddingz\
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
          ?r   r   padding_idxembed_scalec                    s   t  ||| || _d S N)r   r   r:   )r   r   r   r9   r:   r   r   r   r   H   s   
z!TrOCRScaledWordEmbedding.__init__r   c                    s   t  || j S r;   )r   r,   r:   )r   r   r   r   r   r,   L   s   z TrOCRScaledWordEmbedding.forward)r8   )r0   r1   r2   r3   r4   r   floatr   r&   r5   r,   r6   r   r   r   r   r7   C   s    $r7   c                	       s   e Zd ZdZddededee f fddZeddededee fd	d
Ze	
 dde	jdefddZ	dde	jdedee fddZ  ZS )"TrOCRSinusoidalPositionalEmbeddingzDThis module produces sinusoidal positional embeddings of any length.Nnum_positionsr   r9   c                    sB   t    d| _|| _|| _| |||| _| dt	d d S )Nr   _float_tensorr   )
r   r   r   r   r9   get_embeddingweightsZregister_bufferr&   FloatTensor)r   r>   r   r9   r   r   r   r   S   s   
z+TrOCRSinusoidalPositionalEmbedding.__init__r   c                 C   s   |d }t d|d  }ttj|tjd |  }tj| tjd d|d }tjt	|t
|gdd| d}|d dkrUtj|t| dgdd}|durad||ddf< |t S )	z
        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
        description in Section 3.5 of "Attention Is All You Need".
        r   i'  r   )r"   r   dimr$   N)mathlogr&   expr'   Zint64r<   r+   catsincosviewZzerostoZget_default_dtype)r   r   r9   Zhalf_dimZembr   r   r   r@   [   s    $&z0TrOCRSinusoidalPositionalEmbedding.get_embeddingr   r   r    c                 C   s   |  \}}| || j||j}| jd | }| jd u s&|| j dkr0| || j| j| _| j| j| _| j	d|
d
||d }|S )Nr   r   r$   )size"create_position_ids_from_input_idsr9   rL   r#   rA   r@   r   r?   index_selectrK   detach)r   r   r    r-   r.   r!   Zmax_posxr   r   r   r,   n   s   "z*TrOCRSinusoidalPositionalEmbedding.forwardc                 C   s6   | | }tj|dd|| | }| | S )z
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.
        r   rC   )ner4   r&   ZcumsumZtype_asr(   )r   r   r9   r    maskZincremental_indicesr   r   r   rN      s   zETrOCRSinusoidalPositionalEmbedding.create_position_ids_from_input_idsr;   )r   )r0   r1   r2   r3   r4   r   r   staticmethodr@   r&   Zno_gradr5   r,   rN   r6   r   r   r   r   r=   P   s     r=   c                       s   e Zd ZdZ						ddededee d	ee d
edededef fddZde	j
dedefddZ					dde	j
dee	j
 deee	j
  dee	j
 dee	j
 dedee	j
ee	j
 eee	j
  f fddZ  ZS ) TrOCRAttentionz>Multi-headed attention from 'Attention Is All You Need' paper.N        FT	embed_dim	num_headskdimvdimdropout
is_decoderbiasis_cross_attentionc
           
         s   t    || _|d ur|n|| _|d ur|n|| _|| _|| _|| | _| j| | jks9td| j d| d| jd | _	|| _
tj| j||d| _tj| j||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      ࿩r]   )r   r   rW   rY   rZ   rX   r[   head_dim
ValueErrorscalingr\   r   Lineark_projv_projq_projout_proj)
r   configrW   rX   rY   rZ   r[   r\   r]   r^   r   r   r   r      s&   


zTrOCRAttention.__init__tensorr.   r-   c                 C   s    | ||| j| jdd S )Nr   r   )rK   rX   r`   	transpose
contiguous)r   ri   r.   r-   r   r   r   _shape   s    zTrOCRAttention._shapehidden_stateskey_value_statespast_key_valueattention_masklayer_head_maskoutput_attentionsreturnc                 C   sV  |du}|  \}}	}
| || j }|r"|dur"|d }|d }nZ|r9| | |d|}| | |d|}nC|durh| | |d|}| | |d|}tj|d |gdd}tj|d |gdd}n| | |d|}| | |d|}| jr||f}|| j	 d| j
f}| ||	|j| }|j| }|j| }| d}t||dd}|  || j	 |	|fkrtd|| j	 |	|f d|   |dur|  |d|	|fkrtd	|d|	|f d|   ||| j	|	|| }||| j	 |	|}tjj|dd}|dur@|  | j	fkr%td
| j	f d|   |dddd||| j	|	| }||| j	 |	|}|rW||| j	|	|}||| j	 |	|}nd}tjj|| j| jd}t||}|  || j	 |	| j
fkrtd|| j	|	| j
f d|   ||| j	|	| j
}|dd}|||	|
}| |}|||fS )z#Input shape: Batch x Time x ChannelNr   r   r$   r   rC   z$Attention weights should be of size z	, but is z!Attention mask should be of size z/Head mask for a single layer should be of size ptrainingz `attn_output` should be of size )rM   rf   rb   rl   rd   re   r&   rH   r\   rX   r`   rK   Zbmmrj   ra   r   
functionalZsoftmaxr[   rv   Zreshaperg   )r   rm   rn   ro   rp   rq   rr   r^   r-   tgt_lenrW   Zquery_statesZ
key_statesZvalue_statesZ
proj_shapeZsrc_lenZattn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr   r   r   r,      s   





"

zTrOCRAttention.forward)NNrV   FTF)NNNNF)r0   r1   r2   r3   r4   r   r<   boolr   r&   r5   rl   r   r,   r6   r   r   r   r   rU      s\    	
!rU   c                       s   e Zd Zdef fddZ								ddejdeej d	eej d
eej deej deej deeej  dee	 dee	 fddZ
  ZS )TrOCRDecoderLayerrh   c              
      s   t    |j| _t|| j|j|jdd| _|j| _t	|j
 | _|j| _t| j| _|jrGt|| j|j|j|j|jddd| _t| j| _t| j|j| _t|j| j| _t| j| _d S )NT)rW   rX   r[   r\   )rW   rX   rY   rZ   r[   r\   r^   )r   r   hidden_sizerW   rU   Zdecoder_attention_headsZattention_dropout	self_attnr[   r   Zactivation_functionactivation_fnactivation_dropoutr   	LayerNormself_attn_layer_normr\   Zcross_attention_hidden_sizeencoder_attnencoder_attn_layer_normrc   Zdecoder_ffn_dimfc1fc2final_layer_normr   rh   r   r   r   r   $  s8   

zTrOCRDecoderLayer.__init__NFTrm   rp   encoder_hidden_statesencoder_attention_maskrq   cross_attn_layer_head_maskro   rr   	use_cachec
                 C   sZ  |}
|dur|dd nd}| j |||||d\}}}tjj|| j| jd}|
| }| |}d}d}|durk|}
|durC|dd nd}| j||||||d\}}}tjj|| j| jd}|
| }| |}|| }|}
| | 	|}tjj|| j
| jd}| |}tjj|| j| jd}|
| }| |}|f}|r|||f7 }|	r||f7 }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size *(decoder_attention_heads,)*.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        Nr   )rm   ro   rp   rq   rr   rt   )rm   rn   rp   rq   ro   rr   )r|   r   rw   r[   rv   r   r   r   r}   r   r~   r   r   )r   rm   rp   r   r   rq   r   ro   rr   r   ZresidualZself_attn_past_key_valueZself_attn_weightsZpresent_key_valueZcross_attn_present_key_valueZcross_attn_weightsZcross_attn_past_key_valueoutputsr   r   r   r,   F  sT   
	



zTrOCRDecoderLayer.forward)NNNNNNFT)r0   r1   r2   r   r   r&   r5   r   r   ry   r,   r6   r   r   r   r   rz   #  s:    %	
rz   c                   @   s&   e Zd ZeZdZdZdgZdd ZdS )TrOCRPreTrainedModelmodelTrz   c                 C   s   | j j}t|tjtjfr%|jjjd|d |j	d ur#|j	j
  d S d S t|tjrD|jjjd|d |jd urF|jj|j 
  d S d S d S )NrV   )meanstd)rh   Zinit_std
isinstancer   rc   ZConv1dr)   dataZnormal_r]   Zzero_	Embeddingr9   )r   moduler   r   r   r   _init_weights  s   

z"TrOCRPreTrainedModel._init_weightsN)	r0   r1   r2   r   Zconfig_classZbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesr   r   r   r   r   r     s    r   c                       sX   e Zd ZdZdef fddZdd Zdd Z																								dd
dZ  Z	S )TrOCRDecoderz
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TrOCRDecoderLayer`]

    Args:
        config: TrOCRConfig
    rh   c                    s   t     j| _ j| _ j| _ jrt	 j
nd}t j j
| j|d| _ jr5t j j
| _nt j| j d  j
| j| _ jrOt j
| _nd | _t fddt jD | _d| _|   d S )Nr8   )r:   r   c                    s   g | ]}t  qS r   )rz   ).0_rh   r   r   
<listcomp>  s    z)TrOCRDecoder.__init__.<locals>.<listcomp>F)r   r   r[   Zdecoder_layerdrop	layerdropZpad_token_idr9   Zscale_embeddingrE   sqrtr{   r7   
vocab_sizeembed_tokensuse_learned_position_embeddingsr   Zmax_position_embeddingsembed_positionsr=   layernorm_embeddingr   r   Z
ModuleListrangeZdecoder_layerslayersgradient_checkpointing	post_init)r   rh   r:   r   r   r   r     s*    zTrOCRDecoder.__init__c                 C      | j S r;   r   r   r   r   r   get_input_embeddings     z!TrOCRDecoder.get_input_embeddingsc                 C   
   || _ d S r;   r   r   valuer   r   r   set_input_embeddings     
z!TrOCRDecoder.set_input_embeddingsNc                 C   s  |
dur|
n| j j}
|dur|n| j j}|	dur|	n| j j}	|dur$|n| j j}|dur4|dur4td|durD|}|d|jd }n|dur\| dd }|dddddf }ntd|durm|d d jd nd}|du rx| 	|}| j j
r| j||d}n| j||d}|| }| jdur| |}tjj|| j| jd}|j}t||||}|dur|durt||j|d d	}| jr| jr|	rtd
 d}	|rdnd}|
rdnd}|
r|durdnd}|	rdnd}t||gddgD ]*\}}|dur| d t| jkrtd| dt| j d| d  dqt| jD ]\}}|r0||f7 }| jrAtg }|| jk rAq$|durJ|| nd}| jrv| jrv| |j|||||dure|| nd|duro|| ndd|
|	
}n||||||dur|| nd|dur|| nd||
|	d	}|d }|	r|||
rdnd f7 }|
r||d f7 }|dur||d f7 }q$|r||f7 }|	r|nd}|st dd |||||fD S t!|||||dS )a  
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
                on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        NzTYou cannot specify both decoder_input_ids and decoder_inputs_embeds at the same timer$   zEYou have to specify either decoder_input_ids or decoder_inputs_embedsr   r   )r    rt   )rx   z^`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...Fr   	head_maskcross_attn_head_maskzThe `z` should be specified for z layers, but it is for .)rp   r   r   rq   r   ro   rr   r   r   r   c                 s   s    | ]	}|d ur|V  qd S r;   r   )r   vr   r   r   	<genexpr>  s    z'TrOCRDecoder.forward.<locals>.<genexpr>)Zlast_hidden_statepast_key_valuesrm   
attentionscross_attentions)"rh   rr   output_hidden_statesr   use_return_dictra   rK   r%   rM   r   r   r   r   r   rw   r[   rv   r   r
   r"   r   loggerZwarning_onceziplenr   	enumerater&   Zrandr   Z_gradient_checkpointing_func__call__tupler   )r   r   rp   r   r   r   r   r   inputs_embedsr   rr   r   return_dictinputZinput_shaper    Z	embed_posrm   Zall_hidden_statesZall_self_attnsZall_cross_attentionsZnext_decoder_cacheZ	attn_maskZ	mask_nameidxZdecoder_layerZdropout_probabilityro   Zlayer_outputsZ
next_cacher   r   r   r,     s   P








zTrOCRDecoder.forward)NNNNNNNNNNNN)
r0   r1   r2   r3   r   r   r   r   r,   r6   r   r   r   r   r     s$    r   a  
    The TrOCR Model with a language modeling head. Can be used for summarization.
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    )Zcustom_introc                       s$   e Zd Z fddZdd Z  ZS )TrOCRDecoderWrapperc                    s   t  | t|| _d S r;   )r   r   r   decoderr   r   r   r   r     s   zTrOCRDecoderWrapper.__init__c                 O   s   | j |i |S r;   )r   )r   argskwargsr   r   r   r,     s   zTrOCRDecoderWrapper.forward)r0   r1   r2   r   r,   r6   r   r   r   r   r     s    r   zy
    The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and
    c                        s  e Zd ZdgZ fddZdd Zdd Zdd	 Zd
d Zdd Z	dd Z
e													d#deej deej deej deej deej deej deeeej   deej deej dee dee dee dee deeef fdd Zed!d" Z  ZS )$TrOCRForCausalLMzoutput_projection.weightc                    sN   t |}d|_d|_t | t|| _tj	|j
|jdd| _|   d S )NTFr_   )copydeepcopyr\   Zis_encoder_decoderr   r   r   r   r   rc   r{   r   output_projectionr   r   r   r   r   r     s   

zTrOCRForCausalLM.__init__c                 C   s
   | j jjS r;   r   r   r   r   r   r   r   r     r   z%TrOCRForCausalLM.get_input_embeddingsc                 C   s   || j j_d S r;   r   r   r   r   r   r     s   z%TrOCRForCausalLM.set_input_embeddingsc                 C   r   r;   r   r   r   r   r   get_output_embeddings  r   z&TrOCRForCausalLM.get_output_embeddingsc                 C   r   r;   r   )r   Znew_embeddingsr   r   r   set_output_embeddings  r   z&TrOCRForCausalLM.set_output_embeddingsc                 C   s   || j _d S r;   r   r   )r   r   r   r   r   set_decoder  s   zTrOCRForCausalLM.set_decoderc                 C   s   | j jS r;   r   r   r   r   r   get_decoder  s   zTrOCRForCausalLM.get_decoderNr   rp   r   r   r   r   r   r   labelsr   rr   r   r   rs   c                 C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}| jj|||||||||
|||d}| |d }d}|	durNt }||d| j j	|	d}|sd|f|dd  }|durb|f| S |S t
|||j|j|j|jdS )a
  
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import (
        ...     TrOCRConfig,
        ...     TrOCRProcessor,
        ...     TrOCRForCausalLM,
        ...     ViTConfig,
        ...     ViTModel,
        ...     VisionEncoderDecoderModel,
        ... )
        >>> import requests
        >>> from PIL import Image

        >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
        >>> # init vision2text model with random weights
        >>> encoder = ViTModel(ViTConfig())
        >>> decoder = TrOCRForCausalLM(TrOCRConfig())
        >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

        >>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
        >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        >>> # load image from the IAM dataset
        >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
        >>> pixel_values = processor(image, return_tensors="pt").pixel_values
        >>> text = "industry, ' Mr. Brown commented icily. ' Let us have a"

        >>> # training
        >>> model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
        >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
        >>> model.config.vocab_size = model.config.decoder.vocab_size

        >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
        >>> outputs = model(pixel_values, labels=labels)
        >>> loss = outputs.loss
        >>> round(loss.item(), 2)
        5.30

        >>> # inference
        >>> generated_ids = model.generate(pixel_values)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> generated_text
        'industry, " Mr. Brown commented icily. " Let us have a'
        ```N)r   rp   r   r   r   r   r   r   r   rr   r   r   r   r$   r   )losslogitsr   rm   r   r   )rh   rr   r   r   r   r   r   r   rK   r   r   r   rm   r   r   )r   r   rp   r   r   r   r   r   r   r   r   rr   r   r   r   r   r   Zloss_fctoutputr   r   r   r,     sD   LzTrOCRForCausalLM.forwardc                    s.   d}| D ]}|t  fdd|D f7 }q|S )Nr   c                 3   s$    | ]}| d  |jV  qdS r/   )rO   rL   r#   )r   Z
past_statebeam_idxr   r   r   q  s   " z2TrOCRForCausalLM._reorder_cache.<locals>.<genexpr>)r   )r   r   Zreordered_pastZ
layer_pastr   r   r   _reorder_cachel  s   zTrOCRForCausalLM._reorder_cache)NNNNNNNNNNNNN)r0   r1   r2   Z_tied_weights_keysr   r   r   r   r   r   r   r   r   r&   Z
LongTensorr5   rB   r   ry   r   r   r,   rT   r   r6   r   r   r   r   r     sl    	

ur   )+r3   r   rE   typingr   r   r   r&   r   Ztorch.nnr   Zactivationsr   Z
generationr	   Zmodeling_attn_mask_utilsr
   r   Zmodeling_outputsr   r   Zmodeling_utilsr   utilsr   r   Zconfiguration_trocrr   Z
get_loggerr0   r   r   r   r7   Moduler=   rU   rz   r   r   r   r   __all__r   r   r   r   <module>   sJ   
>    	 "