o
    Zh7Q                    @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddl	m
Z
 ddlmZ ddlmZ ddlmZmZmZ dd	lmZmZ dd
lmZ ddlmZ eeZG dd de
jZG dd de
jZG dd de
jZ G dd de
jZ!G dd de
jZ"G dd de
jZ#G dd de
jZ$G dd de
jZ%G dd de
jZ&G dd  d e
jZ'G d!d" d"e
jZ(G d#d$ d$e
jZ)G d%d& d&e
jZ*eG d'd( d(eZ+G d)d* d*e
jZ,		+	djd,ej-d-e.d.ee/ d/e0d0e1f
d1d2Z2		dkd,ej-d3ee/e1f d.ee/ d0e1fd4d5Z3G d6d7 d7e
jZ4G d8d9 d9e
jZ5G d:d; d;e
jZ6G d<d= d=e
jZ7G d>d? d?e
jZ8eG d@dA dAeZ9G dBdC dCe+Z:eG dDdE dEeZ;edFdGG dHdI dIe+Z<eG dJdK dKeZ=G dLdM dMe+Z>eG dNdO dOeZ?eG dPdQ dQeZ@eG dRdS dSeZAdTejBjCdUej-dVej-fdWdXZDdldYej-dZeej- dVej-fd[d\ZEG d]d^ d^e+ZFeG d_d` d`eZGG dadb dbe+ZHeG dcdd ddeZIG dedf dfe
jZJG dgdh dhe+ZKg diZLdS )mzPyTorch PatchTSMixer model.    N)	dataclass)OptionalTupleUnion)PreTrainedModel)ModelOutput   )NegativeBinomialOutputNormalOutputStudentTOutput)auto_docstringlogging)deprecate_kwarg   )PatchTSMixerConfigc                       s2   e Zd ZdZdedef fddZdd Z  ZS )PatchTSMixerGatedAttentionz
    Module that applies gated attention to input data.

    Args:
        in_size (`int`): The input size.
        out_size (`int`): The output size.
    in_sizeout_sizec                    s*   t    t||| _tjdd| _d S )Ndim)super__init__nnLinear
attn_layerZSoftmaxattn_softmax)selfr   r   	__class__ e/var/www/auris/lib/python3.10/site-packages/transformers/models/patchtsmixer/modeling_patchtsmixer.pyr   -   s   
z#PatchTSMixerGatedAttention.__init__c                 C   s   |  | |}|| }|S N)r   r   )r   inputsZattn_weightr    r    r!   forward2   s   z"PatchTSMixerGatedAttention.forward)__name__
__module____qualname____doc__intr   r$   __classcell__r    r    r   r!   r   $   s    r   c                       6   e Zd ZdZdef fddZdejfddZ  Z	S )PatchTSMixerBatchNormzP
    Compute batch normalization over the sequence length (time) dimension.
    configc                    s"   t    tj|j|jd| _d S )NZeps)r   r   r   BatchNorm1dd_modelnorm_eps	batchnormr   r-   r   r    r!   r   >   s   
zPatchTSMixerBatchNorm.__init__r#   c                 C   s"   | dd}| |}| ddS )a  
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        r      )	transposer2   )r   r#   outputr    r    r!   r$   B   s   
zPatchTSMixerBatchNorm.forward
r%   r&   r'   r(   r   r   torchTensorr$   r*   r    r    r   r!   r,   9   s    r,   c                       sN   e Zd ZdZdef fddZededejfddZ	de
jfd	d
Z  ZS )PatchTSMixerPositionalEncodingz'
    Class for positional encoding
    r-   c                    s<   t    |jr| || _d S tt|j	|j
| _d S r"   )r   r   use_positional_encoding_init_peposition_encr   	Parameterr8   zerosnum_patchesr0   r3   r   r    r!   r   T   s   
z'PatchTSMixerPositionalEncoding.__init__returnc                 C   s   | j dkrtjt| j| jdd}|S | j dkrvt| j| j}td| j	d}t
td| jdtd| j   }t|| |d d dd df< t|| |d d dd df< ||  }|| d	  }tj|d
d}|S t| j  d)NrandomTZrequires_gradZsincosr   r   r4   g     @
   FzN is not a valid positional encoder. Available types are 'random' and 'sincos'.)positional_encoding_typer   r>   r8   Zrandnr@   r0   r?   Zarange	unsqueezeexpmathlogsincosmeanstd
ValueError)r-   r=   positionZdiv_termr    r    r!   r<   \   s    

(  
z'PatchTSMixerPositionalEncoding._init_pepatch_inputc                 C   s   || j  }|S r"   )r=   )r   rP   hidden_stater    r    r!   r$   p   s   
z&PatchTSMixerPositionalEncoding.forward)r%   r&   r'   r(   r   r   staticmethodr   r>   r<   r8   r9   r$   r*   r    r    r   r!   r:   O   s    r:   c                       r+   )PatchTSMixerNormLayerzeNormalization block

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r-   c                    sF   t    |j| _d|j v rt|| _d S tj|j|j	d| _d S )Nbatchr.   )
r   r   norm_mlplowerr,   normr   	LayerNormr0   r1   r3   r   r    r!   r   ~   s
   
zPatchTSMixerNormLayer.__init__r#   c                 C   sf   d| j  v r,t||jd |jd  |jd |jd f}| |}t||j}|S | |}|S )a  
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                Input to the normalization layer.
        Returns:
            `torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`
        rT   r   r   r4   r   )rU   rV   r8   reshapeshaperW   )r   r#   Zinputs_reshapedr    r    r!   r$      s   


