a
    h                    @   s  d dl Z d dlZd dlZd dlZd dlZd dlmZmZ d dlm	Z	m
Z
 d dlmZ d dlmZmZmZmZmZ d dlZd dlmZ d dlmZ d dlmZmZ d dlmZ d d	lmZ d
dl m!Z! d
dl"m#Z#m$Z$m%Z% d
dl&m'Z' g dZ(e)e*Z+G dd deZ,e,j-Z-e,j.Z.e,j/Z/e,j0Z0e,j1Z1e,j2Z2e,j3Z3e,j4Z4e,j5Z5e,j6Z6e-Z7e.Z8e/Z9e6Z:e;dZ<G dd deZ=dMe>e?e@ee=  f ee? eAdddZBG dd deZCdNe@ejD eeA e@ejE dddZFdOe@ejD eeA e>e?e@ejE f dddZGe@ejE ddd ZHG d!d" d"eCZIG d#d$ d$eIZJG d%d& d&eIZKG d'd( d(eIZLdPe@ee=  e?e@e= d*d+d,ZMe@ee=  e@e= d-d.d/ZNe>e?e@e= f ee?ge?f e?e>e?e@e= f d0d1d2ZOe>e?e@ee=  f e?e?e?e>e?e?f d3d4d5ZPG d6d7 d7eCZQG d8d9 d9eQZRG d:d; d;eQZSdQd=d>ZTG d?d@ d@eQZUG dAdB dBeQZVG dCdD dDeQZWeAdEdFdGZXee?ge?f e?dHdIdJZYdKdL ZZdS )R    N)ABCabstractmethod)Counterdefaultdict)Enum)AnyCallable
NamedTupleOptionalUnion)OptimizedModule)
FSDPModuleUnshardHandle)_Loss)record_function   )generate_stage_to_rank_mapping)merge_chunkssplit_args_kwargs_into_chunksTensorChunkSpec)_PipelineStageBase)	get_schedule_classPipelineScheduleSinglePipelineScheduleMultiSchedule1F1BScheduleGPipeScheduleInterleaved1F1BScheduleLoopedBFSScheduleInterleavedZeroBubbleScheduleZBVZeroBubblec                   @   sH   e Zd ZdZdZdZdZdZdZdZ	dZ
d	Zd
Zdd Zedd ZdS )_ComputationTyper                        	   
   c                 C   sH   t jdt jdt jdt jdt jdt jdt jdt jdt j	d	t j
d
i
}||  S )NFIWUNSHARDRESHARDSEND_FRECV_FSEND_BRECV_BB)r    FORWARDBACKWARD_INPUTBACKWARD_WEIGHTr-   r.   r/   r0   r1   r2   FULL_BACKWARD)selfZstr_map r9   T/var/www/auris/lib/python3.9/site-packages/torch/distributed/pipelining/schedules.py__str__6   s    z_ComputationType.__str__c                 C   s   | dkrt jS | dkrt jS | dkr*t jS | dkr8t jS | dkrFt jS | dkrTt jS | dkrbt jS | dkrpt jS | d	kr~t j	S | d
krt j
S td|  d S )Nr*   r+   r,   r-   r.   r/   r0   r1   r2   r3   zInvalid computation type )r    r4   r5   r6   r-   r.   r/   r0   r1   r2   r7   RuntimeErroractionr9   r9   r:   from_strE   s*    z_ComputationType.from_strN)__name__
__module____qualname__r4   r5   r6   r-   r.   r/   r0   r1   r2   r7   r;   staticmethodr?   r9   r9   r9   r:   r    )   s   r    z?(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)c                   @   sH   e Zd ZU eed< eed< dZee ed< dd Ze	e
ddd	ZdS )
_Actionstage_indexcomputation_typeNmicrobatch_indexc                 C   s4   t | j}|t | j7 }| jd ur0|t | j7 }|S N)strrE   rF   rG   )r8   reprr9   r9   r:   __repr__{   s
    

z_Action.__repr__)action_stringc                 C   sj   |   } t|  }rJ| \}}}tt|t|t|rDt|ndS | dkrVdS t	d|  ddS )z
        Reverse of __repr__

        String should be formatted as [stage][action type][(microbatch)]
            e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
        N zInvalid action string: zD, should be formatted as [stage][action type][(microbatch)] e.g. 2F0)
strip_action_regexmatchgroupsrD   intr    r?   lenr<   )rL   rP   rE   rF   rG   r9   r9   r:   r?      s    
z_Action.from_str)r@   rA   rB   rR   __annotations__r    rG   r
   rK   rC   rI   r?   r9   r9   r9   r:   rD   v   s   
rD   )pipeline_ordererror_step_numberreturnc                    s6  t D ]6}tt| D ] }| | du r"d| |< q"qtdd  D fddtD }fddtD }ttj	|ddi}t}d	d t|D }d
d t
|g|R  D dt|d d  dfddt|D  }	 fddt
||D }
|	d d|
 d }|S )z
    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
    and returns the formatted string.

    If `error_step_number` is passed in, an additional label will be added to signify which step
    that it is erroring on.
    NrM   c                 s   s   | ]}t |V  qd S rH   )rS   ).0actionsr9   r9   r:   	<genexpr>       z)_format_pipeline_order.<locals>.<genexpr>c              	      s*   g | ]"}d t |tt  d  qS )zStep r   )rI   zfillrS   rX   i)	num_stepsr9   r:   
<listcomp>   s   z*_format_pipeline_order.<locals>.<listcomp>c                    s   g | ]} |d g  qS )rM   )get)rX   key)r_   rU   r9   r:   r`      s   	fillvaluec                 S   s   g | ]}d t | qS )zRank rI   r]   r9   r9   r:   r`      r[   c                 S   s   g | ]}t d d |D qS )c                 s   s&   | ]}|d urt t|ndV  qd S )Nr   )rS   rI   )rX   itemr9   r9   r:   rZ      r[   4_format_pipeline_order.<locals>.<listcomp>.<genexpr>)max)rX   colr9   r9   r:   r`      s    r   r!   c                 3   s$   | ]\}}|d  |  V  qdS <Nr9   )rX   r^   labelmax_lengthsr9   r:   rZ      s   c                    sZ   g | ]R\}}| d d fddt|D   durPt| d  krPdnd qS )z: ri   c                 3   s(   | ] \}}t |d  |  V  qdS rj   rd   )rX   r^   re   rm   r9   r:   rZ      r[   rf   Nr   z <-- ERROR HERErM   )join	enumeraterR   split)rX   rl   row)rV   rn   r9   r:   r`      s   	
)copydeepcopyrangerS   rg   valuessortedlist	itertoolszip_longestzipro   rp   )rU   rV   rankr^   Zstep_labelsZrank_actionsZtransposed_actionsZ	num_ranksZrank_labelsZ
header_rowZformatted_rowsZformatted_tabler9   )rV   rn   r_   rU   r:   _format_pipeline_order   s4    

 
	r~   c                
   @   s,  e Zd Zdeeedejf  eee	df  ee
ee	f  eee
eef ee f  edddZdd Zd	d
 Zdd Zedee ee ee ee dddZedddee dddZd ee ee ee ee dddZdd Zd!eedf ee
