o
    ZhA                     @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddl	Zddlm
Z
 ddlmZ ddlmZ dd	lmZ dd
lmZmZmZ ddlmZmZ ddlmZ eeZe reddlmZ ddlm Z m!Z! nd\Z Z!Ze rvddl"m#Z#m$Z$ nd\Z$Z#e%ee e!e#e$fZ&dej'de(fddZ)dd Z*dd Z+dd Z,G dd dZ-G dd  d ej
j.Z/G d!d" d"e
j.Z0G d#d$ d$e
j.Z1G d%d& d&e
j.Z2eG d'd( d(eZ3eG d)d* d*eZ4eG d+d, d,eZ5eG d-d. d.e3Z6ed/d0G d1d2 d2e3eZ7g d3Z8dS )4zPyTorch MAMBA2 model.    N)	dataclass)OptionalTupleUnion)nn   )ACT2FN)GenerationMixin)PreTrainedModel)ModelOutputauto_docstringlogging)is_causal_conv1d_availableis_mamba_2_ssm_available   )Mamba2Config)selective_state_update)mamba_chunk_scan_combined mamba_split_conv1d_scan_combinedNNN)causal_conv1d_fncausal_conv1d_update)NNinput_tensorpad_sizec                 C   sH   t | jdkrddddd|ddfnddd|ddf}tjjj| |dddS )z
    Padding x tensor with `pad_size` on the seq_len dim (dim=1)

    Assumes that we only have tensors of either size 4 or 3
       r   Zconstant)modevalue)lenshapetorchr   
functionalpad)r   r   Z	pad_shape r"   Y/var/www/auris/lib/python3.10/site-packages/transformers/models/mamba2/modeling_mamba2.pypad_tensor_by_sizeA   s   2r$   c                 C   sX   t | |} t| jdkr| | jd d|| jd S | | jd d|| jd | jd S )z
    Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splitting it into chunk sequences.

    Assumes that we only have tensors of either size 4 or 3
    r   r      )r$   r   r   reshape)r   r   
chunk_sizer"   r"   r#   reshape_into_chunksL   s   
r)   c                 C   s   |  d}| d jg |   |R  } tjtj||| jtjddd}| | d} tj| dd}tjtj||| jtjddd}|| tj	 }|S )zo
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    r%   .Ndevicedtype)Zdiagonalr   dim)
sizeexpandr   Ztrilonesr,   boolZmasked_fillcumsuminf)r   r(   maskZtensor_segsumr"   r"   r#   segment_sum`   s   
  r8   c                 C   sN   |dur%|j d dkr%|j d dkr%| j}| |dddddf  |} | S )zm
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    Nr   r   )r   r-   to)hidden_statesattention_maskr-   r"   r"   r#   apply_mask_to_padding_statest   s   $ r<   c                
   @   sv   e Zd ZdZejdfdededejde	e
 fddZ		dd
edejdedejfddZd
edejfddZdd ZdS )Mamba2Cachea  
    Arguments:
        config: Mamba2Config
        batch_size: int
        dtype: torch.dtype
        device: torch.device

    Attributes:
        dtype: (`torch.dtype`):
            The default `dtype` used to initializing the cache.
        conv_kernel_size: (`int`):
            Model's convolution kernel size taken from config.
        n_groups: (`int`):
            Model's number of groups taken from the config - similar to tensor parallel in Transformer.
        state_size: (`int`):
            Model's SSM state size taken from config.
        num_heads: (`int`):
            The number of heads used in the linear attention / SSM.
        head_dim: (`int`):
            The respective dimension of the heads used in the linear attention / SSM.
        intermediate_size: (`int`):
            Model's intermediate_size based on (expand * hidden_dim) from config.
        conv_states: (`torch.Tensor`):
            A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states.
        ssm_states: (`torch.Tensor`):
            A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states.
    Nconfig
