a
    h$5                     @   s   d dl Z d dlZd dlmZ d dlZd dlZd dlmZmZm	Z	m
Z
mZmZmZmZmZmZmZmZmZmZmZmZmZmZmZ G dd dZG dd deZG dd	 d	eZejejejd
ddZG dd dZ dS )    N)Union)_keep_floatBitwiseFn_bitwise_andBitwiseFn_bitwise_orFloatPowFloatTrueDivFloorDiv
IntTrueDivMaxMinModOpaqueUnaryFn_expOpaqueUnaryFn_logOpaqueUnaryFn_log2OpaqueUnaryFn_sqrtPowByNaturalRoundDecimal
RoundToIntToFloat
TruncToIntc                   @   s  e Zd Zedd Zedd Zedd Zedd Zed	d
 Z	edd Z
edd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd  Zed!d" Zed#d$ Zed%d& Zed'd( Zed)d* Zed+d, Zed-d. Zed/d0 Zed1d2 Zed3d4 Zed5d6 Zed7d8 Z ed9d: Z!ed;d< Z"ed=d> Z#ed?d@ Z$edAdB Z%edCdD Z&edEdF Z'edGdH Z(edIdJ Z)edKdL Z*edMdN Z+edOdP Z,edQdR Z-dSS )TReferenceAnalysisc                 C   s
   t | S N)sympyZsympifycdtype r   J/var/www/auris/lib/python3.9/site-packages/torch/utils/_sympy/reference.pyconstant'   s    zReferenceAnalysis.constantc                 C   s   | |B S r   r   abr   r   r   or_+   s    zReferenceAnalysis.or_c                 C   s   | |@ S r   r   r   r   r   r   and_/   s    zReferenceAnalysis.and_c                 C   s,   t | tjst |tjr$t| |S | |kS r   )
isinstancer   ExprEqr   r   r   r   eq3   s    zReferenceAnalysis.eqc                 C   s   |  | ||S r   )not_r'   clsr    r!   r   r   r   ne9   s    zReferenceAnalysis.nec                 C   s   | |k S r   r   r   r   r   r   lt=   s    zReferenceAnalysis.ltc                 C   s   | |kS r   r   r   r   r   r   gtA   s    zReferenceAnalysis.gtc                 C   s   | |kS r   r   r   r   r   r   leE   s    zReferenceAnalysis.lec                 C   s   | |kS r   r   r   r   r   r   geI   s    zReferenceAnalysis.gec                 C   s   t | trJ |  S r   )r$   boolr    r   r   r   r(   M   s    zReferenceAnalysis.not_c                 C   s
   t d| S )Ng      ?r   xr   r   r   
reciprocalR   s    zReferenceAnalysis.reciprocalc                 C   s
   t | dS )N   r   r3   r   r   r   squareV   s    zReferenceAnalysis.squarec                 C   s   t | S r   )r   r4   r   r   r   r   trunc_to_intZ   s    zReferenceAnalysis.trunc_to_intc                 C   s
   t | S r   )r   ceilingr9   r   r   r   ceil_to_int^   s    zReferenceAnalysis.ceil_to_intc                 C   s
   t | S r   )r   floorr9   r   r   r   floor_to_intb   s    zReferenceAnalysis.floor_to_intc                 C   s   t tj| S r   )r   r   r=   r3   r   r   r   r=   f   s    zReferenceAnalysis.floorc                 C   s   t tj| S r   )r   r   r;   r3   r   r   r   ceilj   s    zReferenceAnalysis.ceilc                 C   s&   |t jkrt| S td| dd S Nz	to_dtype z NYI)torchfloat64r   NotImplementedErrorr9   r   r   r   to_dtypen   s    
