a
    hTm                     @   s  d dl mZmZmZmZ d dlZd dlmZ ddlmZm	Z	m
Z
mZmZmZmZmZ ddgZG dd deZd	d
e de
  d e_ee ee eee  eee  eee  ee ee ee eeeef eeee eeedddZeeeeej eej ef eeee   f dddZee ee eee  eee  eee  ee ee ee eeeef eeee eeedddZeeddee ee eee  eee  eee  ee ee ee ee eeeeef eeeeedddZdS )    )castOptionalTYPE_CHECKINGUnionN)Tensor   )_disable_dynamo_if_unsupported_get_scalar_dtype_maximize_doc_params_doc
_to_scalar	OptimizerParamsTTensorListList	Adafactor	adafactorc                       sx   e Zd Zddddeeeef eeee ef eeee	 e	d	 fd
dZ
 fddZdd Ze dddZ  ZS )r   {Gz?皙鿩NgMbP?      ?        NF)foreachmaximize)paramslrbeta2_decayepsdweight_decayr   r   c          
   	      s   t |tr| dkrtdd|ks4td| d|ksJtd| |d d urtd|d ksttd|d  d|d kstd|d  d	|kstd
| d|kstd| t|||||||d}	t ||	 d S )Nr   zTensor lr must be 1-elementr   z%Learning rate should be >= 0 but is: z#beta2_decay should be <= 0 but is: r   z epsilon1 should be >= 0 but is: z epsilon2 should be >= 0 but is: r   z,Clipping threshold d should be >= 1 but is: z$weight_decay should be >= 0 but is: )r   r   r   r   r   r   r   )
isinstancer   numel
ValueErrordictsuper__init__)
selfr   r   r   r   r   r   r   r   defaults	__class__ D/var/www/auris/lib/python3.9/site-packages/torch/optim/_adafactor.pyr$      s0    	zAdafactor.__init__c                    s~   t  | | jD ]f}|dd  |d D ]L}| j|g }t|dkr*t|d s*t	|d }tj
|t d|d< q*qd S )Nr   r   r   stepdtype)r#   __setstate__param_groups
setdefaultstategetlentorchZ	is_tensorfloattensorr	   )r%   r1   grouppZp_stateZstep_valr'   r)   r*   r.   =   s    
zAdafactor.__setstate__c                 C   s4  |d D ]$}|j d u rqt|r,td|j jr<td|| ||j  | j| }	t|	dkrtjdt	 d|	d< |j 
 dkrt|j j}
d|
d	< |j |
|	d
< t|j j}d|d< |j ||	d< ntj|j tjd|	d< ||	d
d  ||	dd  ||	dd  ||	d  qdS )Nr   z-Adafactor does not support complex parametersz+Adafactor does not support sparse gradientsr   r   r,   r+   r   row_varcol_var)Zmemory_formatvarianceF)gradr4   Z
is_complexRuntimeErrorZ	is_sparseappendr1   r3   r6   r	   dimlistshapeZ	new_zerosZ
zeros_likeZpreserve_formatr2   )r%   r7   params_with_gradgradsrow_varscol_vars	variancesstate_stepsr8   r1   Z	row_shapeZ	col_shaper)   r)   r*   _init_groupG   s6    





zAdafactor._init_groupc                 C   s   |    d}|durBt  | }W d   n1 s80    Y  | jD ]}g }g }g }g }g }g }	|d \}
}| |||||||	}t||||||	|d |d |d |d |
||d |d t| d	dt| d
d|d qH|S )zPerform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        Nr   r   r   r   r   r   r   
grad_scale	found_inf)r   r   r   r   eps1eps2r   r   rK   rL   has_complex)Z _cuda_graph_capture_health_checkr4   Zenable_gradr/   rJ   r   getattr)r%   closureZlossr7   rD   rE   rF   rG   rH   rI   rM   rN   rO   r)   r)   r*   r+   y   sT    
$