batch_sizer-   r,   c              	   C   s   || _ |j| _|j| _|j| _|j| _|j| _t|j|j	 | _
tj|j|| j
d| j | j  | j||d| _tj|j|| j| j| j||d| _d S )Nr&   r+   )r-   conv_kernelconv_kernel_sizen_groups
state_size	num_headshead_dimintr2   hidden_sizeintermediate_sizer   Zzerosnum_hidden_layersconv_states
ssm_states)selfr>   r?   r-   r,   r"   r"   r#   __init__   s0   zMamba2Cache.__init__F	layer_idxnew_conv_state
cache_initreturnc                 C   sv   |r| | jj| j|< n)| j| jddd| j|< |d d dd d f  | jj| j| d d d d df< | j| S )Nr%   )Zshiftsdimsr   )r9   rJ   r,   Zroll)rL   rN   rO   rP   r"   r"   r#   update_conv_state   s
   8
zMamba2Cache.update_conv_statenew_ssm_statec                 C   s   | | jj| j|< | j| S N)r9   rK   r,   )rL   rN   rT   r"   r"   r#   update_ssm_state   s   
zMamba2Cache.update_ssm_statec                 C   s   | j   | j  d S rU   )rJ   Zzero_rK   rL   r"   r"   r#   reset   s   
zMamba2Cache.reset)F)__name__
__module____qualname____doc__r   Zfloat16r   rF   r-   r   strrM   Tensorr4   rS   rV   rX   r"   r"   r"   r#   r=      s0    


r=   c                       s(   e Zd Zd fdd	ZdddZ  ZS )	MambaRMSNormGatedư>c                    s&   t    tt|| _|| _d S rU   superrM   r   	Parameterr   r3   weightvariance_epsilonrL   rG   eps	__class__r"   r#   rM      s   

zMambaRMSNormGated.__init__Nc                 C   sj   |j }|tj}|d ur|tj|tj }|djddd}|t	|| j
  }| j|| S Nr&   r%   T)Zkeepdim)r-   r9   r   float32r   r    silupowmeanrsqrtre   rd   )rL   r:   gateinput_dtypevariancer"   r"   r#   forward   s   zMambaRMSNormGated.forwardr`   rU   rY   rZ   r[   rM   rs   __classcell__r"   r"   rh   r#   r_      s    r_   c                
       s   e Zd ZdZdedef fddZ			ddejde	e
 d	e	ej d
e	ej fddZ			ddejde	e
 d	e	ej d
e	ej fddZ			dde	e
 d	e	ej d
e	ej fddZ  ZS )Mamba2Mixeru  
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    r>   rN   c                    s  t    |j| _|j| _|j| _|j| _t|j	| j | _
t|j| _|| _|j| _|j| _t|j | _|j| _|j| _|j| _|j| _|j| _|j| _|j| _|j| _| j
d| j | j  | _tj| j| j|j|j| j|jd d| _| j
| j | j }tj| j||jd| _ t!t"#| j| _$t"%d| jd }t!t"&|| _'d| j'_(t)| j
| jd| _*t!t"#| j| _+d| j+_(tj| j
| j|jd| _,|j| _t-st./d d S d S )Nr&   r   )Zin_channelsZout_channelsbiasZkernel_sizegroupspaddingrx   Trg   a  The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d)0rb   rM   rD   rG   rC   ssm_state_sizer@   rA   rF   r2   rH   Ztime_step_rankrN   use_conv_biasZ
hidden_act
activationr   actlayer_norm_epsilonZrms_normrB   rE   r(   time_step_limittime_step_mintime_step_maxconv_dimr   ZConv1dconv1dLinearZuse_biasin_projrc   r   r3   dt_biasarangelogA_log_no_weight_decayr_   normDout_projis_fast_path_availableloggerZwarning_once)rL   r>   rN   Zprojection_sizeArh   r"   r#   rM      s`   

	zMamba2Mixer.__init__Nr:   cache_paramscache_positionr;   c                 C   s  t ||}| |}|j\}}}| j| j }	|jd d| j  d| j | j  | j d }
