a
    h                     @   s  U d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlZd dlZd dl	Z	d dl
Z	d dlZd dlZd dlZd dlZd dlmZ d dlmZmZmZmZmZ d dlmZ d dlZd dlmZmZmZ d dlmZ d dlm Z  d dl!m"Z" d d	l#m$Z$ d d
l%m&Z& d dl'm(Z(m)Z) d dl*m+Z+ d dl,m-Z- d dl.m/Z/ d dl0m1Z1 ddl2m3Z3m4Z4 ddl5m6Z6m7Z7m8Z8m9Z9m:Z: ddl;m<Z< e=e>Z?e$e>dZ@e$e>dZAeBe ZCe DdddgZEg dZFejGeHdddZIdJeBe6 eHeeJ dddd ZKeBe6 ejLd!d"d#ZMdKeeC eNeJeJf eeJ eOdd$d%d&ZPeNeJeJf eNeJeEf d'd(d)ZQejj&eCdd*d+d,ZRejSed dd-d.ZTi ZUeNeJef eVd/< dZWeeO eVd0< G d1d2 d2ZXG d3d4 d4ZYeCdd5d6d7ZZeCdd5d8d9Z[ej\G d:d; d;Z]e^ Z_eOeNeJef eNeJef eNeJeNeJef f d<d=d>Z`eedd?d@dAZaeJedBdCdDZbddEedFeJf ejcjdeNeJef ee- eJdGdHdIZedS )L    N)Iterator)AnyCallableIOOptionalUnion)patch)
draw_graphget_aot_graph_nameget_graph_being_compiled)fx)save_graph_repro)get_debug_dir)getArtifactLogger)GraphModule)_extract_tensor_metadataTensorMetadata)legalize_graph)FileLike)
OrderedSet)tree_map   )configir)BaseSchedulerNodeFusedSchedulerNodeNopKernelSchedulerNode
OutputNodeSchedulerNode)Vir_pre_fusionir_post_fusionBufMetanameZn_origin)dotz-Gnslimit=2z-Gnslimit1=2z-Gmaxiter=5000returnc                   C   s   t dd uS )Nr$   )shutilwhich r)   r)   C/var/www/auris/lib/python3.9/site-packages/torch/_inductor/debug.pyhas_dot5   s    r+   F)nodesprint_graphfnamer&   c           	   	   C   s   t  std dS |du r"t }t| }|jD ]~}d|jvr@q0|jd j}t|t	rxt|d t
rp|d f}n|d }d}t|tjr|jj}t||ddddd}||jd< q0|rt| ti |}t| |j  t||dtjjd dS )z$
    Draw a graph in fname.svg.
    z*draw_buffers() requires `graphviz` packageNfusion_metar   Ztensor_metaF)
clear_metadot_graph_shape)r+   logwarningr   create_fx_from_snodesr,   metagroup
isinstancetupleintr   ZComputedBufferdatadtyper   printr   r   graphZlintr	   r   tracer1   )	r,   r-   r.   r=   noder6   r;   metadatagmr)   r)   r*   draw_buffers:   s6    





rB   )snodesr&   c              
      sl  t tdtf ddd}tdg d}i }i }tj }d}g }d}| D ]F}	|	 rbd}
|
}nZ|		 rtd	}
|
}nHt
|	trd
}
|
}n4t
|	trd}
|	j}nt
|	trd}
|	j}ntdtjj|	 d}|
 d| }||}i }t|	drd|	 i}|j|d|d}tttf td fdd  |	rB|| |	 }||_|||	|
|jd< |||< |	 D ]}||| < qr|du rJ|}qJ| D ]}	|	 }|	jj }|| }g }|D ]x}|j|v r||j }nB|!|& |"|j}|||j< W d   n1 s0    Y  ||kr(q|| qt#||_$q|%t&|dkr^|d nt#| |S )B
    Creates a FX Graph from a list of SchedulerNode objects.
    .r#   r&   c                 S   s   t tddd}| |_|S )N)argsr&   c                  W   s   dS )Nr   r)   )rF   r)   r)   r*   func1n   s    z;create_fx_from_snodes.<locals>.get_fake_func.<locals>.func1)r   r9   __name__)r#   rG   r)   r)   r*   get_fake_funcm   s    z,create_fx_from_snodes.<locals>.get_fake_func
FusionMeta)r6   snodetypeNexterntemplateZnopZcomputeZfusedzUnknown node typeZoriginal_atenz: 
get_devicedevicer)   rF   kwargs)rK   r&   c                    s8   t | tr"t fdd| jD S tdd |  D S )Nc                 3   s   | ]} |V  qd S Nr)   ).0x	in_outputr)   r*   	<genexpr>       z;create_fx_from_snodes.<locals>.in_output.<locals>.<genexpr>c                 s   s&   | ]}|j D ]}t|jtV  qqd S rS   )Zusersr7   r?   r   )rT   bufuserr)   r)   r*   rX      s   
)r7   r   anyrC   get_outputs)rK   rV   r)   r*   rW      s
    
z(create_fx_from_snodes.<locals>.in_outputr/   r   r   )'strr   r9   collections
namedtupletorchr   GraphZ	is_externZis_templater7   r   r   r6   r   RuntimeError	_inductorutilsZget_fused_kernel_name	get_nodeshasattrrO   call_functionr   r   boolappendget_namer#   r5   r]   Zread_writesZreadsZinserting_beforeplaceholderr8   rF   outputlen)rC   rI   rJ   Zbuf_to_fx_nodeZnode_to_fx_noder=   Z
first_nodeoutputsr6   rK   Z	node_typeZ
fused_name	func_nameZ	node_funcrR   Zfx_noder#   rZ   depsnew_argsdepZdep_noder)   rV   r*   r4   h   s|    




	

*
$r4   )r,   node_name_to_buf_nameparent_buf_name	n_originsr&   c           
      C   s   | d u rd S | D ]}|  }| }|d urTt|dkrTt|||d u rJ|n| qnt|dkrl|d |kspJ |j}|d u s|jd u rq|jD ]&}|j}	|	|vr|d u r|n|||	< qqd S )Nr   r   )rk   rf   rn   $update_orig_fx_node_name_to_buf_namer?   Zoriginsr#   )
r,   rt   ru   rv   r?   buf_nameZchildren_nodesZir_nodeorigin	node_namer)   r)   r*   rw      s*    
rw   )rt   r&   c                 C   sp   i }|   D ].\}}||vr,t|g||< q|| | qi }|   D ]"\}}t|| }t||||< qH|S rS   )itemsr   addrn   r"   )rt   Zbuf_name_to_n_noderz   rx   node_name_to_buf_metaZn_noder)   r)   r*   get_node_name_to_buf_meta   s    r~   )rA   rC   r&   c                 C   sP   i }t || |du rdS t|}| jjD ] }|j|v r*||j|jd< q*dS )rD   NZbuf_meta)rw   r~   r=   r,   r#   getr5   )rA   rC   rt   r}   r?   r)   r)   r*   annotate_orig_fx_with_snodes   s    

r   c               	   c   s   t jdddk} dd l}t|jjj}t	
 }| sXzd V  W |  n
|  0 d S |tdd t jt d}t j|dd tt j|d	t  d
}|tj |td || zd V  W || |  n|| |  0 d S )NZTORCH_COMPILE_DEBUG01r   z*functorch.compile.config.debug_partitionerTtorchinductor)exist_okZaot_z
_debug.log3[%(filename)s:%(lineno)d %(levelname)s] %(message)s)osenvironr   Ztorch._functorch.aot_autogradlogging	getLoggerZ
_functorchZaot_autogradrH   
contextlib	ExitStackcloseenter_contextr   pathjoinr   makedirsFileHandlerr
   setLevelDEBUGsetFormatter	Formatter
addHandlerremoveHandler)Zcompile_debugra   r2   stackr   fhr)   r)   r*   enable_aot_logging  s:    



r    _inductor_post_to_pre_grad_nodes_pre_grad_graph_idc                	   @   s.  e Zd ZU e Zi Zeee	e f e
d< eeee dddZddddZedd	d
dZd%eeeeee dddZejd&eeeeeee  dddZeedddZddddZddddZeeddddZeee  ee ee ddddZdddd Zeee d!  d"d#d$Z!dS )'DebugContext._inductor_triton_kernel_to_post_grad_node_info)folder_namer&   c                 C   sV   t jjpt }tjD ]<}tj|d|  d| }tj	|st
| |  S qd S )Nr   .)r   r>   	debug_dirr   r   _counterr   r   r   existsr   )r   r   ndirnamer)   r)   r*   create_debug_dirE  s    


zDebugContext.create_debug_dirNr%   c                 C   s   d | _ d | _t | _d S rS   )_prof_pathr   r   _stack)selfr)   r)   r*   __init__S  s    zDebugContext.__init__)new_pathr&   c                 C   s   | j s
d S |dsJ |ddlm} zV|| d4 tj|rPt| t	| j | W d    n1 sr0    Y  W n" t
y   td| j | Y n0 d S )Nz.debugr   )FileLockz.lockz(Failed to copy debug files from %s to %s)r   endswithZfilelockr   r   r   r   r'   rmtreecopytreeOSErrorr2   r3   )r   r   r   r)   r)   r*   copyX  s    
0zDebugContext.copyw)filename
write_moderF   rR   r&   c                 O   s.   | j s
J ttj| j ||g|R i |S rS   r   openr   r   r   )r   r   r   rF   rR   r)   r)   r*   fopenh  s    
zDebugContext.fopenc                 o   sX   | j s
J ttj| j ||g|R i |}|V  W d    n1 sJ0    Y  d S rS   r   )r   r   r   rF   rR   fr)   r)   r*   fopen_contextr  s    
&zDebugContext.fopen_context)suffixr&   c                 C   s   | j s
J tj| j |S rS   )r   r   r   r   )r   r   r)   r)   r*   r   ~  s    
zDebugContext.filenamec                 C   s   t jjd urdd l}| jsJ tj| jtj| j d}|	|d*}|j
| jtj| jd W d    n1 sz0    Y  t j| d S )Nr   z.tar.gzzw:gz)arcname)r   r>   
upload_tartarfiler   r   r   r   basenamer   r|   )r   r   Ztar_filetarr)   r)   r*   r     s    
8zDebugContext.upload_tarc                    s   t jrDtd  j} tj td d fdd}| j	|| | j
t|  t jjsbd S | t | _t jjr| dtj t jjr| dtj d S )Nztorch._dynamo)levelr&   c                    s     |  d S rS   )r   )r   r2   r)   r*   reset_log_level  s    z/DebugContext.__enter__.<locals>.reset_log_levelz	debug.logzinfo.log)r   debugr   r   r   r   r   r   r   callbackr   r   Zset_debug_handlerr>   enabledr   r
   r   Z	debug_log_setup_log_captureZinfo_logINFO)r   Z
prev_levelr   r)   r   r*   	__enter__  s    
zDebugContext.__enter__)r   r   r&   c                 C   sp   t d}| j| |}t |}|| |t d |	| |t
|j| | j|j| d S )Nztorch._inductorr   )r   r   r   r   r   StreamHandlerr   r   r   r   minr   r   r   )r   r   r   r2   fdchr)   r)   r*   r     s    



zDebugContext._setup_log_capture)exc_typeexc_valexc_tbr&   c                 C   sF   | j r| j   |   | jr8|   tdt | j | j	  d S )Nz%s debug trace: %s)
r   disable_save_profile_datar   r   r2   r3   r   r   r   )r   r   r   r   r)   r)   r*   __exit__  s    
zDebugContext.__exit__c                 C   s   | j s
J | j | d | dP}tj| j |d}|  |d |d |d |d W d    n1 s|0    Y  d S )Nzcompile.profzcompile.stats)streamZcumtimed   Ztottime)	r   Z
dump_statsr   r   pstatsZStatsZ
strip_dirsZ
sort_statsZprint_stats)r   r   statsr)   r)   r*   r     s    



zDebugContext._save_profile_data).NrE   c                 C   sd   t jjrJtt j|rJztt| |W S  tyF   tjddd Y d S 0 nttd ddd}|S d S )Nz Ignoring exception in debug codeT)exc_inforF   rR   r&   c                  _   s   d S rS   r)   rQ   r)   r)   r*   ignored  s    z)DebugContext.__getattr__.<locals>.ignored)	r   r>   r   getattrDebugFormatter	Exceptionr2   r3   r   )r   r#   r   r)   r)   r*   __getattr__  s    
zDebugContext.__getattr__)r   )r   )"rH   
__module____qualname__	itertoolscountr   r   dictr^   list__annotations__staticmethodr   r   r   r   r   r   r   r   contextmanagerr   r   r   r   r   r9   r   rL   BaseExceptionr   r   r   r   r)   r)   r)   r*   r   ?  sJ   
 
 

r   c                   @   s  e Zd ZeddddZejjeej	 ddddZ
ejjeej	 dddd	Zedd
ddZedd
ddZeeed
ddZedd
ddZejjeddddZd"eeddddZd#eeeeee f eeef f dddZeeej edef eeee ddd d!ZdS )$r   N)handlerr&   c                 C   s"   |j | _ |j| _|j| _|| _d S rS   )r   r   r   r   )r   r   r)   r)   r*   r     s    zDebugFormatter.__init__)rA   inputsr&   c              
   C   s   |  d}d }tjjjjr8tjj|}t	j
|j}tjjjj}tjjddd$ t|||d||d W d    n1 s0    Y  W d    n1 s0    Y  |  d"}||jdd W d    n1 s0    Y  d S )Nzfx_graph_runnable.pyF)ztrace.enabledztrace.save_real_tensorsZinductor)save_dirstable_hashzfx_graph_readable.pyZprint_output)r   ra   rd   r   r>   Zsave_real_tensors_subclassesZ
fake_utilsZtry_convert_fake_to_realr   r   r   r#   r   r   writeprint_readable)r   rA   r   r   r   r   r)   r)   r*   fx_graph  s&    B	zDebugFormatter.fx_graphc                 C   s@   |  d"}||jdd W d    n1 s20    Y  d S )Nzfx_graph_transformed.pyFr   )r   r   r   )r   rA   r   r   r)   r)   r*   fx_graph_transformed  s    z#DebugFormatter.fx_graph_transformedr,   r&   c                 C   s>   |  d }|| | W d    n1 s00    Y  d S )Nzir_pre_fusion.txtr   r   	_write_irr   r,   r   r)   r)   r*   r      s    zDebugFormatter.ir_pre_fusionc                 C   s>   |  d }|| | W d    n1 s00    Y  d S )Nzir_post_fusion.txtr   r   r)   r)   r*   r!     s    zDebugFormatter.ir_post_fusionc                 C   s2   t  }| D ]}||  |d q| S )Nz


)ioStringIOr   Z	debug_strgetvalue)r,   rZ   r?   r)   r)   r*   r     s
    zDebugFormatter._write_irc                 C   s   t || dd d S )Nzgraph_diagram.svg)r.   )rB   r   )r   r,   r)   r)   r*   graph_diagram  s    zDebugFormatter.graph_diagram)rA   r,   r&   c                 C   s,   t || t|| ddtdtjjd d S )Nzorig_fx_graph_diagram.svgFT)r.   r0   progZparse_stack_tracer1   )r   r	   r   GRAPHVIZ_COMMAND_SCALABLEr   r>   r1   )r   rA   r,   r)   r)   r*   draw_orig_fx_graph  s    
z!DebugFormatter.draw_orig_fx_graphpy)r   	extensionr&   c                 C   s   t || d|  d S )Nzoutput_code.)r'   r   r   )r   r   r   r)   r)   r*   output_code-  s    zDebugFormatter.output_code1inductor_generated_kernel_to_post_grad_nodes.json)r   r&   c                 C   s   i }|  |d0}td|j tj}t|| W d    n1 sF0    Y  i }tr|  dd(}t	tt
|}t|| W d    n1 s0    Y  ||fS )Nr   z/Writing provenance tracing debugging info to %sz/inductor_provenance_tracking_node_mappings.json)r   r2   infor#   r   r   jsondumpr   create_node_mappingr   )r   r   Z
debug_infor   Znode_mappingr)   r)   r*   1log_inductor_triton_kernel_to_post_grad_node_info0  s     **z@DebugFormatter.log_inductor_triton_kernel_to_post_grad_node_infoZChoiceCaller)r#   input_nodestimingselapseprecompile_elapseprescreening_elapser&   c                    s   ddl m  t jtttf d fdd|tj tj fdd|D |||d}| j	d	d
ddV}|
 D ]<\}	}
t|	 }|| |
|d< t|| |d qtW d    n1 s0    Y  d S )Nr   )FixedLayoutr?   r&   c                    s  t | dr| j}nd}|t| jd}z|  }t| rd}zt|j}W n> ty   zt	j
jj|jdd}W n ty   Y n0 Y n0  |j|jg t	j
j|jg t	j
j|j|d}t||d< nt||d< W n ty   Y n0 zt|  |d< W n ty   Y n0 zt|  |d	< W n tyD   Y n0 ztt	j
j|  |d
< W n tyx   Y n0 ztt	j
j|  |d< W n ty   Y n0 ztt	j
j|  |d< W n ty   Y n0 t | drt| jtjr| j|d< |S )Nr#    )r#   rL   r   )fallback)r;   sizestrideoffsetlayoutr;   rP   r  r  Znumelr:   )rg   r#   rL   rH   Zget_output_specr7   r9   r  r   r   r=   ZsizevarsZ	size_hintrP   r;   Z
size_hintsr  r  r^   Z	get_dtyperO   Z
get_strideget_sizeZ	get_numelr:   r   IRNode)r?   rz   Z	node_infor  r  Zstatic_layoutr  build_node_infor)   r*   r  N  sn    


z>DebugFormatter.log_autotuning_results.<locals>.build_node_infoc                    s   g | ]} |qS r)   r)   )rT   r?   )r  r)   r*   
<listcomp>  rY   z9DebugFormatter.log_autotuning_results.<locals>.<listcomp>)Zop_nameZcuda_device_nameZcuda_device_countr  Zautotuning_timeZprecompile_timeZprescreening_timezautotuning_result_json_list.txtatzutf-8)encodingZbenchmark_result
)r   r  r  r   r^   ra   ZcudaZget_device_nameZdevice_countr   r{   	info_dictupdater  r  r   )r   r#   r  r  r	  r
  r  Zgeneral_propertiesr   Zcallertimer  r)   r  r*   log_autotuning_resultsC  s(    	 =	
z%DebugFormatter.log_autotuning_results)r   )r  )rH   r   r   r   r   ra   r   r   r   Tensorr   r   SchedulerNodeListr    r!   r   r^   r   r   r   r   r8   r   r   r  r   r  floatr   r  r)   r)   r)   r*   r     s@     
r   r   c                 C   s.   t tjrt dt|  tj	|  d S )NzBEFORE FUSION
%s)
ir_pre_fusion_logisEnabledForr   r   r  r   r   r   r   r    r,   r)   r)   r*   log_ir_pre_fusion  s    r&  c                 C   s.   t tjrt dt|  tj	|  d S )NzAFTER FUSION
%s)
ir_post_fusion_logr$  r   r   r  r   r   r   r   r!   r%  r)   r)   r*   log_ir_post_fusion  s    r(  c                   @   s    e Zd ZU eed< ejed< dS )TensorMetadataHoldertensor_metadatarP   N)rH   r   r   r   r   ra   rP   r)   r)   r)   r*   r)    s   
r)  )pre_grad_graph_idpost_to_pre_grad_nodes_jsontriton_kernel_to_post_grad_jsonr&   c              
      s  i i i i d}t d t|ts0t d |S t|tsHt d |S t| ts`t d |S tt}tt}tt}z|	 D ]>\ }t|t