zPatchTSMixerNormLayer.forwardr7   r    r    r   r!   rS   v   s    
rS   c                       s,   e Zd Z fddZdejfddZ  ZS )PatchTSMixerMLPc                    sP   t    ||j }t||| _t|j| _t||| _	t|j| _
d S r"   )r   r   Zexpansion_factorr   r   fc1Dropoutdropoutdropout1fc2dropout2)r   in_featuresout_featuresr-   Z
num_hiddenr   r    r!   r      s   

zPatchTSMixerMLP.__init__r#   c                 C   s0   |  tj| |}| |}| |}|S )z
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                Input to the MLP layer.
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        )r_   r   
functionalZgelur\   r`   ra   )r   r#   r    r    r!   r$      s   

zPatchTSMixerMLP.forward)r%   r&   r'   r   r8   r9   r$   r*   r    r    r   r!   r[      s    r[   c                       r+   )$PatchTSMixerChannelFeatureMixerBlockzThis module mixes the features in the channel dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r-   c                    P   t    t|| _|j| _t|j|j|d| _|jr&t|j|jd| _	d S d S Nrb   rc   r-   r   r   )
r   r   rS   rW   
gated_attnr[   num_input_channelsmlpr   gating_blockr3   r   r    r!   r      s   

z-PatchTSMixerChannelFeatureMixerBlock.__init__r#   c                 C   sT   |}|  |}|dddd}| jr| |}| |}|dddd}|| }|S )z
        Args:
            inputs (`torch.Tensor` of shape `((batch_size, num_channels, num_patches, d_model))`):
                input to the MLP layer
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        r   r   r4   r   )rW   Zpermuterj   rm   rl   )r   r#   residualoutr    r    r!   r$      s   


z,PatchTSMixerChannelFeatureMixerBlock.forwardr7   r    r    r   r!   re      s    re   c                       s   e Zd ZdZ						ddededed	ed
ededee dee f fddZ	e
ddde
ddde
ddd						ddejdeej deeej  deej deej dedeej deejeej eeej  f fddZ  ZS )PatchTSMixerAttentionz=Multi-headed attention from 'Attention Is All You Need' paper        FTN	embed_dim	num_headsr^   
is_decoderbias	is_causalr-   	layer_idxc	           	         s   t    || _|| _|| _|| | _|| _| j| | jkr*td| j d| d| jd | _|| _	|| _
|| _|d u rK| j	rKtd| jj d tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: ).g      zInstantiating a decoder z without passing `layer_idx` is not recommended and will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.)ru   )r   r   rr   rs   r^   head_dimr-   rN   scalingrt   rv   rw   loggerZwarning_oncer   r%   r   r   k_projv_projq_projout_proj)	r   rr   rs   r^   rt   ru   rv   r-   rw   r   r    r!   r      s0   


zPatchTSMixerAttention.__init__key_value_statesz4.55)versionpast_key_valuecache_positionhidden_statesattention_masklayer_head_maskoutput_attentionsrA   c                 C   s  |  \}}	}
| ||d| j| jdd}|| j }| ||d| j| jdd}| ||d| j| jdd}|| j d| jf}|j	| }|j	| }|j	| }| d}t
||dd}|  || j |	|fkrtd|| j |	|f d|   |dur|ddddddd|jd f }||| j|	|| }||| j |	|}tjj|dd}|dur|  | jfkrtd	| jf d|   |dddd||| j|	| }||| j |	|}|r||| j|	|}||| j |	|}nd}tjj|| j| jd
}t
||}|  || j |	| jfkr8td|| j |	| jf d|   ||| j|	| j}|dd}|	||	| j}| |}||dfS )z#Input shape: Batch x Time x Channelr   r   r4   z$Attention weights should be of size z	, but is Nr   z/Head mask for a single layer should be of size )ptrainingz `attn_output` should be of size )sizer~   viewrs   ry   r5   rz   r|   r}   rY   r8   ZbmmrN   rZ   r   rd   Zsoftmaxr^   r   rr   r   )r   r   r   r   r   r   r   r   ZbszZtgt_len_Zquery_statesZ
key_statesZvalue_statesZ
proj_shapeZsrc_lenZattn_weightsZattn_weights_reshapedZ
attn_probsZattn_outputr    r    r!   r$     s`   "
""



&"

zPatchTSMixerAttention.forward)rq   FTFNN)NNNNFN)r%   r&   r'   r(   r)   floatboolr   r   r   r   r8   r9   r   r$   r*   r    r    r   r!   rp      sf    	
(

	rp   c                       .   e Zd ZdZdef fddZdd Z  ZS )PatchMixerBlockzxThis module mixes the patch dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r-   c                    s~   t    t|| _|j| _|j| _t|j|j|d| _|jr(t	|j|jd| _
|jr=t|j|j|jd| _t|| _d S d S )Nrh   ri   )rr   rs   r^   )r   r   rS   rW   	self_attnrj   r[   r@   rl   r   rm   rp   r0   Zself_attn_headsr^   self_attn_layer	norm_attnr3   r   r    r!   r   x  s&   

zPatchMixerBlock.__init__c                 C   s   |}|  |}| jr,|j\}}}}||| ||}| j|dd\}}	}	|||||}|dd}| |}| jr?| |}|dd}| jrO| 	|| }|| }
|
S )z
        Args:
            hidden_state (`torch.Tensor`): Input tensor.

        Returns:
            `torch.Tensor`: Transformed tensor.
        F)r   r4   r   )
