o
    Zh                     @   s  d Z ddlZddlZddlmZ ddlmZmZm	Z	m
Z
 ddlZddlZddlmZ ddlmZ ddlmZ dd	lmZmZmZmZmZmZ dd
lmZ ddlmZmZmZ ddlm Z m!Z! ddl"m#Z# e!$e%Z&eG dd deZ'G dd dej(Z)G dd dej(Z*G dd dej(Z+G dd dej(Z,G dd dej(Z-G dd dej(Z.G dd dej(Z/G dd  d ej(Z0G d!d" d"ej(Z1G d#d$ d$ej(Z2e G d%d& d&eZ3e G d'd( d(e3Z4G d)d* d*ej(Z5e d+d,G d-d. d.e3Z6G d/d0 d0ej(Z7G d1d2 d2ej(Z8e d3d,G d4d5 d5e3Z9e d6d,G d7d8 d8e3Z:e d9d,G d:d; d;e3Z;e G d<d= d=e3Z<g d>Z=dS )?zPyTorch ViLT model.    N)	dataclass)ListOptionalTupleUnion)nn)CrossEntropyLoss   )ACT2FN)BaseModelOutputBaseModelOutputWithPoolingMaskedLMOutputModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel) find_pruneable_heads_and_indicesmeshgridprune_linear_layer)auto_docstringlogging   )
ViltConfigc                   @   sj   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeeej   ed< dZeeeej   ed< dS )(ViltForImagesAndTextClassificationOutputa  
    Class for outputs of [`ViltForImagesAndTextClassification`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            List of tuples of `torch.FloatTensor` (one for each image-text pair, each tuple containing the output of
            the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`List[tuple(torch.FloatTensor)]`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            List of tuples of `torch.FloatTensor` (one for each image-text pair, each tuple containing the attention
            weights of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the
            attention softmax, used to compute the weighted average in the self-attention heads.
    Nlosslogitshidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r   r    r%   r%   U/var/www/auris/lib/python3.10/site-packages/transformers/models/vilt/modeling_vilt.pyr   -   s   
 r   c                       s6   e Zd ZdZ fddZd
ddZ	ddd	Z  ZS )ViltEmbeddingsz
    Construct the text and patch embeddings.

    Text embeddings are equivalent to BERT embeddings.

    Patch embeddings are equivalent to ViT embeddings.
    c                    s   t    t|| _ttdd|j| _	t
|| _| jj}ttd|d |j| _t|j|j| _t|j| _|| _d S Nr   )super__init__TextEmbeddingstext_embeddingsr   	Parameterr"   zeroshidden_size	cls_tokenViltPatchEmbeddingspatch_embeddingsnum_patchesposition_embeddings	EmbeddingZmodality_type_vocab_sizetoken_type_embeddingsDropouthidden_dropout_probdropoutconfig)selfr:   r3   	__class__r%   r&   r*   P   s   



zViltEmbeddings.__init__   c                    s  | j jjj\}}}}|  |}|d d d d d d d f  }tjj||jd |jd fd }|d d df j	ddd d df }	|d d df j	ddd d df }
|j\}} | j
j| j
j }| jd d dd d d f ddd|||tj fddt|	|
D dd}|ddd}|ddd}tjtt|jd	 t|jd
 ddd
dj|jd}|d d d d d d d d f }||jd |jd d
d
d
}|dd}|d}dk sd u stts|	|
 }| n|	|
 }t| |jddd| jddd d df  }fdd|D }fdd|D }dd |D }dd |D }fdd|D }g }t t|||D ]B\}\}}}|dkrut!t"| }|#|| |  qStj!t"| |dd}|#tj|| || | gdd qStj|dd}||d d df |d d df f |d
|}||d d df |d d df f |d
}||d d df |d d df f |d
d}||d d df |d d df f |d
|}| j$|d
d
}tj||fdd}tj| jd d dd d f d d d d d f |d
d
|fdd}|| }| %|}tjt"|jd d||gdd}||| fffS )N   r	   sizer   r   dimc              
      sB   g | ]\}}t jt jj||fd ddd| d | fqS )ZbilinearT)rA   modeZalign_cornersr   )r   