zReferenceAnalysis.to_dtypec                 C   s
   t | |S r   )r   r4   yr   r   r   modt   s    zReferenceAnalysis.modc                 C   s   t | S r   )absr3   r   r   r   rH   x   s    zReferenceAnalysis.absc                 C   s   |  S r   r   r3   r   r   r   neg|   s    zReferenceAnalysis.negc                 C   s
   t | |S r   r2   r   r   r   r   truediv   s    zReferenceAnalysis.truedivc                 C   s
   t | |S r   )r	   r   r   r   r   int_truediv   s    zReferenceAnalysis.int_truedivc                 C   s
   t | |S r   )r   r   r   r   r   floordiv   s    zReferenceAnalysis.floordivc                 C   s   t dd S )NzTODO: truncdivrC   r   r   r   r   truncdiv   s    zReferenceAnalysis.truncdivc                 C   s   t tj| |S r   )r   operatoraddr   r   r   r   rP      s    zReferenceAnalysis.addc                 C   s
   t j| S r   )r   Add)r*   argsr   r   r   sym_sum   s    zReferenceAnalysis.sym_sumc                 C   s   t tj| |S r   )r   rO   mulr   r   r   r   rT      s    zReferenceAnalysis.mulc                 C   s   t tj| |S r   )r   rO   subr   r   r   r   rU      s    zReferenceAnalysis.subc                 C   s   t | S r   )r   r3   r   r   r   exp   s    zReferenceAnalysis.expc                 C   s   t | S r   )r   r3   r   r   r   log   s    zReferenceAnalysis.logc                 C   s   t | S r   )r   r3   r   r   r   log2   s    zReferenceAnalysis.log2c                 C   s   t | S r   )r   r3   r   r   r   sqrt   s    zReferenceAnalysis.sqrtc                 C   s   t t| |S r   )r   r   r   r   r   r   pow   s    zReferenceAnalysis.powc                 C   s
   t | |S r   r7   r   r   r   r   pow_by_natural   s    z ReferenceAnalysis.pow_by_naturalc                 C   s
   t | |S r   )r   r   r   r   r   minimum   s    zReferenceAnalysis.minimumc                 C   s
   t | |S r   )r
   r   r   r   r   maximum   s    zReferenceAnalysis.maximumc                 C   s   t | S r   )r   r    r   r   r   r   round_to_int   s    zReferenceAnalysis.round_to_intc                 C   s
   t | |S r   )r   r   r   r   r   round_decimal   s    zReferenceAnalysis.round_decimalc                 C   s
   t | |S r   )r   r   r   r   r   bitwise_and   s    zReferenceAnalysis.bitwise_andc                 C   s
   t | |S r   )r   r   r   r   r   
bitwise_or   s    zReferenceAnalysis.bitwise_orN).__name__
__module____qualname__staticmethodr   r"   r#   r'   classmethodr+   r,   r-   r.   r/   r(   r5   r8   r:   r<   r>   r=   r?   rD   rG   rH   rI   rJ   rK   rL   rN   rP   rS   rT   rU   rV   rW   rX   rY   rZ   r[   r\   r]   r_   r`   ra   rb   r   r   r   r   r   &   s   







































r   c                   @   s,  e Zd Zedd Zedd Zedd Zedd Zed	d
 Z	edd Z
edd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd  Zed!d" Zed#d$ Zed%d& Zed'd( Zed)d* Zed+d, Zed-d. Zed/d0 Zd1S )2PythonReferenceAnalysisc                 C   sH   |t ju rt| S |t ju r$t| S |t ju r6t| S td| d S )Nunrecognized dtype )rA   int64intdoublefloatr0   AssertionErrorr   r   r   r   r      s    


