import math
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...integrations import use_kernel_forward_from_hub
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, can_return_tuple, logging
from .configuration_timesfm import TimesFmConfig


logger = logging.get_logger(__name__)


@dataclass
class TimesFmOutput(BaseModelOutput):
    r"""
    Args:
        loc (`torch.Tensor` of shape `(batch_size, )`):
            The mean of the time series inputs.
        scale (`torch.Tensor` of shape `(batch_size,)`):
            The scale of the time series inputs.
    Nlocscale)
__name__
__module____qualname____doc__r   r   torchTensor__annotations__r    r   r   [/var/www/auris/lib/python3.10/site-packages/transformers/models/timesfm/modeling_timesfm.pyr   *   s   
 r   c                   @   sP   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejef  ed< dS )TimesFmOutputForPredictiona  
    Args:
        mean_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
            The mean predictions of the time series.
        full_predictions (`torch.Tensor` of shape `(batch_size, sequence_length)`):
            The full predictions of the time series including the mean and the quantiles.
        loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
            The loss of the TimesFM model.
    Nmean_predictionsfull_predictionsloss)r   r   r   r   r    r   r   r   r   r!   r"   r   floatr   r   r   r   r   8   s
   
 
r   c                       0   e Zd ZdZdef fddZdddZ  ZS )	
TimesFmMLPzPax MLP in pytorch.configc                    sF   t    |j}|j}t||| _t||| _tj|dd| _	d S )Nư>)Znormalized_shapeeps)
super__init__hidden_sizeintermediate_sizennLinear	gate_proj	down_proj	LayerNorm
layer_norm)selfr&   r+   r,   	__class__r   r   r*   L   s   
zTimesFmMLP.__init__Nc                 C   sV   |  |}| |}t|}| |}|d ur'|d|d d d d d f   }|| S )N      ?)r2   r/   FZrelur0   )r3   xpaddingsZgate_inpZgateoutputsr   r   r   forwardU   s   



zTimesFmMLP.forwardNr   r   r   r   r   r*   r;   __classcell__r   r   r4   r   r%   I   s    	r%   c                       s(   e Zd ZdZ fddZdd Z  ZS )TimesFmResidualBlockzTimesFM residual block.c                    sT   t    || _|| _|| _t||| _t | _	t||| _
t||| _d S r<   )r)   r*   
input_dimshidden_dimsoutput_dimsr-   r.   input_layerZSiLU
activationoutput_layerresidual_layer)r3   r@   rA   rB   r4   r   r   r*   b   s   

zTimesFmResidualBlock.__init__c                 C   s0   |  |}| |}| |}| |}|| S r<   )rC   rD   rE   rF   )r3   r8   Zhiddenoutputresidualr   r   r   r;   m   s
   



zTimesFmResidualBlock.forward)r   r   r   r   r*   r;   r>   r   r   r4   r   r?   _   s    r?   ZRMSNormc                       s.   e Zd Zd fdd	Zdd Zdd Z  ZS )	TimesFmRMSNormr'   c                    s&   t    tt|| _|| _dS )z=


@use_kernel_forward_from_hub("RMSNorm")
class TimesFmRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        TimesFmRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
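

# In math form (sketch of the forward above): T5-style RMS normalization without mean
# subtraction, computed in float32 and cast back:
#     y = w * x / sqrt(mean(x**2, dim=-1) + eps)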


class TimesFmPositionalEmbedding(nn.Module):
    """Generates position embedding for a given 1-d sequence."""

    def __init__(self, config: TimesFmConfig):
        super().__init__()
        min_timescale = config.min_timescale
        max_timescale = config.max_timescale
        self.embedding_dims = config.hidden_size

        num_timescales = self.embedding_dims // 2
        log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1)
        self.register_buffer(
            "inv_timescales",
            min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float32) * -log_timescale_increment),
        )

    def forward(self, seq_length=None, position=None):
        """Generates a Tensor of sinusoids with different frequencies.

        Args:
            seq_length: an optional Python int defining the output sequence length;
              not required if the `position` argument is specified.
            position: [B, seq_length], optional position for each token in the
              sequence, only required when the sequence is packed.

        Returns:
            [B, seqlen, D] if `position` is specified, else [1, seqlen, D]
        """
        if position is None and seq_length is None:
            raise ValueError("Either position or seq_length must be provided")

        if position is None:
            position = torch.arange(seq_length, dtype=torch.float32, device=self.inv_timescales.device).unsqueeze(0)
        elif position.ndim != 2:
            raise ValueError(f"position must be 2-dimensional, got shape {position.shape}")

        scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1)
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)

        # Pad with zeros when the embedding dimension is odd.
        signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2))
        return signal


def simple_eager_attention_forward(
    module: nn.Module,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
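

# Expected shapes (a sketch, assuming the conventional layout): query/key/value are
# (B, num_heads, N, head_dim); the additive `attention_mask` broadcasts against the
# (B, num_heads, N, N) score matrix; the output is returned as (B, N, num_heads, head_dim)
# after the final transpose, ready to be flattened for the output projection.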


class TimesFmAttention(nn.Module):
    """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query."""

    def __init__(self, config: TimesFmConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.layer_idx = layer_idx

        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_heads * self.head_dim
        self.scaling = nn.Parameter(torch.empty((self.head_dim,)))

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)

    def _scale_query(self, query: torch.Tensor) -> torch.Tensor:
        scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim))
        return query * scale[None, None, None, :]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states = self._scale_query(query_states)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        attention_interface: Callable = simple_eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`."
                    " Falling back to eager attention. This warning can be removed using the argument"
                    ' `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=1.0,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
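

# Per-dim query scaling (sketch of the mechanism above): instead of the usual
# 1/sqrt(head_dim) factor applied to the attention scores, each query channel is scaled
# by a learned softplus(self.scaling) * 1.442695041 / sqrt(head_dim) before the matmul,
# which is why `scaling=1.0` is passed to the attention interface.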
zTimesFmAttention.forwardr<   )r   r   r   r   r   intr*   r   r   r   r   r   r	   rV   r;   r>   r   r   r4   r   r      s    r   c                       sb   e Zd ZdZdedef fddZ	ddejdejd	ejd
e	de
eej ejf f
ddZ  ZS )TimesFmDecoderLayerzTransformer layer.r&   r   c                    s8   t    t||d| _t|| _t|j|jd| _	d S )N)r   )r(   )
r)   r*   r   	self_attnr%   mlprI   r+   Zrms_norm_epsinput_layernormr   r4   r   r   r*     s   

zTimesFmDecoderLayer.__init__FrU   rv   r9   r   r   c                 C   s@   |}|  |}| j|||d\}}|| }| j||d}||fS )N)rU   rv   r   )r9   )r   r   r   )r3   rU   rv   r9   r   rH   scoresr   r   r   r;     s   

zTimesFmDecoderLayer.forward)F)r   r   r   r   r   r   r*   r   r   boolrV   r   r;   r>   r   r   r4   r   r     s    r   c                   @   s*   e Zd ZeZdZdgZdZdZdd Z	dS )TimesFmPreTrainedModelZtimesfmr   past_valuesTc                 C   s   t |tjr|jjjd| jjd d S t |tjr4|jjjd| jjd |j	d ur2tj
|j	 d S d S t |tjrJtj
|j tj
|j	 d S t |trXtj
|j d S t |trftj
|j d S d S )Nr   )rT   Zstd)
isinstancer-   	EmbeddingrL   dataZnormal_r&   Zinitializer_ranger.   ZbiasinitZzeros_r1   Zones_rI   r   rw   )r3   rr   r   r   r   _init_weights9  s   


z$TimesFmPreTrainedModel._init_weightsN)


@auto_docstring
class TimesFmModel(TimesFmPreTrainedModel):
    def __init__(self, config: TimesFmConfig):
        super().__init__(config)
        self.config = config
        self.input_ff_layer = TimesFmResidualBlock(
            input_dims=2 * config.patch_length,
            output_dims=config.hidden_size,
            hidden_dims=config.intermediate_size,
        )
        self.freq_emb = nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size)
        self.layers = nn.ModuleList(
            [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        if self.config.use_positional_embedding:
            self.position_emb = TimesFmPositionalEmbedding(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _forward_transform(
        self, inputs: torch.Tensor, patched_pads: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Input is of shape [B, N, P]."""
        mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads)
        sigma = torch.where(
            sigma < self.config.tolerance,
            torch.tensor(1.0, dtype=sigma.dtype, device=sigma.device),
            sigma,
        )

        # Normalize each patch.
        outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
        outputs = torch.where(
            torch.abs(inputs - self.config.pad_val) < self.config.tolerance,
            torch.tensor(self.config.pad_val, dtype=outputs.dtype, device=outputs.device),
            outputs,
        )
        return outputs, (mu, sigma)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        past_values: torch.Tensor,
        past_values_padding: torch.LongTensor,
        freq: torch.Tensor,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> TimesFmOutput:
        r"""
        past_values_padding (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The padding indicator of the time series.
        past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Past values of the time series that serves as input to the model.
        freq (`torch.LongTensor` of shape `(batch_size,)`):
            Frequency indices for the time series data.
        """
        bsize = past_values.shape[0]
        patched_inputs = past_values.view(bsize, -1, self.config.patch_length)
        patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length)

        patched_inputs = torch.where(
            torch.abs(patched_pads - 1.0) < self.config.tolerance,
            torch.tensor(0.0, dtype=patched_inputs.dtype, device=patched_inputs.device),
            patched_inputs,
        )
        patched_pads = torch.where(
            torch.abs(patched_inputs - self.config.pad_val) < self.config.tolerance,
            torch.tensor(1.0, dtype=patched_pads.dtype, device=patched_pads.device),
            patched_pads,
        )
        patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads)

        # B x N x D
        patched_inputs = patched_inputs * (1.0 - patched_pads)
        concat_inputs = torch.cat([patched_inputs, patched_pads], dim=-1)
        model_input = self.input_ff_layer(concat_inputs)

        # A patch is only treated as padded if all of its values are padded.
        patched_padding = torch.min(patched_pads, dim=-1)[0]
        if self.config.use_positional_embedding:
            pos_emb = self.position_emb(model_input.shape[1])
            pos_emb = torch.concat([pos_emb] * model_input.shape[0], dim=0)
            pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb)
            model_input += pos_emb

        f_emb = self.freq_emb(freq)  # B x 1 x D
        model_input += f_emb

        # Convert paddings to a 4D attention mask combined with the causal mask.
        hidden_states = model_input
        attention_mask = self._prepare_4d_attention_mask(
            attention_mask=patched_padding,
            sequence_length=hidden_states.shape[1],
            dtype=hidden_states.dtype,
            device=hidden_states.device,
            is_causal=True,
        )

        all_attentions = []
        all_hidden_states = []

        for layer in self.layers[: self.config.num_hidden_layers]:
            scores, hidden_states = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                paddings=patched_padding,
                output_attentions=output_attentions,
            )
            if output_attentions:
                all_attentions.append(scores)
            if output_hidden_states:
                all_hidden_states.append(hidden_states)

        if output_hidden_states:
            all_hidden_states = [model_input] + all_hidden_states
        else:
            all_hidden_states = None

        return TimesFmOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions if output_attentions else None,
            loc=stats[0],
            scale=stats[1],
        )

    @staticmethod
    def _prepare_4d_attention_mask(
        attention_mask: Optional[torch.Tensor],
        sequence_length: int,
        dtype: torch.dtype,
        device: torch.device,
        is_causal: bool = True,
    ) -> Optional[torch.Tensor]:
        """
        Creates 4D attention mask and combines causal and padding masks if needed.

        Args:
            attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
            sequence_length: Length of the sequence
            dtype: Data type of the mask
            device: Device of the mask
            is_causal: Whether to apply causal masking

        Returns:
            4D attention mask of shape (batch_size, 1, seq_length, seq_length)
        """
        # Get the minimum representable value for the dtype.
        min_value = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min

        # Convert the 2D padding mask to a 4D additive attention mask.
        if attention_mask is not None:
            attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
            attention_mask = attention_mask * min_value

        # Create the causal mask if needed and combine it with the padding mask.
        if is_causal:
            causal_mask = torch.triu(
                torch.ones((sequence_length, sequence_length), dtype=dtype, device=device) * min_value,
                diagonal=1,
            )
            causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
            if attention_mask is not None:
                attention_mask = torch.minimum(attention_mask, causal_mask)
            else:
                attention_mask = causal_mask

        return attention_mask

    @staticmethod
    def _timesfm_masked_mean_std(
        inputs: torch.Tensor, padding: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculates mean and standard deviation of `inputs` across axis 1.

        It excludes values where `padding` is 1.

        Args:
            inputs: A PyTorch tensor of shape [b, n, p].
            padding: A PyTorch tensor of shape [b, n, p] with values 0 or 1.

        Returns:
            A tuple containing the mean and standard deviation.
            We return the statistics of the first patch with more than three non-padded values.
        """

        # Select the first patch with more than three unpadded values.
        def _get_patch_index(arr: torch.Tensor):
            indices = torch.argmax((arr >= 3).to(torch.int32), dim=1)
            row_sum = (arr >= 3).to(torch.int32).sum(dim=1)
            return torch.where(row_sum == 0, arr.shape[1] - 1, indices)

        pad_sum = torch.sum(1 - padding, dim=2)
        patch_indices = _get_patch_index(pad_sum)
        bidxs = torch.arange(inputs.shape[0])

        arr = inputs[bidxs, patch_indices, :]
        pad = padding[bidxs, patch_indices, :]

        # Create a mask where padding is 0.
        mask = 1 - pad

        # Calculate the number of valid elements, avoiding division by zero.
        num_valid_elements = torch.sum(mask, dim=1)
        num_valid_elements = torch.where(
            num_valid_elements == 0,
            torch.tensor(1, dtype=num_valid_elements.dtype, device=num_valid_elements.device),
            num_valid_elements,
        )

        # Calculate the masked sum and squared sum.
        masked_sum = torch.sum(arr * mask, dim=1)
        masked_squared_sum = torch.sum((arr * mask) ** 2, dim=1)

        # Calculate the masked mean and standard deviation.
        masked_mean = masked_sum / num_valid_elements
        masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
        masked_var = torch.where(
            masked_var < 0.0,
            torch.tensor(0.0, dtype=masked_var.dtype, device=masked_var.device),
            masked_var,
        )
        masked_std = torch.sqrt(masked_var)

        return masked_mean, masked_std

    @staticmethod
    def _timesfm_shift_padded_seq(mask: torch.Tensor, seq: torch.Tensor) -> torch.Tensor:
        """Shifts rows of seq based on the first 0 in each row of the mask.

        Args:
            mask: mask tensor of shape [B, N]
            seq: seq tensor of shape [B, N, P]

        Returns:
            The shifted sequence.
        """
        batch_size, num_seq, feature_dim = seq.shape

        new_mask: torch.BoolTensor = mask == 0

        # Use argmax to find the first True value (first non-padded position) in each row.
        indices = new_mask.to(torch.int32).argmax(dim=1)

        # Handle rows that are entirely padded.
        indices[~new_mask.any(dim=1)] = -1

        # Create index ranges for each sequence in the batch.
        idx_range = torch.arange(num_seq, device=seq.device).view(1, -1, 1).expand(batch_size, -1, feature_dim)

        # Calculate shifted indices for each element in each sequence.
        shifted_idx = (idx_range - indices[:, None, None]) % num_seq

        # Gather values from seq using the shifted indices.
        shifted_seq = seq.gather(1, shifted_idx)

        return shifted_seq
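

# Behavior sketch for `_timesfm_shift_padded_seq` (illustrative): with mask [[1, 0, 0]]
# the first 0 is at index 1, so rows [p0, p1, p2] become [p2, p0, p1] — position
# embedding p0 is rotated onto the first real (non-padded) patch, and the rotated-in
# row lands on the padded prefix where it is ignored.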
W-(6&r   c                       s4  e Zd ZdZdef fddZdeej dee	 de
ejejejf fdd	Zd
ejde
ejejf dejfddZdejdejdejfddZee								d!deej deeeeje	f   dee	 deej dee	 dededee dee defddZedejde	deej fdd Z  ZS )"TimesFmModelForPredictionz/TimesFM model for quantile and mean prediction.r&   c                    s\   t  | || _|j| _|j| _t|| _t	|j
|jdt|j  |jd| _|   d S )Nr   r   )r)   r*   r&   Zcontext_lengthcontext_lenhorizon_lengthhorizon_lenr   decoderr?   r+   len	quantilesr,   horizon_ff_layerr   r   r4   r   r   r*   Y  s   
z"TimesFmModelForPrediction.__init__r   r   r   c                 C   s$  g g g }}}t |D ]l\}}|jd }tj|| j |j|jd}	|| jk rO| j| }
tjtj|
|j|jd|gdd}tjtj	|
|j|	jd|	gdd}	n|| jkrg|| j d }|	| j| j  d }	|
| |
|	 |
||  qtj|ddtj|ddtj|tjdddfS )a  Formats and pads raw inputs to feed into the model.

        This function both pads each time series to match the context length, and
        pads the inputs to meet the SPMD shape requirement.

        Args:
          inputs: A list of 1d Tensors. Each Tensor is the context time series of
            a single forecast task.
          freq: list of frequencies

        Returns:
        A tuple of:
        - the padded input time series to meet the model required context.
        - the padding indicator.
        - the number of padded examples for SPMD so that each core has the same
            number (a multiple of `batch_size`) of examples.
        r   rd   rf   Nr[   rO   r   )	enumeraterW   r   Zzerosr   rP   re   r   rk   rK   r   stackr   r   r   )r3   r   r   input_tsinput_paddinginp_freqitsZ	input_lenr   Znum_front_padr   r   r   _preprocessl  s$   


"$


z%TimesFmModelForPrediction._preprocessmodel_outputr   c           	      C   sj   |  |}|j\}}}|||| jjt| jjd }|\}}||dddddf  |dddddf  S )z*Postprocess output of stacked transformer.r   N)r   rW   rj   r&   r   r   r   )	r3   r   r   Z	output_tsbn_r   r   r   r   r   _postprocess_output  s
   
 ,z-TimesFmModelForPrediction._postprocess_outputpredictionstargetsc                 C   s^   g }t | jjD ]\}}||d|f  }t|d | || }||  qt| S )N.r   )r   r&   r   r   ra   r   rT   r   )r3   r   r   Zlossesr   qerrorsr"   r   r   r   _quantile_loss  s   z(TimesFmModelForPrediction._quantile_lossNFr   window_sizefuture_valuesforecast_context_lenreturn_forecast_on_contexttruncate_negativer   r   c
           #         s  |du r| j  n| |d j}
 fdd|D }ttdd |D }|durUg }g }t|D ]\}}|| || |durL||| gd  q1|}|durU|}|du ret	d dgt
| }|du rm| jj}|	du ru| jj}	| ||\}}}||
}||
}||
}|}|jd }g }|jd |jd | j krtd	|jd  d
|jd  d| j | jj}| j| d | }t|D ]}|ddd|jd f }|dd  df }|dd  df }| j|||||	d}| |j|j|jf}|r.|dkr.|ddddd| jjddf }||dd|d}|| |dddd|df }|dddd|ddf }|| tj||gdd}q|rttj|ddddd|| jj | j ddf }ntj|ddddd| jddf }|dddddf }|dur|ddddf |ddddf  }|ddddf |ddddf  }|dkr|rt|d}t|d}d} |durt !||}!| "|ddddddf |}"|!|" } t#|j|r|j$nd|	r|j%nd||| dS )aa  
        window_size (`int`, *optional*):
            Window size of trend + residual decomposition. If None then we do not do decomposition.
        future_values (`torch.Tensor`, *optional*):
            Optional future time series values to be used for loss computation.
        forecast_context_len (`int`, *optional*):
            Optional max context length.
        return_forecast_on_context (`bool`, *optional*):
            True to return the forecast on the context when available, i.e. after the first input patch.
        truncate_negative (`bool`, *optional*):
            Truncate to only non-negative values if all of the contexts are non-negative,
            otherwise do nothing.
        output_attentions (`bool`, *optional*):
            Whether to output the attentions.
        output_hidden_states (`bool`, *optional*):
            Whether to output the hidden states.
        past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Past values of the time series that serves as input to the model.
        freq (`torch.LongTensor` of shape `(batch_size,)`):
            Frequency indices for the time series data.

        Example:

        ```python
        >>> from transformers import TimesFmModelForPrediction

        >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch")

        >>> forecast_input = [torch.linspace(0, 20, 100).sin(), torch.linspace(0, 20, 200).sin(), torch.linspace(0, 20, 400).sin()]
        >>> frequency_input = torch.tensor([0, 1, 2], dtype=torch.long)

        >>> # Generate
        >>> with torch.no_grad():
        >>>     outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True)
        >>>     point_forecast_conv = outputs.mean_predictions
        >>>     quantile_forecast_conv = outputs.full_predictions
        ```
        Nr   c                    s   g | ]	}|  d  qS r<   r   r   r   Zfcontext_lenr   r   r     s    z5TimesFmModelForPrediction.forward.<locals>.<listcomp>c                 S   s   g | ]}t |qS r   )r   r   r   r   r   r   r     r   rN   z6No frequency provided via `freq`. Default to high (0).r   z=Length of paddings must match length of input + horizon_len: z != z + )r   r   r   r   r   rO   r   )Zaxis.rq   )r   r   rU   r    r!   r"   )&r   re   r   r   r   r   extend_timesfm_moving_averager   infor   r&   r   r   r   rQ   rW   r   rh   r   r   r   r   r   r   r   r   r   sizer   Zconcatenatemaximumr7   mse_lossr   r   r   rU   )#r3   r   r   r   r   r   r   r   r   r   re   r   Zinp_minZ
new_inputsZ	new_freqsr   r   r   r   r   Z	final_outr   Zfull_outputsZoutput_patch_lenZnum_decode_patchesZ
step_indexZcurrent_paddingZdecoder_outputZfprop_outputsZnew_full_tsZnew_tsZmean_outputsr"   r  Zquantile_lossr   r   r   r;     s   4






$

"&
$$
"z!TimesFmModelForPrediction.forwardr   c                 C   s`   t | |d dfdd}tj|| j| jd| }t |ddd|ddd }|| | gS )zCCalculates the moving average using PyTorch's convolution function.r   r   Zconstantrd   rO   )	r7   rn   r   rK   rP   re   Zconv1drj   Zsqueeze)r   r   Z
arr_paddedZkernelZsmoothed_arrr   r   r   r   T  s   $z1TimesFmModelForPrediction._timesfm_moving_average)NNNNFFNN)r   r   r   r   r   r*   r   r   r   r   rV   r   r   r   r   r   r   r   r   r   r;   r   listr   r>   r   r   r4   r   r   V  sh    
+
	
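

# Loss sketch (restating the training objective above): for quantile q and error
# e = target - prediction, the pinball loss per element is max((q - 1) * e, q * e);
# `_quantile_loss` averages this over all configured quantiles, and the total loss
# adds an MSE term computed on the mean forecast channel.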


__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"]