eef  dddZee edddZdS )"_PipelineScheduleNT.n_microbatchesloss_fnargs_chunk_speckwargs_chunk_specoutput_merge_specscale_gradsc                 C   sJ   || _ || _|| _|| _|| _|| _| jd u| _g | _t	d| j
j d S )NzUsing %s)_n_microbatches_loss_fnr   _args_chunk_spec_kwargs_chunk_spec_output_merge_spec_has_backward_internal_lossesloggerinfo	__class__r@   )r8   r   r   r   r   r   r   r9   r9   r:   __init__   s    
z_PipelineSchedule.__init__c                 C   s,   |j r(| jr(| ||| }| j| d S rH   )is_lastr   _compute_lossr   append)r8   stageoutput
target_mbsmb_indexlossr9   r9   r:   _maybe_compute_loss   s    z%_PipelineSchedule._maybe_compute_lossc                 C   sj   d|  kot | jk n  }|jr8| jr8|r8| j| S t | jdkrb|sbtd| d| j nd S d S )Nr   zLoss for microbatch z6 is not available. Available losses for microbatches: )rS   r   r   r   r<   )r8   r   r   Zvalid_indexr9   r9   r:   _maybe_get_loss   s    
z!_PipelineSchedule._maybe_get_lossc                 C   s|   t |ts|g}tdd |D }|rn|durnt| j| jkrZtd| j dt| j |  || j | j  dS )zB
        Update the losses to those in the internal state
        c                 s   s   | ]}|j V  qd S rH   r   rX   r   r9   r9   r:   rZ     r[   z3_PipelineSchedule._update_losses.<locals>.<genexpr>N
Expecting z losses but got )	
isinstancery   anyrS   r   r   r<   clearextend)r8   stageslossesZcontains_last_stager9   r9   r:   _update_losses  s    
z _PipelineSchedule._update_lossesarg_mbs	kwarg_mbsr   r   c                 C   s   t dS )z
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.

        Args:
            microbatches: list of microbatch args.
        NNotImplementedError)r8   r   r   r   r   r9   r9   r:   _step_microbatches"  s    z$_PipelineSchedule._step_microbatchestargetr   r   c                O   s   t dS a  
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        Nr   )r8   r   r   argskwargsr9   r9   r:   step4  s    z_PipelineSchedule.stepc                    s   t d fdd}|dur&||d ndg j }|durF||d ni g j }|durd||d |durt|tstd	t| ||fS )
z*
        Pre-process/check inputs
        )namec                    sR   t | ts t| dt|  t|  jkrNtd j d| dt|  d S )Nz must be a list but got a r   ri   z	 but got )r   ry   	TypeErrortyperS   r   
ValueError)Zmbsr   r8   r9   r:   check_type_and_lenM  s    
z;_PipelineSchedule._check_inputs.<locals>.check_type_and_lenNr   r9   r   r   z losses must be a list but got a )rI   r   r   ry   r   r   )r8   r   r   r   r   r   r9   r   r:   _check_inputsB  s    

z_PipelineSchedule._check_inputsc                 C   s   |  ||S rH   )r   )r8   r   r   r9   r9   r:   r   h  s    z_PipelineSchedule._compute_loss)r   r   c                 C   sF   |s|r*t ||| j| j| j\}}||fS dg| j i g| j fS dS )zj
        Splits a full-batch input into chunks (i.e. microbatches) and returns
        the chunks
        r9   N)r   r   r   r   )r8   r   r   
args_splitkwargs_splitr9   r9   r:   _split_inputsk  s    	z_PipelineSchedule._split_inputs)output_chunksrW   c                 C   s   t || jS )z
        Merge output chunks back to a batch state.
        If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
        )r   r   )r8   r   r9   r9   r:   _merge_outputs  s    z _PipelineSchedule._merge_outputs)NNNNT)NNNN)NNNN)N)r@   rA   rB   rR   r
   r   torchZTensortupler   dictrI   r   r   boolr   r   r   r   r   ry   r   r   r   r   r   r   r9   r9   r9   r:   r      s\        "        & 
r   )p2p_opsdescrW   c                 C   s:   t | dkrg S |r| dnd}td||  t| S )zt
    Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
    r   z, rM   zbatch_p2p %s%s)rS   r   debugdistZbatch_isend_irecv)r   r   Zdesc_strr9   r9   r:   
_batch_p2p  s
    r   c                 C   s`   t t}i }t| dkr|S | D ]}||j | q t| D ]\}}t||d||< qB|S )z
    Sorts the list of P2P ops by the peer rank, and then calls
    batch_isend_irecv. Return a dictionary of works by peer rank. This function
    helps us avoid hangs in case of skip connections.
    r   r   )r   ry   rS   peerr   rx   itemsr   )r   r   Zops_by_peerZwork_by_peeropr   opsr9   r9   r:   _sorted_batch_p2p  s    r   )workc                 C   s   | D ]}|   qdS )zX
    Waits for a list of dist.Work (typically from _batch_p2p / _sorted_batch_p2p).
    N)wait)r   wr9   r9   r:   _wait_batch_p2p  s    r   c                       s   e Zd ZdZdeeee eee	df  ee
ee	f  eee
eef ee f  ed fddZdd	 Zddd
ee dddZee
eeee  f  dddZ  ZS )r   a  
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.

    Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True.  This setting
    should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
    or sum losses (scale_grads=False).
    NT.)r   r   r   r   r   r   r   c                    sf   t  j||||||d || _|j| _| j| j_d| _|| jk rXtd| d| j d| 	 | _
d S )Nr   FzNumber of microbatches (z9) must be greater than or equal to the number of stages (z).)superr   _stage
num_stages_num_stagesr   has_backward_stage_initializedr   _get_pipeline_orderrU   )r8   r   r   r   r   r   r   r   r   r9   r:   r     s(    	

zPipelineScheduleSingle.__init__c                 C   s0   | j | j|| | jr&| j | j d| _d S NT)r   _prepare_forward_infrar   r   _prepare_backward_infrar   )r8   r   r   r9   r9   r:   _initialize_stage  s    z(PipelineScheduleSingle._initialize_stager   r   c                O   sh   | j   | ||\}}|dur6tt|| j}nd}| |||| | j jr`| 	| j j
S dS dS r   )r   clear_runtime_statesr   ry   r   tensor_splitr   r   r   r   r   )r8   r   r   r   r   r   r   targets_splitr9   r9   r:   r     s    
zPipelineScheduleSingle.steprW   c                 C   s   dS )a  
        Returns the pipeline execution order as a schedule IR.

        The returned IR is a dictionary mapping rank IDs to lists of actions.
        Each action is either an _Action object representing computation to perform,
        or None representing a deliberate idle step.

        The None values are used to represent pipeline bubbles where a rank
        must wait for dependencies from other ranks before proceeding. However
        during execution, with  the _PipelineScheduleRuntime, these Nones are
        skipped since the relevant communication (send/recv) will be scheduled and waited on.

        Returns:
            A dictionary mapping rank -> list of actions
        Nr9   r   r9   r9   r:   r     s    z*PipelineScheduleSingle._get_pipeline_order)NNNNT)r@   rA   rB   __doc__r   rR   r
   r   r   r   r   rI   r   r   r   r   r   ry   r   rD   r   __classcell__r9   r9   r   r:   r     s$        $!r   c                   @   s6   e Zd ZdZdee ee ee ee dddZdS )_ScheduleForwardOnlyzo
    The forward-only schedule.
    Will go through all the microbatches and perform only the forward pass
    Nr   c           
   	   C   s  |dus|durt d| ||||\}}| jsF| |d |d  g }t| jD ]}td| | | j|}t	|dd}|
 D ]}	t|	 q| j||| ||  | j|}t	|dd}||
  W d   n1 s0    Y  td| jj| qT|D ]}	t|	 q
