import logging
import warnings
from collections.abc import Collection, Mapping
from copy import deepcopy
from typing import Any, Callable, Optional, overload, Union

import torch
import torch.nn as nn
from torch import optim
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


__all__: list[str] = []

logger = logging.getLogger(__name__)


class _NamedOptimizer(optim.Optimizer):
    """
    ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key.

    We replace the original numeric keys in the optimizer state with each
    parameter's fully qualified name (FQN) string. Users can initialize the
    optimizer as they would a regular PyTorch optimizer; the only difference
    is that they also need to pass in the FQN of each parameter.

    Args:
        named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]):
            Mapping from FQN to parameter.
        optimizer_class (optim.Optimizer):
            The class of optimizer to instantiate.
        param_groups (Collection[Mapping[str, Any]]):
            `param_groups` to pass to the optimizer if specified.
            The key of the inner map needs to be FQNs.
            Default: None
        module (nn.Module): the module whose parameters are updated
            by the optimizer.
        args: positional arguments to pass to the optimizer constructor.
        kwargs: keyword arguments to pass to the optimizer constructor.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch import optim
        >>> from torch.distributed.optim import _NamedOptimizer
        >>>
        >>> # Define the named optimizer.
        >>> m = Model(...)
        >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD)
        >>> # Forward pass + backward pass.
        >>> named_optim.step()
        >>> ...
        >>> # Calling state_dict on the named optimizer returns an FQN-keyed state_dict.
        >>> named_optim.state_dict()
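        >>>
        >>> # A hedged sketch of restoring from the FQN-keyed checkpoint; the FQN
        >>> # "linear.weight" below is illustrative and depends on the model.
        >>> ckpt = named_optim.state_dict()
        >>> # ckpt["state"]["linear.weight"] holds that parameter's optimizer state.
        >>> named_optim.load_state_dict(ckpt)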

    Warning: This API is still in development and subject to change.

    TODO: Add tutorial for _NamedOptimizer.
    TODO: Add documentation in the docstring for the public attributes
          like self.param_groups and self.named_parameters.
    """

    def __init__(
        self,
        named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
        optimizer_class: optim.Optimizer,
        param_groups: Optional[Collection[Mapping[str, Any]]] = None,
        module: Optional[nn.Module] = None,
        *args: tuple[Any, ...],
        **kwargs: dict[str, Any],
    ) -> None:
        torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
        self.param_groups = param_groups
        self._param_groups_check()
        self.named_parameters = dict(named_parameters)
        params_for_optimizer = (
            self.named_parameters.values() if param_groups is None else param_groups
        )
        self._optimizer = optimizer_class(
            params_for_optimizer,
            *args,
            **kwargs,
        )
        self.module = module
        if param_groups is None:
            self.ordered_param_keys = list(self.named_parameters.keys())
        else:
            warnings.warn(
                "Since we pass in param_groups, we will use param_groups to "
                "initialize the optimizer, not all parameters of the module."
            )
            param_to_key = {param: key for key, param in self.named_parameters.items()}
            ordered_param_keys = []
            for group in param_groups:
                for param in group["params"]:
                    if param not in param_to_key:
                        raise ValueError(
                            f"Expect param name {param} found in param group but is missing."
                        )
                    ordered_param_keys.append(param_to_key[param])
            self.ordered_param_keys = ordered_param_keys
        # Keep param_groups in sync with the wrapped optimizer.
        self.param_groups = self._optimizer.param_groups

    def _param_groups_check(self):
        if self.param_groups is not None:
            for param_group in self.param_groups:
                assert isinstance(param_group, dict), "param group must be a dict"
                assert "params" in param_group, "param group must contain key params"
                params = param_group["params"]
                if isinstance(params, torch.Tensor):
                    params = [params]
                params = list(params)
                for param in params:
                    if not isinstance(param, torch.Tensor):
                        raise TypeError(
                            "optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param)
                        )
                param_group["params"] = params

    def state_dict(self) -> dict[str, Any]:
        """
        Return the ``state_dict`` of the optimizer.

        Instead of using numbers to index
        parameters, we use the module's fully qualified name (FQN) as the key.
        """
        state_dict = self._optimizer.state_dict()
        param_groups = state_dict["param_groups"]

        ret_state = {
            self.ordered_param_keys[st_key]: state_val
            for st_key, state_val in state_dict["state"].items()
        }

        ret_groups = []
        for group in param_groups:
            param_keys = [self.ordered_param_keys[param] for param in group["params"]]
            ret_group = {"params": sorted(param_keys)}
            for k, v in group.items():
                if k != "params":
                    ret_group[k] = deepcopy(v)
            ret_groups.append(ret_group)

        return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})

    @overload
    def step(self, closure: None = ...) -> None: ...

    @overload
    def step(self, closure: Callable[[], float]) -> float: ...

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """
        Perform a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on the wrapped
        optimizer.
        """
        return self._optimizer.step(closure=closure)

    @property
    def state(self) -> Mapping[torch.Tensor, Any]:
        return self._optimizer.state

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        Define the default behavior to load a state_dict for ``_NamedOptimizer``.

        Sample Code
        ```
            my_model = MyModule()
            optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad)
            ...

            optim_state_dict = optimizer.state_dict()
            ...
            ...

            optimizer.load_state_dict(optim_state_dict)
            ...
        ```
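
        If the optimizer state has not been initialized yet (e.g. no ``step`` has
        run), a hedged sketch of the restore flow is to materialize the state with
        ``init_state`` first:
        ```
            optimizer.init_state()
            optimizer.load_state_dict(optim_state_dict)
        ```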
        Args:
            state_dict (dict[str, Any]): A ``state_dict`` to load into the optimizer.
                Note that this state dict update is performed in place.

        .. note:: PyTorch uses lazy init to initialize the optim states.
            So it is possible that there is no optim state when the user calls
            ``load_state_dict``, and for ``_NamedOptimizer`` we are stricter:
            users can only call ``load_state_dict`` after the state is initialized.
            By doing this, we can validate the optim ``state_dict`` to be loaded.
        """
        new_state_dict = self._optimizer.state_dict()
        state_dict = self._pre_load_state_dict(state_dict)
        state = state_dict["state"]
        new_state = new_state_dict["state"]
        if len(new_state) == 0:
            raise ValueError(
                "Expects the optim to be initialized before load but found not initialized."
            )

        for idx, param_key in enumerate(self.ordered_param_keys):
            # Not every parameter has to appear in the loaded state_dict
            # (e.g. with conditional training), so skip the missing ones.
            if param_key not in state.keys():
                continue
            if len(state[param_key]) != len(new_state[idx]):
                raise ValueError(
                    f"Expects equal length as {len(new_state[idx])} "
                    f"for parameter {param_key} but found: {len(state[param_key])}"
                )
            # Copy every optimizer state entry of this parameter in place.
            for state_key, state_val in new_state[idx].items():
                if state_key not in state[param_key]:
                    raise ValueError(
                        f"Expects state {state_key} for parameter {param_key} but not found."
                    )

                src_state_val = state[param_key][state_key]
                if isinstance(state_val, ShardedTensor):
                    assert isinstance(src_state_val, ShardedTensor)
                    num_shards = len(state_val.local_shards())
                    num_new_shards = len(src_state_val.local_shards())
                    if num_shards != num_new_shards:
                        raise ValueError(
                            f"Expects equal number of shards as {num_new_shards} "
                            f"but found {num_shards} for {param_key}/{state_key}"
                        )
                    for shard, src_shard in zip(
                        state_val.local_shards(), src_state_val.local_shards()
                    ):
                        shard.tensor.detach().copy_(src_shard.tensor)
                elif isinstance(state_val, torch.Tensor):
                    assert isinstance(src_state_val, torch.Tensor)
                    state_val.detach().copy_(src_state_val)
                else:
                    new_state[idx][state_key] = deepcopy(src_state_val)

        # Load the param_groups of the state_dict.
        src_param_groups = state_dict["param_groups"]
        new_param_groups = new_state_dict["param_groups"]

        src_group_map = {}
        for group in src_param_groups:
            param_keys = list(group["params"])
            src_group_map[_gen_param_group_key(param_keys)] = group
        new_group_map = {}
        for new_group in new_param_groups:
            param_keys = [self.ordered_param_keys[k] for k in new_group["params"]]
            new_group_map[_gen_param_group_key(param_keys)] = new_group
        for group_key, new_group in new_group_map.items():
            # When not all parameters are used, not all groups appear in the state_dict.
            if group_key not in src_group_map:
                continue
            src_group = src_group_map[group_key]
            if len(src_group) != len(new_group):
                raise ValueError(
                    f"Expects equal param_group size as {len(new_group)} "
                    f"for group {group_key} but found {len(src_group)}."
                )
            for k in src_group:
                if k not in new_group:
                    raise ValueError(
                        f"Expects group key {k} to be in group {group_key} "
                        f"in `state_dict` but is missing."
                    )
                if k != "params":
                    new_group[k] = deepcopy(src_group[k])

        self._optimizer.load_state_dict(new_state_dict)

    def add_param_group(self, param_group: Mapping[str, Any]) -> None:
        """
        Add a param group to the :class:`_NamedOptimizer`'s ``param_groups``.

        Warning: This API is still in development and subject to change.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group["params"]
        if isinstance(params, torch.Tensor):
            param_group["params"] = [params]
        else:
            param_group["params"] = list(params)

        param_to_key = {param: key for key, param in self.named_parameters.items()}
        for param in param_group["params"]:
            if param not in param_to_key:
                raise ValueError("some parameters are not in the module")
            self.ordered_param_keys.append(param_to_key[param])

        self._optimizer.add_param_group(param_group)
        # Keep param_groups in sync with the wrapped optimizer.
        self.param_groups = self._optimizer.param_groups

    def init_state(self) -> None:
        """
        Run a dummy optimizer step, which allows initializing the optimizer state because we do lazy init for most optimizers.

        This allows doing in-place loading of optimizer state from a checkpoint.
        """
        for param in self.named_parameters.values():
            if param.requires_grad:
                t = torch.zeros_like(param)
                param.grad = torch.autograd.Variable(t)
        # Calling ``step`` makes the wrapped optimizer materialize its state.
        self.step(closure=None)

    def _pre_load_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        # When the module is FSDP-wrapped, let FSDP translate the FQN-keyed
        # state_dict into the form the wrapped optimizer can load.
        if isinstance(self.module, FSDP):
            return FSDP.optim_state_dict_to_load(
                self.module, self._optimizer, state_dict, is_named_optimizer=True
            )
        return state_dict

    def _post_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        # When the module is FSDP-wrapped, let FSDP post-process the optimizer
        # state_dict before it is returned to the user.
        if isinstance(self.module, FSDP):
            state_dict = FSDP.optim_state_dict(self.module, self._optimizer, state_dict)
        return state_dict


def _gen_param_group_key(param_keys: list[str]) -> str:
    """Concatenate all param keys as a unique identifier for one param group."""
    return "/".join(sorted(param_keys))