rW   r   rZ   rY   r   r5   rl   rj   rm   r   )r   rQ   rn   
batch_sizeZn_varsr@   r0   Zhidden_state_reshapedZx_attnr   ro   r    r    r!   r$     s    


zPatchMixerBlock.forwardr%   r&   r'   r(   r   r   r$   r*   r    r    r   r!   r   p  s    r   c                       r+   )FeatureMixerBlockzThis module mixes the hidden feature dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    r-   c                    rf   rg   )
r   r   rS   rW   rj   r[   r0   rl   r   rm   r3   r   r    r!   r     s   

zFeatureMixerBlock.__init__hiddenc                 C   s4   |}|  |}| |}| jr| |}|| }|S )
        Args:
            hidden (`torch.Tensor` of shape `(batch_size, num_patches, d_model)`):
                Input tensor to the layer.

        Returns:
            `torch.Tensor`: Transformed tensor.
        )rW   rl   rj   rm   )r   r   rn   ro   r    r    r!   r$     s   	


zFeatureMixerBlock.forwardr7   r    r    r   r!   r     s    r   c                       r+   )PatchTSMixerLayerz
    The `PatchTSMixer` layer that does all three kinds of mixing.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    r-   c                    sH   t    t|d| _t|d| _|j| _|jdkr"t|d| _d S d S )Nr-   mix_channel)	r   r   r   patch_mixerr   feature_mixermodere   channel_feature_mixerr3   r   r    r!   r     s   

zPatchTSMixerLayer.__init__r   c                 C   s,   | j dkr
| |}| |}| |}|S )r   r   )r   r   r   r   )r   r   r    r    r!   r$     s
   
	


zPatchTSMixerLayer.forwardr7   r    r    r   r!   r     s    	r   c                       s6   e Zd ZdZdef fddZd	defddZ  ZS )
PatchTSMixerBlockzThe main computing framework of the `PatchTSMixer` model.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r-   c                    s2   t     j}t fddt|D | _d S )Nc                    s   g | ]}t  d qS )r   )r   .0r   r   r    r!   
<listcomp>  s    z.PatchTSMixerBlock.__init__.<locals>.<listcomp>)r   r   
num_layersr   Z
ModuleListrangemixers)r   r-   r   r   r   r!   r     s   
"zPatchTSMixerBlock.__init__Foutput_hidden_statesc                 C   s>   g }|}| j D ]}||}|r|| q|r||fS |dfS )as  
        Args:
            hidden_state (`torch.Tensor`): The input tensor.
            output_hidden_states (`bool`, *optional*, defaults to False.):
                Whether to output the hidden states as well.

        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`.
        N)r   append)r   rQ   r   Zall_hidden_statesZ	embeddingmodr    r    r!   r$     s   

zPatchTSMixerBlock.forwardF)	r%   r&   r'   r(   r   r   r   r$   r*   r    r    r   r!   r     s    r   c                       0   e Zd ZdZddef fddZdd Z  ZS )	PatchTSMixerForPredictionHeadzqPrediction Head for Forecasting

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    Nr-   c                    s|   t    |j| _| jd ur| j  t|j| _|d u r+t|j	|j
 |j| _n
||j	|j
 | _tjdd| _d S )Nr   Z	start_dim)r   r   prediction_channel_indicessortr   r]   head_dropoutdropout_layerr   r@   r0   prediction_lengthbase_forecast_blockget_parameter_projectionFlattenflatten)r   r-   distribution_outputr   r    r!   r   9  s   



z&PatchTSMixerForPredictionHead.__init__c                    s     |} |} |}t|trtdd |D }n|dd} jdurBt|tr;t fdd|D }|S |d jf }|S )ar  

        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
                or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.

        c                 s   s    | ]	}| d dV  qdS )r   r   N)r5   r   zr    r    r!   	<genexpr>\  s    z8PatchTSMixerForPredictionHead.forward.<locals>.<genexpr>r   r   Nc                 3   s    | ]
}|d  j f V  qdS ).N)r   r   r   r    r!   r   b  s    .)r   r   r   
isinstancetupler5   r   r   hidden_featuresforecastr    r   r!   r$   K  s   





z%PatchTSMixerForPredictionHead.forwardr"   r   r    r    r   r!   r   1  s    r   c                       r   )	PatchTSMixerLinearHeadzLinear head for Classification and Regression.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    Nr-   c                    s   t    |j| _|j| _|jd u r|j}nd}|| _|d u r.t|j|j	 | |j
| _n||j|j	 | | _|jd u rGtjdd| _ntjdd| _t|j| _d S )Nr   r   r   )r   r   head_aggregationoutput_ranger@   r   r   r   r0   rk   num_targets
projectionr   r   r   r]   r   r^   )r   r-   r   Z
mul_factorr   r    r!   r   q  s&   


zPatchTSMixerLinearHead.__init__c                 C   s   | dd}| jdkr|d }n| jdkr|jddj}n| jdkr(|jdd}| jr0| |}| |}| |}| jdu rX| j	durXt
|| j	d	 | j	d
   | j	d
  }|S )ai  
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x num_targets)`.
        r   r   Zuse_last).r   Zmax_poolr   Zavg_poolNr   r   )r5   r   maxvaluesrL   r   r^   r   r   r   r8   Zsigmoid)r   r   r    r    r!   r$     s   






