o
    wZh$5                     @   s   d dl Z d dlZd dlmZ d dlZd dlZd dlmZmZm	Z	m
Z
mZmZmZmZmZmZmZmZmZmZmZmZmZmZmZ G dd dZG dd deZG dd	 d	eZd
ejdejdejfddZG dd dZ dS )    N)Union)_keep_floatBitwiseFn_bitwise_andBitwiseFn_bitwise_orFloatPowFloatTrueDivFloorDiv
IntTrueDivMaxMinModOpaqueUnaryFn_expOpaqueUnaryFn_logOpaqueUnaryFn_log2OpaqueUnaryFn_sqrtPowByNaturalRoundDecimal
RoundToIntToFloat
TruncToIntc                   @   s  e Zd Zedd Zedd Zedd Zedd Zed	d
 Z	edd Z
edd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd  Zed!d" Zed#d$ Zed%d& Zed'd( Zed)d* Zed+d, Zed-d. Zed/d0 Zed1d2 Zed3d4 Zed5d6 Zed7d8 Z ed9d: Z!ed;d< Z"ed=d> Z#ed?d@ Z$edAdB Z%edCdD Z&edEdF Z'edGdH Z(edIdJ Z)edKdL Z*edMdN Z+edOdP Z,edQdR Z-dSS )TReferenceAnalysisc                 C   
   t | S N)sympyZsympifycdtype r   K/var/www/auris/lib/python3.10/site-packages/torch/utils/_sympy/reference.pyconstant'      
zReferenceAnalysis.constantc                 C      | |B S r   r   abr   r   r   or_+      zReferenceAnalysis.or_c                 C      | |@ S r   r   r"   r   r   r   and_/   r&   zReferenceAnalysis.and_c                 C   s,   t | tjst |tjrt| |S | |kS r   )
isinstancer   ExprEqr"   r   r   r   eq3   s   zReferenceAnalysis.eqc                 C   s   |  | ||S r   )not_r,   clsr#   r$   r   r   r   ne9      zReferenceAnalysis.nec                 C   s   | |k S r   r   r"   r   r   r   lt=   r&   zReferenceAnalysis.ltc                 C   s   | |kS r   r   r"   r   r   r   gtA   r&   zReferenceAnalysis.gtc                 C   s   | |kS r   r   r"   r   r   r   leE   r&   zReferenceAnalysis.lec                 C   s   | |kS r   r   r"   r   r   r   geI   r&   zReferenceAnalysis.gec                 C   s   t | trJ |  S r   )r)   boolr#   r   r   r   r-   M   s   zReferenceAnalysis.not_c                 C   s
   t d| S )Ng      ?r   xr   r   r   
reciprocalR   r    zReferenceAnalysis.reciprocalc                 C   s
   t | dS )N   r   r9   r   r   r   squareV   r    zReferenceAnalysis.squarec                 C      t | S r   )r   r:   r   r   r   r   trunc_to_intZ   r&   zReferenceAnalysis.trunc_to_intc                 C   r   r   )r   ceilingr@   r   r   r   ceil_to_int^   r    zReferenceAnalysis.ceil_to_intc                 C   r   r   )r   floorr@   r   r   r   floor_to_intb   r    zReferenceAnalysis.floor_to_intc                 C      t tj| S r   )r   r   rD   r9   r   r   r   rD   f      zReferenceAnalysis.floorc                 C   rF   r   )r   r   rB   r9   r   r   r   ceilj   rG   zReferenceAnalysis.ceilc                 C   s"   |t jkr	t| S td| dNz	to_dtype z NYI)torchfloat64r   NotImplementedErrorr@   r   r   r   to_dtypen   s   
