o
    Zh+                     @   s  d Z ddlZddlmZ ddlmZ ddlmZmZm	Z	m
Z
 ddlZddlZddlmZ ddlmZ dd	lmZ dd
lmZmZmZmZmZmZ ddlmZ eeZdadd ZG dd dej j!Z"d)ddZ#d)ddZ$G dd dej%Z&G dd dej%Z'G dd dej%Z(eG dd deZ)eG dd deZ*eG d d! d!eZ+eG d"d# d#e)Z,ed$d%G d&d' d'e)eZ-g d(Z.dS )*zPyTorch RWKV model.    N)	dataclass)Path)ListOptionalTupleUnion)nn   )GenerationMixin)PreTrainedModel)ModelOutputauto_docstringis_bitsandbytes_availableis_ninja_availableis_torch_cuda_availablelogging   )
RwkvConfigc                    s   ddl m} tt jjjd d   fdddD }td ur'tj| kr'd S t	d|  d	 d
dddddd|  g}|d|  |t
 t
jk|da| t_d S )Nr   )loadZkernelsrwkvc                    s   g | ]} | qS  r   ).0fZkernel_folderr   U/var/www/auris/lib/python3.10/site-packages/transformers/models/rwkv/modeling_rwkv.py
<listcomp>4       z(load_wkv_cuda_kernel.<locals>.<listcomp>)z
wkv_op.cppzwkv_cuda.cuzwkv_cuda_bf16.cuz2Loading CUDA kernel for RWKV at context length of .z
-res-usagez--maxrregcount 60z--use_fast_mathz-O3z-Xptxas -O3z--extra-device-vectorizationz-DTmax=Zwkv_)namesourcesverboseZextra_cuda_cflags)Ztorch.utils.cpp_extensionr   r   __file__resolveparentrwkv_cuda_kernelmax_seq_lengthloggerinfor   Zget_verbosityDEBUG)context_lengthZload_kernelZcuda_kernel_filesflagsr   r   r   load_wkv_cuda_kernel.   s*   	
r+   c                   @   s(   e Zd ZedddZedddZdS )	RwkvLinearAttentionNFc              	   C   s  |  \}}}	|tjkrtd| dtj d||	 t|	d dkr4td| d|	 dt|	d d	|j| _|jjd
ksP|jjd
ksP|jjd
ksP|jjd
krTtdt	
|   }|jt	jkrp| }| }| }| }| }| }t	j|t	jd}
|s|d ur|d u rt	j||	dt	j|jt	jd}|d d d d df  d8  < nt	jdd |D dd }|jt	jkrtj}ntj}||||||
| n|jt	jkrtjntj}||||||
 | |||||
 |d urdd t	j|dddD }|
| j|fS )NzCannot process a batch with z+ tokens at the same time, use a maximum of z with this model.    r   zThe product of batch size (z) and hidden size (z") needs to be a round multiple of r   cudazUCalling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.memory_formatr	   )dtypedevicer0      籡*Gc                 S      g | ]}| d qS r3   Z	unsqueezer   sr   r   r   r   }       z/RwkvLinearAttention.forward.<locals>.<listcomp>)dimc                 S   r5   r6   )Zsqueezer8   r   r   r   r      r:   )sizer$   r%   
ValueErrorminr1   input_dtyper2   typetorchexpfloat
contiguousfloat16
empty_likecontiguous_formatzerosfloat32catbfloat16Zforward_with_state_bf16Zforward_with_stateZforward_bf16forwardZsave_for_backwardchunkto)ctx
time_decay
time_firstkeyvaluestatereturn_stateZ
batch_sizeZseq_lenhidden_sizeoutputZforward_funcr   r   r   rL   O   sd   
 zRwkvLinearAttention.forwardc                 C   s   | j }| j\}}}}}tj|tj|tjkrtjntjd}	tj|tjd}
tj|tjd}tj|tjd}|tjkr>| }|tjkrFt	j
nt	j}||||||| |	|
||
 |	||