&zPatchTSMixerLinearHead.forwardr"   r   r    r    r   r!   r   i  s    r   c                   @   s$   e Zd ZeZdZdZdZdd ZdS )PatchTSMixerPreTrainedModelmodelpast_valuesFc                 C   s   t |tr| jjdkrtjj|jddd dS dS t |tjtj	fr1|j
j  |jjd dS t |trG|jj
j  |jjjd dS t |tjre|jjjd| jjd |j
durg|j
j  dS dS dS )zInitialize weightsrB   rq   g?)rL   rM         ?N)r   r:   r-   rE   r   initZnormal_r=   rX   r/   ru   dataZzero_weightZfill_r,   r2   r   Zinit_std)r   moduler    r    r!   _init_weights  s    


z)PatchTSMixerPreTrainedModel._init_weightsN)	r%   r&   r'   r   Zconfig_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingr   r    r    r    r!   r     s    r   c                       r   )PatchTSMixerPretrainHeadzcPretraining head.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r-   c                    s.   t    t|j| _t|j|j| _	d S r"   )
r   r   r   r]   r   r   r   r0   patch_lengthbase_pt_blockr3   r   r    r!   r     s   
z!PatchTSMixerPretrainHead.__init__c                 C   s   |  |}| |}|S )a  
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size x num_patch x d_model)` in `flatten` mode
                or `(batch_size x n_vars x num_patch x d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.

        Returns:
            `torch.Tensor` of shape `(batch_size x n_vars x num_patch x patch_length)`.
        )r   r   r   r    r    r!   r$     s   

z PatchTSMixerPretrainHead.forwardr   r    r    r   r!   r     s    r   Fr#   
mask_ratiounmasked_channel_indiceschannel_consistent_masking
mask_valuec                 C   s*  |dk s|dkrt d| d| j\}}}}| j}	t|d|  }
|r5tj|d||	d}|d|d}n	tj||||	d}tj||||	d}d|ddddd|
f< tj|dd}tj|dd}tj	|d|d	}|
dddd|}|durd|dd|ddddf< | | |}||d
 fS )a  random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is the number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be same across all channels of a timeseries. Otherwise, masking positions will vary
            across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as input Tensor and mask tensor of shape [bs x c x
        n]
    r   r   zMask ratio z has to be between 0 and 1.deviceNr   r   )r   index.r   )rN   rZ   r   r)   r8   ZrandrepeatZonesZargsortgatherrF   masked_fillr   )r#   r   r   r   r   r   num_channelssequence_lengthnum_featuresr   Zlen_keepnoisemaskZids_shuffleZids_restoreinputs_maskr    r    r!   random_masking  s&   r   num_forecast_mask_patchesc                 C   s  t |tr|g}dd |D }| j\}}}}tj|||| jd}	g }
d}t|}t||D ](\}}|dks9||krAtd| dt|| | }|
	|||g ||7 }q-t
|
dd d	}
||k rq|
d d
 ||  |
d d
< n||kr|
d d
 ||  |
d d
< d}|
D ]\}}}|| }d|	||dd| df< |}qt|	jd }|	| }	|	dddd|}	|durd|	dd|ddddf< | |	 |}||	d fS )a  Forecast masking that masks the last K patches where K is from the num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by numbers defined in the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: inputs_mask, masked input, same shape as inputs Tensor and Mask tensor of shape `(bs,
        num_channels , num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`
    c                 S   s   g | ]}d qS )r   r    r   r    r    r!   r   >  s    z$forecast_masking.<locals>.<listcomp>r   r   znum_forecast_mask_patches z6 should be greater than 0 and less than total patches.c                 S   s   | d S Nr4   r    )xr    r    r!   <lambda>P  s    z"forecast_masking.<locals>.<lambda>)keyr4   r   r   Nr   )r   r)   rZ   r8   r?   r   sumziprN   r   sortedZrandpermrF   r   r   r   )r#   r   r   r   Zforecast_mask_ratiosr   r   r   r   r   Zt_listtotal_lengthtotal_ratior   ratioZtemp_lenZbatch1Z	patch_lenr   Zbatch2permr   r    r    r!   forecast_masking$  sB   


r   c                       r+   )PatchTSMixerPatchifyz
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    r-   c                    s   t    |j| _|j| _|j| _| j| jkr$td| j d| j dt| j| j| j | j d | _| j| j| jd   }| j| | _	d S )NzSequence length (z+) has to be greater than the patch length ()r   )
r   r   Zcontext_lengthr   r   patch_striderN   r   r@   sequence_start)r   r-   Znew_sequence_lengthr   r    r!   r   q  s   
 zPatchTSMixerPatchify.__init__r   c                 C   sp   |j d }|| jkrtd| d| j d|dd| jdddf }|jd| j| jd}|dd }|S )a!  
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        r   zInput sequence length (z%) doesn't match model configuration (rx   N)	dimensionr   stepr   )	rZ   r   rN   r   Zunfoldr   r   r5   
contiguous)r   r   r   r6   r    r    r!   r$     s   
	