st d |  W S |D ]}||   qqtttf tddd	}	|	 D ]\ }t|t
st d
 |  W S |D ]}
|	|
s4|    W S |
d| krh||
d    |  |
d   fdd|
dg D }|r| \}|	|s|    W S |d| kr||d   | |d  |fdd|dg D  qqqtttf dddd}|| || || ||||dW S  ty } zNt d| t d| t d| t d|  t t  |W  Y d}~S d}~0 0 dS )zCreate bidirectional mappings between:

    - pre_grad graph nodes and post_grad graph code nodes, and vice versa
    - triton kernel name and post_grad graph code nodes, and vice versa
    )Z	preToPostZ	postToPreZcppCodeToPostZpostToCppCodez.Creating node mappings for provenance trackingzCProvenance tacking error: post_to_pre_grad_nodes_json is not a dictzGProvenance tacking error: triton_kernel_to_post_grad_json is not a dictz9Provenance tacking error: pre_grad_graph_id is not an intzMProvenance tacking error: triton_kernel_to_post_grad_json value is not a listr  c                 S   sB   t | tstd dS d| vs0d| vs0d| vr>td dS dS )NzVProvenance tacking error: node provenance in post_to_pre_grad_nodes_json is not a dictFgraph_idr#   	from_nodezYProvenance tacking error: node provenance in post_to_pre_grad_nodes_json has wrong formatT)r7   r   r2   error)r?   r)   r)   r*   check_format  s    