zReferenceAnalysis.to_dtypec                 C   
   t | |S r   )r   r:   yr   r   r   modt   r    zReferenceAnalysis.modc                 C   r?   r   )absr9   r   r   r   rR   x   r&   zReferenceAnalysis.absc                 C   s   |  S r   r   r9   r   r   r   neg|   s   zReferenceAnalysis.negc                 C   rN   r   r8   r"   r   r   r   truediv   r    zReferenceAnalysis.truedivc                 C   rN   r   )r	   r"   r   r   r   int_truediv   r    zReferenceAnalysis.int_truedivc                 C   rN   r   )r   r"   r   r   r   floordiv   r    zReferenceAnalysis.floordivc                 C      t d)NzTODO: truncdivrL   r"   r   r   r   truncdiv   r&   zReferenceAnalysis.truncdivc                 C      t tj| |S r   )r   operatoraddr"   r   r   r   r\         zReferenceAnalysis.addc                 C   s
   t j| S r   )r   Add)r/   argsr   r   r   sym_sum   r    zReferenceAnalysis.sym_sumc                 C   rZ   r   )r   r[   mulr"   r   r   r   ra      r]   zReferenceAnalysis.mulc                 C   rZ   r   )r   r[   subr"   r   r   r   rb      r]   zReferenceAnalysis.subc                 C   r?   r   )r   r9   r   r   r   exp   r&   zReferenceAnalysis.expc                 C   r?   r   )r   r9   r   r   r   log   r&   zReferenceAnalysis.logc                 C   r?   r   )r   r9   r   r   r   log2   r&   zReferenceAnalysis.log2c                 C   r?   r   )r   r9   r   r   r   sqrt   r&   zReferenceAnalysis.sqrtc                 C   s   t t| |S r   )r   r   r"   r   r   r   pow   rG   zReferenceAnalysis.powc                 C   rN   r   r=   r"   r   r   r   pow_by_natural   r    z ReferenceAnalysis.pow_by_naturalc                 C   rN   r   )r   r"   r   r   r   minimum   r    zReferenceAnalysis.minimumc                 C   rN   r   )r
   r"   r   r   r   maximum   r    zReferenceAnalysis.maximumc                 C   r?   r   )r   r#   r   r   r   r   round_to_int   r&   zReferenceAnalysis.round_to_intc                 C   rN   r   )r   r"   r   r   r   round_decimal   r    zReferenceAnalysis.round_decimalc                 C   rN   r   )r   r"   r   r   r   bitwise_and   r    zReferenceAnalysis.bitwise_andc                 C   rN   r   )r   r"   r   r   r   
bitwise_or   r    zReferenceAnalysis.bitwise_orN).__name__
__module____qualname__staticmethodr   r%   r(   r,   classmethodr0   r2   r3   r4   r5   r-   r;   r>   rA   rC   rE   rD   rH   rM   rQ   rR   rS   rT   rU   rV   rY   r\   r`   ra   rb   rc   rd   re   rf   rg   rh   ri   rj   rl   rm   rn   ro   r   r   r   r   r   &   s    







































r   c                   @   s,  e Zd Zedd Zedd Zedd Zedd Zed	d
 Z	edd Z
edd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd  Zed!d" Zed#d$ Zed%d& Zed'd( Zed)d* Zed+d, Zed-d. Zed/d0 Zd1S )2PythonReferenceAnalysisc                 C   sD   |t ju r	t| S |t ju rt| S |t ju rt| S td| )Nunrecognized dtype )rJ   int64intdoublefloatr6   AssertionErrorr   r   r   r   r      s   


z PythonReferenceAnalysis.constantc                 C   r   r   )rJ   Zsym_notr7   r   r   r   r-      r    zPythonReferenceAnalysis.not_c                 C   s`   t |dkrdS t |dkr|d S | |d |d }tdt |D ]
}| ||| }q#|S )Nr      r<   )lenr\   range)r/   r_   accir   r   r   r`      s   zPythonReferenceAnalysis.sym_sumc                 C   s   | | S r   r   r"   r   r   r   rV      r&   z PythonReferenceAnalysis.floordivc                 C   s   | | S r   r   rO   r   r   r   rQ      r&   zPythonReferenceAnalysis.modc                 C      | | S r   r   r"   r   r   r   rY      r&   z PythonReferenceAnalysis.truncdivc                 C   s$   |t jkr
t | S td| drI   )rJ   rK   Z	sym_floatrL   r@   r   r   r   rM      s   