dS )	z<
        Run one iteration of the pipeline schedule
        Nz7Forward-only schedule does not support loss computationr   Forward fwd_recvr   fwd_send[%s] Forwarded microbatch %s)r<   r   r   r   rv   r   r   r   get_fwd_recv_opsr   rw   r   forward_one_chunkget_fwd_send_opsr   r   r   rE   )
r8   r   r   r   r   fwd_sends_to_waitr^   r   worksr   r9   r9   r:   r   ,  s*    

,z'_ScheduleForwardOnly._step_microbatches)NNNN)r@   rA   rB   r   r
   ry   r   r9   r9   r9   r:   r   &  s       r   c                   @   sX   e Zd ZdZd	ee ee ee ee dddZeeeeee	  f  dddZ
dS )
r   z^
    The GPipe schedule.
    Will go through all the microbatches in a fill-drain manner.
    Nr   c              	   C   s:  |  ||||\}}| js.| |d |d  g }t| jD ]}td| | | j|}t|dd}|	 D ]}	t
|	 qp| j||| || }
| j|}t|dd}||	  W d   n1 s0    Y  td| jj| | | j|
|| q<|D ]}	t
|	 q| js dS g }t| jD ]}td|  | j|}t|d	d}|	 D ]}	t
|	 qb| | j|}| jj|||| jd
 kd | j|}t|dd}||	  W d   n1 s0    Y  td| jj| q.| jj| jr| jnd
d | | j| |D ]}	t
|	 q&dS )z
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.

        Args:
            microbatches: list of microbatch args.
        r   r   r   r   r   Nr   z	Backward bwd_recvr   r   last_backwardbwd_sendz[%s] Backwarded microbatch %sgrad_scale_factor)r   r   r   rv   r   r   r   r   r   rw   r   r   r   r   r   r   rE   r   r   get_bwd_recv_opsr   backward_one_chunkget_bwd_send_opsr   r   )r8   r   r   r   r   r   r^   r   r   r   r   Zbwd_sends_to_waitr   r9   r9   r:   r   _  sV    
,.z ScheduleGPipe._step_microbatchesr   c                 C   s   i }| j }t|D ]}g }|}|dg|  t| jD ]}|t|tj| q8d|d |  }|dg|  t| jD ]}|t|tj| q||||< q|S )z
        Returns the pipeline order for GPipe schedule.

        See base method in PipelineScheduleSingle for details on the schedule IR format.
        Nr"   r   )	r   rv   r   r   r   rD   r    r4   r7   )r8   rU   pp_group_sizer}   rY   Zwarmup_delaymb_idxZbackward_delayr9   r9   r:   r     s    
z!ScheduleGPipe._get_pipeline_order)NNNNr@   rA   rB   r   r
   ry   r   r   rR   rD   r   r9   r9   r9   r:   r   Y  s       Tr   c                   @   sX   e Zd ZdZd	ee ee ee ee dddZeeeeee	  f  dddZ
dS )
r   zo
    The 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady state.
    Nr   c                 C   sj  |  ||||\}}| js.| |d |d  t| j| j| jj }d}d}g }g }	t|D ]~}
| j	|}t
t|dd | j||| || }t
| | j|}	||d krt|	dd}| | j||| |d7 }q\| j|}t
t|	| dd | | j|}| jj|||| jd kd | j|}|d7 }|| jkrJq| j	|}t
t|| dd | j||| || }| | j||| | j|}	|d7 }qt|d	d}|| jk r4| j|}t
t|d
d | | j|}| jj|||| jd kd t
| | j|}t|d	d}|d7 }q| jj| jrH| jndd t
| | | j| dS )z
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.

        Args:
            microbatches: list of microbatch args.
        r   r   r   r   r   Zfwd_send_bwd_recvr   Zbwd_send_fwd_recvr   r   r   N)r   r   r   minr   r   r   rE   rv   r   r   r   r   r   r   r   r   r   r   r   r   )r8   r   r   r   r   Zwarmup_chunksfwd_mb_indexbwd_mb_indexZ	send_workZ	fwd_sends_Z	fwd_recvsr   Z	bwd_recvsr   Z	bwd_sendsr9   r9   r:   r     s|    