zAdafactor.step)r   r   r   r   r   )N)__name__
__module____qualname__r   r   r5   r   tupler   boolr$   r.   rJ   r4   Zno_gradr+   __classcell__r)   r)   r'   r*   r      s,        	
%
2a  Implements Adafactor algorithm.

    .. math::
        \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{(lr)}, \: \tau
                \text{(}\beta_2\text{ decay)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},    \\
            &\hspace{15mm}      \: \epsilon_1, \epsilon_2 \text{ (epsilons)}, \: d \text{(clipping threshold)}, \\
            &\hspace{15mm}      \: \lambda \text{(weight decay)},
                \: \textit{maximize}                                                             \\
            &\textbf{initialize} : \: R_0 \leftarrow 0 \text{ (second moment row factor)},       \\
            &\hspace{23mm} \: C_0 \leftarrow 0 \text{ (second moment col factor)},               \\
            &\hspace{23mm} \: \widehat{V}_0 \leftarrow 0 \text{ (second moment for vectors)}     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\

            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}G_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}G_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm}\widehat{\beta}_{2_t} \leftarrow 1 - t^{\tau}                           \\
            &\hspace{5mm}\rho_t         \leftarrow min(lr, \frac{1}{\sqrt{t}})                   \\
            &\hspace{5mm}\alpha_t       \leftarrow max(\epsilon_2,
                \text{RMS}(\theta_{t-1}))\rho_t                                                  \\
            &\hspace{5mm}\theta_t       \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}    \\
            &\hspace{5mm}\textbf{if} \: \text{dim}(G_t) > 1:                                     \\
            &\hspace{10mm}R_t           \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                (1-\widehat{\beta}_{2_t})(G_t \odot G_t) \cdot 1_m                               \\
            &\hspace{10mm}C_t           \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t)                         \\
            &\hspace{10mm}\widehat{V}_t \leftarrow
                \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)}                        \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\widehat{V}_t \leftarrow \widehat{\beta}_{2_t}\widehat{V}_{t-1}+
                (1-\widehat{\beta}_{2_t}) \cdot (G_t \odot G_t)                                  \\
            &\hspace{5mm}U_t            \leftarrow
                \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)}                                \\
            &\hspace{5mm}\widehat{U}_t  \leftarrow \frac{U_t}{max(1, \frac{\text{RMS}(U_t)}{d})} \\
            &\hspace{5mm}\theta_t       \leftarrow \theta_{t-1} - \alpha_t \widehat{U}_t         \\

            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
        \end{aligned}

    For further details regarding the algorithm we refer to `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`_.
    z
    Args:
        a  
        lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a
            learning rate, and Noam Shazeer and Mitchell Stern do not use lr at all.
            Deviating from the paper, this implementation uses lr for applying weight
            decay and as the maximum value for relative step size rho_t. Note that in
            the paper, a constant of 0.01 is used as the maximum value for relative
            step size, and so we set 0.01 as the default value. (default: 1e-2)
        beta2_decay (float, optional): the decay rate of beta2. beta2 standardly refers
            to the coefficient used for computing the running average of the gradient
            squared. (default: -0.8)
        eps (Tuple[float, float], optional): epsilon1 is the term added to the denominator
            of the update calculation to improve numerical stability. This use of epsilon1
            deviates from the algorithm written in the paper! See note below for more details.
            epsilon2 is the term used to avoid having too small a weight update when applying
            parameter scaling. (default: (None, 1e-3))
        d (float, optional): the clipping threshold, used to avoid larger-than-desired
            updates.
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        foreach (bool, optional): whether foreach implementation of optimizer is used. Note
            that the foreach implementation uses ~ sizeof(params) more peak memory than the
            for-loop version due to the intermediates being a tensorlist vs just one tensor.
            As Adafactor is commonly used when memory is prohibitive, Adafactor will default
            to the slower single tensor for-loop implementation unless this flag is explicitly
            True. This behavior is contrary to other optimizers, which will attempt defaulting
            to foreach on CUDA for faster runtime. (default: None)
        a(  
    .. Note::
        The implementation of Adafactor subtly differs from Noam Shazeer and Mitchell Stern
        and implementations in some other frameworks with its use of learning rate and
        :math:`\epsilon_1`.

        Regarding the learning rate hyperparameter: Noam Shazeer and Mitchell Stern do not
        use lr at all, as the stated algorithm uses :math:`\rho_t` and update clipping to
        affect the step size.

        This implementation allows `lr` to influence the maximum value for :math:`\rho_t`:

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(lr, \frac{1}{\sqrt{t}})
            \end{aligned}

        This differs from Noam Shazeer and Mitchell Stern, who use a constant of 0.01 as
        the maximum value of :math:`\rho_t`

        .. math::
            \begin{aligned}
                &\hspace{5mm}\rho_t \leftarrow min(0.01, \frac{1}{\sqrt{t}})
            \end{aligned}

        Noam Shazeer and Mitchell Stern do not enforce an opinion on how weight decay should
        be computed, and so we use the learning rate as a coefficient for decoupled weight
        decay, similar to what is suggested in `Decoupled Weight Decay Regularization`_.

        Regarding the use of :math:`\epsilon_1`: The implementation attempts to replicate the
        presumed intention of Noam Shazeer and Mitchell Stern to use :math:`\epsilon_1` as
        a stabilizing term when the squared gradient becomes small.

        This stabilization can be written as

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                    (1-\widehat{\beta}_{2_t})(G_t \odot G_t + 1_n \cdot 1^\top_m) \cdot 1_m          \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                    (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + 1_n \cdot 1^\top_m)    \\
                &\hspace{5mm}\widehat{V}_t \leftarrow
                    \frac{R_t \cdot C_t}{max(1^\top_n \cdot R_t, \epsilon_1)}                        \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{max(\sqrt{\widehat{V}_t}, \epsilon_1)}        \\
            \end{aligned}

        where the row and column factors of gradient squared :math:`R_t` and :math:`C_t`
        are left alone, and we apply :math:`\epsilon_1` at the final calculation of
        the variance estimate :math:`\widehat{V}_t` and for the update :math:`U_t`.

        This is in contrast to Noam Shazeer and Mitchell Stern and other frameworks which
        apply :math:`\epsilon_1` to both row and column factors of the squared gradient, but
        not in the calculations after:

        .. math::
            \begin{aligned}
                &\hspace{5mm}R_t \leftarrow \widehat{\beta}_{2_t}R_{t-1}+
                            (1-\widehat{\beta}_{2_t})(G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m) \cdot 1_m          \\
                &\hspace{5mm}C_t \leftarrow \widehat{\beta}_{2_t}C_{t-1}+
                            (1-\widehat{\beta}_{2_t}) 1^\top_n \cdot (G_t \odot G_t + \epsilon_1 1_n \cdot 1^\top_m)    \\
                &\hspace{5mm}\widehat{V}_t \leftarrow \frac{R_t \cdot C_t}{1^\top_n \cdot R_t}                          \\
                &\hspace{5mm}U_t \leftarrow \frac{G_t}{\sqrt{\widehat{V}_t}}                                            \\
            \end{aligned}


    .. _Adafactor\: Adaptive Learning Rates with Sublinear Memory Cost:
        https://arxiv.org/pdf/1804.04235
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    )r   rE   rF   rG   rH   rI   rK   rL   r   r   r   r   rM   rN   r   rO   c          !      C   s4  |d u r|d u sJ dt j r2t|	ts:J nt|	}	t| D ]\}}|sX|| n||  }|| }|| }|| }|| }|d u rt |jj	}|d7 }|
 }||
 }t|	d|d  }t||d
 | d  | }|dkr|d|	|   | dkr|d ur$|d us,J dt j|ddd	 |d}||| t j|d
