o
    Zh                     @   s  d dl Z d dlZd dlmZ d dlmZmZmZmZm	Z	 d dl
Zd dlZd dlmZ d dlm  mZ d dlmZmZmZ d dlmZ ddlmZ ddlmZ dd	lmZ dd
lmZmZm Z  ddl!m"Z"m#Z# ddl$m%Z%m&Z&m'Z'm(Z( ddl)m*Z*m+Z+m,Z, e(-e.Z/eG dd de%Z0eG dd de%Z1eG dd de%Z2G dd dej3Z4	dUdej3dej5dej5dej5deej5 de6de6fdd Z7G d!d" d"ej3Z8G d#d$ d$ej3Z9G d%d& d&eZ:G d'd( d(ej3Z;G d)d* d*ej3Z<G d+d, d,ej3Z=d-d. Z>	1dVd2ej5d3e6d4e6d5e6d6e6d7ej5fd8d9Z?dWd<d=Z@d>d? ZAd@dA ZBG dBdC dCej3ZCe&G dDdE dEe#ZDe&dFdGG dHdI dIeDZEG dJdK dKej3ZFe&dLdGG dMdN dNeDZGe&G dOdP dPeDZHe&dQdGG dRdS dSeDZIg dTZJdS )X    N)	dataclass)AnyCallableOptionalTupleUnion)BCEWithLogitsLossCrossEntropyLossMSELoss)_calculate_fan_in_and_fan_out   )ACT2FN)_prepare_4d_attention_mask)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPoolingImageClassifierOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)ModelOutputauto_docstringcan_return_tuplelogging   )Siglip2ConfigSiglip2TextConfigSiglip2VisionConfigc                   @   j   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dS )Siglip2VisionOutputa  
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nimage_embedslast_hidden_state.hidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r    r!   r   r"    r*   r*   [/var/www/auris/lib/python3.10/site-packages/transformers/models/siglip2/modeling_siglip2.pyr   -      
 r   c                   @   r   )Siglip2TextOutputa  
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Ntext_embedsr    .r!   r"   )r#   r$   r%   r&   r.   r   r'   r(   r)   r    r!   r   r"   r*   r*   r*   r+   r-   J   r,   r-   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeej ed< dZeed< dZeed	< d
ee fddZdS )Siglip2Outputa  
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`Siglip2VisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`Siglip2TextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`Siglip2VisionModel`].
    Nlosslogits_per_imagelogits_per_textr.   r   text_model_outputvision_model_outputreturnc                    s   t  fdd  D S )Nc                 3   s.    | ]}|d vr | nt  | V  qdS ))r3   r4   N)getattrto_tuple).0kselfr*   r+   	<genexpr>   s
    
z)Siglip2Output.to_tuple.<locals>.<genexpr>)tuplekeysr:   r*   r:   r+   r7      s   zSiglip2Output.to_tuple)r#   r$   r%   r&   r0   r   r'   r(   r)   r1   r2   r.   r   r3   r   r4   r   r   r7   r*   r*   r*   r+   r/   g   s   
 r/   c                	       sb   e Zd Zdef fddZedejdejde	dejfdd	Z
d
ejdejdejfddZ  ZS )Siglip2VisionEmbeddingsconfigc                    sn   t    || _|j| _|j| _tj|j| j | j | jd| _	|j
| _
t| j
d | _t| j
| j| _d S )N)Zin_featuresZout_featuresg      ?)super__init__r@   hidden_size	embed_dimZ
patch_sizennLinearZnum_channelspatch_embeddingZnum_patchesintposition_embedding_size	Embeddingposition_embeddingr;   r@   	__class__r*   r+   rB      s   
z Siglip2VisionEmbeddings.__init__positional_embeddingsspatial_shapes
max_lengthr5   c                 C   s   |j d }| j d }| j}tj|||f| j|d}| dddd} | jjdkr/| tj	} t
|D ];}|| \}}	tj| ||	fddd	d
}
|
|||	 dd}
|
|}
|
||d||	 f< |
d ||||	 df< q3|S )ac  
        Resize positional embeddings to image-specific size and pad to a fixed size.

        Args:
            positional_embeddings (`torch.Tensor`):
                Position embeddings of shape (height, width, embed_dim)
            spatial_shapes (`torch.LongTensor`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
            max_length (`int`):
                Maximum length of the positional embeddings to pad resized positional embeddings to

        Returns:
            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
        r   )devicedtype   r   cpuZbilinearFT)sizemodeZalign_cornersZ	antialiasN)shaperT   r'   emptyrS   ZpermuteZ	unsqueezetypetofloat32rangeFZinterpolatereshape	transpose)rO   rP   rQ   
