a
    hK@                     @   s   d dl Z d dlZd dlZd dlZd dlmZmZ d dlmZ d dl	m
Z
 d dlmZmZ d dlmZ d dlmZ eeZeej G dd	 d	ZG d
d dZG dd dZdS )    N)IterableSequence)Optional)GraphModule)_get_qualified_nameNode)OperatorSupportBase)fuse_by_partitionsc                   @   s\   e Zd Zdee eee  dddZedddZ	edd	d
Z
edddZdd ZdS )	PartitionNidnodesc                 C   s"   || _ |d urt|ni | _d S N)r   dictfromkeysr   )selfr   r    r   O/var/www/auris/lib/python3.9/site-packages/torch/fx/passes/infra/partitioner.py__init__   s    zPartition.__init__returnc                 C   s
   t | jS r   )strr   r   r   r   r   __repr__   s    zPartition.__repr__nodec                 C   s   | j |d i d S r   )r   updater   r   r   r   r   add_node   s    zPartition.add_nodec                 C   s   | j |= d S r   r   r   r   r   r   remove_node    s    zPartition.remove_nodec                 C   s
   t | jS r   )lenr   r   r   r   r   size#   s    zPartition.size)NN)__name__
__module____qualname__r   intr   r   r   r   r   r   r    r"   r   r   r   r   r
      s    r
   c                   @   s.   e Zd ZedddZeee dddZdS )_DependencyViewer)graph_modulec                 C   sT   t t| _t|jjD ]6}|jD ]*}| j| | | j| 	| j|  q"qd S r   )
collectionsdefaultdictsetdownstreamsreversedgraphr   usersaddr   )r   r(   r   Zoutput_noder   r   r   r   (   s
    
z_DependencyViewer.__init__r   r   c                 C   s
   | j | S r   )r,   r   r   r   r   downstreams_of1   s    z _DependencyViewer.downstreams_ofN)r#   r$   r%   r   r   r   r+   r2   r   r   r   r   r'   '   s   	r'   c                   @   s   e Zd Zdeeeeee  eee  ddddZ	e
edddZee d	d
dZdee eedddZee dddZdeedddZdS )CapabilityBasedPartitionerFN)r(   operator_supportallows_single_node_partitionnon_compute_ops!allowed_single_node_partition_opsr   c                 C   sD   || _ || _|| _|d ur|ng | _|d ur0|ng | _t|| _d S r   )r(   r4   r5   r6   r7   r'   dependency_viewer)r   r(   r4   r5   r6   r7   r   r   r   r   6   s    z#CapabilityBasedPartitioner.__init__r1   c                 C   s   | j t| j |S r   )r4   Zis_node_supportedr   r(   Znamed_modulesr   r   r   r   _is_node_supportedI   s    z-CapabilityBasedPartitioner._is_node_supportedr   c                    s  t ti  i i }i i t }ttd fdd}ttt d fdd}t	d t
jjjD ]}i }|r| vrt|}|||< ||< ||| d ||< t tdd	D ]\}}d ||< qt| }	t|	dkrz|	d
 }