functionalpadinterpolate).0hw)heightspatial_poswidthr%   r&   
<listcomp>l   s    
z/ViltEmbeddings.visual_embed.<locals>.<listcomp>Zij)ZindexingdeviceF)as_tuplec                    $   g | ]}  d d df |k qS Nr   r%   rH   u)	valid_idxr%   r&   rN         $ c                    rT   rU   r%   rV   )non_valid_idxr%   r&   rN      rY   c                 S      g | ]}| d qS r   r@   rH   vr%   r%   r&   rN          c                 S   r[   r\   r@   r]   r%   r%   r&   rN      r_   c                    s   g | ]} | qS r%   r%   r]   max_image_lengthr%   r&   rN          T)replacement)&r2   
projectionweightshapefloatr   rE   rG   longsumr:   
image_size
patch_sizer4   	transposeviewr"   catzipflattenstackr   arangetorR   expand
isinstanceintmaxminZnonzerounique	enumerateZmultinomialonesappendr0   r9   )r;   pixel_values
pixel_maskra   _phpwxZx_maskZx_hZx_w
batch_sizenum_channelsZ	patch_dimZ	pos_embedpatch_indexZeffective_resolutionZunique_rowsZvalid_row_idxZnon_valid_row_idxZ
valid_numsZnon_valid_numsZpad_numsselectir^   nvpZvalid_choiceZ
pad_choiceZ
cls_tokensr%   )rK   ra   rZ   rL   rX   rM   r&   visual_embed_   sx   
 $$$0
&