batch_sizerD   Zsource_dtypeZresulted_positional_embeddingsiheightwidthZresized_embeddingsr*   r*   r+   resize_positional_embeddings   s2   

	
z4Siglip2VisionEmbeddings.resize_positional_embeddingspixel_valuesc                 C   sT   | j jj}|  |j|d}| jj| j| jd}| j|||jd d}|| }|S )aH  
        Args:
            pixel_values (`torch.FloatTensor`):
                Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
            spatial_shapes (`List[Tuple[int, int]]`):
                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
        )rT   rR   r   )rQ   )	rG   weightrT   r\   rK   r`   rI   rf   rY   )r;   rg   rP   Ztarget_dtypeZpatch_embedsrO   Zresized_positional_embeddings
embeddingsr*   r*   r+   forward   s   


zSiglip2VisionEmbeddings.forward)r#   r$   r%   r   rB   staticmethodr'   Tensor
LongTensorrH   rf   r(   rj   __classcell__r*   r*   rM   r+   r?      s    $:r?           modulequerykeyvalueattention_maskscalingdropoutc           
      K   s|   t ||dd| }|d ur|| }tjj|dt jd|j}tjj	||| j
d}t ||}	|	dd }	|	|fS )NrR   )dimrT   )ptrainingr   rU   )r'   matmulra   rE   
functionalZsoftmaxr]   r\   rT   rv   rz   
contiguous)
rp   rq   rr   rs   rt   ru   rv   kwargsattn_weightsattn_outputr*   r*   r+   eager_attention_forward   s   
r   c                       sj   e Zd ZdZdeeef f fddZ		ddej	de
ej	 d	e
e d
eej	e
ej	 f fddZ  ZS )Siglip2Attentionz=Multi-headed attention from 'Attention Is All You Need' paperr@   c                    s   t    || _|j| _|j| _| j| j | _| j| j | jkr-td| j d| j d| jd | _	|j
| _d| _t| j| j| _t| j| j| _t| j| j| _t| j| j| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).      F)rA   rB   r@   rC   rD   num_attention_heads	num_headshead_dim
ValueErrorscaleZattention_dropoutrv   	is_causalrE   rF   k_projv_projq_projout_projrL   rM   r*   r+   rB     s$   

zSiglip2Attention.__init__NFr!   rt   output_attentionsr5   c              
   C   s  |j \}}}| |}| |}| |}	|||| j| jdd}|||| j| jdd}|	||| j| jdd}	t}
| j	j
dkr[| j	j
dkrU|rUtd nt| j	j
 }
|
| |||	|| j| j| jsjdn| jd\}}|||| }| |}|sd}||fS )	z#Input shape: Batch x Time x Channelr   rU   eagerZsdpaz`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.ro   )r   ru   rv   N)rY   r   r   r   viewr   r   ra   r   r@   _attn_implementationloggerZwarning_oncer   r   r   rz   rv   r`   r}   r   )r;   r!   rt   r   rb   
seq_lengthrD   Zqueriesr>   valuesZattention_interfacer   r   r*   r*   r+   rj     s:   




zSiglip2Attention.forward)NF)r#   r$   r%   r&   r   r   r   rB   r'   rl   r   boolr   rj   rn   r*   r*   rM   r+   r     s    r   c                       s2   e Zd Z fddZdejdejfddZ  ZS )
Siglip2MLPc                    sD   t    || _t|j | _t|j|j	| _
t|j	|j| _d S N)rA   rB   r@   r   Z
hidden_actactivation_fnrE   rF   rC   Zintermediate_sizefc1fc2rL   rM   r*   r+   rB   P  s
   
zSiglip2MLP.__init__r!   r5   c                 C   s"   |  |}| |}| |}|S r   )r   r   r   )r;   r!   r*   r*   r+   rj   W  s   


zSiglip2MLP.forward)r#   r$   r%   rB   r'   rl   rj   rn   r*   r*   rM   r+   r   O  s    r   c                
       sV   e Zd Zdeeef f fddZ	ddejdejde	e
 deej fd	d
Z  ZS )Siglip2EncoderLayerr@   c                    sR   t    |j| _tj| j|jd| _t|| _	tj| j|jd| _
t|| _d S )NZeps)rA   rB   rC   rD   rE   	LayerNormlayer_norm_epslayer_norm1r   	self_attnlayer_norm2r   mlprL   rM   r*   r+   rB   _  s   

zSiglip2EncoderLayer.__init__Fr!   rt   r   r5   c                 C   sb   |}|  |}| j|||d\}}|| }|}| |}| |}|| }|f}|r/||f7 }|S )a=  
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`):
                Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r!   rt   r   )r   r   r   r   )r;   r!   rt   r   residualr   outputsr*   r*   r+   rj   g  s    