zPatchTSMixerPatchify.forwardr7   r    r    r   r!   r   i  s    r   c                       r+   )PatchTSMixerMaskinga  
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSMixerConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    r-   c                    sX   t    |j| _|j| _|j| _|j| _|j| _|j| _| jd ur*t| j| _d S d S r"   )	r   r   random_mask_ratior   	mask_typer   r   r   r   r3   r   r    r!   r     s   

zPatchTSMixerMasking.__init__rP   c                 C   sr   | j dkrt|| j| j| j| jd\}}n| j dkr(t|| j| j| jd\}}n	td| j  d|	 }||fS )a  
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points

        rB   )r#   r   r   r   r   r   )r#   r   r   r   zInvalid mask type .)
r   r   r   r   r   r   r   r   rN   r   )r   rP   Zmasked_inputr   r    r    r!   r$     s$   

zPatchTSMixerMasking.forwardr7   r    r    r   r!   r     s    r   c                	       P   e Zd ZdZdef fddZdejdejdeejejejf fdd	Z	  Z
S )
PatchTSMixerStdScalerz
    Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
    subtracting from the mean and dividing by the standard deviation.
    r-   c                    sV   t    t|dr|jnd| _t|dr|jnd| _t|dr&|j| _d S d| _d S )Nscaling_dimr   keepdimTminimum_scalegh㈵>)r   r   hasattrr  r   r  r  r3   r   r    r!   r     s   
 zPatchTSMixerStdScaler.__init__r   observed_indicatorrA   c                 C   sz   |j | j| jd}|d}|| j | j| jd| }|| | d j | j| jd| }t|| j }|| | ||fS )C  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        r  r   r4   )r   r   r  Z	clamp_minr8   sqrtr  )r   r   r  denominatorlocZvariancescaler    r    r!   r$     s   
"zPatchTSMixerStdScaler.forwardr%   r&   r'   r(   r   r   r8   r9   r   r$   r*   r    r    r   r!   r    s    r  c                	       r   )
PatchTSMixerMeanScalerz
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    r-   c                    sl   t    t|dr|jnd| _t|dr|jnd| _t|dr#|jnd| _t|dr1|j| _d S d | _d S )Nr  r   r  Tr  绽|=default_scale)r   r   r  r  r   r  r  r  r3   r   r    r!   r     s
   
 zPatchTSMixerMeanScaler.__init__r   r  rA   c           
      C   s   ||   j| jdd}|j| jdd}|tj|dd }| jdu r:|jdd}tj|ddd}t|| }n| jt| }t|dk||}tj|| j	d}|| }	| j
sa|j| jd}|	t||fS )r  Tr  r   minNr   r   )absr   r   r8   clampr  Zsqueeze	ones_likewherer  r  
zeros_like)
r   r   r  Zts_sumZnum_observedr  Z	batch_sumZbatch_observationsr  Zscaled_datar    r    r!   r$     s   
zPatchTSMixerMeanScaler.forwardr  r    r    r   r!   r    s    r  c                
       sX   e Zd ZdZdef fddZ	ddejdeej de	ejejejf fd	d
Z
  ZS )PatchTSMixerNOPScalerz|
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    r-   c                    s@   t    t|dr|jnd| _t|dr|j| _d S d| _d S )Nr  r   r  T)r   r   r  r  r   r  r3   r   r    r!   r   7  s   
 zPatchTSMixerNOPScaler.__init__Nr   r  rA   c                 C   sB   t j|ddj| j| jd}t j|ddj| j| jd}|||fS )a  
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        FrC   )r   r  )r8   r  rL   r   r  r  )r   r   r  r  r  r    r    r!   r$   <  s   
zPatchTSMixerNOPScaler.forwardr"   )r%   r&   r'   r(   r   r   r8   r9   r   r   r$   r*   r    r    r   r!   r  2  s    r  c                   @   s:   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dS )PatchTSMixerEncoderOutputa  
    Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
            Hidden-state at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
    Nlast_hidden_stater   )r%   r&   r'   r(   r  r   r8   FloatTensor__annotations__r   r   r    r    r    r!   r  M  s   
 
r  c                       s\   e Zd ZdZdef fddZe		ddejde	e
 d	e	e
 d
eeef fddZ  ZS )PatchTSMixerEncoderz
    Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    r-   c                    sb   t  | |j| _t|j|j| _|jrt	|d| _
nd | _
t|d| _|jr/|   d S d S )Nr   )r   r   use_return_dictr   r   r   r0   patcherr;   r:   positional_encoderr   mlp_mixer_encoder	post_initr3   r   r    r!   r   f  s   zPatchTSMixerEncoder.__init__FNr   r   return_dictrA   c                 C   sh   |dur|n| j }| |}| jdur| |}| j||d\}}|s.tdd ||fD S t||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to
            predict the masked portion. For a forecasting task, this denotes the history/past time series values.
            Similarly, for classification or regression tasks, it denotes the appropriate context values of the
            time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series,
            it is greater than 1.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
        N)r   c                 s       | ]}|V  qd S r"   r    r   vr    r    r!   r     
    
z.PatchTSMixerEncoder.forward.<locals>.<genexpr>)r  r   )r  r  r   r!  r   r  )r   r   r   r#  Zpatchesr  r   r    r    r!   r$   v  s   