z)create_node_mapping.<locals>.check_formatzIProvenance tacking error: post_to_pre_grad_nodes_json value is not a listr.  r#   c                    s   g | ]}| fqS r)   r)   rT   r   )	outer_keyr)   r*   r    rY   z'create_node_mapping.<locals>.<listcomp>r/  c                 3   s   | ]}| fV  qd S rS   r)   r2  )
parent_keyr)   r*   rX     s   z&create_node_mapping.<locals>.<genexpr>N)dr&   c                 S   s&   | D ]}t | | | |< qt| } d S rS   )r   r   )r5  keyr)   r)   r*   convert_sets_to_lists  s    z2create_node_mapping.<locals>.convert_sets_to_listsz+Unexpected error in create_node_mapping: %sz post_to_pre_grad_nodes_json:  %sz$triton_kernel_to_post_grad_json:  %szpre_grad_graph_id:  %s)r2   r  r7   r   r0  r9   r_   defaultdictr   r{   r   r|   r^   r   ri   r   popextendr   	traceback
format_exc)r+  r,  r-  Zempty_returnZpre_to_postZpost_to_preZpost_to_cpp_codeZ
node_arrayZ	curr_noder1  r?   r   Zcurrent_noder7  er)   )r3  r4  r*   r    s    














r  r   c            
      O   s   d}t j|st | ttddd}t|| |f\}}d}| d| dtt d}t|d	 }t	