zSiglip2EncoderLayer.forward)F)r#   r$   r%   r   r   r   rB   r'   rl   r   r   r   r(   rj   rn   r*   r*   rM   r+   r   ^  s    r   c                
       sZ   e Zd ZdZdef fddZe			ddeej	 dee
 dee
 d	efd
dZ  ZS )Siglip2Encoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Siglip2EncoderLayer`].

    Args:
        config: Siglip2Config
    r@   c                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r*   )r   )r8   _r@   r*   r+   
<listcomp>  s    z+Siglip2Encoder.__init__.<locals>.<listcomp>F)	rA   rB   r@   rE   Z
ModuleListr^   Znum_hidden_layerslayersZgradient_checkpointingrL   rM   r   r+   rB     s   
 
zSiglip2Encoder.__init__Nrt   r   output_hidden_statesr5   c           
      C   s   |dur|n| j j}|dur|n| j j}|rdnd}|rdnd}|}| jD ]}|r.||f }||||d}	|	d }|rB||	d f }q%|rJ||f }t|||dS )ad  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr*   )r   r   r   )r    r!   r"   )r@   r   r   r   r   )
r;   inputs_embedsrt   r   r   Zencoder_statesZall_attentionsr!   Zencoder_layerZlayer_outputsr*   r*   r+   rj     s2   


zSiglip2Encoder.forwardNNN)r#   r$   r%   r&   r   rB   r   r   r'   rl   r   r   rj   rn   r*   r*   rM   r+   r     s     r   c                       s`   e Zd Zdef fddZee		ddejdej	dej
dee d	ee d
efddZ  ZS )Siglip2VisionTransformerr@   c                    sr   t    || _|j}t|| _t|| _tj	||j
d| _t|ds%dn|j| _| jr1t|| _|jdk| _d S )Nr   vision_use_headTflash_attention_2)rA   rB   r@   rC   r?   ri   r   encoderrE   r   r   post_layernormhasattrr   use_head$Siglip2MultiheadAttentionPoolingHeadheadr   _use_flash_attention_2r;   r@   rD   rM   r*   r+   rB     s   



z!Siglip2VisionTransformer.__init__Nrg   rt   rP   r   r   r5   c                 C   s   |dur|n| j j}|dur|n| j j}| ||}|dur(| js(t||j}n|}| j||||d}|j}	| 	|	}	| j
rD| |	|nd}
t|	|
|j|jdS )z
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        Nr   rt   r   r   r    pooler_outputr!   r"   )r@   r   r   ri   r   r   rT   r   r    r   r   r   r   r!   r"   )r;   rg   rt   rP   r   r   r!   Zencoder_attention_maskencoder_outputsr    r   r*   r*   r+   rj     s,   
z Siglip2VisionTransformer.forwardNN)r#   r$   r%   r   rB   r   r   r'   r(   rl   rm   r   r   r   rj   rn   r*   r*   rM   r+   r     s&    r   c                	       sX   e Zd Zdef fddZ			ddeej deej deej dej	fd	d
Z
  ZS )Siglip2TextEmbeddingsr@   c                    sR   t    |j}t|j|| _t|j|| _| j	dt