z PythonReferenceAnalysis.constantc                 C   s
   t | S r   )rA   Zsym_notr1   r   r   r   r(      s    zPythonReferenceAnalysis.not_c                 C   s`   t |dkrdS t |dkr$|d S | |d |d }tdt |D ]}| ||| }qF|S )Nr      r6   )lenrP   range)r*   rR   accir   r   r   rS      s    zPythonReferenceAnalysis.sym_sumc                 C   s   | | S r   r   r   r   r   r   rL      s    z PythonReferenceAnalysis.floordivc                 C   s   | | S r   r   rE   r   r   r   rG      s    zPythonReferenceAnalysis.modc                 C   s   | | S r   r   r   r   r   r   rN      s    z PythonReferenceAnalysis.truncdivc                 C   s(   |t jkrt | S td| dd S r@   )rA   rB   Z	sym_floatrC   r9   r   r   r   rD      s    

z PythonReferenceAnalysis.to_dtypec                 C   s   t dd S )Nz!exp is not valid shape sympy exprrn   r3   r   r   r   rV     s    zPythonReferenceAnalysis.expc                 C   s   t dd S )Nz!log is not valid shape sympy exprrt   r3   r   r   r   rW     s    zPythonReferenceAnalysis.logc                 C   s
   t | S r   )rA   Z	_sym_log2r3   r   r   r   rX   	  s    zPythonReferenceAnalysis.log2c                 C   s
   t | S r   )rA   Z	_sym_sqrtr3   r   r   r   rY     s    zPythonReferenceAnalysis.sqrtc                 C   s   t | |S r   )rA   Zsym_minr   r   r   r   r\     s    zPythonReferenceAnalysis.minimumc                 C   s   t | |S r   )rA   Zsym_maxr   r   r   r   r]     s    zPythonReferenceAnalysis.maximumc                 C   s
   t | S r   )mathr=   r9   r   r   r   r>     s    z$PythonReferenceAnalysis.floor_to_intc                 C   s
   t | S r   )ru   r?   r9   r   r   r   r<     s    z#PythonReferenceAnalysis.ceil_to_intc                 C   s   t t| S r   )rm   ru   r=   r3   r   r   r   r=   !  s    zPythonReferenceAnalysis.floorc                 C   s   t t| S r   )rm   ru   r?   r3   r   r   r   r?   %  s    zPythonReferenceAnalysis.ceilc                 C   s   | | S r   r   r   r   r   r   rJ   )  s    zPythonReferenceAnalysis.truedivc                 C   s   | | S r   r   r   r   r   r   rZ   -  s    zPythonReferenceAnalysis.powc                 C   s   | | S r   r   r   r   r   r   r[   1  s    z&PythonReferenceAnalysis.pow_by_naturalc                 C   s   t | S r   roundr^   r   r   r   r_   8  s    z$PythonReferenceAnalysis.round_to_intc                 C   s   t | |dS )N)ndigitsrv   r   r   r   r   r`   <  s    z%PythonReferenceAnalysis.round_decimalc                 C   s   | |@ S r   r   r   r   r   r   ra   @  s    z#PythonReferenceAnalysis.bitwise_andc                 C   s   | |B S r   r   r   r   r   r   rb   D  s    z"PythonReferenceAnalysis.bitwise_orN)rc   rd   re   rf   r   r(   rg   rS   rL   rG   rN   rD   rV   rW   rX   rY   r\   r]   r>   r<   r=   r?   rJ   rZ   r[   r_   r`   ra   rb   r   r   r   r   rh      s`   
























rh   c                   @   s   e Zd Zedd ZdS ) OptimizedPythonReferenceAnalysisc                 C   s
   t | S r   )rA   rS   )rR   r   r   r   rS   L  s    z(OptimizedPythonReferenceAnalysis.sym_sumN)rc   rd   re   rf   rS   r   r   r   r   ry   K  s   ry   )r4   r   returnc                 C   s   t jjj| |S r   )rA   opsZprimsZconvert_element_typedefaultr9   r   r   r   	_to_dtypeQ  s    r}   c                   @   sX  e Zd Zedd Zedd Zedd Zedd Zed	d
 Zedd Z	e
dd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd  Zed!d" Zed#d$ Zed%d& Zed'd( Zed)d* Zed+d, Zed-d. Zed/d0 Zed1d2 Zed3d4 Zed5d6 Zed7d8 Z ed9d: Z!ed;d< Z"ed=d> Z#ed?d@ Z$edAdB Z%edCdD Z&edEdF Z'edGdH Z(edIdJ Z)edKdL Z*edMdN Z+edOdP Z,edQdR Z-edSdT Z.edUdV Z/edWdX Z0edYdZ Z1ed[d\ Z2ed]d^ Z3ed_d` Z4edadb Z5dcS )dTensorReferenceAnalysisc                 C   s^   |t ju rt| }n6|t ju r(t| }n"|t ju r<t| }ntd| t jjj	j
||dS )Nri   )r   )rA   rj   rk   rl   rm   r0   rn   r{   atenZscalar_tensorr|   )r   r   dr   r   r   r   f  s    