zSchedule1F1B._step_microbatchesr   c                 C   sp  i }| j }t|D ]V}g }|dg|  |d | }d}t|D ]}|t|tj| |}qDtdd|d |  }|dg|  d}	| j| }
|
dkr|d7 }|t|tj| |
d8 }
|t|tj	|	 |	d7 }	q| j|	 }|dkrb|| dkr<|d |dkr`|t|tj	|	 |	d7 }	|d8 }q|t|tj	|	 |	d7 }	|d8 }q|||< q|S )z
        Returns the pipeline order for 1F1B schedule.

        See base method in PipelineScheduleSingle for details on the schedule IR format.
        Nr   r   r!   )
r   rv   r   r   rD   r    r4   rg   r   r7   )r8   rU   r   r}   rY   Znum_forwardZ
forward_mbr^   Zwait_for_1f1bZbackward_mbZremaining_forwardZremaining_backwardr9   r9   r:   r   f  sN    








z Schedule1F1B._get_pipeline_order)NNNNr   r9   r9   r9   r:   r     s        r   r"   )compute_actionsmax_active_stagesrW   c           
         s   t ttt  tt  ddd}t  g t d fdd}t d fdd}t| D ]~\}}|d	u rjqX||| |d	 tt fd
d}ttfdd }|D ]}	||	 q|D ]}	||	 q| qXS )aQ  Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.

    UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
    RESHARD does the opposite, releasing memory (but doing no communication)

    We abandon the "timestep lock"  during lowering

    max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice
    3 stages is probably the thing we want?
    (to account for having one f and one b active, and something else prefetching?)
    )countnext_actionsrW   c                 S   sR   t  }g }|D ]>}|dur|j|vr||j ||j t|| kr qNq|S )zdRemove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.N)setrE   addr   rS   )r   r  seenretar9   r9   r:   next_stage_indices  s    z0_add_unshard_reshard.<locals>.next_stage_indicesrE   c                    s      |  t| td  d S rH   )r  r   rD   r-   r  active_stagesfsdp_aware_actionsr9   r:   _unshard  s    
z&_add_unshard_reshard.<locals>._unshardc                    s      |  t| td  d S rH   )remover   rD   r.   r  r	  r9   r:   _reshard  s    
z&_add_unshard_reshard.<locals>._reshardNc                    s   |  vS rH   r9   s)r
  r9   r:   <lambda>  r[   z&_add_unshard_reshard.<locals>.<lambda>c                    s   |  vS rH   r9   r  )next_nr9   r:   r    r[   )rR   ry   r
   rD   r  rp   filterr   )
r   r   r  r  r  r^   r>   fetchZevictr   r9   )r
  r  r  r:   _add_unshard_reshard  s&    


r  )r   rW   c                 C   s   g }| r|  d}|du rqt| r@| d  }du r@|  d q|jtkr|dur|jtkr|j|jkr|j|jkr|t|jt	|j |  d q|| q|S )a9  Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
    (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)

    B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient
    in some cases.
    r   N)
poprS   rF   r5   r6   rE   rG   r   rD   r7   )r   Zmerged_actionsr>   Znext_actionr9   r9   r:   	_merge_bw  s.    	


r  )r   stage_to_rankr   rW   c                    s  dd | D }dd | D }t tdfdd t tt t f d fdd}tt  tt  td	fd
d}| rd}t| D ]}t| | dksJ d|dt| | | | d }	||	|| sq~|	d urJ|| |	 || |	  |	rJ||	\}
}|| |
 || |
 ||j	 | ||j	 | | | 
d t| | dkrp| |= d}q~|slJ dql|S )Nc                 S   s   i | ]
}|g qS r9   r9   rX   r}   r9   r9   r:   
<dictcomp>  r[   z"_add_send_recv.<locals>.<dictcomp>c                 S   s   i | ]}|t  qS r9   r  r  r9   r9   r:   r    r[   r>   rW   c                    sd   | j tkr0| j d ko.| jd | jkS | j ttfv r`| jdko^| jd | jkS dS )Nr   r   F)rF   r*   rE   r5   r7   r=   )r   r  r9   r:   
_has_comms  s    
z"_add_send_recv.<locals>._has_commsc                    sx    | sJ |  d| j }| j}| j}t||tkr8tnt|}|tkrP|d n|d }t||tkrhtnt|}||fS )Nz is not a valid comm actionr   )	rE   rF   rG   rD   r*   r/   r1   r0   r2   )r>   	stage_idxctyper   sendZrecv_stage_idxrecv)r  r9   r:   
_get_comms'  s    z"_add_send_recv.<locals>._get_comms)r>   prev_actionsrW   c                    s   | du rdS | j tkrX| jdksXt| jt| j|v r8dS t| jd t| j|v rTdS dS | j ttfv r| j d kst| jt| j|v rdS t| jd t| j|v rdS t| jd t| j|v rdS dS dS dS )a  We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
        This helps ensure a sane (non-hanging) ordering of sends and recvs.
        But it also means we might not be able to schedule our next compute action yet.
        NTr   r   F)	rF   r*   rE   rD   r0   rG   r5   r7   r2   )r>   r#  )r   r9   r:   _ready_to_schedule1  sF    z*_add_send_recv.<locals>._ready_to_scheduleFr   rank=z, len(compute_actions[rank])=Tz6Malformed compute schedule, can't schedule sends/recvs)rD   r   r   r
   r  rx   rS   r   r  rE   r  )r   r  r   Zcomm_actionsr#  r"  r$  progressr}   r>   r   r!  r9   )r  r   r  r:   _add_send_recv  s>    ,

r'  )rY   r   r   num_microbatchesrW   c                 C   s  t | |ks$J d| dt |  t|D ]}|| v s,J d| q,dd t|D }i }| D ]r}| | D ]b}|d u rqpt|tsJ d| d|j}|j}	|j}
|	tkr|| t |
 n|	t	kr|
|| t v sJ d| d	|
 d
|| t	 |
 n|	t
krR|
|| t v s>J d| d	|
 d
|| t
 |
 nD|	tkr|
|| t
 v sJ d| d	|
 d|| t |
 ||vr|||< qp|| }||kspJ d| d| d| qpqb|D ]}t || t }t || t	 }t || t
 }t || t }||ksJJ d| dt d| d| ||| d  |ksJ d| d| d| d| d| 
q|S )Nz2Schedule has incorrect number of ranks - expected z	, actual z%Schedule is missing actions for rank c                 S   s*   i | ]"}|t t tt tt tt iqS r9   )r*   r  r3   r+   r,   )rX   Zstage_idr9   r9   r:   r    s   z&_validate_schedule.<locals>.<dictcomp>zGot an invalid action: z, expected instance of _Actionz Running Full Backward for stage z, microbatch z without first running Forwardz!Running Backward Input for stage z"Running Backward Weight for stage z% without first running Backward InputzStage z is assigned to both rank z
 and rank zGot ri   z microbatches for stage z, expected r!   z(Invalid backward microbatches for stage z: expected z( total backwards,             but got B=z, I=z, W=)rS   rv   r   rD   rE   rF   rG   r*   r  r3   r+   r,   )rY   r   r   r(  r}   Zstage_actionsZstage_index_to_rank_mappingr>   Zs_idr  Zmb_idZexisting_rankZf_mbZb_mbZi_mbZw_mbr9   r9   r:   _validate_schedule}  sx    	







r)  c                       s   e Zd ZdZdee eee ee	e
df  eeee
f  eeeeef e	e f  ee ed fddZe	edf dd	d
Zeeeee  f ddddZdd ZdddZdddee dddZdee ee ee ee dddZ  ZS )r   aX  
    Base class for multi-stage schedules.
    Implements the `step` method.

    Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True.  This setting
    should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
    or sum losses (scale_grads=False).
    NT.)r   r   r   r   r   r   use_full_backwardr   c	           
         s   t  j||||||d || _|d j| _|d j| _|d j| _t	| j| j| _
| jD ]}	| j
|	_
qX| jD ]}	| j|	_qld| _| jd u  fdd| _i | _|d urtd d S )Nr   r   Fc                    s
   | j o S rH   r   )r   Zhas_lossr9   r:   r    r[   z0PipelineScheduleMulti.__init__.<locals>.<lambda>zDeprecation warning: 'use_full_backward' is no longer supported. Simply stop passing it, and everything should still work fine.)r   r   _stagesr   r   
group_sizer   
group_rankr}   r   stage_index_to_group_rankr   r   _stages_initializedr   Z_should_compute_lossrU   r   warning)
r8   r   r   r   r   r   r   r*  r   r   r   r+  r:   r     s6    	




zPipelineScheduleMulti.__init__)r   c                 C   sV   t  }| jD ]>}|jr(|| j||}n|| j||}| jr|| j qd| _d S r   )r   r,  is_firstr   r   r   r   r0  )r8   r   r   Znext_stage_argsr   r9   r9   r:   _initialize_stages  s    
z(PipelineScheduleMulti._initialize_stages)rY   rW   c                 C   s.   t || j| j| j| _| jD ]}| j|_qdS )z]
        Allocates the stage index to rank mapping which is needed for communication
        N)r)  r   r   r   r/  r,  )r8   rY   r   r9   r9   r:   _validate_and_set_stage_mapping  s    