|jddd d S )Nposition_ids)r   rR   F)
persistent)rA   rB   rC   rE   rJ   Z
vocab_sizetoken_embeddingZmax_position_embeddingsrK   Zregister_bufferr'   Zarangeexpandr   rM   r*   r+   rB     s   

zSiglip2TextEmbeddings.__init__N	input_idsr   r   r5   c                 C   s   |d ur	|j d n|j d }| jjj d }||kr#td| d| |d u r2| jd d d |f }|d u r;| |}| |}|| }|S )NrR   rw   r   zRSequence length must be less than max_position_embeddings (got `sequence length`: z and max_position_embeddings: )rY   rK   rh   r   r   r   )r;   r   r   r   r   Zmax_position_embeddingZposition_embeddingsri   r*   r*   r+   rj   (  s"   

zSiglip2TextEmbeddings.forwardr   )r#   r$   r%   r   rB   r   r'   rm   r(   rl   rj   rn   r*   r*   rM   r+   r     s    r   c                 C   s   dd }||d|  k s||d|  krt jddd ||| | }||| | }| d| d d| d  |   | |td  | | | j||d d S )	Nc                 S   s   dt | t d  d S )N      ?       @)matherfsqrt)xr*   r*   r+   norm_cdfF  s   z _trunc_normal_.<locals>.norm_cdfrU   zjmean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.)
stacklevelr   r   )minmax)	warningswarnuniform_Zerfinv_mul_r   r   add_Zclamp_)tensormeanstdabr   lur*   r*   r+   _trunc_normal_C  s    	
r   r          r   r   r   r   r   r   r5   c                 C   sN   t   t| dd|| | || W d   dS 1 s w   Y  dS )an  Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(	ext{mean}, 	ext{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq 	ext{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    r   r   N)r'   no_gradr   r   r   )r   r   r   r   r   r*   r*   r+   trunc_normal_tf_g  s   
"r   fan_innormalc           	      C   s  t | \}}|dkr|}n|dkr|}n
|dkr|| d }|| }|dkr3t| t|d d d S |dkrWt  | jt|d W d    d S 1 sPw   Y  d S |d	krtd
| }t  | | | W d    d S 1 syw   Y  d S td| )Nr   fan_outZfan_avgrU   truncated_normalg۶%?r   r   uniformr   zinvalid distribution )	r   r   r   r   r'   r   normal_r   r   )	r   r   rX   distributionr   r   denomZvarianceboundr*   r*   r+   variance_scaling_  s(   
"
"r   c                 C      t | ddd d S )Nr   r   rX   r   r   r   r*   r*   r+   lecun_normal_     r   c                 C   r   )Nr   r   r   r   r   r*   r*   r+   default_flax_embed_init  r   r   c                       sr   e Zd Zdef fddZee					ddeej	 deej	 deej	 dee
 d	ee
 d
efddZ  ZS )Siglip2TextTransformerr@   c                    s\   t    || _|j}t|| _t|| _tj	||j
d| _t||j| _|jdk| _d S )Nr   r   )rA   rB   r@   rC   r   ri   r   r   rE   r   r   final_layer_normrF   Zprojection_sizer   r   r   r   rM   r*   r+   rB     s   


zSiglip2TextTransformer.__init__Nr   rt   r   r   r   r5   c                 C   s   |d ur|n| j j}|d ur|n| j j}|d u rtd| }|d|d }| j||d}|d ur<| js<t||j	}| j
||||d}|j}	| |	}	|	d d dd d f }
| |
}
t|	|
|j|jdS )NzYou have to specify input_idsrR   )r   r   r   r   )r@   r   r   r   rW   r   ri   r   r   rT   r   r    r   r   r   r!   r"   )r;   r   rt   r   r   r   Zinput_shaper!   r   r    pooled_outputr*   r*   r+   rj     s4   