(.,..8
&zViltEmbeddings.visual_embedr   c	              	   C   s   | j |||d}	|d u r| j||| jjd\}}
}n|d}
|d u r%d}|	| tj|tj|	j	d }	|| tj
|
|tj|	j	d }tj|	|gdd}tj||
gdd}||fS )N)	input_idstoken_type_idsinputs_embedsr`   r   dtyperR   rB   )r,   r   r:   ra   rp   r6   r"   Z
zeros_likerh   rR   Z	full_likern   )r;   r   attention_maskr   r}   r~   r   image_embedsimage_token_type_idxZtext_embedsZimage_masksr   
embeddingsmasksr%   r%   r&   forward   s&   

zViltEmbeddings.forward)r>   )r   )r   r   r    r!   r*   r   r   __classcell__r%   r%   r<   r&   r'   G   s    
ar'   c                       s*   e Zd ZdZ fddZdddZ  ZS )r+   zGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j| _t|dd| _| jdt|jddd | jd	tj| j tjd
dd d S )N)padding_idxZepsposition_embedding_typeabsoluteposition_ids)r   rP   F)
persistentr   r   )r)   r*   r   r5   
vocab_sizer/   Zpad_token_idword_embeddingsZmax_position_embeddingsr4   Ztype_vocab_sizer6   	LayerNormlayer_norm_epsr7   r8   r9   getattrr   Zregister_bufferr"   rr   rt   r.   r   rA   rh   r;   r:   r<   r%   r&   r*      s   

zTextEmbeddings.__init__Nc                 C   s   |d ur	|  }n|  d d }|d }|d u r$| jd d d |f }|d u rNt| drC| jd d d |f }||d |}|}ntj|tj| jjd}|d u rW| 	|}| 
|}	||	 }
| jdkrn| |}|
|7 }
| |
}
| |
}
|
S )NrP   r   r   r   r   r   )rA   r   hasattrr   rt   r"   r.   rh   rR   r   r6   r   r4   r   r9   )r;   r   r   r   r   input_shape
seq_lengthZbuffered_token_type_idsZ buffered_token_type_ids_expandedr6   r   r4   r%   r%   r&   r      s,   







zTextEmbeddings.forward)NNNNr   r   r    r!   r*   r   r   r%   r%   r<   r&   r+      s    r+   c                       s(   e Zd ZdZ fddZdd Z  ZS )r1   z#
    Image to Patch Embedding.
    c                    s   t    |j|j}}|j|j}}t|tjj	r|n||f}t|tjj	r)|n||f}|d |d  |d |d   }|| _|| _|| _|| _
tj||||d| _d S )Nr   r   )Zkernel_sizeZstride)r)   r*   rj   rk   r   r/   ru   collectionsabcIterabler3   r   Conv2drd   )r;   r:   rj   rk   r   r/   r3   r<   r%   r&   r*     s   
 zViltPatchEmbeddings.__init__c                 C   s@   |j \}}}}|| jkrtd| jjj}| |j|d}|S )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   )rf   r   
ValueErrorrd   re   r   rs   )r;   r}   r   r   rK   rM   Ztarget_dtyper   r%   r%   r&   r   .  s   

zViltPatchEmbeddings.forwardr   r%   r%   r<   r&   r1     s    r1   c                       .   e Zd Z fddZdd Zd	ddZ  ZS )
ViltSelfAttentionc                    s   t    |j|j dkrt|dstd|j d|j d|j| _t|j|j | _| j| j | _t	j
|j| j|jd| _t	j
|j| j|jd| _t	j
|j| j|jd| _t	|j| _d S )Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .bias)r)   r*   r/   num_attention_headsr   r   rv   attention_head_sizeall_head_sizer   LinearZqkv_biasquerykeyvaluer7   Zattention_probs_dropout_probr9   r   r<   r%   r&   r*   :  s   

zViltSelfAttention.__init__c                 C   s6   |  d d | j| jf }|j| }|ddddS )NrP   r   r?   r   r	   )rA   r   r   rm   permute)r;   r   Znew_x_shaper%   r%   r&   transpose_for_scoresL  s   
z&ViltSelfAttention.transpose_for_scoresNFc                 C   s   |  |}| | |}| | |}| |}t||dd}	|	t| j	 }	|d ur4|	| }	t
jdd|	}
| |
}
|d urI|
| }
t|
|}|dddd }| d d | jf }|j| }|rr||
f}|S |f}|S )NrP   rO   rB   r   r?   r   r	   )r   r   r   r   r"   matmulrl   mathsqrtr   r   ZSoftmaxr9   r   
contiguousrA   r   rm   )r;   r   r   	head_maskoutput_attentionsZmixed_query_layerZ	key_layerZvalue_layerZquery_layerZattention_scoresZattention_probsZcontext_layerZnew_context_layer_shapeoutputsr%   r%   r&   r   Q  s(   



zViltSelfAttention.forwardNNF)r   r   r    r*   r   r   r   r%   r%   r<   r&   r   9  s    r   c                       sF   e Zd ZdZdeddf fddZdejdejdejfd	d
Z  Z	S )ViltSelfOutputz
    The residual connection is defined in ViltLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    r:   returnNc                    s.   t    t|j|j| _t|j| _d S N)	r)   r*   r   r   r/   denser7   r8   r9   r   r<   r%   r&   r*   |     
zViltSelfOutput.__init__r   input_tensorc                 C      |  |}| |}|S r   r   r9   r;   r   r   r%   r%   r&   r        

zViltSelfOutput.forward)
r   r   r    r!   r   r*   r"   Tensorr   r   r%   r%   r<   r&   r   v  s    $r   c                       r   )
ViltAttentionc                    s*   t    t|| _t|| _t | _d S r   )r)   r*   r   	attentionr   outputsetpruned_headsr   r<   r%   r&   r*     s   


zViltAttention.__init__c                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rB   )lenr   r   r   r   r   r   r   r   r   r   r   r   union)r;   headsindexr%   r%   r&   prune_heads  s   zViltAttention.prune_headsNFc                 C   s6   |  ||||}| |d |}|f|dd   }|S )Nr   r   )r   r   )r;   r   r   r   r   Zself_outputsattention_outputr   r%   r%   r&   r     s   zViltAttention.forwardr   )r   r   r    r*   r   r   r   r%   r%   r<   r&   r     s    r   c                       s<   e Zd Zdeddf fddZdejdejfddZ  ZS )	ViltIntermediater:   r   Nc                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r   )r)   r*   r   r   r/   intermediate_sizer   ru   
hidden_actstrr
   intermediate_act_fnr   r<   r%   r&   r*     s
   
zViltIntermediate.__init__r   c                 C   r   r   )r   r   r;   r   r%   r%   r&   r     r   zViltIntermediate.forward	r   r   r    r   r*   r"   r   r   r   r%   r%   r<   r&   r     s    r   c                       sB   e Zd Zdeddf fddZdejdejdejfdd	Z  ZS )

ViltOutputr:   r   Nc                    s.   t    t|j|j| _t|j| _	d S r   )
r)   r*   r   r   r   r/   r   r7   r8   r9   r   r<   r%   r&   r*     r   zViltOutput.__init__r   r   c                 C   s    |  |}| |}|| }|S r   r   r   r%   r%   r&   r     s   

zViltOutput.forwardr   r%   r%   r<   r&   r     s    $r   c                       s*   e Zd ZdZ fddZdddZ  ZS )		ViltLayerz?This corresponds to the Block class in the timm implementation.c                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S )Nr   r   )r)   r*   Zchunk_size_feed_forwardZseq_len_dimr   r   r   intermediater   r   r   r   r/   r   layernorm_beforelayernorm_afterr   r<   r%   r&   r*     s   



zViltLayer.__init__NFc           	      C   sj   | j | ||||d}|d }|dd  }|||j }| |}| |}| ||}|f| }|S )N)r   r   r   )r   r   rs   rR   r   r   r   )	r;   r   r   r   r   Zself_attention_outputsr   r   Zlayer_outputr%   r%   r&   r     s   


zViltLayer.forwardr   r   r%   r%   r<   r&   r     s    
r   c                       s0   e Zd Z fddZ					dddZ  ZS )	ViltEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r%   )r   )rH   r   r:   r%   r&   rN     rb   z(ViltEncoder.__init__.<locals>.<listcomp>F)	r)   r*   r:   r   Z
ModuleListrangenum_hidden_layerslayergradient_checkpointingr   r<   r   r&   r*     s   
 
zViltEncoder.__init__NFTc                 C   s   |rdnd }|r
dnd }t | jD ]:\}	}
|r||f }|d ur$||	 nd }| jr7| jr7| |
j||||}n|
||||}|d }|rK||d f }q|rS||f }|satdd |||fD S t|||dS )Nr%   r   r   c                 s   s    | ]	}|d ur|V  qd S r   r%   r]   r%   r%   r&   	<genexpr>  s    z&ViltEncoder.forward.<locals>.<genexpr>)last_hidden_stater   r   )rz   r   r   ZtrainingZ_gradient_checkpointing_func__call__tupler   )r;   r   r   r   r   output_hidden_statesreturn_dictZall_hidden_statesZall_self_attentionsr   Zlayer_moduleZlayer_head_maskZlayer_outputsr%   r%   r&   r     s8   	

zViltEncoder.forward)NNFFTr   r   r    r*   r   r   r%   r%   r<   r&   r     s    	r   c                   @   s(   e Zd ZeZdZdZddgZdd ZdS )ViltPreTrainedModelviltTr'   r   c                 C   s   t |tjtjfr#|jjjd| jjd |j	dur!|j	j
  dS dS t |tjrF|jjjd| jjd |jdurD|jj|j 
  dS dS t |tjr[|j	j
  |jjd dS dS )zInitialize the weightsg        )meanZstdNg      ?)ru   r   r   r   re   dataZnormal_r:   Zinitializer_ranger   Zzero_r5   r   r   Zfill_)r;   moduler%   r%   r&   _init_weights-  s   

z!ViltPreTrainedModel._init_weightsN)	r   r   r    r   Zconfig_classZbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesr   r%   r%   r%   r&   r   &  s    r   c                       s   e Zd Zd fdd	Zdd Zdd Zdd	 Ze	
	
	
	
	
	
	
	
	
	
	
	
ddee	j
 dee	j dee	j
 dee	j dee	j
 dee	j dee	j dee	j dee dee dee dee deeee	j f fddZ  ZS )	ViltModelTc                    sX   t  | || _t|| _t|| _tj|j	|j
d| _|r#t|nd| _|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        r   N)r)   r*   r:   r'   r   r   encoderr   r   r/   r   	layernorm
ViltPoolerpooler	post_init)r;   r:   add_pooling_layerr<   r%   r&   r*   @  s   

zViltModel.__init__c                 C   s
   | j jjS r   r   r,   r   r;   r%   r%   r&   get_input_embeddingsQ  s   
zViltModel.get_input_embeddingsc                 C   s   || j j_d S r   r   )r;   r   r%   r%   r&   set_input_embeddingsT     zViltModel.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   r   r   r   )r;   Zheads_to_pruner   r   r%   r%   r&   _prune_headsW  s   zViltModel._prune_headsNr   r   r   r}   r~   r   r   r   r   r   r   r   r   c              
   C   s  |
dur|
n| j j}
|dur|n| j j}|dur|n| j j}|dur*|dur*td|dur9| || | }n|durF| dd }ntd|\}}|durU|jn|j}|du retj	||f|d}|durq|durqtd|du r}|du r}td|dur|j
d n|j
d }||krtd	|du rtj	|| j j| j jf|d}| || j j}| j||||||||	d
\}}| ||}| j||||
||d}|d }| |}| jdur| |nd}|s||f|dd  S t|||j|jdS )ak  
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        image_token_type_idx (`int`, *optional*):
            - The token type ids for images.

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltModel
        >>> from PIL import Image
        >>> import requests

        >>> # prepare image and text
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "hello world"

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm")

        >>> inputs = processor(image, text, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```NzDYou cannot specify both input_ids and inputs_embeds at the same timerP   z5You have to specify either input_ids or inputs_embedsrQ   zFYou cannot specify both pixel_values and image_embeds at the same timez7You have to specify either pixel_values or image_embedsr   zAThe text inputs and image inputs need to have the same batch size)r   )r   r   r   r   r   r   )r   pooler_outputr   r   )r:   r   r   use_return_dictr   Z%warn_if_padding_and_no_attention_maskrA   rR   r"   r{   rf   rj   Zget_head_maskr   r   Zget_extended_attention_maskr   r   r   r   r   r   )r;   r   r   r   r}   r~   r   r   r   r   r   r   r   r   Ztext_batch_sizer   rR   Zimage_batch_sizeZembedding_outputZextended_attention_maskZencoder_outputssequence_outputpooled_outputr%   r%   r&   r   _  sp   *


