a
    hB                     @   s^  U d Z ddlZddlZddlZddlZddlZddlZddlZddl	Z	ddl
Z
ddlZddlmZ ddlmZmZmZmZmZmZ ddlmZ ddlmZ ddlZddlmZ ddlmZ dd	lmZ d
dlm Z m!Z!m"Z"m#Z# d
dl$m%Z%m&Z&m'Z'm(Z( d
dl)m*Z*m+Z+m,Z, d
dl
m-Z-m.Z.m/Z/ d
dl0m1Z1 dZ2ee
j3 e4d< zddl5Z2W n e6y`   dZ2Y n0 e!j7Z7dZ8e9e:Z;edZ<eej= eej= dddZ>e?e?dddZ@dd ZAejBjCeeeeDe dddZEeeFddd ZGeej=ej=dd!d"ZHeeeDe eIed#f eJeef f eKdd$d"ZHeeej=eKf dd%d"ZHe?d&d'd(ZLdxe?e
jMe?dd*d+d,ZNdd-e.eOeeeOe-d.d/d0ZPG d1d2 d2ZQG d3d4 d4ZRG d5d6 d6ZSG d7d8 d8ZTG d9d: d:ZUe?e?d;d<d=ZVe?e?d;d>d?ZWe?e?d@dAdBZXe?e?d;dCdDZYdyeed#ef eOeeO eeO eOddEdFdGZZej[eDej= ed#ef dHdIdJZ\dzeKeKeFeKe?dMdNdOZ]ej^dPdfeeO eeO ej_ee?ej`f eOej=dQdRdSZaedTZbeFd&dUdVZcee<ebf eee<ebf dWdXdYZdddZd[ d\eee?e?eee? eed#ef ged#ef f eed]d^d_Zfee<ebf ee<ebf d`dadbZgee<ebf ee<ebf d`dcddZhee<ebf ee<ebf d`dedfZiee<ebf ee<ebf d`dgdhZjee<ebf ee<ebf d`didjZkee<ebf ee<ebf d`dkdlZlee<ebf ee<ebf d`dmdnZmee<ebf ee<ebf d`dodpZnd{eFddrdsdtZoee<ebf e<jpe<jqebdudvdwZrdS )|a1  Testing utilities and infrastructure for Dynamo.

This module provides a comprehensive set of testing utilities including:
- Test result collection and validation
- Graph manipulation and comparison tools
- Test case management and execution helpers
- Specialized test decorators for different Python versions and features
- RNG state management
- Compilation counting and monitoring
- Debug utilities for bytecode transformation

The utilities in this module are used across Dynamo's test suite to ensure
consistent testing patterns and proper test isolation.
    N)Sequence)AnyCallableOptionaloverloadTypeVarUnion)	ParamSpec)patch)fx)	aot_eager)OutputGraph   )config
