a
    hք                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlmZ d dlmZm	Z	m
Z
mZ d dlZd dlZd dlZd dlm  mZ d dlmZ d dlmZ d dlmZmZ d dlmZ eeZzzd dlZej e!dd	d
Z"dd Z#eG dd dZ$e	de	dddZ%G dd dejj&Z'G dd dZ(G dd dZ)W n  e*yN   dZ+g dZ,Y n0 dZ+g dZ,d dl-m.Z/ e0dddZ1e2dd d!Z3d"d# Z4G d$d% d%eZ5G d&d' d'eZ6e4  d(d) Z7dS )*    N)	dataclass)AnyCallableOptionalUnion)TorchDynamoException)dynamo_timed)ArgumentTarget)sympy_interpereturnc           
         s  t | sJ d|  t jtt ddd}t | } t | sPtd|  t | sdt 	| rl| 
 S |  }| t|}|| }t jkrd}nt jt jfv rƇ fdd  | }nt jkr8|  d	ksJ | d
}t |sJ |  }t jdt jdt jdi}||v r|| }||}nxt jt jfv r|  d	ks\J t| d
}|drd|dd   S |S t jkr|  d
ksJ t|S |d d| }	d|	  dS )Nzunsupported expression type: r   c                    s    fddt   D S )Nc                    s   g | ]}t  |qS  )z3strarg).0ir   r   M/var/www/auris/lib/python3.9/site-packages/torch/fx/experimental/validator.py
<listcomp>@       z/z3str.<locals>.get_args_str.<locals>.<listcomp>)rangenum_argsr   r   r   r   get_args_str?   s    zz3str.<locals>.get_args_strzcan't print Z3 expression: powc                    sD   t  r   ks$t gS  fddt  D S d S )Nc                    s$   g | ]}  |D ]}|qqS r   )r   )r   r   x)collect_str_argsr   r   r   r   ^   s   z3z3str.<locals>.collect_str_args.<locals>.<listcomp>)z3is_appdeclkindr   r   r   r   r   r!   r   r   r   Z   s
    

zz3str.<locals>.collect_str_args   r   z!=><z(/z(idiv    ())r   Zis_exprExprRefliststrZsimplifyr   
ValueErrorZis_int_valueZis_rational_value	as_stringr    r!   ZZ3_OP_POWERZ	Z3_OP_ADDZ	Z3_OP_MULZ	Z3_OP_NOTr   r   ZZ3_OP_EQZZ3_OP_LEZZ3_OP_GEZZ3_OP_TO_INTZZ3_OP_TO_REALr   
startswithZZ3_OP_UNINTERPRETEDjoinrstrip)
r   r   r    opargsr   ZargkindZlogic_inverseZargstrstringr   r"   r   r   <   sN    







r   c                    s   t   fdd}|S )Nc                    sB   d ur"t dd |D r"| S tdd |D }t | S )Nc                 s   s   | ]}t |tjV  qd S N)
isinstancer   BoolRef)r   r   r   r   r   	<genexpr>   s   z/_bitwise_op.<locals>.wrapper.<locals>.<genexpr>c                 s   s   | ]}t |d V  qdS )@   N)r   ZInt2BVr   ar   r   r   r8      r   )alltupler   ZBV2Int)selfr3   wrapped_argsbitwise_func	bool_funcr   r   wrapper   s    z_bitwise_op.<locals>.wrapper)	functoolswraps)rA   rB   rC   r   r@   r   _bitwise_op   s    	rF   c                   @   s  e Zd ZU ded< eejejdddZeejejdddZejejdd	d
Z	ejejejdddZ
ejejdddZejejejdddZejejdddZejejdddZejejejdddZejejejdddZejejejdddZejejejdd d!Zejejdd"d#Zejejdd$d%Zejejdd&d'ZeejejZeejejZeejd(Zeejd(Zd(S ))_Z3OpsTranslationValidator	validator)r   r   c                 C   s   |   r| S t| S r5   )is_realr   ToRealr   r   r   r   to_real   s    z_Z3Ops.to_realc                 C   s   |   r| S t| S r5   )Zis_intr   ToIntrL   r   r   r   to_int   s    z_Z3Ops.to_int)r3   r   c                 C   s   t |S r5   )sum)r>   r3   r   r   r   sym_sum   s    z_Z3Ops.sym_sum	numeratordenominatorr   c                 C   s$   | j |dk t|t| S Nr   )rI   add_assertionrG   rM   r>   rS   rT   r   r   r   div   s    z
_Z3Ops.div)numberr   c                 C   s
   t |S r5   )rG   rO   r>   rY   r   r   r   floor   s    z_Z3Ops.floorc                 C   s4   |  p|  }t| ||}|r0t|S |S r5   )rJ   rG   rO   rX   rM   )r>   rS   rT   Zcast_result_to_realresultr   r   r   floordiv   s    z_Z3Ops.floordivc                 C   s"   t | ||k | |d |S Nr#   )r   Ifr[   rZ   r   r   r   ceil   s    z_Z3Ops.ceilc                 C   s   t |dk| || |S rU   )r   r_   r[   r`   rZ   r   r   r   trunc   s    z_Z3Ops.trunc)r;   br   c                 C   s   t ||k||S r5   r   r_   r>   r;   rb   r   r   r   max   s    z