zViltModel.forward)TNNNNNNNNNNNN)r   r   r    r*   r  r  r  r   r   r"   
LongTensorr#   rv   boolr   r   r   r   r   r%   r%   r<   r&   r   >  sZ    	
r   c                       $   e Zd Z fddZdd Z  ZS )r   c                    s*   t    t|j|j| _t | _d S r   )r)   r*   r   r   r/   r   ZTanh
activationr   r<   r%   r&   r*     s   
zViltPooler.__init__c                 C   s(   |d d df }|  |}| |}|S rU   )r   r  )r;   r   Zfirst_token_tensorr	  r%   r%   r&   r     s   

zViltPooler.forwardr   r%   r%   r<   r&   r     s    r   zU
    ViLT Model with a language modeling head on top as done during pretraining.
    )Zcustom_introc                       s   e Zd ZddgZ fddZdd Zdd Ze																								dd
ee	j
 dee	j dee	j
 dee	j dee	j
 dee	j dee	j dee	j dee	j
 dee dee dee deeee	j f fddZ  ZS )ViltForMaskedLMzmlm_score.decoder.weightzmlm_score.decoder.biasc                    s,   t  | t|| _t|| _|   d S r   )r)   r*   r   r   ViltMLMHead	mlm_scorer   r   r<   r%   r&   r*     s   