eval_frameoptimize_assertreset)create_instructiondebug_checksis_generatortransform_code_object)CheckFunctionManager	CompileIdGuardedCode)ConvertFrameReturnDynamoFrameTypewrap_guarded_code)samenp   _P)xreturnc                 C   s    | d u rd S |    | jS N)detachcloneZrequires_grad_requires_gradr!    r(   C/var/www/auris/lib/python3.9/site-packages/torch/_dynamo/testing.pyclone_me?   s    r*   )namer"   c                 C   s   t dd| S )Nz^_orig_mod[.] resub)r+   r(   r(   r)   remove_optimized_module_prefixE   s    r0   c                    sJ   ddl m  d d  fdd}tj|dd| |i | jfS )Nr   )InstructionTranslatorc                    s   |    jj| S r#   )Z
current_txoutputregion_tracker)Z_gmargskwargsr1   gmr3   r(   r)   extract_graph_backendO   s    z8extract_graph_and_tracker.<locals>.extract_graph_backendT)backendZ	fullgraph)Ztorch._dynamo.symbolic_convertr1   torchcompilegraph)fnr4   r5   r8   r(   r6   r)   extract_graph_and_trackerI   s    r>   )model
predictionlossexample_inputsr"   c                 C   s  g }| | | | i }i }|  D ]N\}}t| tjrDt|}|}	|j}
|jd u rbt|}
|
||d < |	||< q(| | | | i }| 	 D ]$\}}t| tjrt|}|||< q| | |D ]@}t|t
tfr|dd |D  qt|tjr| |j q|S )Nz.gradc                 s   s    | ]}t |tjr|jV  qd S r#   )
isinstancer:   Tensorgrad).0inpr(   r(   r)   	<genexpr>{       z"collect_results.<locals>.<genexpr>)appendZnamed_parametersrC   r   ZOptimizedModuler0   rE   r:   Z
zeros_likeZnamed_bufferstuplelistextendrD   )r?   r@   rA   rB   resultsZgradsparamsr+   paramZ
param_copyrE   buffersbufferZexampler(   r(   r)   collect_resultsZ   s8    








rS   )outr"   c                 C   s^   t | tjr| jS t | ttfr2tdd | D S | d u r>dS t | trLdS tdt	| d S )Nc                 s   s   | ]}t |V  qd S r#   )requires_bwd_passrF   r!   r(   r(   r)   rH      rI   z$requires_bwd_pass.<locals>.<genexpr>FDon't know how to reduce)
rC   r:   rD   r&   rL   rK   anyintNotImplementedErrortyperT   r(   r(   r)   rU      s    
rU   c                 C   s   d S r#   r(   r\   r(   r(   r)   reduce_to_scalar_loss   s    r]   .c                 C   s   d S r#   r(   r\   r(   r(   r)   r]      s    c                 C   s   t | tjr|  |   S t | ttfrDtdd | D t|  S t| j	dv r\t
| jS t| j	dkrt| j S t | trtdd |  D t|   S tdt| dS )z/Reduce the output of a model to get scalar lossc                 s   s   | ]}t |V  qd S r#   r]   rV   r(   r(   r)   rH      rI   z(reduce_to_scalar_loss.<locals>.<genexpr>)ZMaskedLMOutputZSeq2SeqLMOutputZ!CausalLMOutputWithCrossAttentionsZSquashedNormalc                 s   s   | ]}t |V  qd S r#   r^   )rF   valuer(   r(   r)   rH      rI   rW   N)rC   r:   rD   sumZnumelrL   rK   lenr[   __name__r]   ZlogitsmeandictvalueskeysrZ   r\   r(   r(   r)   r]      s    


r"   c                  C   s0   t jt jtd} t j| s,t |  | S )Nz../debug)ospathjoindirname__file__existsmkdir)ri   r(   r(   r)   	debug_dir   s    
ro   r,   )r+   codeextrar"   c              	   C   sj   t tjt | d@}|t|  dt|  d| d W d    n1 s\0    Y  d S )Nwz


)	openrh   ri   rj   ro   writedisBytecodeinfo)r+   rp   rq   fdr(   r(   r)   
debug_dump   s    (rz   )skip)frame
cache_sizehooks_r{   r"   c          	      C   s  t t tdddd}tjj }tjjd | t| jrdt	 W  d   W  d   S t
| j t| j|}ti ddddddit t | jg dd	}tt|t| j|jtddd
W  d   W  d   S 1 s0    Y  W d   n1 s0    Y  dS )zused to debug jump updatesN)instructionscode_optionsr"   c                 S   s$   |  dtd |  dtd d S )Nr   NOP)insertr   )r   r   r(   r(   r)   insert_nops   s    z&debug_insert_nops.<locals>.insert_nopsdebug_insert_nopsFZ_idr   )r   Zcompiler_fnZroot_txexportZexport_constraintsZframe_stateZlocal_scopeZglobal_scopef_codeZtorch_function_mode_stackpackage)Zframe_idZframe_compile_id)rL   r   r:   Z_dynamoutilsZget_metrics_contextZdynamo_timedr   r   r   r   r   r   localsglobalsr   r   r   Zguard_managerr   )	r|   r}   r~   r   r{   r   Zmetrics_contextrp   r<   r(   r(   r)   r      s6    
"