dd	 |d
}||| || }||jd
dd	j|d n.|d usJ d|| }||| | }|j|| d }|| td|d
 | d |  } |j|| |  d qBd S )N5Grad scaling should occur outside of optimizer.step()r         ?   r   Crow_var and col_var should be defined when grad is multidimensionalr9   TrA   Zkeepdimr;   )min0variance should be defined when grad is a vectorr   alpha)r4   ZjitZis_scriptingr   r5   r   	enumeratefinfor-   r   itemr]   maxnormr    Zmul_rA   Zsquare_Zdiv_sizeZlerp_meanZclamp_cloneZrsqrt_Zadd_)!r   rE   rF   rG   rH   rI   rK   rL   r   r   r   r   rM   rN   r   rO   iparamr>   Zstep_tr:   r<   r=   Z
step_floatZone_minus_beta2_tZrho_tr`   Zrow_meanZcol_meanZvar_estimateZgrad_squaredupdateZdenomr)   r)   r*   _single_tensor_adafactorG  sV    
$

$rl   )tensorlistsreturnc                 C   s   t | }i }| D ]\\}}\} }||df}||df}t| d D ]\}}	|	dus^J d|	 dkr||vrdd | D ||< tt| D ]}
|| |
 | |
 |  qqF||vrdd | D ||< tt| D ]}
|| |
 | |
 |  qqFq|S )	zGroups tensors by device, dtype, AND multidimensionality -- whether the tensor
    has multiple dims or just one dim (is a vector). This allows the foreach impl of
    Adafactor to assume that every group of params will either be factored or not.TFr   Nzgrad should not be Nonec                 S   s   g | ]}g qS r)   r)   .0_r)   r)   r*   
<listcomp>      zB_group_tensors_by_device_dtype_and_is_multidim.<locals>.<listcomp>c                 S   s   g | ]}g qS r)   r)   ro   r)   r)   r*   rr     rs   )r   Z"_group_tensors_by_device_and_dtypeitemsra   rA   ranger3   r@   )rm   grouped_tensorsZultra_grouped_tensorsdevicer-   rq   Z
matrix_keyZ
vector_keyjr6   ri   r)   r)   r*   ._group_tensors_by_device_dtype_and_is_multidim  s$    	


 ry   c          +   	      s\  t | dkrd S |d u r |d u s(J dt|	}	t| |||||g}| D ]\\}}}\}}}}}}ttt |}ttt |}ttt |}|d u r|d usJ dt|j	}t
r|d d usJ |rt|}tj s|d jrtj|tjddddd nt|d g }g }g }|D ]J} ||  |
  |d|  |
   |t|	d|  d	   q(fd
dt||D }!|dkrt|d|	|   |rttt |}"ttt |}#|"d d ur|#d d usJ ddd |D }$t|$|$ t|$dd |D  t|"|$| ~$dd |D }%t|%|% t|%dd |D  t|#|%| ~%dd t|"|#D }&dd |"D }'t|'| t|&|' ~'nNttt |}(|(d d usJ dt||})t|(|)| ~)dd |(D }&t|&||  t|& t|&| |&}* fddt|!|*D }!t|*|! t||* qLd S )Nr   rX   z2dtype is needed to compute eps1 when eps1 is unsetr   cpu)rw   r_   r   rY   c                    s4   g | ],\}}t  |d  | d  | qS )rZ   rY   rd   re   rc   r    )rp   r8   r)rN   r)   r*   rr     s   z+_multi_tensor_adafactor.<locals>.<listcomp>r[   c                 S   s   g | ]}t j|d ddqS )r9   Tr\   r4   re   rp   r>   r)   r)   r*   rr      s   c                 S   s   g | ]}| d qS )r9   rf   r~   r)   r)   r*   rr   $  rs   c                 S   s   g | ]}t j|d ddqS r;   Tr\   r}   r~   r)   r)   r*   rr   )  s   c                 S   s   g | ]}| d qS )r;   r   r~   r)   r)   r*   rr   -  rs   c                 S   s   g | ]\}}|| qS r)   r)   )rp   r:   r<   r)   r)   r*   rr   1  s   c                 S   s   g | ]}|j d ddqS r   )rg   )rp   r:   r)   r)   r*   rr   5  s   r^   c                 S   s   g | ]}|  qS r)   )rh   )rp   vr)   r)   r*   rr   F  rs   c                    s:   g | ]2\}}| t d |d | d     qS )r   rZ   rY   r{   )rp   ark   )r   r)   r*   rr   N  s   )r3   r   ry   rt   r   rB   r   r4   rb   r   r   Z_foreach_negcompileris_compilingZis_cpuZ_foreach_add_r6   r@   rc   r]   zipZ_foreach_mul_Z_foreach_div_Z_foreach_lerp_Z_foreach_clamp_min_Z_foreach_mulZ_foreach_rsqrt_)+r   rE   rF   rG   rH   rI   rK   rL   r   r   r   r   rM   rN   r   rO   rv   rq   r-   Zis_multidimZdevice_params_Zdevice_grads_Zdevice_row_vars_Zdevice_col_vars_Zdevice_variances_Zdevice_state_steps_Zdevice_paramsZdevice_gradsZdevice_state_stepsZone_minus_beta2_tsZbeta2_tsZrho_tssalphasZdevice_row_varsZdevice_col_varsZ	row_meansZ	col_meansZvar_estimatesZrow_var_meansZdevice_variancesZgrads_squaredZupdatesr)   )r   rN   r*   _multi_tensor_adafactor  s    

 



r   )Zsingle_tensor_fnF)r   rE   rF   rG   rH   rI   r   rK   rL   rO   r   r   r   r   rM   rN   r   c
                C   s^   t j s$tdd |D s$td|r.t}nt}|| ||||||
|||||||||	d dS )zxFunctional API that performs Adafactor algorithm computation.

    See :class:`~torch.optim.Adafactor` for details.
    c                 s   s   | ]}t |tjV  qd S )N)r   r4   r   )rp   tr)   r)   r*   	<genexpr>q  s   zadafactor.<locals>.<genexpr>z?`state_steps` argument must contain a list of singleton tensors)
r   r   r   r   rM   rN   r   rK   rL   rO   N)r4   r   r   allr?   r   rl   )r   rE   rF   rG   rH   rI   r   rK   rL   rO   r   r   r   r   rM   rN   r   funcr)   r)   r*   r   V  s6    )NNNF)typingr   r   r   r   r4   r   Z	optimizerr   r	   r
   r   r   r   r   r   __all__r   __doc__rB   r5   rV   rl   r"   rU   rw   r-   ry   r   r   r)   r)   r)   r*   <module>   s   ( /K 



Y#



 
    