z5PipelineScheduleMulti._validate_and_set_stage_mappingc                 C   sX   t |ddd6}t|}| jD ]}|| j|  q W d   n1 sJ0    Y  dS )zQDump a CSV representation of the schedule into a file with the provided filename.r   rM   newlineN)opencsvwriterrU   writerowr8   filenamecsvfiler9  r}   r9   r9   r:   	_dump_csv&  s    

zPipelineScheduleMulti._dump_csvcompute_onlyc                 C   sx   |dksJ t |dd@}t|}t|D ]\}}dd |D | j|< q,W d   n1 s^0    Y  | | j dS )zLoad a CSV representation of the schedule from a file with the provided filename.
        This API will most likely get renamed/refactored so is marked as internal for now.

        format must be "compute_only" for PipelineScheduleMulti.
        r?  rM   r5  c                 S   s   g | ]}t |qS r9   rD   r?   rX   r  r9   r9   r:   r`   7  r[   z3PipelineScheduleMulti._load_csv.<locals>.<listcomp>N)r7  r8  readerrp   rU   r4  )r8   r<  formatr=  rB  r}   rr   r9   r9   r:   	_load_csv-  s    
4zPipelineScheduleMulti._load_csvr   r   c          	      O   sz   | j D ]}|  q| ||\}}|dur@tt|| j}nd}| |||| | j D ]}|jrZ| 	|j
  S qZdS r   )r,  r   r   ry   r   r   r   r   r   r   r   )	r8   r   r   r   r   r   r   r   r   r9   r9   r:   r   =  s    


zPipelineScheduleMulti.stepr   c                 C   sZ  |  ||||\}}| js.| |d |d  dd | jD }t }t }| D ]B}|dkrr|| j|d   || jd k rR|| j|d   qRt	 }	t
| j| j D ]\}
}z4g }|durr|j}|j}|j}|dusJ d|tjkr8|| }|||| || }| |||| ||| n:|tjkr|| }| ||}|	|  d7  < |	| | jk}| jr| jnd}|j||d|d |r|| ||| n|tjkr || }| ||}|j||d	d	d ||| nr|tjkrd|| }|	|  d7  < |	| | jk}| jr@| jnd}|j||d
 |rr|| ntd| |D ]}| j| }d}|
t |k r||
 }|durv|j}|j}|j}|dusJ d|tjkr|d |v r$||d  }||!| n |tttfv rntd| qv|D ]}| j| }d}|
t |k rT||
 }|dur,|j}|j}|j}|dusJ d|ttfv rnH|ttfv r|d |v r||d  }||"| ntd| q,t#t$| W q t%yD } z>t&'d| j| j(j)|
| t&'dt*| j|
d |W Y d}~qd}~0 0 q| +| j| dS )
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.
        r   c                 S   s   i | ]}|j |qS r9   r  r   r9   r9   r:   r  s  s   z<PipelineScheduleMulti._step_microbatches.<locals>.<dictcomp>r   NzCAll currently supported action types require valid microbatch_indexTr   Zfull_backwardr   Fr   zUnknown computation type zy[Rank %s] pipeline schedule %s caught the following exception                      at time_step %s when running action %sz%srV   ),r   r0  r3  r,  r  keysr  r/  r   r   rp   rU   r}   rF   rG   rE   r    r4   r   r   r   r   r7   r   r   r   r   r   r5   r6   backward_weight_one_chunkr   rS   r   r   r   r   	Exceptionr   errorr   r@   r~   r   )r8   r   r   r   r   stage_index_to_stageZall_prev_ranksZall_next_ranksrE   backward_counter	time_stepr>   r   rF   r   r   r   r   r   r   Z	prev_rankZprev_rank_opsZprev_rank_actionZ	next_rankZnext_rank_opsZnext_rank_actioner9   r9   r:   r   _  s   






z(PipelineScheduleMulti._step_microbatches)NNNNNT)r?  )NNNN)r@   rA   rB   r   ry   r   rR   r
   r   r   r   r   rI   r   r   r   r   r3  rD   r4  r>  rD  r   r   r   r9   r9   r   r:   r     sD         2
$    r   c                       s   e Zd ZdZdeeeee  f e	d fddZ
de	e	d fddZe	d	d
dZdd Zdee ee ee ee dddZ  ZS )_PipelineScheduleRuntimea%  
    Provides a simple runtime that requires a 'schedule IR' including specified communication operations.

    Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be
    subclassed and the subclass can be responsible for creating a schedule IR.
    r?  )rY   rC  c                    s   t  | i  _|dkrZ|D ]8}g  j|< || D ] }|dusDJ  j| | q4qnR|dkr|D ]}t||  j|< qft j fdd jd _ntd|ddS )	z
        Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including
        communication actions.  Stores the schedule in self, and must be called before running step_mo()
        compute_commsNr?  c                    s
    j |  S rH   r/  r  r   r9   r:   r  @  r[   z8_PipelineScheduleRuntime._load_actions.<locals>.<lambda>r  r   format= is not implemented)r   r4  pipeline_order_with_commsr   r  r'  r   r   )r8   rY   rC  r}   r>   r   r   r:   _load_actions#  s&    



z&_PipelineScheduleRuntime._load_actions)r<  rC  c                    s   |dkr"t  | | | j n|dkri }t|ddL}t|}t|D ]\}}dd |D ||< qN| j||d W d   q1 s0    Y  ntd	|d
dS )a	  Loads a csv in simple format and then lowers it to include communication actions

        format must be either "compute_only" or "compute_comms".  If compute_only, the lowering passes
        will automatically be run to generate a compute_comms schedule.
        r?  rR  rM   r5  c                 S   s   g | ]}t |qS r9   r@  rA  r9   r9   r:   r`   V  r[   z6_PipelineScheduleRuntime._load_csv.<locals>.<listcomp>)rC  NrU  rV  )	r   rD  rX  rU   r7  r8  rB  rp   r   )r8   r<  rC  rY   r=  rB  r}   rr   r   r9   r:   rD  F  s    
.z"_PipelineScheduleRuntime._load_csv)r<  c                 C   sj   | j dusJ dt|ddd6}t|}| j D ]}|| j |  q2W d   n1 s\0    Y  dS )zaDump a CSV representation of the compute + comms schedule into a file with the provided filename.Nz6Must initialize compute_comms schedule before dump_csvr   rM   r5  )rW  r7  r8  r9  r:  r;  r9   r9   r:   r>  [  s    

z"_PipelineScheduleRuntime._dump_csvc                    s   t  j fdd jS )Nc                    s
    j |  S rH   rS  r  r   r9   r:   r  j  r[   z4_PipelineScheduleRuntime._simulate.<locals>.<lambda>)_simulate_comms_computerW  r   r   r9   r   r:   	_simulateg  s
    