r   c                   @   sL   e Zd ZddddZejjeej e	de
f dddZddd	d
ZdS )CompileCounterNrg   c                 C   s   d| _ d| _d S Nr   frame_countop_countselfr(   r(   r)   __init__   s    zCompileCounter.__init__.r7   rB   r"   c                 C   s:   |  j d7  _ |jjD ]}d|jv r|  jd7  _q|jS )Nr   call)r   r<   nodesopr   forward)r   r7   rB   noder(   r(   r)   __call__   s
    
zCompileCounter.__call__c                 C   s   d| _ d| _d S r   r   r   r(   r(   r)   clear   s    zCompileCounter.clear)rb   
__module____qualname__r   r:   r   GraphModulerL   rD   r   r   r   r   r(   r(   r(   r)   r      s
   
	r   c                   @   sN   e Zd ZeddddZejjeej	 e
def dddZdd	d
dZdS )CompileCounterWithBackendN)r9   r"   c                 C   s   d| _ d| _|| _g | _d S r   )r   r   r9   graphs)r   r9   r(   r(   r)   r      s    z"CompileCounterWithBackend.__init__.r   c                 C   s\   ddl m} |  jd7  _|jjD ]}d|jv r"|  jd7  _q"| j| || j	||S )Nr   )lookup_backendr   )
Zbackends.registryr   r   r<   r   r   r   r   rJ   r9   )r   r7   rB   r   r   r(   r(   r)   r      s    
z"CompileCounterWithBackend.__call__rg   c                 C   s   d| _ d| _g | _d S r   )r   r   r   r   r(   r(   r)   r   
  s    zCompileCounterWithBackend.clear)rb   r   r   strr   r:   r   r   rL   rD   r   r   r   r   r(   r(   r(   r)   r      s
   
r   c                   @   s>   e Zd ZddddZejjeej e	de
f dddZdS )	EagerAndRecordGraphsNrg   c                 C   s
   g | _ d S r#   )r   r   r(   r(   r)   r     s    zEagerAndRecordGraphs.__init__.r   c                 C   s   | j | |jS r#   )r   rJ   r   )r   r7   rB   r(   r(   r)   r     s    zEagerAndRecordGraphs.__call__rb   r   r   r   r:   r   r   rL   rD   r   r   r   r(   r(   r(   r)   r     s   
r   c                   @   s>   e Zd ZddddZejjeej e	de
f dddZdS )	AotEagerAndRecordGraphsNrg   c                 C   s   g | _ g | _g | _d S r#   )r   	fw_graphs	bw_graphsr   r(   r(   r)   r     s    z AotEagerAndRecordGraphs.__init__.r   c                    sl    j | tjjttj tdtf d fdd}tjjttj tdtf d fdd}t	||||dS )N.r   c                    s    j |  | jS r#   )r   rJ   r   r7   rB   r   r(   r)   fw_compiler(  s    z5AotEagerAndRecordGraphs.__call__.<locals>.fw_compilerc                    s    j |  | jS r#   )r   rJ   r   r   r   r(   r)   bw_compiler.  s    z5AotEagerAndRecordGraphs.__call__.<locals>.bw_compiler)r   r   )
r   rJ   r:   r   r   rL   rD   r   r   r   )r   r7   rB   r   r   r(   r   r)   r   #  s    

