import copy
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    ModelOutput,
    auto_docstring,
    can_return_tuple,
    is_torch_flex_attn_available,
    is_torchdynamo_compiling,
    logging,
)
from ...utils.deprecation import deprecate_kwarg
from ..auto import AutoModel
from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig


if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)


@dataclass
class Gemma3ModelOutputWithPast(BaseModelOutputWithPast):
    r"""
    Base class for Gemma3 outputs, with hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
            image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    """

    image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class Gemma3CausalLMOutputWithPast(ModelOutput):
    r"""
    Base class for Gemma3 causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
            image_hidden_states of the model produced by the vision encoder after projecting the last hidden state.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[torch.FloatTensor] = None

class Gemma3TextScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embedding's forward by multiplying with the embedding scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)


class Gemma3MLP(nn.Module):
    def __init__(self, config: Gemma3TextConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class Gemma3RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # The weight is zero-initialized, so the effective scale is (1 + weight).
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


class Gemma3RotaryEmbedding(nn.Module):
    def __init__(self, config: Gemma3TextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type", "default"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
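

# Shape sketch for the rotary helpers above (illustrative only, never executed by the model):
# with q/k of shape [batch, heads, seq, head_dim] and cos/sin of shape [batch, seq, head_dim],
# the default unsqueeze_dim=1 broadcasts cos/sin across the head dimension. A zero rotation
# (cos=1, sin=0) must therefore return q and k unchanged:
#
#   b, h, s, d = 1, 2, 4, 8
#   q, k = torch.randn(b, h, s, d), torch.randn(b, h, s, d)
#   cos, sin = torch.ones(b, s, d), torch.zeros(b, s, d)
#   q2, k2 = apply_rotary_pos_emb(q, k, cos, sin)
#   assert torch.allclose(q2, q) and torch.allclose(k2, k)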

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if scaling is None:
        scaling = module.head_dim**-0.5

    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    if softcap is not None:
        attn_weights = attn_weights / softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * softcap
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights


class Gemma3Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        super().__init__()
        self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.attention_dropout = self.config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.attn_logit_softcapping = self.config.attn_logit_softcapping
        self.sliding_window = config.sliding_window if self.is_sliding else None

        self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position is needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
                "sliding_window": self.sliding_window,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

            # Here we need to slice as we use a static cache by default, but FA2 does not support it
            if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
                seq_len = attention_mask.shape[-1]
                key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`."
                    " Falling back to eager attention. This warning can be removed using the argument"
                    ' `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if attention_mask is not None:
            attention_mask = attention_mask.to(query_states)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Gemma3DecoderLayer(nn.Module):
    def __init__(self, config: Gemma3TextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Gemma3MLP(config)
        self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.is_sliding = self.self_attn.is_sliding
        self.sliding_window = config.sliding_window

    @deprecate_kwarg("last_cache_position", version="4.53.0")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings_global: torch.Tensor,
        position_embeddings_local: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # In prefill, we may be larger than the sliding window
            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
            # thus we must slice from the right (at most `effective_seq_len` elements)
            if self.config._attn_implementation == "flash_attention_2":
                attention_mask = attention_mask[:, -effective_seq_len:]
            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
            # from the left, with an offset if we are beyond the sliding window
            else:
                min_dtype = torch.finfo(attention_mask.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
                offset = cache_position[-1] - effective_seq_len + 1
                # Should only be used when beyond the sliding window (i.e. offset > 0)
                offset = torch.clamp(offset, min=0)
                # Equivalent to `attention_mask[:, :, :, offset : offset + effective_seq_len]`,
                # but without data-dependent slicing (i.e. torch.compile friendly)
                mask_indexes = torch.arange(
                    min(effective_seq_len, attention_mask.shape[-1]), device=attention_mask.device
                )
                mask_indexes += offset
                attention_mask = attention_mask[:, :, :, mask_indexes]

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # apply global RoPE to non-sliding layers only
        if self.self_attn.is_sliding:
            position_embeddings = position_embeddings_local
        else:
            position_embeddings = position_embeddings_global

        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


@auto_docstring
class Gemma3PreTrainedModel(PreTrainedModel):
    config_class = Gemma3Config
    base_model_prefix = ""
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "Gemma3DecoderLayer",
        "SiglipVisionEmbeddings",
        "SiglipEncoderLayer",
        "SiglipMultiheadAttentionPoolingHead",
    ]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Gemma3RMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, Gemma3MultiModalProjector):
            module.mm_input_projection_weight.data.zero_()


@auto_docstring
class Gemma3TextModel(Gemma3PreTrainedModel):
    config_class = Gemma3TextConfig

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Gemma3 scales embeddings by sqrt(hidden_size)
        self.embed_tokens = Gemma3TextScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
        )
        self.layers = nn.ModuleList(
            [Gemma3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Local (sliding) layers use a separate RoPE base frequency; build a second rotary
        # embedding from a copy of the config with the local theta
        config = copy.deepcopy(config)
        config.rope_theta = config.rope_local_base_freq
        config.rope_scaling = {"rope_type": "default"}
        self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = HybridCache(
                self.config,
                max_batch_size=batch_size,
                max_cache_len=seq_len,
                dtype=inputs_embeds.dtype,
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
        position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    hidden_states,
                    position_embeddings_global,
                    position_embeddings_local,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    position_embeddings_global=position_embeddings_global,
                    position_embeddings_local=position_embeddings_local,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    **flash_attn_kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    @torch.no_grad()
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: HybridCache,
        output_attentions: bool = False,
    ):
        # Flash Attention 2 uses the mask as-is; its shape is used later to cut out trailing
        # zeros when a static cache maps to a shorter attention mask.
        if self.config._attn_implementation == "flash_attention_2":
            return attention_mask

        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if isinstance(past_key_values, (HybridCache, StaticCache)):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        # In case the provided `attention_mask` is 2D, we generate a causal 4D mask here.
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


@auto_docstring
class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Gemma3TextConfig
    base_model_prefix = "language_model"

    def __init__(self, config: Gemma3TextConfig):
        super().__init__(config)
        self.model = Gemma3TextModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma3ForCausalLM

        >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```
        """
        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma3 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with "
                "`AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **loss_kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten: this model uses a special cache type, `HybridCache`

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        if logits_to_keep is None:
            _ = model_inputs.pop("logits_to_keep", None)

        # Prepare the full 4D mask eagerly to avoid data-dependent slicing later on when tracing
        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask is not None
            and attention_mask.ndim == 2
            and self.config._attn_implementation != "flash_attention_2"
        ):
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.lm_head.weight.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
            )
            model_inputs["attention_mask"] = attention_mask

        return model_inputs


class Gemma3MultiModalProjector(nn.Module):
    def __init__(self, config: Gemma3Config):
        super().__init__()
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
        )

        self.mm_soft_emb_norm = Gemma3RMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )

        self.patches_per_image = int(config.vision_config.image_size // config.vision_config.patch_size)
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        batch_size, _, seq_length = vision_outputs.shape

        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
            batch_size, seq_length, self.patches_per_image, self.patches_per_image
        )
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()

        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)

        projected_vision_outputs = torch.matmul(normed_vision_outputs, self.mm_input_projection_weight)
        return projected_vision_outputs.type_as(vision_outputs)
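

# Arithmetic sketch for the projector above (illustrative only). With the released Gemma3
# vision settings, assumed here as image_size=896, patch_size=14 and mm_tokens_per_image=256
# (read the real values from the config in practice), the SigLIP tower emits 64 x 64 = 4096
# patch embeddings per image and the average pool shrinks them to 16 x 16 = 256 image tokens:
#
#   patches_per_image = 896 // 14       # 64
#   tokens_per_side = int(256 ** 0.5)   # 16
#   kernel_size = 64 // 16              # 4  -> nn.AvgPool2d(kernel_size=4, stride=4)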

@auto_docstring(
    custom_intro="""
    The base Gemma3 model, which consists of a vision backbone and a language model without a language modeling head.
    """
)
class Gemma3Model(Gemma3PreTrainedModel):
    _checkpoint_conversion_mapping = {"language_model.model": "language_model"}

    def __init__(self, config: Gemma3Config):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)

        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModel.from_config(config=config.text_config)
        self.language_model = language_model

        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def _update_causal_mask(
        self,
        attention_mask,
        token_type_ids,
        past_key_values,
        cache_position,
        input_tensor,
        is_training: bool = False,
    ):
        if self.config.text_config._attn_implementation == "flash_attention_2":
            return attention_mask

        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            return attention_mask

        using_static_cache = isinstance(past_key_values, StaticCache)
        min_dtype = torch.finfo(self.dtype).min
        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[0] + sequence_length + 1
            )

        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
        )

        # Causal diagonal mask only if training, otherwise attend to the whole prefix
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)

        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)

        # Apply a bidirectional mask on image tokens if token type ids are provided
        if token_type_ids is not None and sequence_length != 1:
            token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
            token_type_mask[token_type_ids == 0] = False  # if text token, do not change anything

            # Find where a new image block starts: 1 if image and the previous token was not image.
            # Images can attend bidirectionally within the same image, but not to future images.
            is_image = token_type_ids == 1
            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))

            same_image_mask = image_group_ids.unsqueeze(1) == image_group_ids.unsqueeze(2)
            same_image_mask[image_group_ids == -1] = False  # remove non-image tokens
            image_mask = (token_type_mask & same_image_mask).unsqueeze(1).to(causal_mask.device, dtype=torch.bool)

            causal_mask = causal_mask.clone()
            causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
                image_mask, 0.0
            )

        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]

            # Then apply padding mask (will mask pad tokens)
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                causal_mask.device
            )
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        return causal_mask

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Projects the last hidden state from the vision model into language model space.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
        """
        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **lm_kwargs,
    ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
        >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt,  return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        is_training = token_type_ids is not None and labels is not None

        # Replace image ids with PAD if the image token is OOV, to avoid index errors
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)

            if input_ids is None:
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but "
                    f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
        )
        outputs = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=True,
            **lm_kwargs,
        )

        return Gemma3ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )


@auto_docstring(
    custom_intro="""
    The Gemma3 model, which consists of a vision backbone and a language model, with a language modeling head
    for conditional generation.
    """
)
class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Gemma3Config):
        super().__init__(config)
        self.model = Gemma3Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @property
    def language_model(self):
        return self.model.language_model

    @property
    def vision_tower(self):
        return self.model.vision_tower

    @property
    def multi_modal_projector(self):
        return self.model.multi_modal_projector

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
        >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

        >>> messages = [
        ...     {
        ...         "role": "system",
        ...         "content": [
        ...             {"type": "text", "text": "You are a helpful assistant."}
        ...         ]
        ...     },
        ...     {
        ...         "role": "user", "content": [
        ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
        ...             {"type": "text", "text": "Where is the cat standing?"},
        ...         ]
        ...     },
        ... ]

        >>> inputs = processor.apply_chat_template(
        ...     messages,
        ...     tokenize=True,
        ...     return_dict=True,
        ...     return_tensors="pt",
        ...     add_generation_prompt=True
        ... )
        >>> # Generate
        >>> generate_ids = model.generate(**inputs)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **lm_kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Upcast to float to avoid potential precision issues when computing the loss
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # Use the input attention mask (2D) to shift logits and labels; also crop the mask
                # in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Gemma3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        # Overwritten: custom `position_ids` and `pixel_values` handling
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )

        # If we're in the cached decoding stage, pixel values should be None because the input ids
        # no longer contain special image tokens. Otherwise pixel values must be passed to the model.
        # NOTE: use_cache=False always needs pixel_values
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values

        is_training = token_type_ids is not None and labels is not None
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
            causal_mask = self.model._update_causal_mask(
                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
            )
            model_inputs["attention_mask"] = causal_mask

        return model_inputs

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        # Identical helper to `Gemma3TextModel._prepare_4d_causal_attention_mask_with_cache_position`;
        # delegate instead of duplicating the implementation.
        return Gemma3TextModel._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=batch_size,
            **kwargs,
        )


__all__ = [
    "Gemma3PreTrainedModel",
    "Gemma3TextModel",
    "Gemma3ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Gemma3Model",
]