z TensorReferenceAnalysis.constantc                 C   s   t jjj| |S r   )rA   r{   r   
logical_orr|   r   r   r   r   r"   s  s    zTensorReferenceAnalysis.or_c                 C   s   t jjj| |S r   )rA   r{   r   logical_andr|   r   r   r   r   r#   w  s    zTensorReferenceAnalysis.and_c                 C   s   t jj| |S r   )rA   r{   r   ra   r   r   r   r   ra   {  s    z#TensorReferenceAnalysis.bitwise_andc                 C   s   t jj| |S r   )rA   r{   r   rb   r   r   r   r   rb     s    z"TensorReferenceAnalysis.bitwise_orc                 C   s   t jjj| |S r   )rA   r{   r   r'   Tensorr   r   r   r   r'     s    zTensorReferenceAnalysis.eqc                 C   s   t jjj||S r   )rA   r{   r   r+   r   r)   r   r   r   r+     s    zTensorReferenceAnalysis.nec                 C   s   t jjj| |S r   )rA   r{   r   r,   r   r   r   r   r   r,     s    zTensorReferenceAnalysis.ltc                 C   s   t jjj| |S r   )rA   r{   r   r-   r   r   r   r   r   r-     s    zTensorReferenceAnalysis.gtc                 C   s   t jjj| |S r   )rA   r{   r   r.   r   r   r   r   r   r.     s    zTensorReferenceAnalysis.lec                 C   s   t jjj| |S r   )rA   r{   r   r/   r   r   r   r   r   r/     s    zTensorReferenceAnalysis.gec                 C   s   t jjj| S r   )rA   r{   r   Zlogical_notr|   r1   r   r   r   r(     s    zTensorReferenceAnalysis.not_c                 C   s   t jjj| S r   )rA   r{   r   r5   r|   r3   r   r   r   r5     s    z"TensorReferenceAnalysis.reciprocalc                 C   s   t jjj| S r   )rA   r{   r   r8   r|   r3   r   r   r   r8     s    zTensorReferenceAnalysis.squarec                 C   s   t tjjj| |S r   )r}   rA   r{   r   truncr|   r9   r   r   r   r:     s    z$TensorReferenceAnalysis.trunc_to_intc                 C   s   t tjjj| |S r   )r}   rA   r{   r   r?   r|   r9   r   r   r   r<     s    z#TensorReferenceAnalysis.ceil_to_intc                 C   s   t tjjj| |S r   )r}   rA   r{   r   r=   r|   r9   r   r   r   r>     s    z$TensorReferenceAnalysis.floor_to_intc                 C   s   t jjj| S r   )rA   r{   r   r=   r|   r3   r   r   r   r=     s    zTensorReferenceAnalysis.floorc                 C   s   t jjj| S r   )rA   r{   r   r?   r|   r3   r   r   r   r?     s    zTensorReferenceAnalysis.ceilc                 C   s
   t | |S r   )r}   r9   r   r   r   rD     s    z TensorReferenceAnalysis.to_dtypec                 C   s   t dd S )Nz8no C-style modulus operation available from frontend atmrM   rE   r   r   r   rG     s    zTensorReferenceAnalysis.modc                 C   s   t jjj| S r   )rA   r{   r   rH   r|   r3   r   r   r   rH     s    zTensorReferenceAnalysis.absc                 C   s   t jjj| S r   )rA   r{   r   rI   r|   r3   r   r   r   rI     s    zTensorReferenceAnalysis.negc                 C   s   t jjj| |S r   )rA   r{   r   true_divider   r   r   r   r   rJ     s    zTensorReferenceAnalysis.truedivc                 C   s*   t dtjjjt| tjt|tjS )Nz8Python int truediv difficult to implement in PyTorch atm)rC   rA   r{   r   r   r|   r}   rB   r   r   r   r   rK     s    