||f| W d
   n1 s0    Y  ttjrd| d|d}	t|	 d
S )z
    This function is used to save arguments for a compile_fx_inner function call
    to the file system.  Later on one can replay the compile_fx_inner call
    with the saved arguments using load_args_and_run_compile_fx_inner.
    z/tmp/inductor_saved_argsrU   r&   c                 S   s$   t | tjrtt| | jS | S dS )z
        Pickle FakeTensor will result in error:
        AttributeError: Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'

        Convert all Tensor to metadata. This may also makes pickle faster.
        N)r7   ra   r   r)  r   rP   rU   r)   r)   r*   handle_tensor6  s    z5save_args_for_compile_fx_inner.<locals>.handle_tensorcompile_fx_inner/_z.pklwbNz3
Arguments for a compile_fx_inner call is saved to z. To replay the call,
run the following:

from torch._inductor.debug import load_args_and_run_compile_fx_inner
load_args_and_run_compile_fx_inner(z
)
        )r   r   r   mkdirr   r   nextsave_args_cntr   pickler  r2   r$  r   r   r<   )
rF   rR   folderr@  Zargs_to_saveZkwargs_to_savefn_namer   r   messager)   r)   r*   save_args_for_compile_fx_inner+  s     
.rL  )r   r&   c              	   C   s   ddl m} t| d}t|\}}W d    n1 s:0    Y  ttddd}tjjdd}|j t	
d	d
> t|||f\}}||i |W  d    W  d    S 1 s0    Y  W d    n1 s0    Y  d S )Nr   )rA  rbr>  c                 S   s4   t | tr,tjj| jj| jj| jj	| j
S | S d S rS   )r7   r)  ra   Z_dynamoZtestingZrand_stridedr*  shaper  r;   rP   r?  r)   r)   r*   r@  _  s    
z9load_args_and_run_compile_fx_inner.<locals>.handle_tensorT)Zallow_non_fake_inputsZ	save_argsF)torch._inductor.compile_fxrA  r   rH  loadr   ra   r   ZFakeTensorModer   r   r   )r   rA  r   rF   rR   r@  Z	fake_moder)   r)   r*   "load_args_and_run_compile_fx_innerY  s    ,rQ  )package_path.)funcexported_programinductor_configsrR  r&   c             
   C   s  ddl m} ddlm} ddlm} ddlm} |jj	}|
 }	t|	tjjsRJ |j\}
}z|r||jjdkr|||d|d |r|jjd	krt|	}t|j}t|}|||d |d
 |d\}}t|}tjj||dd}| |
 |||ddd | |	|