|d ur#|d ur#|d dkr#|dj|
|
| j| j	| jgdd\}}}}}t
||j| j | jjd| jj| j}tj|| j|	|	gdd\}}}t| j  }|d d d df d d d d d f d| j| jjtjd}|d d d d d f dd| j}| jd d d df d| j}| jd d d df d| j}||| j|jd | j }||| j|jd | j }||| j| j}t|j| j ||||||d |dd	
}||| j| j }| ||}| |d d d df }|S t| j  }| j d
tdfkr8i nd| j i}| j!rv|d u rvt"|| jjd| jj| j|f| j| j#d | j| jj| jj$| jj| jj| j| jddd|}|S |j|
|
| j| j	| jgdd\}}}}}|d ur|%dd}t&j'(||j)|jd  df}|j*| j|dd | jdvr| +| |%dddd |f %dd}nt,|%dd| jjd| jj| jd%dd}t ||}tj|| j|	|	gdd\}}}t-|||d| j|||||| jd|||| jdf| j#| jd d d| jdd|\}}|d ur6|d ur6|j.| j|d |||d}| ||}| |}|S )Nr%   r&   r   r   r/   .r-   T)zr   dt_softplusg        r6   Zdt_limitF)r   r(   seq_idxr   Zrmsnorm_weightZrmsnorm_epsZoutproj_weightZoutproj_biasZheaddimZngroupsZnorm_before_gatereturn_final_statesrN   rO   rP   )rl   Zswish)xrd   rx   r   )r(   r   r   r   r   r   r   rN   rT   )/r<   r   r   rB   r}   rH   rD   squeezesplitr   r   rJ   rN   r   rd   rx   r   r   expr   floatr2   rE   r9   rk   r   r   viewr   rK   r   r   r   trainingr   r(   re   	transposer   r    r!   rA   rS   r   r   r   rV   )rL   r:   r   r   r;   projected_statesr?   seq_len_Zgroups_time_state_sized_mlprp   hidden_states_B_CdtBCr   r   r   Zhidden_states_reshapedoutZdt_limit_kwargshidden_states_B_C_transposedrJ   scan_output	ssm_stater"   r"   r#   cuda_kernels_forward(  s  

"


<"
]"T
$




z Mamba2Mixer.cuda_kernels_forwardc           2   
      s  |j \}}}|j}t||}|}	|	j d dj  dj j  j d }
|	j|
|
jj	jgdd\}}}}}|d ur|d ur|d dkr|j
j|dd |jj jjjjd}tj|jjd dd}jry|jj }|}n8|d ur|dd}tj||j|j d  df}|j
j|d	d |ddd
d |f dd}t||}tj|jjj jj gdd\}}}tj  }|d ur7|d ur7|d dkr7|jj}|d d dd d f d d d d
f }|dd ||j d j!}j"d  j"j d j!}tjj#|||j }t$|j%d j%d }|d  jj!jjtj&d}t|d | j|d}|'|jdd
d d d f }| |jjj |j d ( }|'|d|j d }|d |d
d d d f  }|'|dj!}||d  j|d}|j)j|jj | | d |'|jdd
d d d f }| |jjj |j d ( }|'|d|j d }|jj j|j|jd}|*|j j!j}|*|j jd}t+||}|*|jj!}j,d  j,j d j!}|||  |j}|'|dd d d d
f }ntj#|j" }t$|j%d j%d }|'||dj! }|'||dj }|'||dj }|j-jj djd}|j-jj djd}j.|j.  j.  j,d t/|  }||d  }||j| } fdd||||fD \}}}}|0dddd}tj1|dd}tt2|}|d d d d d d d d d d d f |d d d d d d d d d d d f  } | jdd}!|!d |0dddddd  }"|"jdd}#|#d |d d d d d f  jdd}$t|d d d d d d dd f | }%||%0ddddd  }&|&d
d d d f |d  jdd}'|d ur|d ur|d dkr|jj d d d d
f j|'jd}(nt3|'d d d df }(tj4|(|'gdd}'tt2tj|d d d d d d df d})|)dd})|)d |'d d d d d d
f  jdd}*|*d d d df |*d d df }'}+t|},|d
d d d f |'d d d d d d
f  }-|,0dddd}.|-d|.d  }/|$|/ }|'|djj!}|| } dkr,|d d d |d d d d f }|'||d}|+d urE|d urE|j)j|+d 5||}06|0|}1|1S )Nr%   r&   r/   r   Fr   r,   r   T.r*   ).NNr   r   r+   )r0   Zoutput_sizec                    s   g | ]	}t | jqS r"   )r)   r(   ).0tr   rL   r"   r#   
<listcomp>J  s    z-Mamba2Mixer.torch_forward.<locals>.<listcomp>r   r   r.   )r   r   )7r   r-   r<   r   rH   rB   r}   rD   r   r   rS   rN   rJ   r9   r   rd   r,   r   sumr   r~   rx   r   r   r   r    r!   rA   r   r   r   rK   r2   rE   r   Zsoftplusclampr   rk   r'   
contiguousrV   r   Zbmmr   Zrepeat_interleaver(   r$   Zpermuter5   r8   Z
zeros_likecatr   r   )2rL   r:   r   r   r;   r?   r   r   r-   r   r   rp   r   r   rJ   r   r   r   r   Zcache_devicer   ZdAZdBZdBxrK   Zssm_states_reshapedZ
C_reshapedyr   Z
D_residualZA_cumsumLZG_intermediateGZM_intermediateMZY_diagZdecay_statesZB_decayZstatesZprevious_statesZdecay_chunkZ
new_statesr   Zstate_decay_outZC_times_statesZstate_decay_out_permutedZY_offr   Zcontextualized_statesr"   r   r#   torch_forward  s   