zViltForMaskedLM.__init__c                 C   s   | j jS r   )r  decoderr   r%   r%   r&   get_output_embeddings  s   z%ViltForMaskedLM.get_output_embeddingsc                 C   s   || j _|j| j _d S r   )r  r  r   )r;   Znew_embeddingsr%   r%   r&   set_output_embeddings  s   z%ViltForMaskedLM.set_output_embeddingsNr   r   r   r}   r~   r   r   r   labelsr   r   r   r   c                 C   s  |dur|n| j j}| j|||||||||
||d}|dd \}}|dur+|jd n|jd }|ddd|f |dd|df }}| |}d}|	durgt }|	|j}	||d| j j	|	d}|s}|f|dd  }|dur{|f| S |S t
|||j|jdS )a/  
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        labels (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in *[-100, 0, ...,
            config.vocab_size]* (see *input_ids* docstring) Tokens with indices set to *-100* are ignored (masked), the
            loss is only computed for the tokens with labels in *[0, ..., config.vocab_size]*

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltForMaskedLM
        >>> import requests
        >>> from PIL import Image
        >>> import re
        >>> import torch

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "a bunch of [MASK] laying on a [MASK]."

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
        >>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm")

        >>> # prepare inputs
        >>> encoding = processor(image, text, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**encoding)

        >>> tl = len(re.findall("\[MASK\]", text))
        >>> inferred_token = [text]

        >>> # gradually fill in the MASK tokens, one by one
        >>> with torch.no_grad():
        ...     for i in range(tl):
        ...         encoded = processor.tokenizer(inferred_token)
        ...         input_ids = torch.tensor(encoded.input_ids)
        ...         encoded = encoded["input_ids"][0][1:-1]
        ...         outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values)
        ...         mlm_logits = outputs.logits[0]  # shape (seq_len, vocab_size)
        ...         # only take into account text features (minus CLS and SEP token)
        ...         mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
        ...         mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1)
        ...         # only take into account text
        ...         mlm_values[torch.tensor(encoded) != 103] = 0
        ...         select = mlm_values.argmax().item()
        ...         encoded[select] = mlm_ids[select].item()
        ...         inferred_token = [processor.decode(encoded)]

        >>> selected_token = ""
        >>> encoded = processor.tokenizer(inferred_token)
        >>> output = processor.decode(encoded.input_ids[0], skip_special_tokens=True)
        >>> print(output)
        a bunch of cats laying on a couch.
        ```N
r   r   r}   r~   r   r   r   r   r   r   r?   r   rP   r   r   r   r   )r:   r  r   rf   r  r   rs   rR   rm   r   r   r   r   )r;   r   r   r   r}   r~   r   r   r   r  r   r   r   r   r  r	  Ztext_seq_lenZtext_featuresr   Z
mlm_logitsZmasked_lm_lossloss_fctr   r%   r%   r&   r     s@   I*
zViltForMaskedLM.forwardr
  )r   r   r    Z_tied_weights_keysr*   r  r  r   r   r"   r  r#   r  r   r   r   r   r   r%   r%   r<   r&   r    sZ    		
r  c                       r  )ViltPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S )Nr   )r)   r*   r   r   r/   r   ru   r   r   r
   transform_act_fnr   r   r   r<   r%   r&   r*   r  s   
z$ViltPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S r   )r   r  r   r   r%   r%   r&   r   {  s   


z#ViltPredictionHeadTransform.forwardr   r%   r%   r<   r&   r  q  s    	r  c                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	r  Nc                    sb   t    || _t|| _tj|j|jdd| _	t
t|j| _|d ur*|| j	_| j| j	_d S )NFr   )r)   r*   r:   r  	transformr   r   r/   r   r  r-   r"   r.   r   re   )r;   r:   re   r<   r%   r&   r*     s   

zViltMLMHead.__init__c                 C   s   | j | j_ d S r   )r   r  r   r%   r%   r&   _tie_weights  r  zViltMLMHead._tie_weightsc                 C   r   r   )r  r  )r;   r   r%   r%   r&   r     s   

zViltMLMHead.forwardr   )r   r   r    r*   r  r   r   r%   r%   r<   r&   r    s    r  z
    Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
    token) for visual question answering, e.g. for VQAv2.
    c                          e Zd Z fddZe												ddeej deej deej deej deej d	eej d
eej deej deej dee	 dee	 dee	 de
eeej f fddZ  ZS )ViltForQuestionAnsweringc              	      sl   t  | |j| _t|| _tt|j|jd t	|jd t
 t|jd |j| _|   d S )Nr?   )r)   r*   
num_labelsr   r   r   
Sequentialr   r/   r   GELU
classifierr   r   r<   r%   r&   r*     s   
z!ViltForQuestionAnswering.__init__Nr   r   r   r}   r~   r   r   r   r  r   r   r   r   c                 C   s   |dur|n| j j}| j|||||||||
||d}|r|jn|d }| |}d}|	dur@|	|j}	tj	||	|	j
d  }|sV|f|dd  }|durT|f| S |S t|||j|jdS )aX  
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        labels (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the visual question answering loss. This tensor must be either a one-hot encoding of
            all answers that are applicable for a given example in the batch, or a soft encoding indicating which
            answers are applicable, where 1.0 is the highest score.

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltForQuestionAnswering
        >>> import requests
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "How many cats are there?"

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        >>> model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

        >>> # prepare inputs
        >>> encoding = processor(image, text, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**encoding)
        >>> logits = outputs.logits
        >>> idx = logits.argmax(-1).item()
        >>> print("Predicted answer:", model.config.id2label[idx])
        Predicted answer: 2
        ```Nr  r   r?   r  )r:   r  r   r  r"  rs   rR   r   rE   Z binary_cross_entropy_with_logitsrf   r   r   r   )r;   r   r   r   r}   r~   r   r   r   r  r   r   r   r   r  r   r   r   r%   r%   r&   r     s:   1
