# Reconstructed from the Python 3.10 bytecode cache of
# /var/www/auris/lib/python3.10/site-packages/transformers/models/zamba2/modeling_zamba2.py.
# Only the module structure, signatures, and docstrings recoverable from the compiled
# cache are kept; method bodies that cannot be recovered verbatim are elided with `...`.
import math
import re
from itertools import cycle
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
from .configuration_zamba2 import Zamba2Config


if is_mamba_ssm_available():
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
    selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None


logger = logging.get_logger(__name__)


class Zamba2RMSNormGated(torch.nn.Module):
    def __init__(self, hidden_size, group_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.group_size = group_size

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        if gate is not None:
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        *prefix_dims, last_dim = hidden_states.shape
        group_count = last_dim // self.group_size
        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
        return self.weight * hidden_states.to(input_dtype)


class Zamba2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Zamba2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


class Zamba2HybridDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors, whose expected shapes are as
    follows.
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    """

    def __init__(
        self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
    ):
        ...

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        ...

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        ...

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        ...

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.")

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.")

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        ...

    def reset(self):
        ...


class Zamba2RotaryEmbedding(nn.Module):
    def __init__(self, config: Zamba2Config, device=None):
        ...

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        ...


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Zamba2Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".

    Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
    The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
    The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
    (see fig. 2 in https://arxiv.org/pdf/2405.16712).
    Additionally, replaced
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)

    Finally, this attention layer contributes to tied transformer blocks aimed at increasing compute without increasing
    model size. Because this layer is tied, un-tied adapter modules (formally the same as LoRA, but used in the base
    model) are added to the q, k, v projectors to increase expressivity with a small memory overhead
    (see Fig. 2 of https://arxiv.org/pdf/2411.15242).
    """

    def __init__(
        self,
        config: Zamba2Config,
        layer_idx: Optional[int] = None,
        num_fwd_mem_blocks: Optional[int] = None,
        block_id: Optional[int] = None,
    ):
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_idx: int,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Zamba2HybridDynamicCache] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...


def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
    """
    pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)

    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)


def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    """
    # [bsz, seq_len, ...] -> [bsz, -1, chunk_size, ...]
    input_tensor = pad_tensor_by_size(input_tensor, pad_size)

    if len(input_tensor.shape) == 3:
        return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
    else:
        return input_tensor.reshape(
            input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
        )


def segment_sum(input_tensor):
    """
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    """
    chunk_size = input_tensor.size(-1)
    # 1. expand input tensor to have an additional dimension and repeat along that dimension
    # [..., chunk_size] -> [..., chunk_size, chunk_size]
    input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # 2. create a lower triangular mask with the diagonal set to 0 to zero out elements above it
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
    input_tensor = input_tensor.masked_fill(~mask, 0)
    # 3. compute the actual cumulative sum
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)
    # 4. keep only the lower triangular part of the cumulative sum (including the diagonal this time)
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
    tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
    return tensor_segsum


is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))


class Zamba2MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None):
        ...

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Zamba2HybridDynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        ...

    def torch_forward(
        self,
        input_states,
        cache_params: Optional[Zamba2HybridDynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        ...

    def forward(
        self,
        hidden_states,
        cache_params: Optional[Zamba2HybridDynamicCache] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Use the fused CUDA kernels when they are installed and the weights live on a CUDA
        # device; otherwise fall back to the pure PyTorch implementation.
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
        return self.torch_forward(hidden_states, cache_params, attention_mask)


class Zamba2MLP(nn.Module):
    def __init__(
        self, config: Zamba2Config, num_fwd_mem_blocks: Optional[int] = None, block_id: Optional[int] = None
    ):
        """
        This MLP layer contributes to tied transformer blocks aimed at increasing compute without increasing model
        size. Because this layer is tied, un-tied adapter modules (formally the same as LoRA, but used in the base
        model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
        """
        ...

    def forward(self, hidden_state, layer_idx=None):
        gate_up_state = self.gate_up_proj(hidden_state)
        layer_idx = self.layer_dic[layer_idx]
        gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state)
        gate_up_state = torch.chunk(gate_up_state, 2, dim=-1)
        hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1]
        output = self.down_proj(hidden_state)
        return output


class Zamba2AttentionDecoderLayer(nn.Module):
    def __init__(self, config: Zamba2Config, block_id: Optional[int] = None, layer_idx: Optional[int] = None):
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        original_hidden_states: torch.Tensor,
        layer_idx: int,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Zamba2HybridDynamicCache] = None,
        output_attentions: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
                This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
                concatenated tensor is then used as input of the pre-attention RMSNorm
                (see fig. 2 in https://arxiv.org/pdf/2405.16712).
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        """
        ...


class Zamba2MambaDecoderLayer(nn.Module):
    def __init__(self, config: Zamba2Config, layer_idx: int):
        super().__init__()
        self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx)
        self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        original_hidden_states: Optional[torch.Tensor] = None,
        layer_idx: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        causal_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Zamba2HybridDynamicCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        transformer_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
        """
        ...


class Zamba2HybridLayer(nn.Module):
    def __init__(
        self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer
    ):
        super().__init__()
        self.linear = linear
        self.mamba_decoder = mamba
        self.shared_transformer = shared_transformer

    def forward(
        self,
        hidden_states: torch.Tensor,
        original_hidden_states: Optional[torch.Tensor] = None,
        layer_idx: Optional[int] = None,
        attention_mask: Optional[torch.Tensor] = None,
        causal_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Zamba2HybridDynamicCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
                hidden activations to form the input of the shared transformer layer.
            layer_idx (`int`): layer number.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        """
        ...


class Zamba2PreTrainedModel(PreTrainedModel):
    config_class = Zamba2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_flex_attn = True
    _supports_sdpa = True
    _supports_cache_class = True
    _is_stateful = True

    def _init_weights(self, module):
        ...


@auto_docstring
class Zamba2Model(Zamba2PreTrainedModel):
    """
    Model consisting of *config.num_hidden_layers* layers.

    Args:
        config: Zamba2Config
    """

    def __init__(self, config: Zamba2Config):
        ...

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Zamba2HybridDynamicCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        ...

    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        ...

    def get_layers(self, blocks, linear_layers, mamba_layers):
        ...


class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin):
    def __init__(self, config: Zamba2Config):
        super().__init__(config)
        self.model = Zamba2Model(config)
        self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys]
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Zamba2HybridDynamicCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Zamba2ForCausalLM

        >>> model = Zamba2ForCausalLM.from_pretrained("Zyphra/Zamba2-7B-v1")
        >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-v1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        ...

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        ...


@auto_docstring(
    custom_intro="""
    The Zamba2 Model with a sequence classification head on top (linear layer).

    [`Zamba2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class Zamba2ForSequenceClassification(Zamba2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Zamba2Model(config)
        self._tied_weights_keys = self.model._tied_weights_keys
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        ...


__all__ = ["Zamba2ForCausalLM", "Zamba2ForSequenceClassification", "Zamba2Model", "Zamba2PreTrainedModel"]