zSiglip2TextTransformer.forwardNNNNN)r#   r$   r%   r   rB   r   r   r   r'   rl   r   r   rj   rn   r*   r*   rM   r+   r     s,    r   c                   @   s0   e Zd ZeZdZdZg dZdZdZ	dd Z
dS )Siglip2PreTrainedModelZsiglip2T)r   r   r?   r   r   c                 C   sf  t |tr%t | jtr| jjjn| jj}tjj|j	j
dt| d dS t |tjr2t|j
 dS t |trytj|jj
 tj|jj
 tj|jj
 tj|jj
 tj|jj tj|jj tj|jj tj|jj dS t |trtj|jj
 tj|jj
 tjj|jjdd tjj|jjdd dS t |trtj|jj tj|jjj tj|jjj dS t |t rt!"t!#d}|j$j%| |j&j'  dS t |t(rtjj|j)j
| jjjd | jj* d dS t |tj+tj,frt-|j
 |jdurtj|j dS dS t |tj.r1|jj'  |j
j%d dS dS )zInitialize the weightsr   r   gư>r   r   N)/
isinstancer?   r@   r   vision_configrC   rE   initr   rK   rh   npr   rJ   r   r   Zxavier_uniform_r   r   r   r   Zzeros_Zbiasr   r   r   r   probedata	attentionZin_proj_weightZin_proj_biasSiglip2Modelr'   logr   logit_scaleZfill_
logit_biasZzero_Siglip2ForImageClassification
classifierZinitializer_factorrF   ZConv2dr   r   )r;   rp   re   Zlogit_scale_initr*   r*   r+   _init_weights  sX   

"






z$Siglip2PreTrainedModel._init_weightsN)r#   r$   r%   r   config_classZbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_supports_flash_attn_2Z_supports_sdpar  r*   r*   r*   r+   r     s    r   zL
    The text model from Siglip2 without any head or projection on top.
    )Zcustom_introc                       s   e Zd ZeZdef fddZdejfddZdd Z	e
e										dd
eej deej deej dee dee defddZ  ZS )Siglip2TextModelr@   c                    "   t  | t|| _|   d S r   )rA   rB   r   
text_model	post_initrL   rM   r*   r+   rB   (  s   
zSiglip2TextModel.__init__r5   c                 C   
   | j jjS r   r  ri   r   r:   r*   r*   r+   get_input_embeddings.     
z%Siglip2TextModel.get_input_embeddingsc                 C   s   || j j_d S r   r  )r;   rs   r*   r*   r+   set_input_embeddings1  s   z%Siglip2TextModel.set_input_embeddingsNr   rt   r   r   r   c                 C      | j |||||dS )a  
        Examples:

        ```python
        >>> from transformers import AutoTokenizer, Siglip2TextModel

        >>> model = Siglip2TextModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```r   rt   r   r   r   )r  )r;   r   rt   r   r   r   r*   r*   r+   rj   4  s   zSiglip2TextModel.forwardr   )r#   r$   r%   r   r  rB   rE   Moduler	  r  r   r   r   r'   rl   r   r   rj   rn   r*   r*   rM   r+   r     s2    r  c                       sH   e Zd ZdZdef fddZddejdeej dejfd	d
Z	  Z
S )r   zMultihead Attention Pooling.r@   c                    sd   t    ttdd|j| _tjj|j|j	dd| _
tj|j|jd| _t|| _|j	| _d S )Nr   T)Zbatch_firstr   )rA   rB   rE   	Parameterr'   randnrC   r   ZMultiheadAttentionr   r   r   r   	layernormr   r   r   rL   rM   r*   r+   rB   [  s   

