"""PyTorch FLAVA model."""

import collections
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from .configuration_flava import (
    FlavaConfig,
    FlavaImageCodebookConfig,
    FlavaImageConfig,
    FlavaMultimodalConfig,
    FlavaTextConfig,
)


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook"
LOGIT_SCALE_CLAMP_MIN = 0
LOGIT_SCALE_CLAMP_MAX = 4.6052

FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig]
dZdS )FlavaModelOutputa  
    Output from FlavaModel containing embeddings and outputs from individual encoders.

    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
    transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
    `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.

    Args:
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
            The image embeddings which are basically the pooled output of [`FlavaImageModel`].
        image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
            The output of the [`FlavaImageModel`].
        text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
        text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
            The output of the [`FlavaTextModel`].
        multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
            The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
        multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
            The output of the [`FlavaMultimodalModel`].
    """

    image_embeddings: Optional[torch.FloatTensor] = None
    image_output: Optional[BaseModelOutputWithPooling] = None
    text_embeddings: Optional[torch.FloatTensor] = None
    text_output: Optional[BaseModelOutputWithPooling] = None
    multimodal_embeddings: Optional[torch.FloatTensor] = None
    multimodal_output: Optional[BaseModelOutputWithPooling] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


@dataclass
class FlavaLosses(ModelOutput):
    """Class representing pretraining losses from FLAVA model

    Args:
        mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.:
            Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
        mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.:
            Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
        itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked`, `pixel_values` are present and `itm_weight` > 0.:
            Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
            masked pairs in FLAVA.
        global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.:
            Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
            data. This is calculated on unmasked images and texts.
        mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.:
            Masked Multimodal Modeling loss's image component calculated on paired image-text data.
        mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.:
            Masked Multimodal Modeling loss's text component calculated on paired image-text data.
    """

    mim: Optional[torch.FloatTensor] = None
    mlm: Optional[torch.FloatTensor] = None
    itm: Optional[torch.FloatTensor] = None
    global_contrastive: Optional[torch.FloatTensor] = None
    mmm_image: Optional[torch.FloatTensor] = None
    mmm_text: Optional[torch.FloatTensor] = None

    def all_none(self) -> bool:
        all_none = True
        for v in self.values():
            if v is not None:
                all_none = False
                break
        return all_none


@dataclass
class FlavaForPreTrainingOutput(ModelOutput):
    """
    Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders.

    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
    transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
    `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True):
            Total loss calculated for this model.
        loss_info (`FlavaLosses`):
            Detailed info for FLAVA Pretraining losses. Check `FlavaLosses` class description for the information on
            the keys.
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
            The image embeddings which are basically the pooled output of [`FlavaImageModel`].
        image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
            The output of the [`FlavaImageModel`].
        text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
        text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
            The output of the [`FlavaTextModel`].
        multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
            The multimodal embeddings which are basically the pooled output of [`FlavaTextModel`].
        multimodal_output (`BaseModelOutputWithPooling`, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
            The output of the [`FlavaMultimodalModel`].

        image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
            The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos`
            to create masked images.
        image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
            The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images.
        text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
        text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
            The output of the [`FlavaTextModel`].
        multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present):
            The multimodal masked embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
        multimodal_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
            The output of the [`FlavaMultimodalModel`].

        mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)` , *optional*, returned when `pixel_values` are present and `input_ids_masked` are not):
                The logits for MIM unimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened output is
                returned when `bool_masked_pos` has some of the patches masked.
        mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not):
                The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of
                the tokens masked.
        itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
                The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA.
        mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape`(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
                The logits for MMM image multimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened
                output is returned when `bool_masked_pos` has some of the patches masked.
        mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
                The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has
                some of the tokens masked.
        contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
            `image_projection` and `text_projection` layers respectively. This represents the image-text similarity
            scores. This is calculated on unmasked images and texts.
        contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
            `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
            texts.
    Nloss	loss_infor   r   r    r!   r"   r#   image_masked_embeddingsimage_masked_outputtext_masked_embeddingstext_masked_outputmultimodal_masked_embeddingsmultimodal_masked_output
mim_logits
mlm_logits
itm_logitscontrastive_logits_per_imagecontrastive_logits_per_textmmm_image_logitsmmm_text_logitsr$   c                    s$   g dt  fdd  D S )N)r!   r   r#   rK   rI   rM   c                 3   s.    | ]}|vr | nt  | V  qd S Nr%   r(   r,   Ztransformer_outputsr-   r.   r/      s   , z5FlavaForPreTrainingOutput.to_tuple.<locals>.<genexpr>r0   r+   r-   rV   r.   r'      s   z"FlavaForPreTrainingOutput.to_tuple)"r3   r4   r5   r6   rF   r   r7   r8   r9   rG   r:   r   r   r   r    r!   r"   r#   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   rS   rT   r	   r   r'   r-   r-   r-   r.   rE   }   s0   
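# Illustrative sketch (the helper name `_summarize_flava_losses` is hypothetical and not part of the
# library API): the scalar `loss` returned by `FlavaForPreTraining` is, conceptually, the sum of
# whichever components in `loss_info` were computed for the given inputs; each component already
# carries its configured weight (`mim_weight`, `mlm_weight`, `itm_weight`, ...).
def _summarize_flava_losses(loss_info: "FlavaLosses") -> Optional[torch.FloatTensor]:
    if loss_info.all_none():
        return None
    present = [value for value in loss_info.values() if value is not None]
    return sum(present)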
class FlavaImageEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    Fconfiguse_mask_tokenr$   Nc                    s   t    |p	|j}ttdd|j| _|r#ttdd|jnd | _t	|j
|j|j|jd| _| jj}ttd|d |j| _t|j| _|j| _|| _d S )Nr   )
image_size
patch_sizenum_channels	embed_dim)super__init__
mask_tokenr   	Parameterr7   zeroshidden_size	cls_tokenPatchEmbeddingsrZ   r[   r\   patch_embeddingsnum_patchesposition_embeddingsDropouthidden_dropout_probdropoutrX   )r,   rX   rY   rg   	__class__r-   r.   r_      s   

 
zFlavaImageEmbeddings.__init__
embeddingsheightwidthc                 C   s   |j d d }| jj d d }tj s||kr||kr| jS | jddddf }| jddddf }|j d }|| j }	|| j }
t|d }|d|||}|dddd}t	j
j||	|
fdd	d
}|dddddd|}tj||fddS )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # Return the stored encodings unchanged when not tracing and the resolution matches the training one.
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        batch_size, seq_len, _ = embeddings.size()
        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # B X H X W = B X HW
            if bool_masked_pos.dim() == 3:
                bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings
class PatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        num_channels: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x


class FlavaTextEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        input_shape = input_ids.size()
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class FlavaSelfAttention(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Scaled dot-product attention scores between "query" and "key".
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs


class FlavaSelfOutput(nn.Module):
    """
    The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
    models), due to the layernorm applied before each block.
    """

    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class FlavaAttention(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.attention = FlavaSelfAttention(config)
        self.output = FlavaSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(
            hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class FlavaIntermediate(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class FlavaOutput(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = hidden_states + input_tensor
        return hidden_states


class FlavaLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = FlavaAttention(config)
        self.intermediate = FlavaIntermediate(config)
        self.output = FlavaOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Pre-norm: layernorm is applied before self-attention.
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]

        # first residual connection
        hidden_states = attention_output + hidden_states

        # layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs
        return outputs


class FlavaEncoder(nn.Module):
    def __init__(self, config: FlavaConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__, hidden_states, attention_mask, layer_head_mask, output_attentions
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )


class FlavaPooler(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor):
        # "Pool" the model by taking the hidden state corresponding to the first ([CLS]) token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


@auto_docstring
class FlavaPreTrainedModel(PreTrainedModel):
    config_class = FlavaConfig
    base_model_prefix = "flava"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, FlavaMaskedPredictionHead):
            module.bias.data.zero_()
        elif isinstance(module, FlavaImageEmbeddings):
            module.cls_token.data.zero_()
            module.position_embeddings.data.zero_()
            if module.mask_token is not None:
                module.mask_token.data.zero_()
        elif isinstance(module, FlavaMultimodalModel):
            if module.use_cls_token:
                module.cls_token.data.zero_()
        elif isinstance(module, FlavaModel):
            module.logit_scale.data.fill_(self.config.logit_scale_init_value)


@auto_docstring
class FlavaImageModel(FlavaPreTrainedModel):
    config_class = FlavaImageConfig
    base_model_prefix = "flava.image_model"
    main_input_name = "pixel_values"

    def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = FlavaImageEmbeddings(config)
        self.encoder = FlavaEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = FlavaPooler(config) if add_pooling_layer else None

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.patch_embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.embeddings.patch_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.layernorm(encoder_outputs[0])
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring
class FlavaTextModel(FlavaPreTrainedModel):
    config_class = FlavaTextConfig
    base_model_prefix = "flava.text_model"

    def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        self.embeddings = FlavaTextEmbeddings(config)
        self.encoder = FlavaEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = FlavaPooler(config) if add_pooling_layer else None

        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
        NzYou have to specify input_idsr   )r   r   r   r)  r   r   r*  )rX   r   r   r+  r   rs   r7   onesr   r,  r   get_extended_attention_maskrn   r  r  r  r   r   r   )r,   r   r   r   r   r   r   r   r   r   extended_attention_maskr-  r.  r/  r   r-   r-   r.   r   \  sJ   
zFlavaTextModel.forwardr0  r1  )r3   r4   r5   r   r  r  rD   r_   re   r  r   r3  r  r   r   r   r'  r   r   r7   r   r
   r1   r   r   r   r-   r-   rl   r.   r4  8  sF    	

r4  c                       s   e Zd ZeZdZdZddef fddZdee	e
e	 f dd	fd
dZe										ddejdeej deej dee dee dee deeef fddZ  ZS )r  zflava.multimodal_modelr   TrX   c                    sv   t  | || _| jj| _| jrttdd|j| _	t
|| _tj|j|jd| _|r2t|nd| _|   dS )r  r   r   N)r^   r_   rX   r  r   ra   r7   rb   rc   rd   r   r  r   r   r  r   r  r  r  rl   r-   r.   r_     s   

zFlavaMultimodalModel.__init__r!  r$   Nc                 C   r"  r#  r$  r&  r-   r-   r.   r'    r(  z!FlavaMultimodalModel._prune_headsr   r   r   r   r   c                 C   s&  |dur|n| j j}|dur|n| j j}|dur|n| j j}| \}}}	| jr=| j|dd}
tj	|
|fdd}|d7 }|du rKtj
||f|jd}| || j j}| |||f|j}| j||||||d}|d }| |}| jdur{| |nd}|s||f|dd  S t|||j|jdS )	z
        hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
            The concatenated hidden states of unimodal encoders.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length, _ = hidden_states.size()

        if self.use_cls_token:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
            seq_length += 1

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states.device
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.layernorm(encoder_outputs[0])
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring
class FlavaModel(FlavaPreTrainedModel):
    config_class = FlavaConfig

    def __init__(self, config: FlavaConfig):
        super().__init__(config)

        if not isinstance(config.text_config, FlavaTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type FlavaTextConfig but is of type"
                f" {type(config.text_config)}."
            )
        if not isinstance(config.image_config, FlavaImageConfig):
            raise TypeError(
                "config.image_config is expected to be of type FlavaImageConfig but is of type"
                f" {type(config.image_config)}."
            )
        if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
            raise TypeError(
                "config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
                f"is of type {type(config.multimodal_config)}."
            )

        text_config = config.text_config
        image_config = config.image_config
        multimodal_config = config.multimodal_config

        self.projection_dim = config.projection_dim
        self.text_hidden_size = text_config.hidden_size
        self.image_hidden_size = image_config.hidden_size
        self.mm_hidden_size = multimodal_config.hidden_size

        self.text_model = FlavaTextModel(text_config)
        self.image_model = FlavaImageModel(image_config)
        self.multimodal_model = FlavaMultimodalModel(multimodal_config)

        self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
        self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
        self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)

        self.post_init()

    @auto_docstring
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)

        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`FlavaTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoProcessor, FlavaModel

        >>> model = FlavaModel.from_pretrained("{0}")
        >>> processor = AutoProcessor.from_pretrained("{0}")

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
        ... )
        >>> text_features = model.get_text_features(**inputs)
        ```
        )r   r   r   r   r   r   r   r   )r?  rC  )r,   r   r   r   r   r   r   r   Ztext_outputsr   Ztext_featuresr-   r-   r.   get_text_features8  s   )

zFlavaModel.get_text_featuresr~   r   r}   r   c	              
   C   s0   | j ||||||||d}	|	d }
| |
}|S )a  
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`FlavaImageModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, FlavaModel

        >>> model = FlavaModel.from_pretrained("{0}")
        >>> processor = AutoProcessor.from_pretrained("{0}")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```
        )r~   r   r   r   r   r   r}   r   r   )r@  rB  )r,   r~   r   r}   r   r   r   r   r   Zimage_outputsr   Zimage_featuresr-   r-   r.   get_image_featuresp  s   &
zFlavaModel.get_image_featuresTimage_attention_maskskip_multimodal_encoderc              	   C   sz  |dur|n| j j}|
stdd}d}d}d}|dur7| j||||	|
|d}|d |d }}| |d }d}d}d}d}|dur_| j|||||	|
|d}|d |d }}| |d }d}d}|dur|dur|s|dur|j\}}}| jj	r|d7 }t
j|||jd	}t
j||gdd
}nd}t
j||gdd
}| j|||d}|d }|s||||||fS t||||||dS )a	  
        input_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        skip_multimodal_encoder (*bool*, *optional*):
            Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
        image_attention_mask (`torch.Tensor` of shape `(batch_size, image_num_patches)`, *optional*):
            Mask to avoid performing attention on padding pixel values for image inputs. Mask values selected in `[0, 1]`:
            - 1 for pixel values that are real (i.e., **not masked**),
            - 0 for pixel values that are padding (i.e., **masked**).

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, FlavaModel

        >>> model = FlavaModel.from_pretrained("facebook/flava-full")
        >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)

        >>> outputs = model(**inputs)

        >>> image_embeddings = outputs.image_embeddings
        >>> text_embeddings = outputs.text_embeddings
        >>> multimodal_embeddings = outputs.multimodal_embeddings

        >>> outputs.image_embeddings.shape
        torch.Size([1, 197, 768])

        >>> text_embeddings.shape
        torch.Size([1, 7, 768])

        >>> multimodal_embeddings.shape
        torch.Size([1, 205, 768])
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        if not output_hidden_states:
            raise ValueError("FLAVA model requires hidden states to work. Please set `output_hidden_states=True`")
        ...


class FlavaImageCodebookResPath(nn.Module):
    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__()
        hid_size = out_size // 4

        path = OrderedDict()
        path["relu_1"] = nn.ReLU()
        path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1)
        path["relu_2"] = nn.ReLU()
        path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
        path["relu_3"] = nn.ReLU()
        path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
        path["relu_4"] = nn.ReLU()
        path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0)

        self.path = nn.Sequential(path)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.path(x)


class FlavaImageCodebookBlock(nn.Module):
    def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
        super().__init__()
        # Scale the residual branch as in the DALL-E encoder.
        self.post_gain = 1 / (num_layers**2)

        if in_size != out_size:
            self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
        else:
            self.id_path = nn.Identity()

        self.res_path = FlavaImageCodebookResPath(in_size, out_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)


class FlavaImageCodebookLayerGroup(nn.Module):
    def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True):
        super().__init__()
        blocks = OrderedDict()
        for i in range(num_blocks):
            if i == 0:
                blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers)
            else:
                blocks[f"block_{i + 1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers)

        if use_pool:
            blocks["pool"] = nn.MaxPool2d(kernel_size=2)

        self.group = nn.Sequential(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.group(x)


@auto_docstring(
    custom_intro="""
    The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used
    to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
    `get_codebook_indices` to get image tokens for an image.
    """
)
class FlavaImageCodebook(FlavaPreTrainedModel):
    base_model_prefix = ""
    config_class = FlavaImageCodebookConfig
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False

    def __init__(self, config: FlavaImageCodebookConfig, **kwargs: Any):
        super().__init__(config)

        self.config = config
        self.num_groups = config.num_groups
        self.input_channels = config.input_channels
        self.num_blocks_per_group = config.num_blocks_per_group
        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size

        num_layers = self.num_groups * self.num_blocks_per_group

        output_blocks = OrderedDict()
        output_blocks["relu"] = nn.ReLU()
        output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0)

        blocks = OrderedDict()
        blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3)
        blocks["group_1"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size
        )
        blocks["group_2"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size
        )
        blocks["group_3"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size
        )
        blocks["group_4"] = FlavaImageCodebookLayerGroup(
            self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False
        )
        blocks["output"] = nn.Sequential(output_blocks)

        self.blocks = nn.Sequential(blocks)

        self.post_init()

        if self.config.freeze:
            for param in self.parameters():
                param.requires_grad = False

    def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
                `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoImageProcessor, FlavaImageCodebook

        >>> model = FlavaImageCodebook.from_pretrained("{0}")
        >>> image_processor = AutoImageProcessor.from_pretrained("{0}")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
        >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)

        >>> outputs = model.get_codebook_indices(**inputs)
        ```
        """.format(_CHECKPOINT_FOR_CODEBOOK_DOC)
        z_logits = self.blocks(pixel_values)
        return torch.argmax(z_logits, axis=1)

    def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
        z_logits = self.blocks(pixel_values)
        return nn.Softmax(dim=1)(z_logits)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """
        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
                `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoImageProcessor, FlavaImageCodebook

        >>> model = FlavaImageCodebook.from_pretrained("{0}")
        >>> image_processor = AutoImageProcessor.from_pretrained("{0}")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
        >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)

        >>> outputs = model(**inputs)
        >>> print(outputs.shape)
        (1, 196)
        ```
        rN  zinput shape z
 is not 4dr   z
input has z channels but model built for )rn  ro  r   rw   r   rj  rc  )r,   r~   r-   r-   r.   r     s   
zFlavaImageCodebook.forward)r3   r4   r5   r  r   r  r2  r  r   r_   r7   r   rq  rr  r8   r   r   r-   r-   rl   r.   rd  o  s    ,rd  c                       $   e Zd Z fddZdd Z  ZS )FlavaPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S )Nr   )r^   r_   r   r   rc   r   r   r   r   r   transform_act_fnr   r   r   rl   r-   r.   r_     s   
z%FlavaPredictionHeadTransform.__init__c                 C   s"   |  |}| |}| |}|S rU   )r   ru  r   r   r-   r-   r.   r     s   


z$FlavaPredictionHeadTransform.forwardr3   r4   r5   r_   r   r   r-   r-   rl   r.   rt    s    	rt  c                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	r  Nc                    sb   t    || _t|| _tj|j|jdd| _	t
t|j| _|d ur*|| j	_| j| j	_d S )NFr   )r^   r_   rX   rt  	transformr   r   rc   r   decoderra   r7   rb   r   r  )r,   rX   r  rl   r-   r.   r_     s   

z"FlavaMaskedPredictionHead.__init__c                 C   s   | j | j_ d S rU   )r   rx  r+   r-   r-   r.   _tie_weights	  s   z&FlavaMaskedPredictionHead._tie_weightsc                 C   r   rU   )rw  rx  rV  r-   r-   r.   r        

z!FlavaMaskedPredictionHead.forwardrU   )r3   r4   r5   r_   ry  r   r   r-   r-   rl   r.   r    s    r  c                       rs  )FlavaITMHeadc                    s.   t    || _t|| _t|jd| _d S )Nrr   )	r^   r_   rX   r   r  r   r   rc   seq_relationshipr   rl   r-   r.   r_     s   

zFlavaITMHead.__init__c                 C   r   rU   )r  r|  rV  r-   r-   r.   r     rz  zFlavaITMHead.forwardrv  r-   r-   rl   r.   r{    s    r{  c                       rs  )FlavaGlobalContrastiveHeadc                    s   t    || _|j| _d S rU   )r^   r_   rX   global_backprop_contrastiver   rl   r-   r.   r_      s   
z#FlavaGlobalContrastiveHead.__init__c                    s2  t |}t j rt j s!t j d jd} g}g}nQ d}t j }	| j	r?t jj
j }t jj
j}n$fddt|	D } fddt|	D }t j|  t j| |t j  t j| jd }t |}t |}t  |dd| }
t |dd| }|
||fS )Nr   r6  c                       g | ]}t  qS r-   r7   Z
zeros_liker   )r    r-   r.   r   5      z6FlavaGlobalContrastiveHead.forward.<locals>.<listcomp>c                    r  r-   r  r   )r   r-   r.   r   6  r  r   )r7   expdistributedZis_availableZis_initializedr   rs   r   Zget_world_sizer~  r   rz   Z
all_gatherr   Zget_rankr|   r   r   )r,   r   r    r	  ZtemperaturelabelsZimage_embeddings_allZtext_embeddings_allZlocal_batch_sizeZ
world_sizelogits_per_imagelogits_per_textr-   )r   r    r.   r   %  s,   





z"FlavaGlobalContrastiveHead.forwardrv  r-   r-   rl   r.   r}    s    r}  zk
    The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
    """
)
class FlavaForPreTraining(FlavaPreTrainedModel):
    _tied_weights_keys = [
        "mmm_text_head.decoder.bias",
        "mmm_image_head.decoder.bias",
        "mlm_head.decoder.bias",
        "mim_head.decoder.bias",
    ]

    def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
        r"""
        image_codebook ([`nn.Module`]):
            If passed, the image codebook will be set to this. Otherwise, it will be initialized using the
            image_codebook_config defined in the config first as the first parameter.
        """
        super().__init__(config)
        self.flava = FlavaModel(config)

        self.image_codebook = image_codebook
        if self.image_codebook is None and config.init_codebook:
            self.image_codebook = FlavaImageCodebook(config.image_codebook_config)

        # The masked prediction heads reuse the text and image encoder configs so that they
        # carry the right vocabulary sizes.
        self.mim_head = FlavaMaskedPredictionHead(config.image_config)
        self.mlm_head = FlavaMaskedPredictionHead(config.text_config)
        self.itm_head = FlavaITMHead(config)
        self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config)
        self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config)
        self.global_contrastive_head = FlavaGlobalContrastiveHead(config)

        self.image_vocab_size = config.image_config.vocab_size
        self.text_vocab_size = config.text_config.vocab_size
        self.mlm_weight = config.mlm_weight
        self.mim_weight = config.mim_weight
        self.global_contrastive_weight = config.global_contrastive_weight
        self.ce_ignore_index = config.ce_ignore_index
        self.itm_weight = config.itm_weight
        self.mmm_image_weight = config.mmm_image_weight
        self.mmm_text_weight = config.mmm_text_weight
        self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder

        self.post_init()

    def _resize_to_2d(self, x: torch.Tensor):
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
        return x

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_ids_masked: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        codebook_pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        skip_unmasked_multimodal_encoder: Optional[bool] = None,
        mlm_labels: Optional[torch.Tensor] = None,
        mim_labels: Optional[torch.Tensor] = None,
        itm_labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: bool = True,
        return_dict: Optional[bool] = None,
        return_loss: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], FlavaForPreTrainingOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        input_ids_masked (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
            Indices of input sequence tokens in the vocabulary. These are the masked version of `input_ids`,
            to be used for the MLM objective. Indices can be obtained using [`AutoTokenizer`] together with a
            masked-language-modeling data collator such as [`DataCollatorForLanguageModeling`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        image_attention_mask (`torch.FloatTensor` of shape `(batch_size, image_num_patches)`, *optional*):
            Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
            in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        skip_unmasked_multimodal_encoder (`bool`, *optional*):
            Whether to skip running the multimodal encoder on the unmasked inputs. FLAVA pretraining currently
            does not need the unmasked multimodal embeddings or outputs.
        mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
            Labels for computing the masked language modeling (MLM) and multimodal masked modeling losses.
            Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see the `input_ids` docstring).
            Tokens with indices set to `-100` are ignored; the loss is only computed for tokens with labels in
            `[0, ..., text_config.vocab_size - 1]` (the `-100` convention is illustrated in the short sketch
            after this argument list).
        mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
            Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
            image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
            computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, labels
            are generated automatically from `codebook_pixel_values` using the image codebook assigned to the
            model (by default a [`FlavaImageCodebook`]); the example below ends with a hedged line showing this
            step.
        itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
            The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
        return_loss (`bool`, *optional*, defaults to `None`):
            Whether to compute and return the pretraining losses.
        codebook_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_image_patches, patch_size, patch_size, 3)`, *optional*):
            Pixel values for image patches that are used to compute the image codebook labels for masked image modeling.
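
        All of the `*_labels` arguments above share the `-100` ignore-index convention mentioned for
        `mlm_labels` and `mim_labels`. The snippet below is a small, self-contained illustration of that
        convention using plain PyTorch cross-entropy (generic PyTorch behaviour, not FLAVA-specific code);
        the tensor shapes and values are made up for the example.

        ```python
        import torch
        import torch.nn.functional as F

        logits = torch.randn(2, 4, 10)       # (batch, seq_len, vocab_size)
        labels = torch.full((2, 4), -100)    # -100 everywhere: every position is ignored
        labels[0, 1] = 3                     # only this (masked) position contributes to the loss

        loss = F.cross_entropy(
            logits.view(-1, 10),             # flatten to (batch * seq_len, vocab_size)
            labels.view(-1),                 # flatten to (batch * seq_len,)
            ignore_index=-100,               # positions labelled -100 are skipped
        )
        ```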

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import FlavaForPreTraining, AutoProcessor

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
        >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")

        >>> text = ["a photo of a cat"]

        >>> inputs = processor(
        ...     images=[image],
        ...     text=text,
        ...     return_masks=True,
        ...     return_codebook_pixels=True,
        ...     padding=True,
        ...     max_length=77,
        ...     return_tensors="pt",
        ... )


        >>> output = model(**inputs)
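        >>> # --- hedged additions (not part of the original example) ---
        >>> # The pre-training output is assumed to expose pooled embeddings and a per-task loss
        >>> # breakdown; values depend on the checkpoint, so these lines are skipped in doctests.
        >>> output.image_embeddings.shape  # doctest: +SKIP
        >>> output.loss_info  # doctest: +SKIP
        >>> # To pass explicit MIM labels, the image codebook can (assumption) quantize the
        >>> # codebook pixels into token ids:
        >>> mim_labels = model.image_codebook.get_codebook_indices(inputs["codebook_pixel_values"])  # doctest: +SKIP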
        ```
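
        The global contrastive term mentioned under `itm_labels` survives in the implementation only as
        bytecode residue (identifier fragments such as `normalize`, the logit-scale clamp, and
        `global_contrastive_head`). The sketch below is a hedged reconstruction of that step under those
        assumptions; the function name, the clamp bounds, and the use of `exp()` on the learned scale are
        assumptions, not the verbatim library code.

        ```python
        import torch
        import torch.nn.functional as F


        def global_contrastive_logits(text_emb: torch.Tensor, image_emb: torch.Tensor, logit_scale: torch.Tensor):
            # L2-normalize the pooled text and image embeddings, clamp the learned log temperature
            # (assumed bounds 0 .. ln(100)), and build the image-to-text similarity logits.
            text_emb = F.normalize(text_emb, dim=-1)
            image_emb = F.normalize(image_emb, dim=-1)
            scale = logit_scale.clamp(0.0, 4.6052).exp()
            logits_per_image = scale * image_emb @ text_emb.t()
            logits_per_text = logits_per_image.t()
            # With matched (image, text) pairs on the diagonal, the contrastive targets are simply
            # `torch.arange(batch_size)`, and the loss averages the image->text and text->image terms.
            return logits_per_image, logits_per_text
        ```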
[End of the surviving `forward` docstring. The rest of `FlavaForPreTraining.forward`, the class table, and
the module footer survive only as compiled-bytecode residue. The readable fragments are:
- a runtime warning emitted when `input_ids_masked` is not passed: "`input_ids_masked` isn't passed which
  means MLM loss won't be calculated correctly. Setting it to `input_ids` so that model can work. Please
  pass it if this is unintentional. This is usually OKAY if you are doing inference on unmasked text...";
- an error raised when `return_loss` is set to `True` but the image codebook is not initialized and no
  `mim_labels` were passed: "Reinstantiate the model with `init_codebook` set to `True` or pass in your
  custom `mim_labels`";
- an error raised when `mim_labels` must be generated but `codebook_pixel_values` are missing: "Call
  `AutoProcessor` with `return_codebook_pixels` set to `True`";
- the per-task loss locals (`total_loss`, `mim_loss`, `mlm_loss`, `itm_loss`, `mmm_image_loss`,
  `mmm_text_loss`, `gc_loss`), the logit-scale clamp between `LOGIT_SCALE_CLAMP_MIN` and
  `LOGIT_SCALE_CLAMP_MAX`, the assembly of the pre-training output, and the module's `__all__` export list.]