_Z3Ops.maxc                 C   s   t ||k ||S r5   rc   rd   r   r   r   min   s    z
_Z3Ops.minpqr   c                 C   s   ||  |||  S r5   )r]   r>   rh   ri   r   r   r   mod   s    z
_Z3Ops.modbaseexpr   c                 C   s$   | j t|dk|dk || S rU   )rI   rV   r   Orr>   rm   rn   r   r   r   r      s    z
_Z3Ops.powc                 C   s"   t |}| j|dk |d S )Nr         ?)rG   rM   rI   rV   rZ   r   r   r   sqrt   s    
z_Z3Ops.sqrtc                 C   s
   t |S r5   )r   ZAbsrZ   r   r   r   abs   s    z
_Z3Ops.absc                 C   s4   t | |t ddk| |d | |d S )Nr&   rq   )r   r_   rk   IntValr`   r[   rZ   r   r   r   round_to_int   s
    	z_Z3Ops.round_to_intN) __name__
__module____qualname____annotations__staticmethodr   ArithRefrM   rO   rQ   rX   r[   r]   r`   ra   re   rf   rk   r   rr   rs   ru   rF   operatorand_Andbitwise_andor_ro   
bitwise_orlshiftrshiftr   r   r   r   rG      s0   
		rG   rH   )r2   rI   r   c              (      s0  t jh}| |v   fdd}t|}t j|tjt j||jt j||jt j	||j	t j
||j
t j||jt j||jt j||jt j||jtj||jtj||jtj||jtj||jtj||jtj||jtj||jtj||jtj|dd tj ||j!tj"tj"i}| |v r(||  S || S )Nc                    s0   t jdfddt  fdd}|S )Nr   c                    s   t | tjtjfr| S t | ts. r<t | tr<tt| S t | ttjfrZt	t| S t | t
tjfrxtt
| S tdt|  d S )Nzcan't lift type: )r6   r   r{   r7   boolintBoolValsympyZIntegerrt   floatZFloatRealValr-   type)r;   Zas_boolr   r   wrap  s    z z3op.<locals>.lift.<locals>.wrapc                     sZ   t | dkr<t| d ttfr<tfdd| d D f}ntfdd| D } | S )Nr#   r   c                 3   s   | ]} |V  qd S r5   r   r:   r   r   r   r8   &  r   z6z3op.<locals>.lift.<locals>.wrapper.<locals>.<genexpr>c                 3   s   | ]} |V  qd S r5   r   r:   r   r   r   r8   (  r   )lenr6   r+   r=   )r3   r?   funcr   r   r   rC   "  s    z#z3op.<locals>.lift.<locals>.wrapper)r   r*   rD   rE   )r   rC   r   r   r   lift  s    	zz3op.<locals>.liftc                 S   s   | r|S |S r5   r   )rb   tfr   r   r   <lambda>D  r   zz3op.<locals>.<lambda>)#r|   not_rG   r   Notr}   r   r   r   r   r   r]   truedivrX   rk   rs   builtinsroundru   mathr`   r[   ra   torchZ	sym_floatrM   Zsym_maxre   Zsym_minrf   rQ   Zsym_iteZ	_sym_sqrtrr   _assert)r2   rI   Zboolean_opsr   opsZreplacement_mapr   r   r   z3op  s4    r   c                       st   e Zd Zejjdd fddZeee	df e
eef edddZeee	df e
eef ed fd	d
Z  ZS )PopulateValidatorrH   )graphrI   c                    s*   || _ tjji |d}t j|dd d S )N)rootr   T)Zgarbage_collect_values)rI   r   fxZGraphModulesuper__init__)r>   r   rI   module	__class__r   r   r   V  s    zPopulateValidator.__init__.)targetr3   kwargsr   c                 C   s   t  d }| j|S )Nsymbol)fx_tracebackZget_current_metarI   z3var)r>   r   r3   r   r   r   r   r   placeholder^  s    zPopulateValidator.placeholderc                    sV   |t jkr"t t|| j||S t|dksBJ dt| d| j|d  d S )Nr#   z'expected 1 argument on assertion. Got: r'   r   )r   r   r   call_functionr   rI   r   add_source_expr)r>   r   r3   r   r   r   r   r   d  s    
zPopulateValidator.call_function)rv   rw   rx   r   r   ZGraphr   r
   r=   r	   dictr,   r   r   r   __classcell__r   r   r   r   r   U  s   	r   c                   @   s~  e Zd Zh dZdddddZeejej	ddd	Z