z-Siglip2MultiheadAttentionPoolingHead.__init__Nhidden_statert   r5   c                 C   s   |j d }| j|dd}|d ur3|j d |j d }}t||j|}|d| j|d}|d||}| j||||dd }|}| |}|| 	| }|d d df S )Nr   r   rR   )Z	attn_mask)
rY   r   repeatr   rT   r   r`   r   r  r   )r;   r  rt   rb   r   Z
target_lenZ
source_lenr   r*   r*   r+   rj   d  s   

z,Siglip2MultiheadAttentionPoolingHead.forwardr   )r#   r$   r%   r&   r   rB   r'   rl   r   rj   rn   r*   r*   rM   r+   r   X  s    *	r   zN
    The vision model from Siglip2 without any head or projection on top.
    c                       sx   e Zd ZeZdZdef fddZdejfddZ	e
e		ddejd	ejd
ejdee dee defddZ  ZS )Siglip2VisionModelrg   r@   c                    r  r   )rA   rB   r   vision_modelr  rL   rM   r*   r+   rB     s   
zSiglip2VisionModel.__init__r5   c                 C   r  r   )r  ri   rG   r:   r*   r*   r+   r	    r
  z'Siglip2VisionModel.get_input_embeddingsNpixel_attention_maskrP   r   r   c                 C   r  )a9  
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Siglip2VisionModel

        >>> model = Siglip2VisionModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled features
        ```rg   rt   rP   r   r   )r  )r;   rg   r  rP   r   r   r*   r*   r+   rj     s   #zSiglip2VisionModel.forwardr   )r#   r$   r%   r   r  main_input_namerB   rE   r  r	  r   r   r'   r(   rl   rm   r   r   r   rj   rn   r*   r*   rM   r+   r  w  s,    r  c                       s@  e Zd ZeZdef fddZe					ddeej	 deej	 deej	 dee
 d	ee
 d
ejfddZe					ddeej deej	 deej dee
 d	ee
 d
ejfddZee									ddeej deej deej	 deej deej	 deej dee
 dee
 d	ee
 d
efddZ  ZS )r   r@   c                    s   t  | t|jtstdt|j dt|jts(tdt|j d|j}|j}t	
|}t
|}|j| _|j| _ttd| _ttd| _|   d S )NzNconfig.text_config is expected to be of type Siglip2TextConfig but is of type .zRconfig.vision_config is expected to be of type Siglip2VisionConfig but is of type r   )rA   rB   r   text_configr   	TypeErrorr[   r   r   r  _from_configr  r  r  rE   r  r'   r  r   r   r  )r;   r@   r  r   r  r  rM   r*   r+   rB     s,   

zSiglip2Model.__init__Nr   rt   r   r   r   r5   c                 C   F   |dur|n| j j}|dur|n| j j}| j|||||d}|j}|S )aM  
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`Siglip2TextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip2-base-patch16-224")

        >>> # important: make sure to set padding="max_length" as that's how the model was trained
        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
        >>> with torch.no_grad():
        ...     text_features = model.get_text_features(**inputs)
        ```Nr  )r@   r   r   r  r   )r;   r   rt   r   r   r   text_outputsr   r*   r*   r+   get_text_features  s   zSiglip2Model.get_text_featuresrg   r  rP   c                 C   r  )a  
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.

        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`Siglip2VisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     image_features = model.get_image_features(**inputs)
        ```
        Nr  )r@   r   r   r  r   )r;   rg   r  rP   r   r   vision_outputsr   r*   r*   r+   get_image_features	  s   (zSiglip2Model.get_image_featuresreturn_lossc
              	   C   sD  |dur|n| j j}|	dur|	n| j j}	| j|||||	d}
| j|||||	d}|
j}|j}||jdddd }||jdddd }t||	 
|j}| j
|j| j
|j}}||  | }|	 }d}|rtj|d|jd	}t| d|  }tjj|| }tj|dd
 }| }t|||||||