zPatchTSMixerEncoder.forward)FN)r%   r&   r'   r(   r   r   r   r8   r9   r   r   r   r   r  r$   r*   r    r    r   r!   r  ]  s    
r  c                   @   s   e Zd ZU dZdZeej ed< dZ	ee
ej  ed< dZeej ed< dZeej ed< dZeej ed< dZeej ed< dS )	PatchTSMixerModelOutputa  
    Base class for model's outputs, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor`  of shape `(batch_size, num_channels, num_patches, d_model)`):
            Hidden-state at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
        patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
            Patched input data to the model.
        mask: (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`,*optional*):
            Bool Tensor indicating True in masked patches and False otherwise.
        loc: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
            Gives the mean of the context window per channel. Used for revin denorm outside the model, if revin
            enabled.
        scale: (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`,*optional*):
            Gives the std dev of the context window per channel. Used for revin denorm outside the model, if revin
            enabled.
    Nr  r   rP   r   r  r  )r%   r&   r'   r(   r  r   r8   r  r  r   r   rP   r   r  r  r    r    r    r!   r(    s   
 r(  z=
    The PatchTSMixer Model for time-series forecasting.
    )Zcustom_introc                       sb   e Zd Zddedef fddZe			ddejde	ej d	e	e d
e	e de
f
ddZ  ZS )PatchTSMixerModelFr-   
mask_inputc                    s   t  | |j| _t|| _t|| _|du rt|| _nd| _|j	dkr,t
|| _n|j	dks6|j	du r<t|| _nt|| _|jrJ|   dS dS )z
        mask_input (bool, *optional*, defaults to `False`):
            Whether to mask the input using the [`PatchTSMixerMasking`] module.
        TNrL   rM   )r   r   r  r  encoderr   patchingr   maskingrz   r  scalerr  r  r"  )r   r-   r*  r   r    r!   r     s   



zPatchTSMixerModel.__init__Nr   observed_maskr   r#  rA   c                 C   s   |dur|n| j }d}|du rt|}| ||\}}}| |}	|	}
| jdur0| |	\}
}| j|
||d}t|trAt	| }|sTtdd |j
|j|	|||fD S t|j
|j|	|||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        Nr   r#  c                 s   r$  r"   r    r%  r    r    r!   r     r'  z,PatchTSMixerModel.forward.<locals>.<genexpr>)r  r   rP   r   r  r  )r  r8   r  r.  r,  r-  r+  r   r   r  r  r   r(  )r   r   r/  r   r#  r   Zscaled_past_valuesr  r  Z	patched_xZ	enc_inputZencoder_outputr    r    r!   r$     sD   



zPatchTSMixerModel.forwardr   )NFN)r%   r&   r'   r   r   r   r   r8   r9   r   r(  r$   r*   r    r    r   r!   r)    s"    r)  c                   @   ^   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dS ) PatchTSMixerForPreTrainingOutputa  
    Output type of [`PatchTSMixerForPreTrainingOutput`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
            Prediction output from the pretrain head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss
    Nlossprediction_outputsr  r   r%   r&   r'   r(   r3  r   r8   r  r  r4  r  r   r   r    r    r    r!   r2  &     
 r2  c                       f   e Zd ZdZdef fddZe				ddejd	e	ej d
e	e
 de
de	e
 defddZ  ZS )PatchTSMixerForPretrainingz
    `PatchTSMixer` for mask pretraining.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r-   c                    sL   t  | t|dd| _t|d| _|j| _|j| _|jr$|   d S d S )NT)r*  r   )	r   r   r)  r   r   headmasked_lossr  r"  r3   r   r    r!   r   H  s   z#PatchTSMixerForPretraining.__init__NFTr   r/  r   return_lossr#  rA   c           
      C   s   |dur|n| j }| jdu rtjjdd}ntjjdd}| j||||d}t|tr/t| }| 	|j
}|du r@|||j}	nd}	| jdu r]|	dur]|	jdd|j  |j d	  }	|sntd
d |	||j
|jfD S t|	||j
|jdS )aT  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.
        NTnoneZ	reductionrL   r/  r   r#  r   r   r  c                 s   r$  r"   r    r%  r    r    r!   r     r'  z5PatchTSMixerForPretraining.forward.<locals>.<genexpr>r3  r4  r  r   )r  r:  r8   r   MSELossr   r   r   r(  r9  r  rP   rL   r   r   r   r2  )
r   r   r/  r   r;  r#  r3  model_outputZx_hatloss_valr    r    r!   r$   S  s@   

$
z"PatchTSMixerForPretraining.forwardNFTN)r%   r&   r'   r(   r   r   r   r8   r9   r   r   r2  r$   r*   r    r    r   r!   r8  <  s*    r8  c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dZeej ed< dZeej ed< dS )	PatchTSMixerForPredictionOutputa  
    Output type of [`PatchTSMixerForPredictionOutput`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
            Prediction output from the forecast head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
        loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
            Input mean
        scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
            Input std dev

    Nr3  r4  r  r   r  r  )r%   r&   r'   r(   r3  r   r8   r  r  r4  r  r   r   r  r  r    r    r    r!   rD    s   
 rD  c                   @   $   e Zd ZU dZdZeej ed< dS )"SamplePatchTSMixerPredictionOutputa9  
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, number_channels)`):
            Sampled values from the chosen distribution.
    N	sequences	r%   r&   r'   r(   rG  r   r8   r  r  r    r    r    r!   rF       
 	rF  c                   @   rE  )"SamplePatchTSMixerRegressionOutputa$  
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, num_targets)`
                Sampled values from the chosen distribution.
    NrG  rH  r    r    r    r!   rJ    rI  rJ  inputtargetrA   c                 C   s   |  | S )zc
    Computes the negative log likelihood loss from input distribution with respect to target.
    )Zlog_prob)rK  rL  r    r    r!   nll  s   rM  input_tensorweightsc                 C   sr   |dur3t |dk| | t | }t j|r|j|dn| dd}|r-|j|d| S | | S | j|dS )aj  
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    Nr   r   r   r  )r8   r  r  r  r   rL   )rN  rO  r   Zweighted_tensorZsum_weightsr    r    r!   weighted_average  s
   " rP  c                       s   e Zd ZdZdef fddZe					ddejd	e	ej d