z AotEagerAndRecordGraphs.__call__r   r(   r(   r(   r)   r     s   
r   c                   @   s"   e Zd ZddddZdd ZdS )InductorAndRecordGraphsNrg   c                 C   s   g | _ g | _d S r#   )r   inductor_graphsr   r(   r(   r)   r   =  s    z InductorAndRecordGraphs.__init__c                    sr   dd l m  m} j| |j  fdd}tj|d|d |||W  d    S 1 sd0    Y  d S )Nr   c                     s   j | d   | i |S r   )r   rJ   )r4   r5   Zold_compile_fx_innerr   r(   r)   patchedH  s    z1InductorAndRecordGraphs.__call__.<locals>.patched_compile_fx_inner)new)Ztorch._inductor.compile_fxZ	_inductorZ
compile_fxr   rJ   r   r
   object)r   r7   rB   Zcompile_fx_modr   r(   r   r)   r   A  s    z InductorAndRecordGraphs.__call__)rb   r   r   r   r   r(   r(   r(   r)   r   <  s   r   )rp   r"   c                 C   s   t dd| S )Nz(?m)^ *#.*\n?r,   r-   rp   r(   r(   r)   strip_commentP  s    r   c                 C   s   d dd | dD S )Nrs   c                 S   s   g | ]}|  qS r(   )rstrip)rF   liner(   r(   r)   
<listcomp>U  rI   z)remove_trailing_space.<locals>.<listcomp>)rj   splitr   r(   r(   r)   remove_trailing_spaceT  s    r   )gm_strr"   c                 C   s   t t| S r#   )r   r   )r   r(   r(   r)   normalize_gmX  s    r   c                 C   s   t dd| }|S )z-
    Normalize code: remove empty lines.
    z[\r\n]+rs   r-   )rp   Znormal_coder(   r(   r)   empty_line_normalizer^  s    r   )r   r=   nargsexpected_opsexpected_ops_dynamicexpected_frame_countr"   c                 C   s   t js|d ur|}t }dd t|D }dd t|D }|| }	|| }
t  t||}|| }|| }|| }|| }t  | t||	 | t||	 | t||
 | t||
 | |j	| |d ur| |j
| d S )Nc                 S   s   g | ]}t d d qS 
   r:   randnrF   r   r(   r(   r)   r   s  rI   z!standard_test.<locals>.<listcomp>c                 S   s   g | ]}t d d qS r   r   r   r(   r(   r)   r   t  rI   )r   assume_static_by_defaultr   ranger   r   
assertTruer   assertEqualr   r   )r   r=   r   r   r   r   actualZargs1Zargs2Zcorrect1Zcorrect2Zopt_fnZval1aZval2aZval1bZval2br(   r(   r)   standard_testf  s*    r   r   c                 C   s   | j S r#   )r   r   r(   r(   r)   dummy_fx_compile  s    r   T皙?)speeduppvalue
is_correctpvalue_thresholdr"   c                 C   s.   |sdS ||kr| ddS | dd|dS )NERRORz.3fzx SAMEzx p=z.2fr(   )r   r   r   r   r(   r(   r)   format_speedup  s
    r   cpu)sizestridedtypedevice