|	dd  D ]}||
|\}
}qqzt	d i }jjjD ]x}d}|jD ],}|jdkspt|jdkrPd} q~qP|rB |d }|jD ] } |d |kr|||< qqB| D ]\}}||| qĈjst	d ddh}|tj}g } D ]~\}}d
}|jD ]T}|jdkr"t|jsBJ t|j|vrZ|d7 }t|jj v r"|d7 }q"|dkr|!| q|D ]}|= qt	d  D ]$\}}t	d|dd |jD  qdd " D S )N)self_idother_idc                    s   j   j tt d 	fdd}   B }| ||rbdfS   }}ttk r|| }}| j | j  | j D ]}||< q|= t| | |< |= | | |< |= ||< |= |dfS )N)all_user_nodesc                    s   | D ]x}t  }j|D ]`}|v s.|v r6  dS | v r | }||v rPq| }|v sh|v rp  dS || qqdS )NTF)r+   r8   r2   r0   )r<   	user_nodeZvisited_partition_idsZ	path_nodepartition_idZp_map)
assignmentr;   other_nodespartition_mapr   r:   
self_nodesr   r   dfs_iter_find_cyclel   s    ziCapabilityBasedPartitioner.propose_partitions.<locals>.maybe_merge_partition.<locals>.dfs_iter_find_cycleFT)r   r+   r   difference_updater!   r   minunion)r:   r;   rC   r<   Zmerge_idZ
removed_idr   r?   rA   partition_userspartitions_by_idZpartitions_orderr   )r;   r@   r:   rB   r   maybe_merge_partitiong   s2    

"


zLCapabilityBasedPartitioner.propose_partitions.<locals>.maybe_merge_partitionr   r   c                    s   t td fdd}|  v r0 |   |  |d u rD |  nR|vr| | < t|| gd|< t| j|< || | n| | < | |  d S )NrK   c                    sD   | j D ]8} |d }|d ur| | | |  qd S r   )r/   getr0   r   )r   r   r=   Z	target_id)r?   rA   r   r   _update_partition_map   s
    
zgCapabilityBasedPartitioner.propose_partitions.<locals>.merge_single_node.<locals>._update_partition_mapr   )r   r&   r    popr
   r+   r/   r   )r   r   rM   )r?   rA   rH   rI   r   r   merge_single_node   s    	zHCapabilityBasedPartitioner.propose_partitions.<locals>.merge_single_nodezProposing partitions...   )keyr   z=Reassigning getitem nodes to its producer node's partition...Tcall_functionz_operator.getitemFz'Filtering out single node partitions...ztorch.ops.aten.viewzPartitions proposed:zpartition #%s: %sc                 S   s   g | ]
}|j qS r   )name).0r   r   r   r   
<listcomp>      zACapabilityBasedPartitioner.propose_partitions.<locals>.<listcomp>c                 S   s   g | ]}|  d kr|qS )r   )r"   rT   	partitionr   r   r   rU     s   )#r)   r*   r+   	itertoolscountr&   r   r   loggerdebugr-   r(   r.   r   r9   nextsorteditemsoperator
itemgetterlistkeysr!   r/   opr   targetrL   r5   rF   r6   callabler7   appendvalues)r   Znodes_orderZnew_partition_idrJ   rO   r   Zmerge_candidatesr>   _Zmerge_candidates_listr:   r;   Znodes_reassignmentZis_tuple_outputuserr   Zdefault_non_compute_opsr6   Zpartitions_to_removerX   Zcompute_node_countr   rG   r   propose_partitionsN   s    
H











z-CapabilityBasedPartitioner.propose_partitionsfused_)
partitionsprefixr   c                 C   s$   t d t| jdd |D |dS )NzFusing partitions...c                 S   s   g | ]
}|j qS r   r   rW   r   r   r   rU   %  rV   z>CapabilityBasedPartitioner.fuse_partitions.<locals>.<listcomp>rn   )r[   r\   r	   r(   )r   rm   rn   r   r   r   fuse_partitions  s    
z*CapabilityBasedPartitioner.fuse_partitions)rm   c                    s   t | jtdfdd i i tt t t t d fddtt t t t d fdd|D ]p}t  }|jD ]:} |r||t |j|s|t |j|r||| q|t|d	krl|D ]}|j|d  qqld S )
Nr   c                    s   | j dkot| j v S )NrR   )rd   r   re   r   )r6   r   r   is_non_compute_node-  s    
zVCapabilityBasedPartitioner.remove_bookend_non_compute_ops.<locals>.is_non_compute_node)r   rX   removed_nodesc                    st   | j dks| |vs| |v rdS | v r.|  S  | rh| jD ]}|||s<d| <  dS q<d| < dS d| < dS NplaceholderTF)rd   Zall_input_nodes)r   rX   rr   Zinput_n)rq   is_transparent_input_nodetransparent_input_nodesr   r   ru   7  s$    
z\CapabilityBasedPartitioner.remove_bookend_non_compute_ops.<locals>.is_transparent_input_nodec                    st   | j dks| |vs| |v rdS | v r.|  S  | rh| jD ]}|||s<d| <  dS q<d| < dS d| < dS rs   )rd   r/   )r   rX   rr   Zoutput_n)rq   is_transparent_output_nodetransparent_output_nodesr   r   rw   L  s(    
z]CapabilityBasedPartitioner.remove_bookend_non_compute_ops.<locals>.is_transparent_output_noder   )r+   r6   r   r   r0   r!   rN   )r   rm   rX   r    r   r   )rq   ru   rw   r6   rv   rx   r   remove_bookend_non_compute_ops*  s,    

z9CapabilityBasedPartitioner.remove_bookend_non_compute_ops)rn   r   c                 C   s   |   }| j||d}|S )Nro   )rk   rp   )r   rn   rm   Zfused_gmr   r   r   partition_and_fuseu  s    z-CapabilityBasedPartitioner.partition_and_fuse)FNN)rl   )rl   )r#   r$   r%   r   r   boolr   r   r   r   r   r9   rb   r
   rk   rp   ry   rz   r   r   r   r   r3   5   s*      

 R Kr3   )r)   rY   loggingr`   collections.abcr   r   typingr   Ztorch.fx.graph_moduler   Ztorch.fx.noder   r   Z torch.fx.passes.operator_supportr   Z!torch.fx.passes.utils.fuser_utilsr	   	getLoggerr#   r[   setLevelWARNINGr
   r'   r3   r   r   r   r   <module>   s   