.,
"$"$$$P&*""&0(&
*
 zMamba2Mixer.torch_forwardc                 C   s4   t rd| jjjjv r| ||||S | ||||S )Ncuda)r   r   rd   r,   typer   r   )rL   r:   r   r   r;   r"   r"   r#   rs     s   zMamba2Mixer.forwardr   )rY   rZ   r[   r\   r   rF   rM   r   r^   r   r=   
LongTensorr   r   rs   rv   r"   r"   rh   r#   rw      sN    E
 '
 Irw   c                       s&   e Zd Zd fdd	Zdd Z  ZS )Mamba2RMSNormr`   c                    s&   t    tt|| _|| _dS )zM
        Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        Nra   rf   rh   r"   r#   rM     s   

zMamba2RMSNorm.__init__c                 C   sJ   |j }|tj}|djddd}|t|| j  }| j|| S rj   )	r-   r9   r   rk   rm   rn   ro   re   rd   )rL   r:   rq   rr   r"   r"   r#   rs     s
   zMamba2RMSNorm.forwardrt   ru   r"   r"   rh   r#   r     s    r   c                       sJ   e Zd Z fddZ			d	dee deej deej fddZ	  Z
S )
Mamba2Blockc                    sB   t    || _|| _|j| _t|j|jd| _t	||d| _
d S )Nr|   rN   )rb   rM   r>   rN   residual_in_fp32r   rG   r   r   rw   mixer)rL   r>   rN   rh   r"   r#   rM     s   
zMamba2Block.__init__Nr   r   r;   c                 C   sL   |}|  |j| j jjd}| jr|tj}| j||||d}|| }|S )Nr   r   r   r;   )r   r9   rd   r-   r   r   rk   r   )rL   r:   r   r   r;   Zresidualr"   r"   r#   rs     s   zMamba2Block.forwardr   )rY   rZ   r[   rM   r   r=   r   r   r^   rs   rv   r"   r"   rh   r#   r     s    r   c                   @   s*   e Zd ZeZdZdgZdZdZdd Z	dS )Mamba2PreTrainedModelbackboner   Tc              	   C   s  t |tr\d|j_d|j_tt| jj	t
| jjt
| jj  t
| jj j| jjd}|tt|   }t  |j| W d   n1 sSw   Y  d|j_t |tjrv|jdurut|jddsutj|j nt |tjrtjj|j| jjd | jjr|  D ]2\}}|dv rtjj!|t
"dd	 t  |t
"| jj# }W d   n1 sw   Y  qdS dS )
zInitialize the weights.T)minN
_no_reinitF)Zstd)zout_proj.weight   )a)$
isinstancerw   r   r   r   r   r   Zrandr>   rD   mathr   r   r   r   Ztime_step_floorexpm1Zno_gradr   Zcopy_r   r   r   rx   getattrinitZzeros_	EmbeddingZnormal_rd   Zinitializer_rangeZrescale_prenorm_residualZnamed_parametersZkaiming_uniform_sqrtrI   )rL   moduler   Zinv_dtnamepr"   r"   r#   _init_weights  sD   



z#Mamba2PreTrainedModel._init_weightsN)
rY   rZ   r[   r   Zconfig_classZbase_model_prefixZ_no_split_modulesZsupports_gradient_checkpointingZ_is_statefulr   r"   r"   r"   r#   r     s    r   c                   @   sJ   e Zd ZU dZdZeej ed< dZ	ee
 ed< dZeeej  ed< dS )Mamba2Outputa%  
    Class for the MAMBA2 model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`Mamba2Cache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nlast_hidden_stater   r:   )rY   rZ   r[   r\   r   r   r   FloatTensor__annotations__r   r=   r:   r   r"   r"   r"   r#   r     s
   
 r   c                   @   s\   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
ee ed< dZeeej  ed< dS )Mamba2CausalLMOutputa  
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`Mamba2Cache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    Nlosslogitsr   r:   )rY   rZ   r[   r\   r   r   r   r   r   r   r   r=   r:   r   r"   r"   r"   r#   r     s   
 r   c                       s   e Zd Z fddZdd Zdd Zdd Ze																dd
ee	j
 dee	j
 dee dee dee dee dee	j
 dee	j deeef fddZ  ZS )Mamba2Modelc                    sn   t    t j j| _t fddt j	D | _
d| _t j jd| _| | j |   d S )Nc                    s   g | ]}t  |d qS )r   )r   )r   idxr>   r"   r#   r   9  s    z(Mamba2Model.__init__.<locals>.<listcomp>Fr|   )rb   rM   r   r   
vocab_sizerG   
embeddingsZ
ModuleListrangerI   layersgradient_checkpointingr   r   norm_fZ"_register_load_state_dict_pre_hook	load_hook	post_initrL   r>   rh   r   r#   rM   5  s    zMamba2Model.__init__c                 G   s2   |D ]}d|v r| |||dd<  d S qd S )Nz
embedding.zembeddings.)popreplace)rL   Z
state_dictprefixargskr"   r"   r#   r   A  s   zMamba2Model.load_hookc                 C      | j S rU   r   rW   r"   r"   r#   get_input_embeddingsG     z Mamba2Model.get_input_embeddingsc                 C   
   || _ d S rU   r   rL   Znew_embeddingsr"   r"   r#   set_input_embeddingsJ     
z Mamba2Model.set_input_embeddingsN	input_idsinputs_embedsr   	use_cacheoutput_hidden_statesreturn_dictr   r;   rQ   c	                 K   s  |dur|n| j j}|dur|n| js| j jnd}|dur|n| j j}|du |duA r/td|du r8| |}| jrB| jrB|rBd}|rk|du rbt| j |	d|j
|jd}tjd| j j|j
d}n|du rjtdnd}|}
|rsdnd}| jD ]"}| jr| jr| |j|
|||}
n||
|||d	}
|r||
f }qx| |
}
|r||
f }|std
d |
||fD S t|
|r||dS d|dS )a  
        cache_params (`Mamba2Cache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
            If `cache_params` is passed, `cache_position` should also be passed.
        NFz:You must specify exactly one of input_ids or inputs_embedsr   r+   r   zYou have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will be initialized for you automaticallyr"   r   c                 s   s    | ]	}|d ur|V  qd S rU   r"   )r   vr"   r"   r#   	<genexpr>  s    z&Mamba2Model.forward.<locals>.<genexpr>)r   r   r:   )r>   r   r   r   use_return_dict