|||||d d fS )N)r0   r1   r/   )r?   Zsaved_tensorsrA   rF   rG   rK   rI   rE   rC   r$   Zbackward_bf16backwardrD   rN   )rO   Zg_outputZg_stater?   rP   rQ   rR   rS   rW   Zg_time_decayZg_time_firstZg_keyZg_valueZbackward_funcr   r   r   rX      s@   
zRwkvLinearAttention.backwardNFN)__name__
__module____qualname__staticmethodrL   rX   r   r   r   r   r,   N   s
    >r,   Fc                 C   s  |  \}}}t|}|d u r=tj|d d df tjd}	tj|d d df tjd}
tj|d d df tjdd }n|\}	}
}t|  } t|D ]p}|d d |f  }|d d |f }t||| }t|| }t|| | }||	 ||  }||
 | }|| |j	|d d |f< t||  |}t||  | }t|| }||	 ||  }	||
 | }
|}qL|s|d ur|	|
|g}||fS )Nr   )r1   r4   )
r<   rA   Z
zeros_likerI   rB   rangerC   maximumrN   r1   )rP   rQ   rR   rS   rT   rU   _Z
seq_lengthrW   Z	num_stateZ	den_stateZ	max_stateZcurrent_indexcurrent_keycurrent_valueZmax_for_outpute1e2	numeratordenominatorZmax_for_stater   r   r   rwkv_linear_attention_cpu   s4   
"