ejejejd
ddZejejejd
ddZejejejd
ddZejejejdddZejejejdddZejejejdddZejejejdddZejejejdddZejejejdddZejejejdd d!Zejejejd
d"d#Zejejejd
d$d%Zeed&d'd(Zejej	d)d*d+ZdS ),	SympyToZ3>   neltgtgeaddleeqmulrH   N)rI   r   c                 C   s   || _ t| j | _d S r5   )
_validatorrG   _ops)r>   rI   r   r   r   r   z  s    zSympyToZ3.__init__)valuedtyper   c                 C   sZ   |t ju rtt|S |t ju r0tt|S |t ju rHt	t|S t
d| d S )Nzunsupported dtype (SympyToZ3): )r   int64r   rt   r   doubler   r   r   r   r-   )r>   r   r   r   r   r   constant  s    


zSympyToZ3.constant)r   r   r   c                 C   s(   |t jkrt|S td| dd S )Nz	to_dtype z NYI)r   float64r   rK   NotImplementedErrorr>   r   r   r   r   r   to_dtype  s    

zSympyToZ3.to_dtypec                 C   s
   t |S r5   )r   rN   r   r   r   r   trunc_to_int  s    zSympyToZ3.trunc_to_intc                 C   s   | j |S r5   )r   ru   r   r   r   r   ru     s    zSympyToZ3.round_to_intrR   c                 C   s   | j ||S r5   r   rX   rW   r   r   r   int_truediv  s    zSympyToZ3.int_truedivc                 C   s   | j ||S r5   r   rW   r   r   r   r     s    zSympyToZ3.truedivc                 C   s   | j ||S r5   r   r]   rW   r   r   r   r]     s    zSympyToZ3.floordivc                 C   s   | j ||S r5   r   rW   r   r   r   rX     s    zSympyToZ3.divrl   c                 C   s   | j ||S r5   r   r   rp   r   r   r   r     s    zSympyToZ3.powc                 C   s   | j ||S r5   r   rp   r   r   r   pow_by_natural  s    zSympyToZ3.pow_by_naturalrg   c                 C   s   | j ||S r5   )r   rk   rj   r   r   r   rk     s    zSympyToZ3.modc                 C   s   | j |S r5   )r   r`   r   r   r   r   ceil_to_int  s    zSympyToZ3.ceil_to_intc                 C   s   | j |S r5   )r   r[   r   r   r   r   floor_to_int  s    zSympyToZ3.floor_to_int)namer   c                 C   sx   t jt jt j| jj| jj| jj| jj| jj	| jj
| jj| jjd}||v rR|| S || jv rftt|S td| d S )N)r}   r   r   r   r   r   r   r[   r`   Zminimummaximumzunhandled operator: )r   r~   ro   r   r   r   r   r   r   r[   r`   rf   re   OPERATOR_HANDLESgetattrr|   AttributeError)r>   r   ZREPLACEMENTr   r   r   __getattr__  s"    

zSympyToZ3.__getattr__)exprr   c                 C   s   t | | jj|S r5   )r   r   symbols)r>   r   r   r   r   run  s    zSympyToZ3.run)rv   rw   rx   r   r   r   r   r   r   r*   r   r{   r   r   ru   r   r   r]   rX   r   r   rk   r   r   r,   r   r   Basicr   r   r   r   r   r   w  s2   
r   c                   @   s   e Zd ZddddZejejdddZeje	ejdd	d
Z
ejddddZejejdddZejddddZdddddZeejejf ddddZddddZddddZdS )rH   Nr   c                 C   s,   t d i | _t | _t | _t | _d S )Nznew instance)logdebugr   set_source_exprs_target_exprs_assertionsr>   r   r   r   r     s
    
zTranslationValidator.__init__)r   r   c                 C   s"   || j v sJ d| | j | S )NzZ3 variable not found for: )r   )r>   r   r   r   r   r     s    zTranslationValidator.z3var)r   r   r   c                 C   s   || j v r| j | S td|j|j |tu rRt|j}|jr| j	
|dk n:|tu rht|j}n$|tu r~t|j}ntd| || j |< |S )Nznew variable: %s (%s)r   z"unsupported type for Z3 variable: )r   r   r   r   rv   r   r   ZIntZis_positiver   r   r   Realr   ZBoolRuntimeError)r>   r   r   varr   r   r   add_var  s    


zTranslationValidator.add_varr   c                 C   s*   |j D ]}t|tjsJ | | qd S r5   )Zfree_symbolsr6   r   Symbolr   )r>   r   sr   r   r   _check_freesymbols  s    
z'TranslationValidator._check_freesymbolsc                 C   s,   t | |}t|tjs(J d| |S )Nz"expected boolean expression. Got: )r   r   r6   r   r7   r>   r   Zz3exprr   r   r   to_z3_boolean_expr  s
    z'TranslationValidator.to_z3_boolean_exprc                 C   s*   || j vrtdt| | j | d S )Nzadd source guard: %s)r   r   r   r   r   )r>   r   r   r   r   r     s    
z$TranslationValidator.add_source_exprzsympy.logic.boolalg.Booleanc                 C   s>   |  | | |}|| jvr.tdt| | j| d S )Nzadd target guard: %s)r   r   r   r   r   r   r   r   r   r   r   add_target_expr#  s
    


z$TranslationValidator.add_target_exprc                 C   s`   t |tjr"| | | |}n|}t |tjs6J || jvrPt	dt
| | j| d S )Nzadd assertion: %s)r6   r   r   r   r   r   r7   r   r   r   r   r   )r>   r   refr   r   r   rV   *  s    