z"_PipelineScheduleRuntime._simulateNr   c                    sZ  |  ||||\}}| js.| |d |d  dd | jD }| jdusPJ di }i }g }i  t td fdd}	t }
t| j| j	 D ]\}}z.|j
}|jdur|jnd	}|dks|ttfv sJ d
|d|j}|| }t|jt}|d |v }|d |v }td|| |tkr>|t|| n|tkr`|t|| nh|tkr||f|vsJ dt|||||f< n.|tkr||f|vsJ dt|||||f< n|tkr|r|vr| vsJ d|d|jjdd |< n|tkrp|r|v sHJ d|d| vsbJ d|d|j  nX|t kr
|r|	| |j!s|s||f|v sJ d|dt"|#||f |$||| || }| %|||| |r||d  &|| n|t'kr|r"|	| |j(s`|s`||f|v sNJ d|dt"|#||f | )||}|
|  d7  < |
| | j*k}| j+r| j*nd}|j,||d|d |r|+| |r||d  -|.|| n|t/krv|r|	| |j(s6|s6||f|v s$J d|dt"|#||f | )||}|j,||ddd |r||d  -|.|| nR|t0kr|r|	| |
|  d7  < |j1||
| | j*kd nt2d
|dW q t3y } z0t4d|| t5t6| j|d |W Y d}~qd}~0 0 qt7|r2t"|#  qt7 dksHJ d | 8| j| dS )!rE  r   c                 S   s   i | ]}|j |qS r9   r  r   r9   r9   r:   r    s   z?_PipelineScheduleRuntime._step_microbatches.<locals>.<dictcomp>Nz=Must call _load_actions() before calling _step_microbatches()r  c                    s>   |  v r$ |      | = |  | v s:J d| dS )zQIf an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared.z*Attempted to compute on sharded stage_idx=N)r   r  r[  Zunshard_opsZunsharded_stagesr9   r:   _assert_unsharded  s    

zF_PipelineScheduleRuntime._step_microbatches.<locals>._assert_unshardedzaction=z missing mb_indexr   z8_PipelineScheduleRuntime running time_step %d, action %szARecv twice for {stage_idx=} {mb_index=} without executing forwardzBRecv twice for {stage_idx=} {mb_index=} without executing backwardzUnsharding the same stage_idx=z twiceT)Zasync_opzResharding stage_idx=z without unshardingz before finishing unshardzComputing action=z before receiving inputz Attempted to run compute action=rF  FrG  z is unknown or unsupportedz\_PipelineScheduleRuntime caught exception at step %s when running action %s.  Full Schedule:rH  zUnused unshard operations)9r   r0  r3  r,  rW  r  rR   r   rp   r}   rF   rG   r-   r.   rE   r   submodr   r   r   r/   r   r   r   r1   r   r0   r   r2   r   ZunshardZreshardr4   r2  r   r  r   r   Zset_local_fwd_inputr7   r   r   r   r   r   Zset_local_bwd_inputZget_local_bwd_outputr5   r6   rJ  r   rK  rL  printr~   rS   r   )r8   r   r   r   r   rM  Zbwd_recv_opsZfwd_recv_opsZsend_opsr]  rN  rO  r>   Z	comp_typer   r  r   Zstage_uses_fsdpZis_next_stage_on_this_rankZis_prev_stage_on_this_rankr   r   r   r   rP  r9   r\  r:   r   n  sx   





















z+_PipelineScheduleRuntime._step_microbatches)r?  )r?  )NNNN)r@   rA   rB   r   r   rR   ry   r
   rD   rI   rX  rD  r>  rZ  r   r   r9   r9   r   r:   rQ    s&   
 #	    rQ  c                	       s`   e Zd ZdZd	ee eeee	e
f  eeeeef ee f  ed fddZdd Z  ZS )
r   ai  
    Breadth-First Pipeline Parallelism.
    See https://arxiv.org/abs/2211.05953 for details.
    Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    What is different is that when microbatches are ready for multiple local
    stages, Loops BFS will prioritizes the earlier stage, running all available
    microbatches at once.
    NTr   r   r   r   r   c                    sD   t  j|||||d i | _t| jD ]}| |}|| j|< q&d S )Nra  )r   r   rU   rv   r   !_calculate_single_rank_operations)r8   r   r   r   r   r   r}   rank_opsr   r9   r:   r   k  s    
zScheduleLoopedBFS.__init__c                    s   t | j}t|| j| | j}dd t|D }|D ]" | fddt| jD  q4d| jd |  }|d g|  t|D ]& | fddtt| jD  q|S )Nc                 S   s   g | ]}d qS rH   r9   rX   r   r9   r9   r:   r`     r[   zGScheduleLoopedBFS._calculate_single_rank_operations.<locals>.<listcomp>c                 3   s   | ]}t  tj|V  qd S rH   )rD   r    r4   rX   r   r  r9   r:   rZ     s   zFScheduleLoopedBFS._calculate_single_rank_operations.<locals>.<genexpr>r!   r   c                 3   s   | ]}t  tj|V  qd S rH   )rD   r    r7   re  r  r9   r:   rZ     s   )rS   r,  rv   r   r   r   reversed)r8   r}   n_local_stagesZstage_indicesrc  post_warmup_opsr9   r  r:   rb    s     


z3ScheduleLoopedBFS._calculate_single_rank_operations)NNT)r@   rA   rB   r   ry   r   rR   r
   r   r   r   r   rI   r   r   r   r   rb  r   r9   r9   r   r:   r   a  s      r   Fc
                 C   s  t t}