z PythonReferenceAnalysis.to_dtypec                 C   rW   )Nz!exp is not valid shape sympy exprr{   r9   r   r   r   rc     r&   zPythonReferenceAnalysis.expc                 C   rW   )Nz!log is not valid shape sympy exprr   r9   r   r   r   rd     r&   zPythonReferenceAnalysis.logc                 C   r   r   )rJ   Z	_sym_log2r9   r   r   r   re   	  r    zPythonReferenceAnalysis.log2c                 C   r   r   )rJ   Z	_sym_sqrtr9   r   r   r   rf     r    zPythonReferenceAnalysis.sqrtc                 C      t | |S r   )rJ   Zsym_minr"   r   r   r   ri        zPythonReferenceAnalysis.minimumc                 C   r   r   )rJ   Zsym_maxr"   r   r   r   rj     r   zPythonReferenceAnalysis.maximumc                 C   r   r   )mathrD   r@   r   r   r   rE     r    z$PythonReferenceAnalysis.floor_to_intc                 C   r   r   )r   rH   r@   r   r   r   rC     r    z#PythonReferenceAnalysis.ceil_to_intc                 C      t t| S r   )rz   r   rD   r9   r   r   r   rD   !  rG   zPythonReferenceAnalysis.floorc                 C   r   r   )rz   r   rH   r9   r   r   r   rH   %  rG   zPythonReferenceAnalysis.ceilc                 C   r   r   r   r"   r   r   r   rT   )  r&   zPythonReferenceAnalysis.truedivc                 C      | | S r   r   r"   r   r   r   rg   -  r&   zPythonReferenceAnalysis.powc                 C   r   r   r   r"   r   r   r   rh   1  s   z&PythonReferenceAnalysis.pow_by_naturalc                 C   r?   r   roundrk   r   r   r   rl   8  r&   z$PythonReferenceAnalysis.round_to_intc                 C   s   t | |dS )N)ndigitsr   r"   r   r   r   rm   <  r   z%PythonReferenceAnalysis.round_decimalc                 C   r'   r   r   r"   r   r   r   rn   @  r&   z#PythonReferenceAnalysis.bitwise_andc                 C   r!   r   r   r"   r   r   r   ro   D  r&   z"PythonReferenceAnalysis.bitwise_orN)rp   rq   rr   rs   r   r-   rt   r`   rV   rQ   rY   rM   rc   rd   re   rf   ri   rj   rE   rC   rD   rH   rT   rg   rh   rl   rm   rn   ro   r   r   r   r   ru      sb    
























ru   c                   @   s   e Zd Zedd ZdS ) OptimizedPythonReferenceAnalysisc                 C   r   r   )rJ   r`   )r_   r   r   r   r`   L  r    z(OptimizedPythonReferenceAnalysis.sym_sumN)rp   rq   rr   rs   r`   r   r   r   r   r   K  s    r   r:   r   returnc                 C      t jjj| |S r   )rJ   opsZprimsZconvert_element_typedefaultr@   r   r   r   	_to_dtypeQ  s   r   c                   @   sX  e Zd Zedd Zedd Zedd Zedd Zed	d
 Zedd Z	e
dd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd Zedd  Zed!d" Zed#d$ Zed%d& Zed'd( Zed)d* Zed+d, Zed-d. Zed/d0 Zed1d2 Zed3d4 Zed5d6 Zed7d8 Z ed9d: Z!ed;d< Z"ed=d> Z#ed?d@ Z$edAdB Z%edCdD Z&edEdF Z'edGdH Z(edIdJ Z)edKdL Z*edMdN Z+edOdP Z,edQdR Z-edSdT Z.edUdV Z/edWdX Z0edYdZ Z1ed[d\ Z2ed]d^ Z3ed_d` Z4edadb Z5dcS )dTensorReferenceAnalysisc                 C   s^   |t ju r
t| }n|t ju rt| }n|t ju rt| }ntd| t jjj	j
||dS )Nrv   )r   )rJ   rw   rx   ry   rz   r6   r{   r   atenZscalar_tensorr   )r   r   dr   r   r   r   f  s   