rh   c                 C   s`   t dd | |||fD }|ddk}td u s|s|r&t| |||||dS t| |||||S )Nc                 s   s    | ]	}|j jd kV  qdS )r.   N)r2   r@   )r   tr   r   r   	<genexpr>       z(rwkv_linear_attention.<locals>.<genexpr>r   rT   rU   )anyr<   r$   rh   r,   apply)rP   rQ   rR   rS   rT   rU   Zno_cudaZ	one_tokenr   r   r   rwkv_linear_attention   s
   ro   c                       s2   e Zd Zd
 fdd	ZdddZddd	Z  ZS )RwkvSelfAttentionr   c                    sD  t    || _td uotj|jk}t r0t r0|s0zt|j W n t	y/   t
d Y nw || _|j}|jd ur>|jn|}|| _tt|| _tt|| _ttdd|| _ttdd|| _ttdd|| _td| _tj||dd| _tj||dd| _tj||dd| _tj||dd| _d S )Nz9Could not load the custom CUDA kernel for RWKV attention.r   r   r   r   FZbias)super__init__configr$   r%   r)   r   r   r+   	Exceptionr&   r'   layer_idrV   attention_hidden_sizer   	ParameterrA   emptyrP   rQ   time_mix_keytime_mix_valuetime_mix_receptance	ZeroPad2d
time_shiftLinearrR   rS   
receptancerW   )selfrv   rx   Zkernel_loadedrV   ry   	__class__r   r   ru      s0   
zRwkvSelfAttention.__init__Nc                 C   s  | ddkr|d ur|d d d d d | jf }n| |}|d ur7|d d d d d | jf |d d df< || j |d| j   }|| j |d| j   }|| j |d| j   }| |}| |}t	| 
|}|d ur|d d df |d d d d d | jf< ||||fS Nr   r   rr   )r<   rx   r   r|   r}   r~   rR   rS   rA   sigmoidr   )r   hiddenrT   shiftedrR   rS   r   r   r   r   extract_key_value  s   
(

(z#RwkvSelfAttention.extract_key_valueFc           	         s    j ||d\}}}}|d urt fdd|dd  D nd }t j j||||d\}}|d urb|d |d d d d d  jf< |d |d d d d d  jf< |d |d	 d d d d  jf<  || |fS )
NrT   c                 3   s(    | ]}|d d d d  j f V  qd S rZ   rx   r8   r   r   r   rj   #  s   & z,RwkvSelfAttention.forward.<locals>.<genexpr>r3   rl   r   r   r	      )r   tuplero   rP   rQ   rx   rW   )	r   r   rT   	use_cacher   rR   rS   Zlayer_stater   r   r   r   rL   !  s   *
	   zRwkvSelfAttention.forwardr   rZ   rY   )r[   r\   r]   ru   r   rL   __classcell__r   r   r   r   rp      s    
rp   c                       s(   e Zd Zd fdd	ZdddZ  ZS )	RwkvFeedForwardr   c                    s   t    || _|| _|j}|jd ur|jnd|j }td| _t	t
dd|| _t	t
dd|| _tj||dd| _tj||dd| _tj||dd| _d S )Nr   rq   r   Frs   )rt   ru   rv   rx   rV   intermediate_sizer   r   r   rz   rA   r{   r|   r~   r   rR   r   rS   )r   rv   rx   rV   r   r   r   r   ru   6  s   
zRwkvFeedForward.__init__Nc                 C   s
  | ddkr|d ur|d d d d d | jf }n| |}|d ur7|d d d d d | jf |d d df< || j |d| j   }|| j |d| j   }tt| |}| 	|}t
| |}|d ur|d d df |d d d d d | jf< || |fS r   )r<   rx   r   r|   r~   rA   ZsquareZrelurR   rS   r   r   )r   r   rT   r   rR   r   rS   r   r   r   rL   G  s   
(
(zRwkvFeedForward.forwardr   rZ   r[   r\   r]   ru   rL   r   r   r   r   r   r   5  s    r   c                       s&   e Zd Z fddZdddZ  ZS )	RwkvBlockc                    sv   t    || _|| _|dkrtj|j|jd| _tj|j|jd| _	tj|j|jd| _
t||| _t||| _d S )Nr   )Zeps)rt   ru   rv   rx   r   	LayerNormrV   Zlayer_norm_epsilonpre_lnln1ln2rp   	attentionr   feed_forward)r   rv   rx   r   r   r   ru   \  s   
zRwkvBlock.__init__NFc                 C   s|   | j dkr
| |}| j| |||d\}}|| }| j| ||d\}}|| }||f}|r8||f7 }|S |d7 }|S )Nr   )rT   r   r   rZ   )rx   r   r   r   r   r   )r   r   rT   r   output_attentionsr   r   Zoutputsr   r   r   rL   j  s   


zRwkvBlock.forward)NFFr   r   r   r   r   r   [  s    r   c                   @   s2   e Zd ZeZdZdgZddgZdZdZ	dd Z
dS )	RwkvPreTrainedModelr   r   rP   rQ   Tc                    s"  t |tr|j}|jj}|jj|j ||d  d||  }tjfddt	D |j
j|j
jd}|ddddf } fddt	 D }tj||jj|jjd}tjdd t	 D |jj|jjdd	 }t : ||j_t|jtd
 | |j_t|||j
_t||d
  |j_t|d	| |j_W d   dS 1 sw   Y  dS t |tr|j}|jj}|jjd||  }tjfddt	D |j
j|j
jd}|ddddf }t  t|||j
_t|||j_W d   dS 1 sw   Y  dS dS )zInitialize the weights.r   g      ?c                       g | ]}|  qS r   r   r   irV   r   r   r     r   z5RwkvPreTrainedModel._init_weights.<locals>.<listcomp>r1   r2   Nc                    s,   g | ]}d d| d  dd     qS )   r   gffffff?g?r   )r   h)ry   ratio_0_to_1r   r   r     s    c                 S   s   g | ]
}|d  d d  qS )r   r	   r   r   r   r   r   r     s    g      ?g333333?c                    r   r   r   r   r   r   r   r     r   )
isinstancerp   rx   rv   num_hidden_layersrV   ry   rA   Ztensorr_   r|   r1   r2   rP   rQ   no_graddataZ	ones_likemathlogpowr}   r~   r   )r   modulerx   r   Zratio_1_to_almost0Ztime_weightZdecay_speedZzigzagr   )ry   rV   r   r   _init_weights  s`   

	"