z ViltForQuestionAnswering.forwardr
  r   r   r    r*   r   r   r"   r  r#   r  r   r   r   r   r   r%   r%   r<   r&   r    sT    	
r  z
    Vilt Model transformer with a classifier head on top (a linear layer on top of the final hidden state of the [CLS]
    token) for image-to-text or text-to-image retrieval, e.g. MSCOCO and F30K.
    c                       r  )ViltForImageAndTextRetrievalc                    s2   t  | t|| _t|jd| _|   d S r(   )	r)   r*   r   r   r   r   r/   rank_outputr   r   r<   r%   r&   r*     s   
z%ViltForImageAndTextRetrieval.__init__Nr   r   r   r}   r~   r   r   r   r  r   r   r   r   c                 C   s   |dur|n| j j}d}|	durtd| j|||||||||
||d}|r)|jn|d }| |}|sH|f|dd  }|durF|f| S |S t|||j|jdS )ad  
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels are currently not supported.

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltForImageAndTextRetrieval
        >>> import requests
        >>> from PIL import Image

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["An image of two cats chilling on a couch", "A football player scoring a goal"]

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
        >>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")

        >>> # forward pass
        >>> scores = dict()
        >>> for text in texts:
        ...     # prepare inputs
        ...     encoding = processor(image, text, return_tensors="pt")
        ...     outputs = model(**encoding)
        ...     scores[text] = outputs.logits[0, :].item()
        ```NzTraining is not yet supported.r  r   r?   r  )	r:   r  NotImplementedErrorr   r  r%  r   r   r   )r;   r   r   r   r}   r~   r   r   r   r  r   r   r   r   r   r  r   r   r%   r%   r&   r     s8   -
z$ViltForImageAndTextRetrieval.forwardr
  r#  r%   r%   r<   r&   r$    sT    	
r$  zq
    Vilt Model transformer with a classifier head on top for natural language visual reasoning, e.g. NLVR2.
    c                       r  )"ViltForImagesAndTextClassificationc              	      sv   t  | |j| _t|| _|j}tt|j	| |j	| t
|j	| t t|j	| |j| _|   d S r   )r)   r*   r  r   r   
num_imagesr   r   r   r/   r   r!  r"  r   )r;   r:   r(  r<   r%   r&   r*   q  s   
z+ViltForImagesAndTextClassification.__init__Nr   r   r   r}   r~   r   r   r   r  r   r   r   r   c                 C   sD  |
dur|
n| j j}
|dur|n| j j}|dur|n| j j}|dur,|jdkr,|d}|dur:|jdkr:|d}|durC|jd nd}|du rT|durR|jd nd}|| j jkr^tdg }|rdg nd}|
rjg nd}t	|D ]l}| j
||||dur|dd|ddddddf nd|dur|dd|ddddf nd|||dur|dd|ddddf nd|d |
||d}|r|jn|d }|| |r||j |
r||j qptj|dd}| |}d}|	durt }|	|j}	||d| j|	d}|s|||f}|dur|f| S |S t||||d	S )
a3  
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Binary classification labels.

        Examples:

        ```python
        >>> from transformers import ViltProcessor, ViltForImagesAndTextClassification
        >>> import requests
        >>> from PIL import Image

        >>> image1 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_0.jpg", stream=True).raw)
        >>> image2 = Image.open(requests.get("https://lil.nlp.cornell.edu/nlvr/exs/ex0_1.jpg", stream=True).raw)
        >>> text = "The left image contains twice the number of dogs as the right image."

        >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
        >>> model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")

        >>> # prepare inputs
        >>> encoding = processor([image1, image2], text, return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
        >>> logits = outputs.logits
        >>> idx = logits.argmax(-1).item()
        >>> print("Predicted answer:", model.config.id2label[idx])
        Predicted answer: True
        ```N   r   r	   z\Make sure to match the number of images in the model with the number of images in the input.)r   r   r}   r~   r   r   r   r   r   r   r   rP   rB   r  )r:   r   r   r  ndimZ	unsqueezerf   r(  r   r   r   r  r|   r   r   r"   rn   r"  r   rs   rR   rm   r  r   )r;   r   r   r   r}   r~   r   r   r   r  r   r   r   r(  Zpooler_outputsr   r   r   r   r  r	  r   r   r  r   r%   r%   r&   r     sp   /

,&&



z*ViltForImagesAndTextClassification.forwardr
  )r   r   r    r*   r   r   r"   r  r#   r  r   r   r   r   r   r%   r%   r<   r&   r'  k  sT    	
r'  c                       r  )ViltForTokenClassificationc                    sN   t  | |j| _t|dd| _t|j| _t	|j
|j| _|   d S )NF)r   )r)   r*   r  r   r   r   r7   r8   r9   r   r/   r"  r   r   r<   r%   r&   r*     s   z#ViltForTokenClassification.__init__Nr   r   r   r}   r~   r   r   r   r  r   r   r   r   c                 C   s   |dur|n| j j}| j|||||||||
||d}|d }|dur'|jd n|jd }| |}| |ddd|f }d}|	durZt }|	|j}	||	d| j
|		d}|sp|f|dd  }|durn|f| S |S t|||j|jdS )a/  
        image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
            Optionally, instead of passing `pixel_values`, you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `pixel_values` into patch embeddings.
        labels (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr  r   r   rP   r?   r  )r:   r  r   rf   r9   r"  r   rs   rR   rm   r  r   r   r   )r;   r   r   r   r}   r~   r   r   r   r  r   r   r   r   r  Ztext_input_sizer   r   r  r   r%   r%   r&   r     s@   
z"ViltForTokenClassification.forwardr
  )r   r   r    r*   r   r   r"   r  r#   r  r   r   r   r   r   r%   r%   r<   r&   r+    sT    	
r+  )r$  r'  r+  r  r  r   r   r   )>r!   collections.abcr   r   dataclassesr   typingr   r   r   r   r"   Ztorch.utils.checkpointr   Ztorch.nnr   Zactivationsr
   Zmodeling_outputsr   r   r   r   r   r   Zmodeling_utilsr   Zpytorch_utilsr   r   r   utilsr   r   Zconfiguration_viltr   Z
get_loggerr   loggerr   Moduler'   r+   r1   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r$  r'  r+  __all__r%   r%   r%   r&   <module>   st    
 9=#&5  j] O