t t}t t}dd t|D }| | d|d |   ||  }|	rZ|| d }|| | }g }d}|	rvtnt}t|D ]}||k r||}|
|  }d |
|< |t|tj| ||d kr|	d g|  q||  kr|| k rn n||}|
|  }d |
|< |t|tj| ||}||  }d ||< |t||| || |	rJ|| |krJ||| }||  }d ||< |t|tj
| |d7 }q|	s|d  ||}||  }d ||< |t||| || |	r|| |kr||| }||  }d ||< |t|tj
| |d7 }q|	r|t|k r||| }||  }d ||< |t|tj
| |d7 }qL|S )Nc                 S   s   g | ]}d qS rH   r9   rd  r9   r9   r:   r`     r[   z&_get_1f1b_rank_ops.<locals>.<listcomp>r!   r   r   )r   rR   rv   r5   r7   r   rD   r    r4   r   r6   rS   )rg  r   
warmup_opsfwd_bwd_opscooldown_opsr}   forward_stage_indexbackward_stage_indexnum_1f1b_microbatchesenable_zero_bubbleZfwd_stage_mb_indexZbwd_stage_mb_indexZweight_stage_mb_indexrc  rh  	total_opsZbackward_op_idsZweight_op_countZFULL_BACKWARD_OR_BACKWARD_INPUTr   Zfwd_stage_indexr   r   Zbwd_stage_indexr   Zweight_stage_indexZweight_mb_indexr9   r9   r:   _get_1f1b_rank_ops  s    	














rq  c                       s   e Zd ZdZdee eee ee	e
df  eeee
f  eeeeef e	e f  ed fddZeee  dd	d
Z  ZS )r   a  
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.
    Will perform one forward and one backward on the microbatches in steady
    state and supports multiple stages per rank. When microbatches are ready for
    multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
    (also called "depth first").

    This schedule is mostly similar to the original paper.
    It differs by being relaxing the requirement of num_microbatch % pp_size == 0.
    Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
    it works as long as n_microbatches % num_rounds is 0. As a few examples, support

    1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
    2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
    NT.r   r   r   r   r   r   r   c           
   	      s   |d j | _t j|||||||d t|| _|d j| _td|| j | _	|| j	 | _
|| j	 dkrtd| j	 d| di | _t| jD ]}| |}	|	| j|< qd S )Nr   rr  r   z_Interleaved 1F1B requires the number of microbatches to be a multiple of the number of rounds (), but got .)r-  r   r   r   rS   rg  r.  r}   rg   number_of_roundsmicrobatches_per_roundr   rU   rv   rb  )
r8   r   r   r   r   r   r   r   r}   rc  r   r9   r:   r   <  s4    
	

z ScheduleInterleaved1F1B.__init__r   c           	   	      s   fdd}| j j }| }|| }| | }td |||  fdd} fdd}tj j|| ||S )Nc                    s<    j d  j }d}|| jd |    }t| j j  S )Nr   r!   rg  rv  r   r   r   r}   Zwarmups_ops_last_stageZmultiply_factorri  r   r9   r:   get_rank_warmup_opsc  s    zVScheduleInterleaved1F1B._calculate_single_rank_operations.<locals>.get_rank_warmup_ops=rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %sc                    s   | j  j }|j   S rH   rv  rg  r   r   Zlocal_indexr}   r8   r9   r:   rl    s    zVScheduleInterleaved1F1B._calculate_single_rank_operations.<locals>.forward_stage_indexc                    s,   j d |  j j   }|j   S Nr   rg  rv  r   r|  r}   r8   ri  r9   r:   rm    s    zWScheduleInterleaved1F1B._calculate_single_rank_operations.<locals>.backward_stage_indexrg  r   r   r   rq  r   )	r8   r}   ry  microbatch_opsrj  rk  rp  rl  rm  r9   r  r:   rb  b  s4    
	z9ScheduleInterleaved1F1B._calculate_single_rank_operations)NNNNTr@   rA   rB   r   ry   r   rR   r
   r   r   r   r   rI   r   r   r   r   rD   rb  r   r9   r9   r   r:   r   *  s         &r   c                       s   e Zd ZdZdee eee ee	e
df  eeee
f  eeeeef e	e f  ed fddZeee  dd	d
Zdd Z  ZS )r   aw  
    The Interleaved Zero Bubble schedule.
    See https://arxiv.org/pdf/2401.10241 for details.
    Will perform one forward and one backward on inputs for the microbatches in steady
    state and supports multiple stages per rank. Uses the backward for weights to fill in
    the pipeline bubble.

    In particular this is implementing the ZB1P schedule in the paper.
    NT.rr  c              	      s   |D ]}t |jtrtdq|d j| _t j|||||||d t|| _	|d j
| _td|| j | _|| j | _|| j dkrtd| j d| di | _t| jD ]}	| |	}
|
| j|	< q| | j	| j | _d S )NzYThe Zero Bubble schedule is not supported with stage modules that have used torch.compiler   rr  r   zZZero bubble requires the number of microbatches to be a multiple of the number of rounds (rs  rt  )r   r_  r   r<   r-  r   r   r   rS   rg  r.  r}   rg   ru  rv  r   rU   rv   rb  _add_bubbles_to_actionsr8   r   r   r   r   r   r   r   r   r}   rc  r   r9   r:   r     sD    	


z&ScheduleInterleavedZeroBubble.__init__r   c           
         s   fdd}| j j }| }|| }| | }td |||  fdd} fdd} }	tj j|| |||	dd	
S )
Nc                    s<    j d  j }d}|| jd |    }t| j j  S r~  rw  rx  r   r9   r:   ry    s    z\ScheduleInterleavedZeroBubble._calculate_single_rank_operations.<locals>.get_rank_warmup_opsrz  c                    s   | j  j }|j   S rH   r{  r|  r}  r9   r:   rl  	  s    z\ScheduleInterleavedZeroBubble._calculate_single_rank_operations.<locals>.forward_stage_indexc                    s,   j d |  j j   }|j   S r~  r  r|  r  r9   r:   rm  	  s    z]ScheduleInterleavedZeroBubble._calculate_single_rank_operations.<locals>.backward_stage_indexT)ro  r  )
r8   r}   ry  r  rj  rk  rp  rl  rm  rn  r9   r  r:   rb    s:    	z?ScheduleInterleavedZeroBubble._calculate_single_rank_operationsc                 C   sx  | j }dd }t }i }i }i }d}t| jD ]}	g ||	< d||	< d||	< q.d}
t }t| jD ]}	||	 }|t||	 kr~q`d}
||	 | d ur(||	 | }|d usJ |\}}}||||||s||	 ||	 |  |d ur||||f ||	  d7  < n||	 d  ||	  d7  < q`||	  d7  < ||	 d  q`|| |
rLq\qL|dkrtt	d|| |S )Nc                 S   sf   |t jkr*| dkrb| d ||f|vrbdS n8|t jkrb| |d krP| t j|f|vS | d ||f|vS dS )Nr   r   TF)r    r4   r7   )r   r   
microbatchnum_stages_globalseen_opsr9   r9   r:   need_bubble"	  s    

zJScheduleInterleavedZeroBubble._add_bubbles_to_actions.<locals>.need_bubbler   TFr   z?Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s)
rU   r  rv   r   rS   r   r  updater   r1  )r8   r  rY   r  r  resultZnext_pointerZbubbles_addedZtotal_bubbles_addedr}   Zshould_stopZtemp_seen_ops	timestampZtemp_actionrE   r   r  r9   r9   r:   r  	  sV    





z5ScheduleInterleavedZeroBubble._add_bubbles_to_actions)NNNNT)r@   rA   rB   r   ry   r   rR   r
   r   r   r   r   rI   r   r   r   r   rD   rb  r  r   r9   r9   r   r:   r     s"        6@r   c                       s   e Zd ZdZdee eee ee	e
df  eeee
f  eeeeef e	e f  ed fddZeee  dd	d
Z  ZS )r   a  
    The Zero Bubble schedule (ZBV variant).
    See https://arxiv.org/pdf/2401.10241 Section 6 for details.

    This schedules requires exactly two stages per rank.

    This schedule will perform one forward and one backward on inputs for the microbatches in steady
    state and supports multiple stages per rank. Uses backward with respect to weights to fill in
    the pipeline bubble.

    This ZB-V schedule would have the "zero bubble" property only if time forward == time backward input == time backward weights.
    In practice, this is not likely true for real models so alternatively
    a greedy scheduler could be implemented for unequal/unbalanced time.
    NT.rr  c              	      s   |d j | _t j|||||||d t| j| jdd| _| jD ]}| j|_q@t|| _	| j	dkrtt
d| j	 d|d j| _|d j| _i | _t| jD ]}	| |	}
|
| j|	< qd S )Nr   rr  v)styler!   z0ZBV requires exactly 2 stages per rank, but got rt  )r-  r   r   r   r   r   r/  r,  rS   rg  r   r.  r}   r   rU   rv   rb  r  r   r9   r:   r   r	  s8    
	