$z!RwkvPreTrainedModel._init_weightsN)r[   r\   r]   r   Zconfig_classZbase_model_prefixZ_no_split_modulesZ_keep_in_fp32_modulesZsupports_gradient_checkpointingZ_is_statefulr   r   r   r   r   r   }  s    r   c                   @   sn   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZeeejdf  ed< dZeeejdf  ed< dS )
RwkvOutputa  
    Class for the RWKV model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nlast_hidden_staterT   .hidden_states
attentions)r[   r\   r]   __doc__r   r   rA   FloatTensor__annotations__rT   r   r   r   r   r   r   r   r   r     s   
 r   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeej  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	RwkvCausalLMOutputa|  
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    NlosslogitsrT   .r   r   )r[   r\   r]   r   r   r   rA   r   r   r   rT   r   r   r   r   r   r   r   r   r     s   
 r   c                       s   e Zd Z fddZdd Zdd Ze								ddeej	 d	eej	 d
eej
 deeej
  dee dee dee dee deeef fddZdd Zdd Z  ZS )	RwkvModelc                    sd   t    t j j| _t fddt j	D | _
t j| _d| _d| _|   d S )Nc                    s   g | ]}t  |d qS )r   )r   )r   idxrv   r   r   r     s    z&RwkvModel.__init__.<locals>.<listcomp>F)rt   ru   r   Z	Embedding
vocab_sizerV   
embeddingsZ
ModuleListr_   r   blocksr   ln_outlayers_are_rescaledgradient_checkpointing	post_initr   rv   r   r   r   ru     s    zRwkvModel.__init__c                 C      | j S rZ   r   r   r   r   r   get_input_embeddings     zRwkvModel.get_input_embeddingsc                 C   
   || _ d S rZ   r   r   Znew_embeddingsr   r   r   set_input_embeddings     
zRwkvModel.set_input_embeddingsN	input_idsattention_maskinputs_embedsrT   r   r   output_hidden_statesreturn_dictreturnc	                    s,  |dur|n| j j}|dur|n| j j}|dur|n| js!| j jnd}|dur)|n| j j}|dur6td | j| jkr@| 	  |durL durLt
d|du rX du rXt
d du ra| | |r|du r d| j j| j jf fddtd	D }|d
  d8  < | jr| jr|rtd d} }	|rdnd}
|rdnd}t| jD ]H\}}| jr| jr| |j|	|||\}	}}n||	|||d\}	}}| jr| j jdkr|d | j j dkr|	d }	|r||	f }|r|
|f }
q| |	}	|r||	f }|stdd |	|||
fD S t|	|||
dS )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        NFz<`attention_mask` was passed, but it is unused in this model.zDYou cannot specify both input_ids and inputs_embeds at the same timez5You have to specify either input_ids or inputs_embedsr   c                    s0   g | ]}t j|d kr jnt j jdqS )r   r   )rA   rH   r1   rI   r2   r   r   shaper   r   r   K  s    z%RwkvModel.forward.<locals>.<listcomp>   r   gꌠ9Y>)FzZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...r   )rT   r   r   r   r3   c                 s   s    | ]	}|d ur|V  qd S rZ   r   )r   xr   r   r   rj   {  rk   z$RwkvModel.forward.<locals>.<genexpr>)r   rT   r   r   )rv   r   r   trainingr   use_return_dictr&   Zwarning_oncer   _rescale_layersr=   r   r<   rV   r   r_   r   	enumerater   Z_gradient_checkpointing_func__call__rescale_everyr   r   r   )r   r   r   r   rT   r   r   r   r   r   Zall_self_attentionsZall_hidden_statesr   blockr   r   r   r   rL     sv   