extra_sizer"   c                 C   s~   t dd t| |D d | }|jr^|jdkrLtj|tj|dj|d}qptj|||d}ntj|g||d}t	|| |S )Nc                 s   s   | ]\}}|d  | V  qdS )r   Nr(   )rF   shaper   r(   r(   r)   rH     rI   zrand_strided.<locals>.<genexpr>r   )r   r   )r   )r   r   r   )
r`   zipZis_floating_pointitemsizer:   r   float16tozerosZ
as_strided)r   r   r   r   r   Zneeded_sizerR   r(   r(   r)   rand_strided  s    
r   _Tc                   C   s   t j S r#   )r   r   r(   r(   r(   r)   check_dynamic_shape_capture  s    r   )r=   patchesr"   c                    s*   t  tjtjtd fdd}|S )N)r4   r5   r"   c               	      s^   t  B}D ]\}}}|t||| q | i |W  d    S 1 sP0    Y  d S r#   )
contextlib	ExitStackenter_contextr
   r   )r4   r5   stackmoduleattrvalr=   r   r(   r)   _fn  s    
z"_make_fn_with_patches.<locals>._fn)	functoolswrapsr    r4   r5   r   )r=   r   r   r(   r   r)   _make_fn_with_patches  s    r   c                 C   s   | S r#   r(   r'   r(   r(   r)   <lambda>  rI   r   )
xfail_prop	decorator)cls
cls_prefix	fn_suffixr   r   r   r"   c                G   s   t | | j | ji }|j|_t| D ]}|drt| |}t|s\t||t| | q(| | }	t	|g|R  }
|	|
_|d urt
||rt|
}
t||	||
 q(t
||s(t||t| | q(|S )NZtest_)r[   rb   	__bases__r   dir
startswithgetattrcallablesetattrr   hasattrunittestexpectedFailure)r   r   r   r   r   r   ZDummyTestClassr+   r=   new_nameZnew_fnr(   r(   r)   make_test_cls_with_patches  s"    



r   )r=   r"   c                 C   s   t jdkr| S t| S )N)r      sysversion_infor   r{   r=   r(   r(   r)   skipIfNotPy311  s    
r  c                 C   s   t jdkr| S td| S )Nr      zRequires Python 3.12+r  r  r(   r(   r)   skipIfNotPy312  s    
r  c                 C   s   t jdkrt| S | S )Nr  )r  r  r   r   r  r(   r(   r)   xfailIfPy312  s    

r	  c                 C   s   t jdkrtd| S | S )Nr  zNot supported in Python 3.12+r  r  r(   r(   r)   skipIfPy312  s    
r
  c                 C   s    t jdkr| S td| S d S )N)r   r   zRequires Python 3.10+r  r  r(   r(   r)   requiresPy310  s    
r  c                 C   s
   d| _ | S NT)Z_expected_failure_dynamicr  r(   r(   r)   expectedFailureDynamic
  s    r  c                 C   s
   d| _ | S r  )Z!_expected_failure_codegen_dynamicr  r(   r(   r)   expectedFailureCodegenDynamic  s    r  c                 C   s
   d| _ | S r  )Z!_expected_failure_dynamic_wrapperr  r(   r(   r)   expectedFailureDynamicWrapper  s    r  F)use_xlar"   c                 C   sR   t d td tr$tjd | rNdd lm  m} |dt	|
  d S )Ni9  r   )r:   Zmanual_seedrandomseedr   Ztorch_xla.core.xla_modelcoreZ	xla_modelZset_rng_stater   Z
xla_device)r  Zxmr(   r(   r)   reset_rng_state  s    

r  )fr4   r5   r"   c                 O   s   | |i |S r#   r(   )r  r4   r5   r(   r(   r)   &_skipped_function_for_test_reconstruct&  s    r  )r,   )NNr   )Tr   )F)s__doc__r   rv   r   loggingos.pathrh   r  r.   r  typesr   collections.abcr   typingr   r   r   r   r   r   Ztyping_extensionsr	   Zunittest.mockr
   r:   r   Z torch._dynamo.backends.debuggingr   Ztorch._dynamo.output_graphr   r,   r   r   r   r   Zbytecode_transformationr   r   r   r   Zguardsr   r   r   r   r   r   r   r   r   
ModuleType__annotations__numpyModuleNotFoundErrorunsupportedZthree	getLoggerrb   logr    rD   r*   r   r0   r>   nnModulerL   rS   boolrU   r]   rK   rd   floatro   CodeTyperz   rY   r   r   r   r   r   r   r   r   r   r   r   r   r   r   float32r   r   r   r   r   r   r[   r   r  r  r	  r
  r  r  r  r  r  r4   r5   r  r(   r(   r(   r)   <module>   s    

( 
(   
"
	  "     	   