z"TranslationValidator.add_assertionc                 C   s4   t d |  W  d    S 1 s&0    Y  d S )NTranslationValidator.validate)r   	_validater   r   r   r   validate5  s    
r   c                    s   t | jdkst | jdkr d S td}|jt d | jD ]}|| q>|t	tj
| j  |j| j  td | }|tjkr|  t | j| j fdd| jD dn.|tjkrtd n|tjksJ td	 d S )
Nr   ZQF_NRA)timeoutztranslation validation: startc                    s   g | ]}  |s|qS r   )evaluate)r   inpmodelr   r   r   `  s   z2TranslationValidator._validate.<locals>.<listcomp>)failed_source_exprsz:translation validation: could not validate: got z3.unknownztranslation validation: success)r   r   r   r   Z	SolverForr   translation_validation_timeoutr   r   r   r~   r   r   checksatr   ValidationExceptionunknownwarningZunsat)r>   ZsolverZ	assertionrr   r   r   r   9  s4    




	
zTranslationValidator._validate)rv   rw   rx   r   r   r   r   r*   r   r   r   r   r   r7   r   r   r   r   rV   r   r   r   r   r   r   rH     s   F)translation_validation_enabledr   r   BisectValidationExceptionT)	r   r   r   r   rH   r   r   r   r   )_configr   c                   C   s   t   totjS r5   )_assert_z3_installed_if_tv_set_HAS_Z3configtranslation_validationr   r   r   r   r     s    r   c                   C   s   t jS r5   )r   r   r   r   r   r   r     s    r   c                   C   s   t stjrJ dd S )Nzotranslation validation requires Z3 package. Please, either install z3-solver or disable translation validation.)r   r   r  r   r   r   r   r     s    r   c                   @   s   e Zd Zdd Zdd ZdS )r   c                    s   t sJ td fdd}tddd}|tt| }|ttt|}|ttt|}	|ttt|}
d| _d| d| d	|	 d
|
 | _d S )Nr   c                    s   |  d |   S )N: r   )symr   r   r   	symbolstr  s    z/ValidationException.__init__.<locals>.symbolstrc                 S   s   d dd | D S )N
c                 s   s   | ]}d | V  qdS )z  ==> Nr   )r   r   r   r   r   r8     r   zBValidationException.__init__.<locals>.joinlines.<locals>.<genexpr>)r0   )xsr   r   r   	joinlines  s    z/ValidationException.__init__.<locals>.joinlinesztranslation validation failed.zModel:
z

Assertions:
z

Target Expressions:
z

Failed Source Expressions:
)r   r,   sortedmapr   msgdetails)r>   r   Z
assertionsZtarget_exprsr   r  r  Z	model_strZassertions_strZtarget_exprs_strZfailed_source_exprs_strr   r   r   r     s"    zValidationException.__init__c                 C   s   | j  d| j S N

r
  r  r   r   r   r   __str__  s    zValidationException.__str__Nrv   rw   rx   r   r  r   r   r   r   r     s   r   c                   @   s   e Zd Zdd Zdd ZdS )r   c                 C   s.   d| d| | _ d|  d|j | _d S )Nz#translation validation failed when r  z)Failure occurred while running node:
    r  )r
  Zformat_noder  )r>   Zvalidation_excr   failed_actiontraced_noder   r   r   r     s    z"BisectValidationException.__init__c                 C   s   | j  d| j S r  r  r   r   r   r   r    s    z!BisectValidationException.__str__Nr  r   r   r   r   r     s   r   c                    sP  ddl m mm} ddlm}m}m | jt	j
j|dfdd}|td fdd	|ttt  tt d
fddt	j
jtt dfdd}| |  }|std d S | jrtjr|i }dd | jjD }ddt|d   }	}
}||| ||< |	|k rh|	| d }
||
 }td|
|| ||||
< ||
 r\|
}n|
d }	q|	|v rt||	 tsJ ||	 }||}| rd}n| sJ d| d}|j}|d usJ t|dksJ d|j dt| t|d tj s.J d|j dt!|d  t"||	 |d ||j#| dd S )Nr   )FakeTensorMetareplay_shape_env_eventsShapeEnvEvent)CURRENT_NODE_KEYShapeEnvSHAPEENV_EVENT_KEY)noder   c                    s    | j v sJ | j    S r5   )meta)r  )r  eventsr   r   get_node_event  s    zbisect.<locals>.get_node_event)	shape_envr   c                    s   t |tr|S t |tjr,t|j S t |tjrJt|j S t |sXJ t fdd| D t fdd|	 D  |
 |jS )Nc                 3   s   | ]} |V  qd S r5   r   r   r   new_with_shape_envr  r   r   r8     r   z5bisect.<locals>.new_with_shape_env.<locals>.<genexpr>c                 3   s   | ]} |V  qd S r5   r   r  r  r   r   r8     r   )r6   r   r   ZSymIntr  Zwith_shape_envZSymFloatr=   sizeZstrideZstorage_offsetZ	is_nested)r  fake)r  r   r  r   r     s    
z"bisect.<locals>.new_with_shape_env)r  tracked_fakesr   c              
      st   |d usJ z: j  fdd|D dd |D dd |D d W d S  tyn } z|W  Y d }~S d }~0 0 d S )Nc                    s   g | ]} |j qS r   )r"  r:   r  r   r   r   
  r   z8bisect.<locals>.check_shapeenv_fails.<locals>.<listcomp>c                 S   s   g | ]
}|j qS r   )sourcer:   r   r   r   r     r   c                 S   s   g | ]
}|j qS r   )Zsymbolic_contextr:   r   r   r   r     r   )Zinput_contexts)Zproduce_guardsr   )r  r$  r   )r   r#  r   check_shapeenv_fails  s    z$bisect.<locals>.check_shapeenv_failsc                    s8   | j   }d |d  }|j  || jS r^   )r  r   Zlintr$  )r  rY   r  )r  r&  r  r  r   r   check_node_fails  s    

z bisect.<locals>.check_node_failsz2translation validation succeeded: no errors found.c                 S   s   g | ]}|j tjkr|qS r   )r   r   r   )r   r  r   r   r   r   /  s   zbisect.<locals>.<listcomp>r#   r&   zbisecting at %s: %sZ
evaluatingzunexpected event type: zadding runtime assertzbisecting expects z/ to have at least 2 positional arguments. Got: z9 to have a SymPy expression as its second argument. Got: )r   r  r  )$Ztorch.fx.experimental.recordingr  r  r  Z%torch.fx.experimental.symbolic_shapesr  r  r  r  r   r   Noder   r   r+   r   Z_snapshot_tracked_fakesr   infoZshould_record_eventsr   Z translation_validation_no_bisectr   nodesr   r   r6   Zis_evaluate_exprZis_defer_runtime_assertr3   r   r   r   r   r   r  )r  r  r  r  r  r'  Zlast_exception	exceptionZassert_nodesleftZmidrightr  eventr  r3   r   )r  r  r&  r  r   r  r   bisect  sp    	"







r/  )8r   rD   loggingr   r|   Zdataclassesr   typingr   r   r   r   r   r   Ztorch.fxZtorch.fx.tracebackr   	tracebackr   Ztorch._dynamo.excr   Ztorch._dynamo.utilsr   Ztorch.fx.noder	   r
   Ztorch.utils._sympy.interpr   	getLoggerrv   r   r   r*   r,   r   rF   rG   r   ZInterpreterr   r   rH   ImportErrorr   __all__Ztorch.fx.experimentalr   r   r   r   r   r   r   r   r   r/  r   r   r   r   <module>   sN   
$UgJ"f !