z#TensorReferenceAnalysis.int_truedivc                 C   s   t jjjj| |ddS )Nr=   )Zrounding_mode)rA   r{   r   divZTensor_moder   r   r   r   rL     s    z TensorReferenceAnalysis.floordivc                 C   s   t dd S )Nz9no C-style truncdiv operation available from frontend atmrM   r   r   r   r   rN     s    z TensorReferenceAnalysis.truncdivc                 C   s   t jjj| |S r   )rA   r{   r   rP   r   r   r   r   r   rP     s    zTensorReferenceAnalysis.addc                 C   s   t jjj| |S r   )rA   r{   r   rT   r   r   r   r   r   rT     s    zTensorReferenceAnalysis.mulc                 C   s   t jjj| |S r   )rA   r{   r   rU   r   r   r   r   r   rU     s    zTensorReferenceAnalysis.subc                 C   s   t jjj| S r   )rA   r{   r   rV   r|   r3   r   r   r   rV     s    zTensorReferenceAnalysis.expc                 C   s   t jjj| S r   )rA   r{   r   rW   r|   r3   r   r   r   rW     s    zTensorReferenceAnalysis.logc                 C   s   t jjj| S r   )rA   r{   r   rX   r|   r3   r   r   r   rX     s    zTensorReferenceAnalysis.log2c                 C   s   t jjj| S r   )rA   r{   r   rY   r|   r3   r   r   r   rY     s    zTensorReferenceAnalysis.sqrtc                 C   s   t jjj| S r   )rA   r{   r   sinr|   r3   r   r   r   r     s    zTensorReferenceAnalysis.sinc                 C   s   t jjj| S r   )rA   r{   r   cosr|   r3   r   r   r   r   
  s    zTensorReferenceAnalysis.cosc                 C   s   t jjj| S r   )rA   r{   r   tanhr|   r3   r   r   r   r     s    zTensorReferenceAnalysis.tanhc                 C   s   t jjj| S r   )rA   r{   r   sinhr|   r3   r   r   r   r     s    zTensorReferenceAnalysis.sinhc                 C   s   t jjj| S r   )rA   r{   r   coshr|   r3   r   r   r   r     s    zTensorReferenceAnalysis.coshc                 C   s   t jjj| S r   )rA   r{   r   tanr|   r3   r   r   r   r     s    zTensorReferenceAnalysis.tanc                 C   s   t jjj| S r   )rA   r{   r   acosr|   r3   r   r   r   r     s    zTensorReferenceAnalysis.acosc                 C   s   t jjj| S r   )rA   r{   r   atanr|   r3   r   r   r   r   "  s    zTensorReferenceAnalysis.atanc                 C   s   t jjj| S r   )rA   r{   r   asinr|   r3   r   r   r   r   &  s    zTensorReferenceAnalysis.asinc                 C   s   t jjj| |S r   rA   r{   r   rZ   ZTensor_Tensorr   r   r   r   rZ   *  s    zTensorReferenceAnalysis.powc                 C   s   t jjj| |S r   r   r   r   r   r   r[   .  s    z&TensorReferenceAnalysis.pow_by_naturalc                 C   s   t jjj| |S r   )rA   r{   r   r\   r|   r   r   r   r   r\   3  s    zTensorReferenceAnalysis.minimumc                 C   s   t jjj| |S r   )rA   r{   r   r]   r|   r   r   r   r   r]   7  s    zTensorReferenceAnalysis.maximumc                 C   s   t jjj| S r   )rA   r{   r   rw   r|   r^   r   r   r   r_   ;  s    z$TensorReferenceAnalysis.round_to_intc                 C   s   t dd S )Nz8round decimal doesn't support Tensor second argument atmrM   r   r   r   r   r`   ?  s    z%TensorReferenceAnalysis.round_decimalN)6rc   rd   re   rf   r   r"   r#   ra   rb   r'   rg   r+   r,   r-   r.   r/   r(   r5   r8   r:   r<   r>   r=   r?   rD   rG   rH   rI   rJ   rK   rL   rN   rP   rT   rU   rV   rW   rX   rY   r   r   r   r   r   r   r   r   r   rZ   r[   r\   r]   r_   r`   r   r   r   r   r~   c  s   















































r~   )!ru   rO   typingr   r   rA   Ztorch.utils._sympy.functionsr   r   r   r   r   r   r	   r
   r   r   r   r   r   r   r   r   r   r   r   r   rh   ry   r   r   r}   r~   r   r   r   r   <module>   s   T /w