import math
import re
from itertools import cycle
from typing import Callable, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import logging
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from ..mamba2.modeling_mamba2 import pad_tensor_by_size, reshape_into_chunks, segment_sum
from ..zamba.modeling_zamba import (
    ZambaAttention,
    ZambaAttentionDecoderLayer,
    ZambaForCausalLM,
    ZambaForSequenceClassification,
    ZambaHybridDynamicCache,
    ZambaHybridLayer,
    ZambaMambaDecoderLayer,
    ZambaModel,
    ZambaRMSNorm,
    eager_attention_forward,
)
from .configuration_zamba2 import Zamba2Config


if is_mamba_ssm_available():
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
    selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

_CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B"

logger = logging.get_logger(__name__)


class Zamba2RMSNormGated(torch.nn.Module):
    def __init__(self, hidden_size, group_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.group_size = group_size

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        if gate is not None:
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        *prefix_dims, last_dim = hidden_states.shape
        group_count = last_dim // self.group_size
        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
        return self.weight * hidden_states.to(input_dtype)


class Zamba2RMSNorm(ZambaRMSNorm):
    pass
   @   sx   e Zd ZdZejdfdededejde	e
 fddZd	ed
ejdejdejfddZdd Zdd	e	e defddZdS )Zamba2HybridDynamicCachea  
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors, and the expected shape of each tensor depends on the layer type.
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    Nconfig
batch_sizer>   devicec              	      s  || _ |j| _d| _t|j|j | _|j| _|j	| _
|j| _g | _i | _i | _i | _i | _i | _t|jD ]7}tj | jd|j |j  | j
|d| j|< tj | j|j| j|d| j|< | j| dkrm| j| q6 fddt|jD | _ fddt|jD | _d S )NFr   rU   r>   hybridc                        g | ]}t jg g  d qS rU   r1   Ztensor.0_rT   rU   r;   r<   
<listcomp>        z5Zamba2HybridDynamicCache.__init__.<locals>.<listcomp>c                    rX   rY   r[   r\   r_   r;   r<   r`      ra   )r>   layers_block_typehas_previous_stateintmamba_expandr7   intermediate_sizemamba_d_statessm_state_sizemamba_d_convconv_kernel_sizen_mamba_headstransformer_layersZ_modules_parameters_buffersconv_states
ssm_statesrangenum_hidden_layersr1   zerosmamba_ngroupsmamba_headdimappend	key_cacheZvalue_cache)r6   rS   rT   r>   rU   ir;   r_   r<   r/   p   s:    z!Zamba2HybridDynamicCache.__init__	layer_idxnew_conv_statecache_positionreturnc                 C   sr   | j | }|d| jd }|jddd}||j|d d d d |f< | j |   | j |  |7  < | j | S )Nr   r"   r=   Zshiftsdims)ro   clamprj   rollr?   rU   zero_)r6   ry   rz   r{   
conv_stater;   r;   r<   update_conv_state   s   

z*Zamba2HybridDynamicCache.update_conv_statec                 C   s   | j   | j  d S r-   )ro   r   rp   )r6   r;   r;   r<   reset   s   
zZamba2HybridDynamicCache.resetr   c                 C   sL   || j vr
| j d n|}t| j|ks| j|  dkrdS | j| jd S )zYReturns the sequence length of the cached states. A layer index can be optionally passed.r   )rl   lenrw   ZnumelrC   )r6   ry   r;   r;   r<   get_seq_length   s    z'Zamba2HybridDynamicCache.get_seq_length)r   )rJ   rK   rL   __doc__r1   Zfloat16r#   rd   r>   r   strr/   Tensor
LongTensorr   r   r   r;   r;   r;   r<   rR   b   s.    
 
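
# A hedged usage sketch of the hybrid cache (the helper name and the concrete values are ours, for
# illustration only): attention ("hybrid") layers grow their `key_cache`/`value_cache` entries along
# the sequence dimension, while mamba layers keep fixed-size `conv_states`/`ssm_states` regardless
# of how many tokens have been processed.
def _hybrid_cache_sketch(config: "Zamba2Config"):
    cache = Zamba2HybridDynamicCache(config, batch_size=1, dtype=torch.float32, device="cpu")
    # Fixed-shape mamba state for layer 0: (batch, intermediate + 2 * ngroups * d_state, d_conv)
    conv_shape = cache.conv_states[0].shape
    # Sequence length is read from the first hybrid layer's key cache (0 before any forward pass).
    return conv_shape, cache.get_seq_length()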
class Zamba2RotaryEmbedding(LlamaRotaryEmbedding):
    def __init__(self, config: Zamba2Config, device=None):
        super().__init__(config, device)
        inv_freq, self.attention_scaling = self.rope_init_fn(
            device=device, base=config.rope_theta, dim=config.attention_head_dim
        )


class Zamba2Attention(ZambaAttention):
    """
    Multi-headed attention from 'Attention Is All You Need' paper.

    Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
    The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
    The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
    (see fig. 2 in https://arxiv.org/pdf/2405.16712).
    Additionally, replaced
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
    Finally, this attention layer contributes to tied transformer blocks aimed at increasing compute without increasing model size. Because this
    layer is tied, un-tied adapter modules (formally the same as LoRA, but used in the base model) are added to the q, k, v projectors to increase
    expressivity with a small memory overhead (see Fig. 2 of https://arxiv.org/pdf/2411.15242).
    """

    def __init__(
        self,
        config: Zamba2Config,
        layer_idx: Optional[int] = None,
        num_fwd_mem_blocks: Optional[int] = None,
        block_id: Optional[int] = None,
    ):
        super().__init__(config, layer_idx)
        self.num_fwd_mem_blocks = num_fwd_mem_blocks
        self.layer_block_map = config.hybrid_layer_ids
        self.block_id = block_id

        if config.use_shared_attention_adapter:
            self.linear_q_adapter_list = nn.ModuleList([])
            self.linear_k_adapter_list = nn.ModuleList([])
            self.linear_v_adapter_list = nn.ModuleList([])

            for i in range(self.num_fwd_mem_blocks):
                if i % config.num_mem_blocks == block_id:
                    linear_q_adapter = nn.Sequential(
                        nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
                        nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
                    )
                    linear_k_adapter = nn.Sequential(
                        nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
                        nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
                    )
                    linear_v_adapter = nn.Sequential(
                        nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
                        nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
                    )
                else:
                    linear_q_adapter = nn.Identity()
                    linear_k_adapter = nn.Identity()
                    linear_v_adapter = nn.Identity()
                self.linear_q_adapter_list.append(linear_q_adapter)
                self.linear_k_adapter_list.append(linear_k_adapter)
                self.linear_v_adapter_list.append(linear_v_adapter)

        self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)}

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_idx: int,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Zamba2HybridDynamicCache] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        if self.config.use_shared_attention_adapter:
            # the adapters of this (tied) block are indexed by the position of the layer in the model
            adapter_layer_idx = self.layer_dic[layer_idx]
            query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
            key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
            value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)

        query_states = query_states.view(hidden_shape).transpose(1, 2)
        key_states = key_states.view(hidden_shape).transpose(1, 2)
        value_states = value_states.view(hidden_shape).transpose(1, 2)

        if self.config.use_mem_rope:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`."
                    " Falling back to eager attention. This warning can be removed using the argument"
                    ' `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Zamba2MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = int(config.mamba_expand * self.hidden_size)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.activation = "silu"
        self.act = nn.SiLU()
        self.use_mem_eff_path = config.use_mem_eff_path

        self.n_groups = config.mamba_ngroups
        self.head_dim = config.mamba_headdim
        self.num_heads = self.config.n_mamba_heads
        self.chunk_size = config.chunk_size

        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max

        # depthwise causal convolution over the concatenated (x, B, C) channels
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=True,
            kernel_size=config.mamba_d_conv,
            groups=self.conv_dim,
            padding=config.mamba_d_conv - 1,
        )

        # projection of the input hidden states
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(self.hidden_size, projection_size, bias=config.add_bias_linear)

        # time step bias (discretization)
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization; A is stored in log space and is not discretized here
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.norm = Zamba2RMSNormGated(
            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
        )
        self.D = nn.Parameter(torch.ones(self.num_heads))
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
                " is None. Falling back to the naive implementation. To install follow"
                " https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d"
            )
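
# A deliberately tiny, pure-PyTorch sketch of the selective state-space recurrence that this
# mixer's forward paths implement with chunked/fused kernels (single head, unbatched; the helper
# name and the toy shapes are ours, and the explicit Python loop is for illustration only):
#
#   h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t
#   y_t = C_t . h_t + D * x_t
def _selective_scan_sketch(x, dt, A, B, C, D):
    # x, dt: (seq_len,); A, D: python floats; B, C: (seq_len, state_size)
    state = torch.zeros(B.shape[-1], dtype=x.dtype)
    outputs = []
    for t in range(x.shape[0]):
        state = torch.exp(dt[t] * A) * state + dt[t] * B[t] * x[t]
        outputs.append((C[t] * state).sum() + D * x[t])
    return torch.stack(outputs)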



	zZamba2MambaMixer.__init__rG   cache_paramsr   c                 C   sF  |j \}}}| j| j }d| j d| j | j  | j }|d ur|jr| |d}	|	j d | d }
|
|
| j| j| jg}t	j
|	|dd\}}}}}t||j| j | jjd| jj| j}t	j
|| j||gdd\}}}t	| j  }|d d d df d d d d d f d| j| jjt	jd}|d d d d d f dd| j}| jd d d df d| j}| jd d d df d| j}||| j|j d | j }||| j|j d | j }||| j| j}t|j| j ||||||d |dd
}||| j| j }| ||}| |d d d df }|S |d ur;t	 |dks;|j!}||d d d d d f  |}| |}t	| j  }| j"d u rQi nd	| j"i}|d urct	 |dk}nd}| j#r| j$r|d u r|rt%|| jjd| jj| j|f| j| j&d | j| jj| jj'| jj| jj| j| jd
dd|\}}|S t	j
|| j| j| jgdd\}}}|d ur|(dd}t)j*+|| j,|j d  df}|j| j -| t.d u s| jdvr| /| |(dd(ddd d d |f }n t.|(dd| jjd| jj| jd(ddd d d |f }t	j
|| j||gdd\}}}|d urNt	 |dksN|j!}||d d d d d f  |}t0|||d| j|||||| jd|||| jdf| j&| jd d d| jdd|\}}|d ur|d ur|j| j -| |||d}| ||}| |}|S )Nr   r"   r=   r   .r>   T)zr   dt_softplusZdt_limitF)r   r   seq_idxr   Zrmsnorm_weightZrmsnorm_epsZoutproj_weightZoutproj_biasZheaddimZngroupsZnorm_before_gatereturn_final_statesr   )rB   Zswish)xr3   r   r   )r   r   r   r   r   r   r   )1rC   r   rh   rf   r   rc   r   squeezer   r1   splitr)   ro   ry   r   r3   r   r   expr   floatexpandr   r?   r@   r   r   rD   r$   rp   r   r   allr>   r   r   r   r&   r   r4   r   r   rA   padrj   copy_r(   r   r%   )r6   rG   r   r   rT   seq_lenr^   Zgroups_time_state_sizeZd_to_removeZin_projected_statesd_mlpZsplit_projection_dimrH   Zhidden_states_B_CdtBCr   r   r   Zhidden_states_reshapedoutr>   projected_statesZdt_limit_kwargsZinput_not_masked	ssm_stateZ	time_stepZhidden_states_B_C_tr   scan_outputr;   r;   r<   cuda_kernels_forwardr  s   

<"
] 

 
L
(

 

z%Zamba2MambaMixer.cuda_kernels_forwardc           1   
      s	  |j \}}}|j}|d ur|jr|d}n |d ur4t|dks4||d d d d d f  |}|}|j d dj  dj	 j
  j d }	|j|	|	jjjgdd\}}}
}}|d ur8|jj  }||j}|jr|
d}
|jj }tj|ddd}|jdkr|d d dd d f n||d d d d df< |jj | tj||jjjd d dd d f  dd}jr|jj7 }||d d d df }n||dd}tj |j!|j d  df}|jj | |ddd d d |d d f }|d ur7t|dks7|j}||d d d d d f  |}n&tj"|jj#j
f|j|d	}|dddd |f dd}tj|jj	j
 j	j
 gdd\}}}t$j%&  }|d ur|jr|jdkr|d d d df n|d d dd d f d d d df }|dd'||j d j#}j(d
 'j(j d j#}tjj)|||j }t*|j+}|d 'jj#j
jtj,d}t$|d
 | }|-|j	ddd d d f }|'|j	jj	 |j d . }|-|d|j d }|d
 |dd d d f  }|-|dj#}||d
  }|jj |jj | |  |-|j	ddd d d f }|'|j	jj	 |j d . }|-|d|j d }|jj |j}|/|j j#j
}|/|j j
d}t0||}|/|jj#}j1d
 'j1j d j#}|||  |j}|-|dd d d df }ntj)|j( }t*|j+}|-||dj#& }|-||dj
& }|-||dj
& }|j2jj	 djd}|j2jj	 djd}j3|j3  j3  j1d
 t4|  }||d
  }||j| } fdd||||fD \}}}}|5dddd}tj6|dd}t$t7|}|d d d d d d d d d d d f |d d d d d d d d d d d f  }|jdd}|d
 |5dddddd
  } | jdd}!|!d
 |d d d d d f  d}"t$|d d d d d d dd f | }#||#5ddddd
  }$|$5dddddd
 |5ddddddd d d f  jdd5ddddd}%|d ur|jr|jj d d d df }&nt8|%d d d df }&tj9|&|%gdd}%t$t7tj |d d d d d d df d}'|%5ddddd}(|'d |(d d d d d df  jdd})|)5ddddd}*|*d d d df |*d d df }%}t$|}+|dd d d f |%d d d d d df  },|+5dddd}-|,d|-d
  }.|"|. }|-|djj#}|| } dkr|d d d |d d d d f }|-||d}|d ur|d ur|jj | :||
}/;|/|}0|0S )Nr"   r=   r   r   r}   r   r   .rV   ).N).NNr   )r   Zoutput_sizec                    s   g | ]	}t | jqS r;   )r   r   )r]   tZpad_sizer6   r;   r<   r`     s    z2Zamba2MambaMixer.torch_forward.<locals>.<listcomp>   )r"   r   )<rC   r>   rc   r   r   r1   r   r?   rf   r   rh   r   r   r   rp   ry   clonerU   	unsqueezero   r   ndimr   sumr   r3   r   r   r   r   r   rA   r   rj   rs   r   r   r   r   r   r   Zsoftplusr   r   r@   r   r   rD   Zbmmr   Zrepeat_interleaver   r   ZpermuteZcumsumr   Z
zeros_likecatr   r   )1r6   Zinput_statesr   r   rT   r   r^   r>   r   r   rH   rG   r   r   r   r   r   r   r   ZdAZdBZdBxrp   Zssm_states_reshapedZ
C_reshapedyr   Z
D_residualZA_cumsumLZG_intermediateGZM_intermediateMZY_diagZdecay_statesZB_decay_contractionZstatesZprevious_statesZdecay_chunkZstates_permutedresultZ
new_statesZstate_decay_outZC_times_statesZstate_decay_out_permutedZY_offr   Zcontextualized_statesr;   r   r<   torch_forward	  s    
.

60 . ,.B"$$$P$*L0(&
*
 zZamba2MambaMixer.torch_forwardc                 C   s0   t rd| jjjjv r| |||S | |||S )Ncuda)r   r   r3   rU   typer   r   )r6   rG   r   r   r;   r;   r<   rI     s   zZamba2MambaMixer.forwardr-   r*   )rJ   rK   rL   r   r#   r   rd   r/   r1   r   rR   r   r   rI   rM   r;   r;   r9   r<   r   )  s,    D
  Fr   c                       s6   e Zd Zddedee f fddZd	ddZ  ZS )
	Zamba2MLPNrS   r   c              	      s   t    || _|j| _|j| _|| _|| _tj| jd| j |j	d| _
tj| j| j|j	d| _t|j | _tg | _t| jD ]/}||j |krfttj| jj| jjddtj| jjd| j dd}nt }| j| qA|j}dd t|D | _dS )aQ  
        This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer
        is tied, un-tied adapter modules (formally the same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_fwd_mem_blocks = num_fwd_mem_blocks
        self.block_id = block_id

        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)
        self.act_fn = ACT2FN[config.hidden_act]

        self.gate_up_proj_adapter_list = nn.ModuleList([])
        for i in range(self.num_fwd_mem_blocks):
            if i % config.num_mem_blocks == block_id:
                gate_up_proj_adapter = nn.Sequential(
                    nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False),
                    nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False),
                )
            else:
                gate_up_proj_adapter = nn.Identity()
            self.gate_up_proj_adapter_list.append(gate_up_proj_adapter)

        layer_block_map = config.hybrid_layer_ids
        self.layer_dic = {value: index for index, value in enumerate(layer_block_map)}

    def forward(self, hidden_state, layer_idx=None):
        gate_up_state = self.gate_up_proj(hidden_state)
        layer_idx = self.layer_dic[layer_idx]
        gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state)
        gate_up_state = torch.chunk(gate_up_state, 2, dim=-1)
        hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1]
        output = self.down_proj(hidden_state)
        return output
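
# A small sketch of the fused gate/up projection pattern used by the MLP above: a single matmul
# produces both the gate and the up branch, which are split with `chunk` and combined as
# act(gate) * up before the down projection. The helper name, the weight arguments, and the use of
# gelu as the activation are ours, for illustration only:
def _gated_mlp_sketch(hidden_state, gate_up_weight, down_weight):
    gate_up = hidden_state @ gate_up_weight.T            # (..., 2 * intermediate_size)
    gate, up = torch.chunk(gate_up, 2, dim=-1)
    return (nn.functional.gelu(gate) * up) @ down_weight.T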
zZamba2MLP.forwardr*   r-   )	rJ   rK   rL   r#   r   rd   r/   rI   rM   r;   r;   r9   r<   r    s    r  c                       s   e Zd Zddedee dee f fddZ				ddejd	ejded
eej dee	 dee
 deej dee deejeeejejf  f fddZ  ZS )Zamba2AttentionDecoderLayerNrS   r   ry   c                    sD   || _ t|j}t || t|d||d| _t|||d| _d S )Nr=   )ry   r   r   )r   r   )	r   r   r   r.   r/   r   	self_attnr  feed_forward)r6   rS   r   ry   Znum_gsr9   r;   r<   r/     s
   
z$Zamba2AttentionDecoderLayer.__init__FrG   original_hidden_statesr   r   r   r   r   r|   c              	   K   sl   t j||gdd}| |}| jd||||||d|\}}	| |}| ||}|f}
|r4|
|	f7 }
|
S )a  
        Args:
            hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
                This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
                concatenated tensor is then used as input of the pre-attention RMSNorm
                (see fig. 2 in https://arxiv.org/pdf/2405.16712).
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        """
        hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1)
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            layer_idx=layer_idx,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        hidden_states = self.pre_ff_layernorm(hidden_states)
        hidden_states = self.feed_forward(hidden_states, layer_idx)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
r	  c                       s&   e Zd Zdedef fddZ  ZS )Zamba2MambaDecoderLayerrS   ry   c                    s2   t  || t||d| _t|j|jd| _d S )N)rS   ry   r8   )r.   r/   r   mambarO   r7   rms_norm_epsr  )r6   rS   ry   r9   r;   r<   r/   @  s   z Zamba2MambaDecoderLayer.__init__)rJ   rK   rL   r#   rd   r/   rM   r;   r;   r9   r<   r  ?  s    r  c                       s   e Zd Zdedejdef fddZ								ddej	d	e
ej	 d
e
e de
ej	 de
ej	 de
e de
e de
e de
ej deeje
eejejf  f fddZ  ZS )Zamba2HybridLayershared_transformerlinearr  c                    s   t  ||| | `|| _d S r-   )r.   r/   Zshared_transfr  )r6   r  r  r  r9   r;   r<   r/   G  s   
zZamba2HybridLayer.__init__NFrG   r  ry   r   causal_maskr   r   	use_cacher   r|   c
              	   C   sn   | j |||||||	d}
|
d }|r|
d }| |}| j|||||||	d}
|r5|
d |f|
dd  }
|
S )aX  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
            hidden activations to form the input of the shared transformer layer.
            layer_idx (`int`): layer number.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        """
        layer_outputs = self.shared_transformer(
            hidden_states,
            original_hidden_states=original_hidden_states,
            layer_idx=layer_idx,
            attention_mask=causal_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            position_embeddings=position_embeddings,
        )

        transformer_hidden_states = layer_outputs[0]

        if output_attentions:
            self_attn_weights = layer_outputs[1]

        transformer_hidden_states = self.linear(transformer_hidden_states)

        layer_outputs = self.mamba_decoder(
            hidden_states,
            transformer_hidden_states=transformer_hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
        )

        if output_attentions:
            layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:]

        return layer_outputs
	
r  c                   @   s@   e Zd ZeZdZdZddgZdZdZ	dZ
dZdZdZdd ZdS )	Zamba2PreTrainedModelmodelTr	  r  past_key_valuesc                 C   sb  | j j}t|tjtjfr%|jjjd|d |j	d ur#|j	j
  d S d S t|tjrF|jjjd|d |jd urD|jj|j 
  d S d S t|ttfrV|jjd d S t|trtt| j jt| j jt| j j  t| j j j| j jd}|tt|   }|jj| td|jd }|j jt| |j!jd d S d S )Nr   )rF   stdg      ?)minr"   )"rS   Zinitializer_range
isinstancer   r   r   r3   dataZnormal_r   r   	Embeddingpadding_idxrO   r+   Zfill_r   r1   r   Zrandrk   mathr   r   r   r   Ztime_step_floorexpm1r   r   r   r   r   r   )r6   moduler  r   Zinv_dtr   r;   r;   r<   _init_weights  s:   


z#Zamba2PreTrainedModel._init_weightsN)rJ   rK   rL   r#   Zconfig_classZbase_model_prefixZsupports_gradient_checkpointingZ_no_split_modulesZ_skip_keys_device_placementZ_supports_flash_attn_2Z_supports_flex_attnZ_supports_sdpaZ_supports_cache_classZ_is_statefulr(  r;   r;   r;   r<   r    s    r  c                   @   s   e Zd ZdZdefddZdd Z										ddeej	 d	eej
 d
eej	 dee deej dee dee dee dee deej	 deeef fddZdS )Zamba2Modelzh
    Model consisting of *config.num_hidden_layers* layers.

    Args:
        config: Zamba2Config
    rS   c                    sN  t |    | _ j| _ j| _t j j| j| _	 fddt
 jD }g }g } j| _t
 jD ]2} j| dkrH|t |d q5 j| dkrg|tj| jj| jjdd |t |d q5t|}t|}t|}| |||}t|| _ j| _t j jd| _ jr jrtd	 t | _d| _ | !  d S )
Nc                    s   g | ]}t  |d qS ))r   )r	  )r]   krS   r;   r<   r`     s    z(Zamba2Model.__init__.<locals>.<listcomp>r  ry   rW   Fr   r  ze`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`.)"r  r/   rS   Zpad_token_idr$  Z
vocab_sizer   r#  r7   embed_tokensrq   r   rb   rr   rv   r  r   iterr   
get_layersr   layersr   rO   r  final_layernormr   Zuse_long_contextr   r   r   
rotary_embgradient_checkpointingZ	post_init)r6   rS   blocksmamba_layerslinear_layersrx   r0  r;   r+  r<   r/     s>   
zZamba2Model.__init__c                 C   sp  g }g | _ d| _t| jD ]\}}|dkr| jdkr|| _t|}| jjt| jj dkrd| d}t	
|d d d d	 d
 }	| j |	 d}
| jD ]$}|dkrm|
| jj |jkrmt	
dt|
 d }| j | |
d7 }
qM| jjrd}
| jD ]$}|dkr|
| jj |jkrt	
dt|
 d }| j | |
d7 }
q{|t|t|t| q|t| q|S )Nr   rW   r"   z	^layers\.z\.shared_transformer\.z(?:z3self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|z1feed_forward\.(?:gate_up_proj|down_proj)\.weight|z,(?:input_layernorm|pre_ff_layernorm)\.weightz)$z>^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\.z\.(?:0|1)\.weight$zg^shared_transformer\.self_attn\.(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\.)Z_tied_weights_keysfirst_transformer_layer_idr   rb   nextrS   r   r   r   recompilerv   r   r   r   r  )r6   r4  r6  r5  r0  Zlayer_idZ
layer_typeblockZprefix_patternZmain_keys_patternZ
adapter_idZ_layer_typeZadapter_patternZattn_adapter_patternr;   r;   r<   r/    sh   




zZamba2Model.get_layersN	input_idsr   position_idsr  inputs_embedsr  r   output_hidden_statesreturn_dictr{   r|   c                 C   s`  |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}|	d ur$|	n| j j}	|d u |d uA r4td| jrC| jrC|rCt	d d}|d u rL| 
|}|}t|}|rr|d u rr|d urb|jd n|jd }t| j || j| jd}|
d u r|d ur|j| jdnd}tj|||jd  |jd}
|d u r|
d}| |||
}| j jr| ||}nd }|rd	nd }|rd	nd }t| jD ]C\}}|r||f7 }| jr| jr| |j|||||||||
}n||||||||||d
	}|d }|r|d d ur||d f7 }q| |}|r||f7 }|r|jsd|_t||r!|nd ||d}|	r,|S | S )NzaYou cannot specify both input_ids and inputs_embeds at the same time, and must specify either onezX`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.Fr   )r>   rU   r,  r"   rZ   r;   )r  ry   r   r  r   r   r  r   T)Zlast_hidden_stater  rG   Z
attentions) rS   r   r?  r  Zuse_return_dict
ValueErrorr3  r   r   r   r-  r1   r   rC   rR   r>   rU   r   r7  r   r   Z_update_causal_maskr   r2  r   r0  Z_gradient_checkpointing_func__call__r1  rc   r   Zto_tuple)r6   r<  r   r=  r  r>  r  r   r?  r@  r{   rG   r  rT   Zpast_seen_tokensr  r   Zall_hidden_statesZall_self_attnsry   layerr  r  r;   r;   r<   rI     s   





zZamba2Model.forward)
NNNNNNNNNN)rJ   rK   rL   r   r#   r/   r/  r   r1   r   r   rR   r  r  r   r   r   rI   r;   r;   r;   r<   r)    sJ    $2	

r)  c                   @   rN   )Zamba2ForCausalLMNrP   r;   r;   r;   r<   rD    rQ   rD  c                   @   rN   )Zamba2ForSequenceClassificationNrP   r;   r;   r;   r<   rE    rQ   rE  )rD  rE  r)  r  )Nr%  r9  	itertoolsr   typingr   r   r   r   r1   Ztorch.utils.checkpointr   Zactivationsr	   Zmodeling_flash_attention_utilsr
   Zmodeling_outputsr   Zmodeling_utilsr   r   Zprocessing_utilsr   utilsr   Zutils.import_utilsr   r   Zllama.modeling_llamar   r   Zmamba2.modeling_mamba2r   r   r   Zzamba.modeling_zambar   r   r   r   r   r   r   r   r    r!   Zconfiguration_zamba2r#   Z+mamba_ssm.ops.triton.selective_state_updater$   Z!mamba_ssm.ops.triton.ssd_combinedr%   r&   Zcausal_conv1dr(   r)   r   r   Z_CONFIG_FOR_DOCZ
class Zamba2ForCausalLM(ZambaForCausalLM):
    pass


class Zamba2ForSequenceClassification(ZambaForSequenceClassification):
    pass


__all__ = ["Zamba2ForCausalLM", "Zamba2ForSequenceClassification", "Zamba2Model", "Zamba2PreTrainedModel"]