zRwkvModel.forwardc                 C   sx  | j | j kr	d S | jjdkrt  t| jD ]\}}| jrA|jj	j
dt|| jj   |jjj
dt|| jj   qt|jj	j
drl|jj	j
jdt|| jj   |jjj
jdt|| jj   qt|jj	j
dr| |jj	| | |jj| q|jj	j
dt|| jj   |jjj
dt|| jj   qW d    n1 sw   Y  | j | _ d S )Nr   r3   SCBquant_state)r   r   rv   r   rA   r   r   r   r   rW   weightZmul_intr   rS   hasattrr   div_ _bnb_4bit_dequantize_and_rescale)r   block_idr   r   r   r   r     s&   
 ""$ "zRwkvModel._rescale_layersc                 C   st   t  stdddl}|j|jj|jj}|dt	|| j
j   |jj|ddd|j}t|d| dS )	z
        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
        be quantized again.
        z/Please install bitsandbytes to use this method.r   Nr3   cpuF)Zrequires_gradr   )r   ImportErrorZbitsandbytesZ
functionalZdequantize_4bitr   r   r   r   r   rv   r   r   Z
Params4bitrN   r2   setattr)r   Ztarget_layerr   ZbnbZdequant_weightsZquant_weightr   r   r   r     s   z*RwkvModel._bnb_4bit_dequantize_and_rescale)NNNNNNNN)r[   r\   r]   ru   r   r   r   r   rA   
LongTensorr   r   boolr   r   r   rL   r   r   r   r   r   r   r   r     sD    	

nr   z
    The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    )Zcustom_introc                       s   e Zd ZdgZ fddZdd Zdd Zdd	d
Ze									dde	e
j de	e
j de	e
j de	ee
j  de	e
j de	e de	e de	e de	e deeef fddZ  ZS )RwkvForCausalLMzhead.weightc                    s8   t  | t|| _tj|j|jdd| _| 	  d S )NFrs   )
rt   ru   r   r   r   r   rV   r   headr   r   r   r   r   ru     s   
zRwkvForCausalLM.__init__c                 C   r   rZ   r   r   r   r   r   get_output_embeddings  r   z%RwkvForCausalLM.get_output_embeddingsc                 C   r   rZ   r   r   r   r   r   set_output_embeddings  r   z%RwkvForCausalLM.set_output_embeddingsNc                 K   sT   |d ur|d d df  d}|d ur|d u rd|i}nd|i}||d< ||d< |S )Nrr   r   r   rT   r   r7   )r   r   rT   r   r   kwargsZmodel_inputsr   r   r   prepare_inputs_for_generation  s   
z-RwkvForCausalLM.prepare_inputs_for_generationr   r   r   rT   labelsr   r   r   r   r   c
              	   K   s   |	dur|	n| j j}	| j|||||||	d}|d }| |}d}|dur3| j||fd| j ji|
}|	sI|f|dd  }|durG|f| S |S t|||j|j|j	dS )aI  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        N)r   rT   r   r   r   r   r   r   r   )r   r   rT   r   r   )
rv   r   r   r   Zloss_functionr   r   rT   r   r   )r   r   r   r   rT   r   r   r   r   r   r   Zrwkv_outputsr   r   r   rW   r   r   r   rL     s@   %	
zRwkvForCausalLM.forward)NNN)	NNNNNNNNN)r[   r\   r]   Z_tied_weights_keysru   r   r   r   r   r   rA   r   r   r   r   r   r   r   rL   r   r   r   r   r   r     sJ    
	

r   )r   r   r   rY   )/r   r   dataclassesr   pathlibr   typingr   r   r   r   rA   Ztorch.utils.checkpointr   Z
generationr
   Zmodeling_utilsr   utilsr   r   r   r   r   r   Zconfiguration_rwkvr   Z
get_loggerr[   r&   r$   r+   ZautogradFunctionr,   rh   ro   Modulerp   r   r   r   r   r   r   r   __all__r   r   r   r   <module>   sF    
 
j
,F&"B  3l