ValueErrorr   r   r=   r1   r,   r-   r   r   r@   r   Z_gradient_checkpointing_func__call__r   tupler   )rL   r   r   r   r   r   r   r   r;   kwargsr:   Zall_hidden_statesZmixer_blockr"   r"   r#   rs   M  sf   





zMamba2Model.forward)NNNNNNNN)rY   rZ   r[   rM   r   r   r   r   r   r   r   r=   r4   r^   r   r   r   rs   rv   r"   r"   rh   r#   r   3  sB    	
r   z
    The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
    embeddings).
    )Zcustom_introc                       s   e Zd Zg Z fddZdd Zdd Zdd Zd	d
 Z					dde	e
 de	ej de	ej fddZe									dde	ej de	ej de	e
 de	ej de	e de	e de	e de	ej de	ej deeef fddZ  ZS )Mamba2ForCausalLMc                    s8   t  | t|| _tj|j|jdd| _| 	  d S )NFr{   )
rb   rM   r   r   r   r   rG   r   lm_headr   r   rh   r"   r#   rM     s   
zMamba2ForCausalLM.__init__c                 C   r   rU   r  rW   r"   r"   r#   get_output_embeddings  r   z'Mamba2ForCausalLM.get_output_embeddingsc                 C   r   rU   r  r   r"   r"   r#   set_output_embeddings  r   z'Mamba2ForCausalLM.set_output_embeddingsc                 C   s
   | j  S rU   )r   r   rW   r"   r"   r#   r     r   z&Mamba2ForCausalLM.get_input_embeddingsc                 C   s   | j |S rU   )r   r   r   r"   r"   r#   r     s   z&Mamba2ForCausalLM.set_input_embeddingsNr   r   r;   c           	      K   s   |r,|d u r
