"""PyTorch CLIP model."""

from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig


logger = logging.get_logger(__name__)


def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0


def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
    """
    This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
    model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
    """
    square_tensor = torch.pow(tensor, 2)
    sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
    normed_tensor = torch.pow(sum_tensor, 0.5)
    return normed_tensor
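

# Illustrative sketch (not part of the upstream module): how the two loss helpers above combine.
# `clip_loss` averages a text-to-image and an image-to-text cross-entropy, each computed against the
# diagonal targets (the i-th text matches the i-th image). The shapes below are assumptions for the example.
#
#   >>> sim = torch.randn(4, 4)                        # hypothetical (text_batch, image_batch) logits
#   >>> caption_term = contrastive_loss(sim)           # rows are texts, targets are 0..3
#   >>> image_term = contrastive_loss(sim.t())         # rows are images, targets are 0..3
#   >>> bool(torch.allclose(clip_loss(sim), (caption_term + image_term) / 2.0))
#   True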
eeejdf  ed< dZeeejdf  ed< dS )CLIPVisionModelOutputa  
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nimage_embedslast_hidden_state.hidden_states
attentions)__name__
__module____qualname____doc__r6   r   r#   FloatTensor__annotations__r7   r8   r   r9   r&   r&   r&   r'   r5   :      


@dataclass
class CLIPTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
eej ed< dZeej ed< dZeej ed< dZeed< dZeed	< d
ee fddZdS )
CLIPOutputa  
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`CLIPVisionModel`].
    Nlosslogits_per_imagelogits_per_textrB   r6   text_model_outputvision_model_outputr   c                    s   t  fdd  D S )Nc                 3   s.    | ]}|d vr | nt  | V  qdS ))rG   rH   N)getattrto_tuple).0kselfr&   r'   	<genexpr>   s
    
z&CLIPOutput.to_tuple.<locals>.<genexpr>)tuplekeysrM   r&   rM   r'   rJ      s   zCLIPOutput.to_tuple)r:   r;   r<   r=   rD   r   r#   r>   r?   rE   rF   rB   r6   rG   r   rH   r   r   rJ   r&   r&   r&   r'   rC   t   s   


class CLIPVisionEmbeddings(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing so the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
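

# Illustrative sketch (not part of the upstream module): `interpolate_pos_encoding=True` lets the embeddings
# above accept resolutions other than `config.image_size` by bicubically resizing the patch position grid.
# The default config values and the 336x336 input below are assumptions for the example.
#
#   >>> from transformers import CLIPVisionConfig
#   >>> config = CLIPVisionConfig()                      # defaults: hidden_size=768, image_size=224, patch_size=32
#   >>> module = CLIPVisionEmbeddings(config)
#   >>> pixels = torch.randn(1, 3, 336, 336)             # larger than the pre-training resolution
#   >>> module(pixels, interpolate_pos_encoding=True).shape
#   torch.Size([1, 101, 768])                            # 1 class token + (336 // 32) ** 2 = 100 patches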
zCLIPVisionEmbeddings.forwardF)r:   r;   r<   r   rZ   r#   Tensorintrw   r>   r   __classcell__r&   r&   ri   r'   rR      s     )rR   c                	       sX   e Zd Zdef fddZ			ddeej deej deej dej	fd	d
Z
  ZS )CLIPTextEmbeddingsrS   c                    sR   t    |j}t|j|| _t|j|| _| j	dt
|jddd d S )NrU   rV   FrW   )rY   rZ   r[   r   rd   Z
vocab_sizetoken_embeddingZmax_position_embeddingsre   rf   r#   r$   rg   rN   rS   r\   ri   r&   r'   rZ      s   

zCLIPTextEmbeddings.__init__N	input_idsrU   inputs_embedsr   c                 C   s   |d ur	|j d n|j d }| jjj d }||kr#td| d| |d u r2| jd d d |f }|d u r;| |}| |}|| }|S )Nr.   r   zRSequence length must be less than max_position_embeddings (got `sequence length`: z and max_position_embeddings: )rq   re   rr   r|   rU   r   )rN   r   rU   r   
seq_lengthZmax_position_embeddingZposition_embeddingsrk   r&   r&   r'   r      s"   

zCLIPTextEmbeddings.forward)NNN)r:   r;   r<   r   rZ   r   r#   


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    output_attentions: bool = True,
    **kwargs,
):
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights


class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""
        batch_size, seq_length, embed_dim = hidden_states.shape

        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)

        queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
        values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)

        # the text model combines a padding mask and a causal mask; flash attention infers causality instead
        if self.config._attn_implementation == "flash_attention_2":
            self.is_causal = causal_attention_mask is not None
        else:
            if attention_mask is not None and causal_attention_mask is not None:
                attention_mask = attention_mask + causal_attention_mask
            elif causal_attention_mask is not None:
                attention_mask = causal_attention_mask

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            is_causal=self.is_causal,
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
            output_attentions=output_attentions,
        )

        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights


class CLIPMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class CLIPEncoderLayer(nn.Module):
    def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
jj|jd|jd | d t
jj|jj|j j| d t
jj|jj|j j| d nt|tr| j j}|jd d|j j d  | }|jd | }t
jj|jj|d t
jj|jj|d t
jj|jj|d t
jj|jj|d nt|tr| j j}|j jd d|j j d  | }d|j j d | }t
jj|jj|d t
jj|jj|d ntt|trt
jj|jj|jd | j j d t
jj|jj|jd | j j d nKt|trt
jj|jj| j jd | j j d n2t|t r3t
jj|jj| j jd | j j d nt|t!rLt
jj|j"j| j j#jd | j j d t|t
j$r`|j%j&  |jj'd t|t
j(ru|j%durw|j%j&  dS dS dS )	zInitialize the weightsr   g{Gz?)meanstdr   )r   r-   g      ?N))rS   Zinitializer_factor
isinstancer   r   rr   dataZnormal_re   rR   r   initr`   r\   ra   Zinitializer_ranger   num_hidden_layersr   r   r   r   r   r[   r   r   	CLIPModeltext_projectiontext_embed_dimvisual_projectionvision_embed_dimCLIPVisionModelWithProjectionCLIPTextModelWithProjectionCLIPForImageClassification
classifiervision_configr   rT   Zzero_Zfill_r   )rN   r   factorZin_proj_stdZout_proj_stdZfc_stdr&   r&   r'   _init_weights  sj   



 z!CLIPPreTrainedModel._init_weightsN)


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutput:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
zCLIPEncoder.forwardNNNN)r:   r;   r<   r=   r   rZ   r   r   r#   r   r   r   r   r   r&   r&   ri   r'   r     s&    r   c                       sr   e Zd Zdef fddZee					ddeej	 deej	 deej	 dee
 d	ee
 d
efddZ  ZS )CLIPTextTransformerrS   c                    sT   t    || _|j}t|| _t|| _tj	||j
d| _|j| _|jdk| _d S )Nr   r   )rY   rZ   rS   r[   r   rk   r   encoderr   r   r   final_layer_normeos_token_idr   _use_flash_attention_2r   ri   r&   r'   rZ   `  s   


zCLIPTextTransformer.__init__Nr   r   rU   r   r   r   c                 C   s@  |d ur|n| j j}|d ur|n| j j}|d u rtd| }|d|d }| j||d}t||j|j	d}|d urE| j
sEt||j}| j|||||d}	|	j}
| |
}
| jdkrw|
tj|
jd |
j	d|jtj|
j	djdd	f }n|
tj|
jd |
j	d|jtj|
j	d| jk jdd	f }t|
||	j|	jd
S )NzYou have to specify input_idsr.   )r   rU   r    )r   r   r   r   r   r-   r   )r{   r!   rp   r7   pooler_outputr8   r9   )rS   r   r   r|   rn   ru   rk   r   r{   r!   r   r   r   r7   r   r   r#   r$   rq   r}   r   Zargmaxr   r8   r9   )rN   r   r   rU   r   r   Zinput_shaper8   r   encoder_outputsr7   pooled_outputr&   r&   r'   r   n  sT   



	zCLIPTextTransformer.forwardNNNNN)r:   r;   r<   r   rZ   r   r   r   r#   r   r   r   r   r   r&   r&   ri   r'   r   _  s,    r   zI


@auto_docstring(
    custom_intro="""
    The text model from CLIP without any head or projection on top.
    """
)
class CLIPTextModel(CLIPPreTrainedModel):
    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseModelOutputWithPooling:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModel

        >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


class CLIPVisionTransformer(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> BaseModelOutputWithPooling:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        last_hidden_state = encoder_outputs.last_hidden_state
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The vision model from CLIP without any head or projection on top.
    """
)
class CLIPVisionModel(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> BaseModelOutputWithPooling:
        r"""
        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )


@auto_docstring
class CLIPModel(CLIPPreTrainedModel):
    config_class = CLIPConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        text_model = CLIPTextModel._from_config(text_config)
        self.text_model = text_model.text_model

        vision_model = CLIPVisionModel._from_config(vision_config)
        self.vision_model = vision_model.vision_model

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Use the CLIP model's config for these fields (if specified) instead of those of the text component.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = text_outputs.pooler_output
        text_features = self.text_projection(pooled_output)

        return text_features

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use the CLIP model's config for these fields (if specified) instead of those of the vision component.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        pooled_output = vision_outputs.pooler_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> CLIPOutput:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use the CLIP model's config for these fields (if specified) instead of those of the vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        image_embeds = vision_outputs.pooler_output
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs.pooler_output
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / _get_vector_norm(image_embeds)
        text_embeds = text_embeds / _get_vector_norm(text_embeds)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
        logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device)

        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )


@auto_docstring
class CLIPTextModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)

        text_model = CLIPTextModel._from_config(config)
        self.text_model = text_model.text_model

        self.text_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> CLIPTextModelOutput:
        r"""
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection

        >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> text_embeds = outputs.text_embeds
        ```"""
        text_outputs: BaseModelOutputWithPooling = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        pooled_output = text_outputs.pooler_output
        text_embeds = self.text_projection(pooled_output)

        return CLIPTextModelOutput(
            text_embeds=text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )


@auto_docstring
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        vision_model = CLIPVisionModel._from_config(config)
        self.vision_model = vision_model.vision_model

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> CLIPVisionModelOutput:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection

        >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> image_embeds = outputs.image_embeds
        ```"""
        vision_outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        pooled_output = vision_outputs.pooler_output
        image_embeds = self.visual_projection(pooled_output)

        return CLIPVisionModelOutput(
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    """
)
class CLIPForImageClassification(CLIPPreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        vision_model = CLIPVisionModel._from_config(config.vision_config)
        self.vision_model = vision_model.vision_model

        # Classifier head
        self.classifier = (
            nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> ImageClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs: BaseModelOutputWithPooling = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs.last_hidden_state

        # average pool the patch tokens (everything but the class token)
        sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
        # apply classifier
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # move labels to the correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "CLIPModel",
    "CLIPPreTrainedModel",
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "CLIPForImageClassification",
]
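

# Illustrative sketch (not part of the upstream module): reproducing `CLIPModel.forward`'s similarity logits
# from the two feature helpers. The checkpoint name matches the docstring examples above; running it
# requires downloading the pretrained weights.
#
#   >>> from PIL import Image
#   >>> import requests
#   >>> from transformers import AutoProcessor, CLIPModel
#   >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
#   >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
#   >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#   >>> image = Image.open(requests.get(url, stream=True).raw)
#   >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image,
#   ...                    return_tensors="pt", padding=True)
#   >>> text_feat = model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
#   >>> image_feat = model.get_image_features(pixel_values=inputs["pixel_values"])
#   >>> text_feat = text_feat / _get_vector_norm(text_feat)
#   >>> image_feat = image_feat / _get_vector_norm(image_feat)
#   >>> logits_per_image = model.logit_scale.exp() * image_feat @ text_feat.t()  # matches CLIPOutput.logits_per_image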
   r   Zactivationsr   Zmodeling_attn_mask_utilsr   r   Zmodeling_outputsr   r   r   Zmodeling_utilsr   r   utilsr   r   r   r   r   Zconfiguration_clipr   r   r   Z
get_loggerr:   r   r   r(   r+   r3   r5   rA   rC   r   rR   r   floatr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   __all__r&   r&   r&   r'   <module>   s   
$S/
Q2@`Z414 ^A@T