import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.distributed import group, ReduceOp


def broadcast(tensor, src, group=group.WORLD):
    """
    Broadcasts the tensor to the whole group.

    ``tensor`` must have the same number of elements in all processes
    participating in the collective.

    Arguments:
        tensor (Tensor): Data to be sent if ``src`` is the rank of current
            process.
        src (int): Source rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Received tensor from the broadcast op.

    """
    return _Broadcast.apply(src, group, tensor)
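

# Illustrative usage sketch (not part of the upstream module): broadcasting a
# tensor from rank 0 while keeping it on the autograd tape. Assumes the default
# process group has already been initialized (e.g. via dist.init_process_group).
def _example_broadcast_usage():
    x = torch.ones(4, requires_grad=True)
    y = broadcast(x, src=0)  # every rank now holds rank 0's data
    y.sum().backward()  # gradients are reduced back onto rank 0's input
    return x.grad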


def gather(tensor, dst=0, group=group.WORLD):
    """
    Gathers a list of tensors in a single process.

    Arguments:
        tensor (Tensor): Input tensor.
        dst (int, optional): Destination rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple[Tensor]: List of appropriately-sized tensors with the gathered data.
    """
    return _Gather.apply(dst, group, tensor)
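

# Illustrative usage sketch (not part of the upstream module): gathering each
# rank's tensor onto rank 0. Every rank must participate; on ranks other than
# ``dst`` the returned tuple only holds zero-filled placeholders. Assumes an
# initialized process group.
def _example_gather_usage():
    rank = dist.get_rank()
    x = torch.full((2,), float(rank), requires_grad=True)
    gathered = gather(x, dst=0)  # tuple with one entry per rank
    return gathered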


def scatter(tensors, src=0, group=group.WORLD):
    """
    Scatters a list of tensors to all processes in a group.

    Each process will receive exactly one tensor and store its data in the
    ``tensor`` argument.

    Arguments:
        tensors (list[Tensor]): List of tensors to scatter on the source rank.
            Receivers must pass ``None``.
        src (int, optional): Source rank (default is 0).
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output tensor from the scatter operation.

    """
    return _Scatter.apply(src, group, *tensors)


def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces the tensor data across all machines.

    Only the process with rank ``dst`` is going to receive the final result.

    Arguments:
        tensor (Tensor): Input of the collective.
        dst (int): Destination rank.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce.apply(dst, op, group, tensor)
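

# Illustrative usage sketch (not part of the upstream module): summing a tensor
# onto rank 0; only the destination rank ends up holding the full sum. Assumes
# an initialized process group.
def _example_reduce_usage():
    x = torch.ones(3, requires_grad=True)
    total = reduce(x, dst=0, op=ReduceOp.SUM)
    return total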


def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces, then scatters a list of tensors to all processes in a group.

    Arguments:
        output (Tensor): Output tensor.
        input_list (list[Tensor]): List of tensors to reduce and scatter.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _Reduce_Scatter.apply(op, group, output, *input_list)
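

# Illustrative usage sketch (not part of the upstream module): every rank
# contributes one chunk per rank; the element-wise sum of chunk ``i`` lands on
# rank ``i``. Assumes an initialized process group.
def _example_reduce_scatter_usage():
    world_size = dist.get_world_size()
    inputs = [torch.ones(2) for _ in range(world_size)]
    out = torch.empty(2)
    return reduce_scatter(out, inputs)  # every rank receives world_size * ones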


def all_gather(tensor, group=group.WORLD):
    """
    Gathers tensors from the whole group in a list.

    Arguments:
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple([Tensor]): Output of the collective.

    """
    return _AllGather.apply(group, tensor)
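

# Illustrative usage sketch (not part of the upstream module): a differentiable
# all_gather, e.g. for collecting per-rank embeddings before a contrastive
# loss. Assumes an initialized process group.
def _example_all_gather_usage():
    rank = dist.get_rank()
    x = torch.full((2,), float(rank), requires_grad=True)
    gathered = all_gather(x)  # tuple of world_size tensors
    loss = torch.stack(gathered).sum()
    loss.backward()  # gradient flows back to the local ``x``
    return x.grad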


def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
    """
    Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.

    Args:
        output_tensor (Tensor): Output tensor. It should contain
            correctly-sized tensors to be used for output of the collective.
        input_tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> # xdoctest: +SKIP("incorrect want text")
        >>> output_tensor = torch.zeros(2, dtype=torch.int64)
        >>> output_tensor
        [tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank
        >>> tensor
        tensor([1]) # Rank 0
        tensor([2]) # Rank 1
        >>> dist.all_gather_base(output_tensor, tensor)
        >>> output_tensor
        tensor([1,2]) # Rank 0
        tensor([1,2]) # Rank 1

    .. warning::
        `_all_gather_base` is experimental and subject to change.
        It is the caller's responsibility to ensure the output_tensor
        is correctly sized.

    """
    return _AllGatherBase.apply(output_tensor, input_tensor, group)


def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD):
    """
    Each process scatters a list of input tensors to all processes in a group and returns the gathered tensors in the output list.

    Arguments:
        output_tensor_list (list[Tensor]): List of tensors to gather, one per rank.
        input_tensor_list (list[Tensor]): List of tensors to scatter, one per rank.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        tuple([Tensor]): Output of the collective.

    """
    return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
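

# Illustrative usage sketch (not part of the upstream module): each rank sends
# chunk ``i`` of its input list to rank ``i`` and receives one chunk from every
# rank. Assumes an initialized process group.
def _example_all_to_all_usage():
    world_size = dist.get_world_size()
    inputs = [torch.full((2,), float(i)) for i in range(world_size)]
    outputs = [torch.empty(2) for _ in range(world_size)]
    return all_to_all(outputs, inputs)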


def all_to_all_single(
    output, input, output_split_sizes=None, input_split_sizes=None, group=group.WORLD
):
    """
    Each process splits the input tensor and then scatters the split list to all processes in a group.

    The tensors received from all processes in the group are then concatenated and returned as a single output tensor.

    Arguments:
        output (Tensor): Gathered concatenated output tensor.
        input (Tensor): Input tensor to scatter.
        output_split_sizes (list[int], optional): Output split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``output`` tensor must be
            divisible by ``world_size``.
        input_split_sizes (list[int], optional): Input split sizes for dim 0.
            If specified as None or empty, dim 0 of the ``input`` tensor must be
            divisible by ``world_size``.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective.

    """
    return _AlltoAllSingle.apply(
        group, output, output_split_sizes, input_split_sizes, input
    )
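

# Illustrative usage sketch (not part of the upstream module): the single-tensor
# variant; with no split sizes given, dim 0 of the input is split evenly across
# ranks and chunk ``i`` is sent to rank ``i``. Assumes an initialized process group.
def _example_all_to_all_single_usage():
    world_size = dist.get_world_size()
    inp = torch.arange(2 * world_size, dtype=torch.float32)
    out = torch.empty_like(inp)
    return all_to_all_single(out, inp)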


def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
    """
    Reduces the tensor data across all machines in such a way that all get the final result.

    After the call the returned tensor is going to be bitwise
    identical in all processes.

    Arguments:
        tensor (Tensor): Input of the collective.
        op (optional): One of the values from
            ``torch.distributed.ReduceOp``
            enum.  Specifies an operation used for element-wise reductions.
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        Tensor: Output of the collective

    """
    return _AllReduce.apply(op, group, tensor)
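

# Illustrative usage sketch (not part of the upstream module): a differentiable
# all_reduce, e.g. for synchronizing a loss term across ranks. The backward pass
# is itself an all_reduce of the gradients. Assumes an initialized process group.
def _example_all_reduce_usage():
    x = torch.ones(3, requires_grad=True)
    y = all_reduce(x, op=ReduceOp.SUM)  # identical result on every rank
    y.sum().backward()
    return x.grad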


class _Broadcast(Function):
    @staticmethod
    def forward(ctx, src, group, tensor):
        ctx.src = src
        ctx.group = group
        ctx.rank = dist.get_rank(group=group)
        # torch.distributed makes all the calls in place; clone so the caller's
        # tensor is not overwritten.
        tensor = tensor.clone()
        dist.broadcast(tensor, src, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
        if ctx.src != ctx.rank:
            gx.zero_()
        return (None, None, gx)


class _Gather(Function):
    @staticmethod
    def forward(ctx, dst, group, tensor):
        ctx.dst = dst
        ctx.group = group
        # Allocate one correctly-sized buffer per rank for the aggregation.
        tensor_list = [
            torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group))
        ]

        tensor = tensor.contiguous()
        if dist.get_rank(group=group) == dst:
            dist.gather(tensor, tensor_list, dst, group=group)
        else:
            dist.gather(tensor, None, dst, group=group)
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)


class _Scatter(Function):
    @staticmethod
    def forward(ctx, src, group, *tensors):
        ctx.src = src
        ctx.group = group
        assert all(t.size() == tensors[0].size() for t in tensors)
        output = torch.zeros_like(tensors[0])
        if dist.get_rank(group=group) == src:
            dist.scatter(output, list(tensors), src, group=group)
        else:
            dist.scatter(output, None, src, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)


class _Reduce(Function):
    @staticmethod
    def forward(ctx, src, op, group, tensor):
        ctx.src = src
        ctx.group = group
        tensor = tensor.clone()
        dist.reduce(tensor, src, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)


class _Reduce_Scatter(Function):
    @staticmethod
    def forward(ctx, op, group, tensor, *input_tensor_list):
        ctx.group = group
        # Collectives require contiguous tensors.
        tensor = tensor.contiguous()
        input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
        dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None, None) + _AllGather.apply(ctx.group, grad_output)


class _AllGather(Function):
    @staticmethod
    def forward(ctx, group, tensor):
        # Collectives require contiguous tensors.
        tensor = tensor.contiguous()

        ctx.group = group
        out_tensor_list = [
            torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
        ]

        dist.all_gather(out_tensor_list, tensor, group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            rank = dist.get_rank(group=ctx.group)
            gx = torch.empty_like(grad_outputs[rank])
            gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs)
        else:
            # Many backends do not support reduce_scatter, so emulate it with an
            # all-to-all followed by a sum.
            tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
            gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
            gx = torch.sum(torch.stack(gxs), dim=0)
        return (None, gx)


class _AllGatherBase(Function):
    @staticmethod
    def forward(ctx, output_tensor, input_tensor, group):
        ctx.group = group
        dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
        return output_tensor

    @staticmethod
    def backward(ctx, grad_output):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            world_size = dist.get_world_size(group=ctx.group)
            out_size = list(grad_output.size())
            if out_size[0] % world_size != 0:
                raise RuntimeError(
                    f"Tensor with dimensions: {out_size} does not have first dimension "
                    f"divisible by world_size: {world_size}"
                )
            out_size[0] = out_size[0] // world_size
            gx = torch.empty(
                out_size, device=grad_output.device, dtype=grad_output.dtype
            )
            dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
        else:
            raise RuntimeError("Backend not supported!")
        return (None, gx, None)


class _AlltoAll(Function):
    @staticmethod
    def forward(ctx, group, out_tensor_list, *tensors):
        ctx.group = group
        ctx.input_tensor_size_list = [
            tensors[i].size() for i in range(dist.get_world_size(group=group))
        ]
        my_rank = dist.get_rank(group=group)
        tensors = tuple(t.contiguous() for t in tensors)
        # Gloo has no all_to_all, so emulate it with one scatter per rank.
        if dist.get_backend(group=group) is dist.Backend.GLOO:
            for i in range(dist.get_world_size(group=group)):
                to_send = None
                if i == my_rank:
                    to_send = list(tensors)
                dist.scatter(out_tensor_list[i], to_send, i, group=group)
        else:
            dist.all_to_all(out_tensor_list, list(tensors), group=group)
        return tuple(out_tensor_list)

    @staticmethod
    def backward(ctx, *grad_outputs):
        tensor_list = [
            torch.empty(size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
            for size in ctx.input_tensor_size_list
        ]
        return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)


class _AlltoAllSingle(Function):
    @staticmethod
    def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
        ctx.group = group
        ctx.input_size = input.size()
        # The split sizes are stored swapped, since the backward pass runs the
        # collective in the opposite direction.
        ctx.output_split_sizes = input_split_sizes
        ctx.input_split_sizes = output_split_sizes
        dist.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        tensor = torch.empty(
            ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
        )
        return (None, None, None, None) + (
            _AlltoAllSingle.apply(
                ctx.group,
                tensor,
                ctx.output_split_sizes,
                ctx.input_split_sizes,
                grad_output.contiguous(),
            ),
        )


class _AllReduce(Function):
    @staticmethod
    def forward(ctx, op, group, tensor):
        ctx.group = group
        ctx.op = op
        tensor = tensor.clone(memory_format=torch.contiguous_format)
        dist.all_reduce(tensor, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)