t d|d dkr!|d d df d }|d ur d }ntjd| jj|jd}|d ur9|d u r9d|i}nd|i}|||||d |S )	Nz`cache_position` should not be None as it should have been initialized in `model.generate`, you are responsible for passing in a valid `cache_position` if you are calling `prepare_inputs_for_generation` directly with `use_cache=True`r   r%   r*   r   r   r   )r;   r   r   r   )r  r   r   r>   r@   r,   update)	rL   r   r   r   r   r   r;   r  Zmodel_inputsr"   r"   r#   prepare_inputs_for_generation  s,   
z/Mamba2ForCausalLM.prepare_inputs_for_generationr   r   labelsr   r   r   rQ   c
              
   K   s   |dur|n| j j}| j||||||||	d}|d }| || jjj }d}|dur<| jd||| j j	d|
}|sR|f|dd  }|durP|f| S |S t
|||j|jdS )ao  
        cache_params (`Mamba2Cache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
            If `cache_params` is passed, `cache_position` should also be passed.
        N)r   r   r   r   r   r   r;   r   )r   r  r   r   )r   r   r   r:   r"   )r>   r  r   r  r9   rd   r-   r   Zloss_functionr   r   r   r:   )rL   r   r   r   r  r   r   r   r   r;   r  Zmamba2_outputsr:   r   r   outputr"   r"   r#   rs     s2   
zMamba2ForCausalLM.forward)NNNNN)	NNNNNNNNN)rY   rZ   r[   Z_tied_weights_keysrM   r	  r
  r   r   r   r=   r   r   r^   r  r   r   r4   r   r   r   rs   rv   r"   r"   rh   r#   r    sd    
/	

r  )r  r   r   )9r\   r   dataclassesr   typingr   r   r   r   Ztorch.utils.checkpointr   Zactivationsr   Z
generationr	   Zmodeling_utilsr
   utilsr   r   r   Zutils.import_utilsr   r   Zconfiguration_mamba2r   Z
get_loggerrY   r   Z+mamba_ssm.ops.triton.selective_state_updater   Z!mamba_ssm.ops.triton.ssd_combinedr   r   Zcausal_conv1dr   r   allr   r^   rF   r$   r)   r8   r<   r=   Moduler_   rw   r   r   r   r   r   r   r  __all__r"   r"   r"   r#   <module>   sn   

M   A2r 