e	ej de	e
 de
de	e
 defddZ	ddejd	e	ej defddZ  ZS )PatchTSMixerForPredictionz
    `PatchTSMixer` for forecasting application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r-   c                    s   t  | |j| _|j| _|j| _|j| _|jdkrd | _n#|j}tt	t
d}||jd }|d ur:||d| _ntd|j t|| _t|| jd| _|jrX|   d S d S )NmseZ	student_tnormalnegative_binomialr   Unknown distribution output r-   r   )r   r   r3  r  r   num_parallel_samplesr   r   r   r
   r	   getrN   r)  r   r   r9  r"  )r   r-   r   distribution_output_mapoutput_classr   r    r!   r     s0   

z"PatchTSMixerForPrediction.__init__NFTr   r/  future_valuesr   r;  r#  rA   c                 C   s  | j dkrtjdd}n| j dkrt}ntd|dur|n| j}| j||||d}t|tr3t	| }| 
|j}	d}
| jdur| jro| jj|	|jd| jf |jd| jf d	}|durn|d
u rn|||d| jf }
t|
}
nZ|	|jd| jf  |jd| jf  }	|dur|d
u r||	|d| jf }
n5| jr| jj|	|j|jd	}|dur|d
u r|||}
t|
}
n|	|j |j }	|dur|d
u r||	|}
| jdur|jd| jf }|jd| jf }n|j}|j}|stdd |
|	|j|j||fD S t|
|	|j|j||dS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
            - 1 for values that are **observed**,
            - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,:
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target values of the time series, that serve as labels for the model. The `future_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.
        return_loss (`bool`,  *optional*):
            Whether to return the loss in the `forward` call.
        rR  rL   r=  rM  2Invalid loss function: Allowed values: mse and nllNr>  .r  r  Tc                 s   r$  r"   r    r%  r    r    r!   r     r'  z4PatchTSMixerForPrediction.forward.<locals>.<genexpr>)r3  r4  r  r   r  r  )r3  r   r@  rM  rN   r  r   r   r   r(  r9  r  r   r   distributionr  r  rP  r   rD  )r   r   r/  r\  r   r;  r#  r3  rA  y_hatrB  r_  r  r  r    r    r!   r$   !  s   
$






z!PatchTSMixerForPrediction.forwardc                    s\   | j }| |d|dd}| jj|j|j|jd  fddt|D }tj|dd}t	|d	S )
a  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.

            observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
                in `[0, 1]`:

                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, prediction_length, num_input_channels)`.
        NF)r   r\  r/  r   r^  c                       g | ]}   qS r    sampler   r_  r    r!   r     s    z6PatchTSMixerForPrediction.generate.<locals>.<listcomp>r   r   rG  )
rX  r   r_  r4  r  r  r   r8   stackrF  )r   r   r/  rX  outputssamplesr    rd  r!   generate  s   	
z"PatchTSMixerForPrediction.generate)NNFTNr"   )r%   r&   r'   r(   r   r   r   r8   r9   r   r   rD  r$   rF  ri  r*   r    r    r   r!   rQ    s@     |rQ  c                   @   r1  )-PatchTSMixerForTimeSeriesClassificationOutputa  
    Output type of [`PatchTSMixerForTimeSeriesClassificationOutput`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Prediction output from the classification head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
    Nr3  r4  r  r   r5  r    r    r    r!   rj    r6  rj  c                       r7  )'PatchTSMixerForTimeSeriesClassificationz
    `PatchTSMixer` for classification application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r-   c                    sd   t  | t|| _t|d| _|j| _|jdv r$t|j	|j
d| _nd | _|jr0|   d S d S )Nr   rM   rL   Tr0   r@   )r   r   r)  r   r   r9  r  rz   InjectScalerStatistics4Dr0   r@   inject_scaler"  r3   r   r    r!   r     s   

z0PatchTSMixerForTimeSeriesClassification.__init__NFTr   target_valuesr   r;  r#  rA   c           
      C   s   t j }|dur|n| j}| j|||d}t|trt| }| jdur0| j|j	|j
|jd|_	| |j	}|durD|du rD|||}	nd}	|sWtdd |	||j	|jfD S t|	||j	|jdS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target
            values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.
        Nr0  r^  Tc                 s   r$  r"   r    r%  r    r    r!   r   =  r'  zBPatchTSMixerForTimeSeriesClassification.forward.<locals>.<genexpr>r?  )r8   r   ZCrossEntropyLossr  r   r   r   r(  ro  r  r  r  r9  r   rj  )
r   r   rp  r   r;  r#  r3  rA  r`  rB  r    r    r!   r$     sB   
$