zScheduleZBVZeroBubble.__init__r   c                    s,  t d j d  j}dd t|D }d\}}}}d j|  d }|}	 jd | }
t|D ] }|t|	t|d |d7 }q`|}t|D ]<}|t|
t|d |d7 }|t|	t|d |d7 }q j| }t|D ]P}|t|
t|d |d7 }|t|
t|d |t|
t	|d |d7 }q||k sD||k r||k rj|t|	t|d |d7 }|t|	t|d |t|	t	|d |d7 }|t|
t|d |d7 }|t|
t|d |t|
t	|d |d7 }q0|| }}|}t|D ]>}|t|	t|d |d7 }|t|
t|d |d7 }q  j| }t|D ]>}|t|	t|d |d7 }|t|	t	|d |d7 }qR||k r|t|
t	|d |d7 }q||k r|t|	t	|d |d7 }q||kr||ksJ ||kr||ksJ  fdd|D }|S )Nr!   r   c                 S   s   g | ]}d qS rH   r9   rd  r9   r9   r:   r`   	  r[   zKScheduleZBVZeroBubble._calculate_single_rank_operations.<locals>.<listcomp>)r   r   r   r   )rF   rG   c                    s2   g | ]*}|d ur*|j d ur*|j  jk r*|nd qS rH   )rG   r   )rX   r>   r   r9   r:   r`   
  s   
)
rg   r   r   rv   r   r   rD   r*   r+   r,   )r8   r}   Zn_microrc  Zf0_cntZf1_cntZb0_cntZb1_cntZ	warmup_n1Zstage_id_chunk0Zstage_id_chunk1r   Z	warmup_n2Z	warmup_n3Zw0_cntZw1_cntZcooldown_n1Zcooldown_n2r9   r   r:   rb  	  s    










z7ScheduleZBVZeroBubble._calculate_single_rank_operations)NNNNTr  r9   r9   r   r:   r   b	  s         ,r   )schedule_namec              	   C   s`   t tttttttd}dd | D }| 	 }||vrTt
d|  dt|  |||  S )z
    Maps a schedule name (case insensitive) to its corresponding class object.

    Args:
        schedule_name (str): The name of the schedule.
    )Z1F1BZInterleaved1F1BZGPipeZ	LoopedBFSZInterleavedZeroBubbler   r   ZZBVZeroBubblec                 S   s   i | ]}|  |qS r9   )lower)rX   kr9   r9   r:   r  &
  r[   z&get_schedule_class.<locals>.<dictcomp>zUnknown schedule name 'z'. The valid options are )r   r   r   r   r   r   r   r   rI  r  r   ry   )r  Zschedule_mapZlowercase_keysZlowercase_schedule_namer9   r9   r:   r   
  s     
r   rT  c           	         s  fddt D dd t D dd D  ttt d fdd}tt td fd	d
}rd}t D ]Z}t| dkrq~| d }||r|dur||| | d d}q~||d q~t ddD ]}t| dkr|= qt D ]z}t| dkr$q
| d dur:q
| d }||r
|durt|| d<  | | | d q
t ddD ]}t| dkr|= q|sltdt	 D ]"}td|d| d   qt
dqlS )a  This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags
    any deadlocks caused by missing or misordered communications.  It also simulates any bubbles in time where a rank
    can not execute any action due to waiting for unmet dependencies.  The total number of simulator steps can be used
    as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number
    of simulated steps.

    The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams.
    Future work may be to enhance this and model the compute time, comms overlap, and even memory.
    c                    s    i | ]}|d d  | D qS )c                 S   s   g | ]}|d ur|qS rH   r9   )rX   r  r9   r9   r:   r`   <
  r[   z6_simulate_comms_compute.<locals>.<dictcomp>.<listcomp>r9   r  )rU   r9   r:   r  ;
  s   z+_simulate_comms_compute.<locals>.<dictcomp>c                 S   s   i | ]
}|g qS r9   r9   r  r9   r9   r:   r  ?
  s   c                 S   s   i | ]}|t  qS r9   r  r  r9   r9   r:   r  C
  r[   r}   r>   c                    s(   |   | |d ur$ |  | d S rH   )r   r  r  )_prev_ops_rank	_scheduler9   r:   add_to_scheduleE
  s    z0_simulate_comms_compute.<locals>.add_to_scheduler  c                    s  | d u rdS | j } | }| jtkrn| j dkr6dS t| j t| j|v rNdS t| j d t| j|v rjdS dS | jttfv r| j d krdS t| j t| j|v rdS t| j d t| j|v rdS t| j d t| j|v rdS dS | jt	krdS | jt
krt| j t| j}||v S | jtkrF|d }t|t
| j}| | v S | jtkrt| j t| j}t| j t| j}||v p||v S | jtkr|d }t|t| j}| | v S td|  d S )NTr   r   FzUnsupported action type )rE   rF   r*   rD   r0   rG   r5   r7   r2   r6   r/   r1   r   )r>   r  Zprev_opsZ
expected_fZpeer_stage_idxZexpected_sendZ
expected_bZexpected_bw)r  r   r  r9   r:   r$  J
  sh    




z3_simulate_comms_compute.<locals>._ready_to_scheduleFr   NT)reverser^  zWIP comms schedule:
r%  z next action= zSchedule is not progressing)rx   rR   r
   rD   r   rS   r  r  r`  r~   r   )	rU   r  r   r  r$  r&  r}   r>   r^   r9   )r  r  r   rU   r  r:   rY  /
  sX    
:



 
rY  c                 C   s   g }t | D ]V}t| | D ]D\}}|du r.q|t||jtttfv rLdndd|||dd qqddl}t	|d }|
d	|i| W d   n1 s0    Y  dS )
a  
    This function dumps a schedule IR into a chrometrace format so it can be visualized.

    It is currently very basic and only serves as a graphical alternative to dumping the schedule IR as text.

    As future work we may extend this to include more accurate heuristics for durations, or let users input durations,
    add 'flow events' to let the UI show the connection between sends and recvs, and model cuda streams for comm/compute
    as separate streams on the chrometrace view.
    NZcomputationZcommunicationXr   )r   catphpidtidtsZdurr   r   ZtraceEvents)rx   rp   r   rI   rF   r*   r3   r,   jsonr7  dump)Zscheduler<  eventsr}   Ztimestepr>   r  fr9   r9   r:   _dump_chrometrace
  s(    
r  )N)N)N)r"   )r   F)[rt   r8  rz   loggingreabcr   r   collectionsr   r   enumr   typingr   r   r	   r
   r   r   Ztorch.distributedZdistributedr   Ztorch._dynamor   Ztorch.distributed.fsdpr   r   Ztorch.nn.modules.lossr   Ztorch.profilerr   _utilsr   r  r   r   r   r   r   __all__	getLoggerr@   r   r    r4   r5   r6   r-   r.   r/   r0   r1   r2   r7   r*   r+   r,   r3   compilerO   rD   r   rR   ry   rI   r~   r   ZP2POpZWorkr   r   r   r   r   r   r   r  r  r'  r)  r   rQ  r   rq  r   r   r   r   rY  r  r9   r9   r9   r:   <module>   s   
6% > 8  i3{ Z 
G
$j
L  T  HI  
 
t E 4 