dS )a  
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AutoModel
        >>> import torch

        >>> model = AutoModel.from_pretrained("google/siglip2-base-patch16-224")
        >>> processor = AutoProcessor.from_pretrained("google/siglip2-base-patch16-224")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
        >>> # important: we pass `padding=max_length` since the model was trained with this
        >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> logits_per_image = outputs.logits_per_image
        >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
        >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
        31.9% that image 0 is 'a photo of 2 cats'
        ```
        Nr  r  rU   rR   T)ry   rx   Zkeepdimr   )rS   rx   )r0   r1   r2   r.   r   r3   r4   )r@   r   r   r  r  r   Znormr'   r{   tr\   rS   r   r   expeyerW   Z	ones_likerE   r|   Z
logsigmoidsumr   r/   )r;   r   rg   r  rP   rt   r   r"  r   r   r   r  r   r.   r2   r   r   r1   r0   r&  Zm1_diag1ZloglikZnllr*   r*   r+   rj   B  sR   2zSiglip2Model.forwardr   )	NNNNNNNNN)r#   r$   r%   r   r  rB   r   r   r'   rl   r   r(   r  rm   r!  r   r/   rj   rn   r*   r*   rM   r+   r     s     -8	
r   z
    Siglip2 vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
    the patch tokens) e.g. for ImageNet.
    c                       s   e Zd ZdZdeddf fddZee						ddee	j
 dee	j
 dee	j d	ee	j
 d
ee dee defddZ  ZS )r   rg   r@   r5   Nc                    sZ   t  | |j| _t|j}|j| _|jdkr"t|jj	|jnt
 | _|   d S )Nr   )rA   rB   
num_labelsr  r  r   r  rE   rF   rC   ZIdentityr   r  )r;   r@   r  rM   r*   r+   rB     s   "z&Siglip2ForImageClassification.__init__r  rP   labelsr   r   c                 C   s  |dur|n| j j}|dur|n| j j}| j|||||d}|j}|dur>|d |j}	tj||	 ddtj|	dd }ntj	|dd}| 
|}
d}|dur||
j}| j jdu r| jdkrfd| j _n| jdkr||jtjksw|jtjkr|d| j _nd| j _| j jdkrt }| jdkr||
 | }n+||
|}n%| j jdkrt }||
d	| j|d	}n| j jdkrt }||
|}t||
|j|jd
S )a  
        pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*):
            Mask to avoid performing attention on padding pixel indices.
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width) of the input images.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, Siglip2ForImageClassification
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> torch.manual_seed(3)  # doctest: +IGNORE_RESULT
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # note: we are loading a `Siglip2Model` from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
        >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip2-base-patch16-224")
        >>> model = Siglip2ForImageClassification.from_pretrained("google/siglip2-base-patch16-224")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        >>> # model predicts one of the two classes
        >>> predicted_class_idx = logits.argmax(-1).item()
        >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
        Predicted class: LABEL_1
        ```
        N)rt   rP   r   r   ).Nr   r#  Z
regressionZsingle_label_classificationZmulti_label_classificationrR   )r0   logitsr!   r"   )r@   r   r   r  r    r\   rS   r'   r'  r   r   Zproblem_typer(  rT   longrH   r
   Zsqueezer	   r   r   r   r!   r"   )r;   rg   r  rP   r)  r   r   r   Zsequence_outputZ	pool_maskr*  r0   Zloss_fctr*   r*   r+   rj     sT   /"


"


z%Siglip2ForImageClassification.forward)NNNNNN)r#   r$   r%   r  r   rB   r   r   r   r'   rl   rm   r   r   rj   rn   r*   r*   rM   r+   r     s4    r   )r   r   r  r  r   )ro   )ro   r   r   r   )r   r   r   )Kr   r   dataclassesr   typingr   r   r   r   r   numpyr   r'   Ztorch.nnrE   Ztorch.nn.functionalr|   r_   r   r	   r
   Ztorch.nn.initr   Zactivationsr   Zmodeling_attn_mask_utilsr   Zmodeling_layersr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   r   utilsr   r   r   r   Zconfiguration_siglip2r   r   r   Z
get_loggerr#   r   r   r-   r/   r  r?   rl   floatr   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r   r  r   r   __all__r*   r*   r*   r+   <module>   s   
$l
G0P=(%

?>3; u~