z TensorReferenceAnalysis.constantc                 C   r   r   )rJ   r   r   
logical_orr   r"   r   r   r   r%   s  r1   zTensorReferenceAnalysis.or_c                 C   r   r   )rJ   r   r   logical_andr   r"   r   r   r   r(   w  r1   zTensorReferenceAnalysis.and_c                 C      t jj| |S r   )rJ   r   r   rn   r"   r   r   r   rn   {  r]   z#TensorReferenceAnalysis.bitwise_andc                 C   r   r   )rJ   r   r   ro   r"   r   r   r   ro     r]   z"TensorReferenceAnalysis.bitwise_orc                 C   r   r   )rJ   r   r   r,   Tensorr"   r   r   r   r,     r1   zTensorReferenceAnalysis.eqc                 C   s   t jjj||S r   )rJ   r   r   r0   r   r.   r   r   r   r0     r1   zTensorReferenceAnalysis.nec                 C   r   r   )rJ   r   r   r2   r   r"   r   r   r   r2     r1   zTensorReferenceAnalysis.ltc                 C   r   r   )rJ   r   r   r3   r   r"   r   r   r   r3     r1   zTensorReferenceAnalysis.gtc                 C   r   r   )rJ   r   r   r4   r   r"   r   r   r   r4     r1   zTensorReferenceAnalysis.lec                 C   r   r   )rJ   r   r   r5   r   r"   r   r   r   r5     r1   zTensorReferenceAnalysis.gec                 C      t jjj| S r   )rJ   r   r   Zlogical_notr   r7   r   r   r   r-     r]   zTensorReferenceAnalysis.not_c                 C   r   r   )rJ   r   r   r;   r   r9   r   r   r   r;     r]   z"TensorReferenceAnalysis.reciprocalc                 C   r   r   )rJ   r   r   r>   r   r9   r   r   r   r>     s   zTensorReferenceAnalysis.squarec                 C      t tjjj| |S r   )r   rJ   r   r   truncr   r@   r   r   r   rA        z$TensorReferenceAnalysis.trunc_to_intc                 C   r   r   )r   rJ   r   r   rH   r   r@   r   r   r   rC     r   z#TensorReferenceAnalysis.ceil_to_intc                 C   r   r   )r   rJ   r   r   rD   r   r@   r   r   r   rE     r   z$TensorReferenceAnalysis.floor_to_intc                 C   r   r   )rJ   r   r   rD   r   r9   r   r   r   rD     r]   zTensorReferenceAnalysis.floorc                 C   r   r   )rJ   r   r   rH   r   r9   r   r   r   rH     r]   zTensorReferenceAnalysis.ceilc                 C   rN   r   )r   r@   r   r   r   rM     r    z TensorReferenceAnalysis.to_dtypec                 C   rW   )Nz8no C-style modulus operation available from frontend atmrX   rO   r   r   r   rQ     s   zTensorReferenceAnalysis.modc                 C   r   r   )rJ   r   r   rR   r   r9   r   r   r   rR     r]   zTensorReferenceAnalysis.absc                 C   r   r   )rJ   r   r   rS   r   r9   r   r   r   rS     r]   zTensorReferenceAnalysis.negc                 C   r   r   )rJ   r   r   true_divider   r"   r   r   r   rT     r1   zTensorReferenceAnalysis.truedivc                 C   rW   )Nz8Python int truediv difficult to implement in PyTorch atm)rL   rJ   r   r   r   r   r   rK   r"   r   r   r   rU        z#TensorReferenceAnalysis.int_truedivc                 C   s   t jjjj| |ddS )NrD   )Zrounding_mode)rJ   r   r   divZTensor_moder"   r   r   r   rV     r   z TensorReferenceAnalysis.floordivc                 C   rW   )Nz9no C-style truncdiv operation available from frontend atmrX   r"   r   r   r   rY     r   z TensorReferenceAnalysis.truncdivc                 C   r   r   )rJ   r   r   r\   r   r"   r   r   r   r\     r1   zTensorReferenceAnalysis.addc                 C   r   r   )rJ   r   r   ra   r   r"   r   r   r   ra     r1   zTensorReferenceAnalysis.mulc                 C   r   r   )rJ   r   r   rb   r   r"   r   r   r   rb     r1   zTensorReferenceAnalysis.subc                 C   r   r   )rJ   r   r   rc   r   r9   r   r   r   rc     r]   zTensorReferenceAnalysis.expc                 C   r   r   )rJ   r   r   rd   r   r9   r   r   r   rd     r]   zTensorReferenceAnalysis.logc                 C   r   r   )rJ   r   r   re   r   r9   r   r   r   re     r]   zTensorReferenceAnalysis.log2c                 C   r   r   )rJ   r   r   rf   r   r9   r   r   r   rf     r]   zTensorReferenceAnalysis.sqrtc                 C   r   r   )rJ   r   r   sinr   r9   r   r   r   r     r]   zTensorReferenceAnalysis.sinc                 C   r   r   )rJ   r   r   cosr   r9   r   r   r   r   
  r]   zTensorReferenceAnalysis.cosc                 C   r   r   )rJ   r   r   tanhr   r9   r   r   r   r     r]   zTensorReferenceAnalysis.tanhc                 C   r   r   )rJ   r   r   sinhr   r9   r   r   r   r     r]   zTensorReferenceAnalysis.sinhc                 C   r   r   )rJ   r   r   coshr   r9   r   r   r   r     r]   zTensorReferenceAnalysis.coshc                 C   r   r   )rJ   r   r   tanr   r9   r   r   r   r     r]   zTensorReferenceAnalysis.tanc                 C   r   r   )rJ   r   r   acosr   r9   r   r   r   r     r]   zTensorReferenceAnalysis.acosc                 C   r   r   )rJ   r   r   atanr   r9   r   r   r   r   "  r]   zTensorReferenceAnalysis.atanc                 C   r   r   )rJ   r   r   asinr   r9   r   r   r   r   &  r]   zTensorReferenceAnalysis.asinc                 C   r   r   rJ   r   r   rg   ZTensor_Tensorr"   r   r   r   rg   *  r1   zTensorReferenceAnalysis.powc                 C   r   r   r   r"   r   r   r   rh   .  s   z&TensorReferenceAnalysis.pow_by_naturalc                 C   r   r   )rJ   r   r   ri   r   r"   r   r   r   ri   3  r1   zTensorReferenceAnalysis.minimumc                 C   r   r   )rJ   r   r   rj   r   r"   r   r   r   rj   7  r1   zTensorReferenceAnalysis.maximumc                 C   r   r   )rJ   r   r   r   r   rk   r   r   r   rl   ;  r]   z$TensorReferenceAnalysis.round_to_intc                 C   rW   )Nz8round decimal doesn't support Tensor second argument atmrX   r"   r   r   r   rm   ?  r   z%TensorReferenceAnalysis.round_decimalN)6rp   rq   rr   rs   r   r%   r(   rn   ro   r,   rt   r0   r2   r3   r4   r5   r-   r;   r>   rA   rC   rE   rD   rH   rM   rQ   rR   rS   rT   rU   rV   rY   r\   ra   rb   rc   rd   re   rf   r   r   r   r   r   r   r   r   r   rg   rh   ri   rj   rl   rm   r   r   r   r   r   c  s    















































r   )!r   r[   typingr   r   rJ   Ztorch.utils._sympy.functionsr   r   r   r   r   r   r	   r
   r   r   r   r   r   r   r   r   r   r   r   r   ru   r   r   r   r   r   r   r   r   r   <module>   s   T /w