z/PatchTSMixerForTimeSeriesClassification.forwardrC  )r%   r&   r'   r(   r   r   r   r8   r9   r   r   rj  r$   r*   r    r    r   r!   rk    s*    rk  c                   @   r1  )PatchTSMixerForRegressionOutputa  
    Output type of [`PatchTSMixerForRegressionOutput`].

    Args:
        regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
            Prediction output from the regression head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
    Nr3  regression_outputsr  r   )r%   r&   r'   r(   r3  r   r8   r  r  rr  r  r   r   r    r    r    r!   rq  O  r6  rq  c                       sH   e Zd Zddededef fddZdejdejd	ejfd
dZ  ZS )rn  r4   r0   r@   	expansionc                    s`   t    t|d || | _t|| || _tdd| | _td| d| _|| _d S r   )	r   r   r   r   inverse_trans_expansioninverse_trans_compressionmap_scale_expansionmap_scale_compressionr@   )r   r0   r@   rs  r   r    r!   r   f  s   

z!InjectScalerStatistics4D.__init__r#   r  r  c                 C   s   | dd}|d}|dd| jd}| dd}|d}|dd| jd}tj||gdd}| |}| |}tj||gdd}| |}| 	|}|S )a  
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
            loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
            scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
        Returns:
            `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
        r   r   r   r   )
r5   rF   r   r@   r8   catrv  rw  rt  ru  )r   r#   r  r  rL   stdevZconcat_statsr    r    r!   r$   o  s   






z InjectScalerStatistics4D.forward)r4   )	r%   r&   r'   r)   r   r8   r9   r$   r*   r    r    r   r!   rn  e  s    $	rn  c                       sz   e Zd ZdZdef fddZe				ddejd	e	ej d
e	e
 de
de	e
 defddZdejdefddZ  ZS )PatchTSMixerForRegressionz
    `PatchTSMixer` for regression application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    r-   c                    s   t  | t|| _|j| _|j| _|j| _|j| _|jdkr$d | _n tt	t
d}||j}|d ur<||jd| _ntd|j |jdv rSt|j|jd| _nd | _t|| jd| _|jrg|   d S d S )NrR  rS  r   rV  rl  rm  rW  )r   r   r)  r   r3  r   r  rX  r   r
   r	   rY  r   rN   rz   rn  r0   r@   ro  r   r9  r"  )r   r-   rZ  r[  r   r    r!   r     s4   


z"PatchTSMixerForRegression.__init__NFTr   rp  r   r;  r#  rA   c                    sD   j dkrtjdd}n j dkrt}ntd|dur|n j} j|||d}t|tr2t	| } j
durC j
|j|j|jd|_ |j}|dur|d	u r jr jd
krdt|dk rdtd j|}	t fdd|D }||	|}
t|
}
n|||}
nd}
|stdd |
||j|jfD S t|
||j|jdS )a  
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*):
            Target values of the time series, that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
            required for a pretraining task.

            For a forecasting task, the shape is be `(batch_size, target_len, num_input_channels)`. Even if we want
            to forecast only specific channels by setting the indices in `prediction_channel_indices` parameter,
            pass the target data with all channels, as channel Filtering for both prediction and target will be
            manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.
        rR  rL   r=  rM  r]  Nr0  r^  TrU  r   zDtarget_values cannot be negative for negative_binomial distribution.c                    s   g | ]
}| d  jjqS )r   )r   r-   r   )r   itemr   r    r!   r     s    z5PatchTSMixerForRegression.forward.<locals>.<listcomp>c                 s   r$  r"   r    r%  r    r    r!   r     r'  z4PatchTSMixerForRegression.forward.<locals>.<genexpr>)r3  rr  r  r   )r3  r   r@  rM  rN   r  r   r   r   r(  ro  r  r  r  r9  r   r8   any	Exceptionr_  rP  r   rq  )r   r   rp  r   r;  r#  r3  rA  r`  r_  rB  r    r   r!   r$     sX   
#





z!PatchTSMixerForRegression.forwardc                    s^   | j }| |ddd}| j|j  fddt|D }tj|ddd|| jj	}t
|d	S )
a
  
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the target values.

        Return:
            [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, num_targets)`.
        NF)r   rp  r   c                    ra  r    rb  r   rd  r    r!   r   :  s    z6PatchTSMixerForRegression.generate.<locals>.<listcomp>r   r   r   re  )rX  r   r_  rr  r   r8   rf  r   r-   r   rJ  )r   r   rX  rg  rh  r    rd  r!   ri    s   

z"PatchTSMixerForRegression.generaterC  )r%   r&   r'   r(   r   r   r   r8   r9   r   r   rq  r$   rJ  ri  r*   r    r    r   r!   rz    s4    '\rz  )r   r)  r8  rQ  rk  rz  )NFr   )Nr   )NN)Mr(   rH   dataclassesr   typingr   r   r   r8   Ztorch.nnr   Ztransformers.modeling_utilsr   Ztransformers.utilsr   Ztime_series_utilsr	   r
   r   utilsr   r   Zutils.deprecationr   Zconfiguration_patchtsmixerr   Z
get_loggerr%   r{   Moduler   r,   r:   rS   r[   re   rp   r   r   r   r   r   r   r   r   r9   r   listr   r)   r   r   r   r   r  r  r  r  r  r(  r)  r2  r8  rD  rF  rJ  distributionsDistributionrM  rP  rQ  rj  rk  rq  rn  rz  __all__r    r    r    r!   <module>   s   
'11 E-&)8G"
>

E1=$7Ea_" Wn( 7