||||dW S  |yR } z*||dd|d td |W Y d }~n^d }~0  ty } z<|rd}|jjd
kr~d}||d||d |W Y d }~n
d }~0 0 d S )Nr   )AccuracyError)dump_to_minify)r   )_aoti_flatten_inputs   aot_inductor)options   r   F)strictTZaccuracy)rU  rR  load_and_runZcheck_accuracy)rU  rR  r^  Zaot_inductor_accuracyZminify)commandr[  zAccuracy failedrun)Ztorch._dynamo.debug_utilsrV  Ztorch._dynamo.repro.aotirW  Ztorch._inductorr   rO  rX  rZ  Zdump_aoti_minifiermoduler7   ra   r   r   Zexample_inputsZrepro_levelr   deepcopyr8   exportr2   r3   r   )rS  rT  rU  rR  rV  rW  r   rX  Zuse_minifierrA   rF   rR   Zgm_copyZexample_inputs_copyconfig_copyZflat_example_inputsZtuple_inputsZflattened_epr=  r_  r)   r)   r*   aot_inductor_minifier_wrapperp  s    



	
re  )FN)Nr   )fr_   r   r   Zdataclasses	functoolsr   r   r  r   r   os.pathrH  r   r'   r;  collections.abcr   typingr   r   r   r   r   Zunittest.mockr   ra   Zfunctorch.compiler	   r
   r   r   Ztorch._dynamo.repro.after_aotr   Ztorch._dynamo.utilsr   Ztorch._loggingr   Ztorch.fx.graph_moduler   Ztorch.fx.passes.shape_propr   r   Ztorch.fx.passes.tools_commonr   Ztorch.typesr   Ztorch.utils._ordered_setr   Ztorch.utils._pytreer   r  r   r   Z	schedulerr   r   r   r   r   Zvirtualizedr   r   rH   r2   r#  r'  r   r!  r`   r"   r   cacheri   r+   r^   rB   rb   r4   r   r9   rw   r~   r   r   r   r   r   r   r   r   r&  r(  Z	dataclassr)  r   rG  r  rL  rQ  rc  ZExportedProgramre  r)   r)   r)   r*   <module>   s   



  .e  
$

+   A

v.

