a
    hs                     @   s&  d dl Z d dlZd dlZd dlZd dlZd dlmZmZ d dlm	Z	m
Z
mZmZmZ d dlmZ d dlmZ d dlZd dlZd dlmZmZ d dlmZ dd	lmZmZ d
dlmZ d
dlmZ d
dl m!Z!m"Z"m#Z#m$Z$m%Z%m&Z& d
dl'm(Z(m)Z) edZ*e+e,Z-e.dj/Z0G dd de j1Z2ej3ddG dd de2Z4ej3ddG dd de2Z5ej3ddG dd de2Z6ej3ddG dd dZ7ej3G dd dZ8G dd  d e)j9Z:G d!d" d"e)j;Z<e=e>e&e
ej?gej@f f d#d$d%ZAeej? e=e>eBeBej@  e&f d&d'd(ZCd)d*eej? e=e>eBeBej?  e&f d&d+d,ZDd-d)d.d/e
d0e	f eej? eEe=eeBej?  e8d1d2d3ZFdCe	eBeBej?  e&eEe:d4d5d6ZGd7e>eeBej?  eeBej?  f d8d9d:ZHe=d;d<d=ZIG d>d? d?eZJdDe
d0e	f eej? eeej?  eEeej@ d@dAdBZKdS )E    N)IterableSequence)AnyCallableOptionalTypeVarUnion)Self)patch)free_symbolsfree_unbacked_symbols)
OrderedSet   )make_symbolSymT   )index_prevent_reordering)DefaultHandler)get_dtype_sizereduction_num_outputssympy_index_symbol	sympy_str
sympy_subs	VarRanges)ReductionTypeVTzindirect|tmpc                   @   s   e Zd ZU eed< ejed< eje	eef e
dddZejejdddZejedd	d
ZejedddZejedddZdee
dddZdS )Depnameindexrenamesreturnc                 C   s   d S N selfr!   r$   r$   J/var/www/auris/lib/python3.9/site-packages/torch/_inductor/dependencies.pyrename)   s    z
Dep.renamer"   c                 C   s   d S r#   r$   r&   r$   r$   r'   	get_numel-   s    zDep.get_numelc                 C   s   d S r#   r$   r*   r$   r$   r'   numbytes_hint1   s    zDep.numbytes_hintc                 C   s   d S r#   r$   r*   r$   r$   r'   has_unbacked_symbols5   s    zDep.has_unbacked_symbolsc                 C   s   d S r#   r$   r*   r$   r$   r'   is_contiguous9   s    zDep.is_contiguoustprefixr"   c                 C   s   | S r#   r$   )r&   r1   r$   r$   r'   normalize_with_stride_order=   s    zDep.normalize_with_stride_orderN)r/   )__name__
__module____qualname__str__annotations__sympyExprabcabstractmethoddictr	   r(   r+   intr,   boolr-   r.   r2   r$   r$   r$   r'   r   %   s   

r   T)frozenc                   @   sh  e Zd ZU eed< ejed< eejdf ed< eejdf ed< dZ	e
e ed< edd	d
ZeedddZd e
ee  dddZejdddZd dddZd/ed dddZeeejejf dddZd dddZejdddZeeef d ddd Zedd!d"Zedd#d$Zedd%d&Zd0eed(d)d*Zedd+d,Zedd-d.ZdS )1	MemoryDepr   r   .	var_namessizeNmoder)   c                 C   s<   d}| j d urd| j  }d| jd| j d| j | dS )N , z
MemoryDep())rC   r   r   ranges)r&   Z
maybe_moder$   r$   r'   __repr__I   s    
zMemoryDep.__repr__c                 C   s
   t | jS r#   )lenrA   r*   r$   r$   r'   num_varsO   s    zMemoryDep.num_varsotherr"   c                    s  | j |j ksJ | j t| jjkr&dS |j t|jjkr<dS tdd t| j|jD r^dS tj	j
| j| j}tj	j
|j|j}tt|t|kstt|t|krtd| ||| dS t|t|krdS dd t|D   fdd|D }t|ttd	| j ksJ |S )
zD
        Can return None if not able to decide loop orders.
        Nc                 s   s   | ]}|d kp|dkV  qdS )r   r   Nr$   .0sr$   r$   r'   	<genexpr>h       z7MemoryDep.decide_loop_order_to_match.<locals>.<genexpr>zaunable to decide loop order. self_dep=%s v.s. other_dep=%s, self_strides=%s v.s. other_strides=%sc                 S   s   i | ]\}}||qS r$   r$   )rN   irO   r$   r$   r'   
<dictcomp>   rQ   z8MemoryDep.decide_loop_order_to_match.<locals>.<dictcomp>c                    s   g | ]} | qS r$   r$   rM   Zstride_to_indexr$   r'   
<listcomp>   rQ   z8MemoryDep.decide_loop_order_to_match.<locals>.<listcomp>r   )rJ   rI   r   r   any	itertoolschainrB   r   graphsizevarsstride_hintsrA   r   logdebug	enumeraterange)r&   rL   Zself_stridesZother_stridesorderr$   rT   r'   decide_loop_order_to_matchS   s8    
z$MemoryDep.decide_loop_order_to_matchc                 C   s   t | jt| jdS )zF
        Return the offset by setting every variable to be 0.
        r   )r   r   r<   fromkeysrA   r*   r$   r$   r'   
get_offset   s    zMemoryDep.get_offsetc                 C   s$   t | jgt| j| j| jR  S )z
        Normalize by merging loops. The different to normalize_with_stride_order is,
        this method does not reorder loops while normalize_with_stride_order reorder
        loops based on stride order.
        )r@   r   _RecordLoadStoreInner
_normalizer   rG   rC   r*   r$   r$   r'   	normalize   s    zMemoryDep.normalizer/   r0   c                    s   ddl m} tjj| j| j}tt	t
||jdd}||}| j}| j}||}||}	tjj|	|t| jg|	|\}
}}t|\} tt|	| fdd|
D }tt| j|}t| j|t| t| }|S )a'  
        Used to decide if two MemoryDep does not equal due to different loop orders.
        More specifically, when dep1 and dep2 are not equal, we can normalize
        both and check if they are equal after that. If yes, then the mismatch is
        caused by different loop orders.
        r   )irT)keyreversec                    s   g | ]} |qS r$   r$   rN   xadd_varr$   r'   rU      rQ   z9MemoryDep.normalize_with_stride_order.<locals>.<listcomp>)Ztorch._inductorrg   r   rY   rZ   r[   r   rA   sortedr_   rI   __getitem__Zsame_reorderrB   _simplify_loopsr   var_builderr<   zipr   r8   expandr@   r   tuplekeysvalues)r&   r1   rg   stridesr`   Zstride_reordersizesrA   Znew_reordered_sizesZnew_reordered_var_namesZnew_simplified_sizesreindex_prune
var_rangesreplacementZ	new_indexoutr$   rl   r'   r2      s6    


	z%MemoryDep.normalize_with_stride_orderc                 C   s   t t| j| jS )z{c0: 128, c1: 512, ...})r<   rr   rA   rB   r*   r$   r$   r'   rG      s    zMemoryDep.rangesc                 C   s*   t | jtjj| j| j| j| j	| j
dS )N)r   r   rA   rB   rC   )r@   r   r   rY   rZ   simplify_with_rangesr   rG   rA   rB   rC   r*   r$   r$   r'   r~      s    zMemoryDep.simplify_with_rangesc                 C   sX   |   rtj| j}n<t| jj}tj	j
}t| j| jD ]\}}||v r:|| }q:|S r#   )is_indirectr   rY   r+   r   r   r   r   r8   SOnerr   rA   rB   )r&   ZnumelvarsvarrB   r$   r$   r'   r+      s    
zMemoryDep.get_numelr    c                 C   s.   | j |v r*t|| j  | j| j| j| jdS | S )N)rA   rB   rC   )r   r@   r   rA   rB   rC   r%   r$   r$   r'   r(      s    
zMemoryDep.renamec                 C   s@   z&t jj|  tt j| j W S  ty:   Y dS 0 d S Nr   	r   rY   rZ   Z	size_hintr+   r   Z	get_dtyper   NotImplementedErrorr*   r$   r$   r'   r,      s    zMemoryDep.numbytes_hintc                 C   s   t t|  dkS r   rI   r   r+   r*   r$   r$   r'   r-      s    zMemoryDep.has_unbacked_symbolsc                 C   s,   t | jtjrdS t | jtjo*| j| jv S )NT)
isinstancer   r8   IntegerSymbolrA   r*   r$   r$   r'   r.      s    zMemoryDep.is_contiguousT)result_for_complex_expressionr"   c                 C   s   t | jdkrdS t| jtjr(| jjn| jg}| jd }|D ]d}||krP dS t|tjr>t |jdkr>|jd |kr>t|jd ttj	fr>|jd dkr> dS q>|S )zA
        Whether the stride for the last dimension is 1.
        r   Tr   r   F)
rI   rA   r   r   r8   AddargsZMulr=   r   )r&   r   ZtermsZlast_symZtermr$   r$   r'   stride1_for_last_dim   s&    

zMemoryDep.stride1_for_last_dimc                 C   s6   t | jtjr$| j| jvo"|   S t | jttjfS r#   )r   r   r8   r   rA   r   r=   r   r*   r$   r$   r'   	is_scalar  s    zMemoryDep.is_scalarc                 C   s   t dd | jjD S )Nc                 s   s   | ]}t |jV  qd S r#   )r   r   rN   vr$   r$   r'   rP   !  rQ   z(MemoryDep.is_indirect.<locals>.<genexpr>)rV   r   r   r*   r$   r$   r'   r      s    zMemoryDep.is_indirect)r/   )T) r3   r4   r5   r6   r7   r8   r9   rt   r   rC   r   rH   propertyr=   rJ   listra   rc   rf   r2   r<   rG   r~   r+   r(   r,   r>   r-   r.   r   r   r   r$   r$   r$   r'   r@   A   s.   

9,	r@   c                   @   s   e Zd ZU eed< dZee ed< eej	dddZ
ej	dddZeeef d d	d
dZedddZedddZedddZedddZedddZdS )StarDepr   NrC   r)   c                 C   s   t dd S )NzStarDep does not have an indexr   r*   r$   r$   r'   r   *  s    zStarDep.indexc                 C   s   t j| jS r#   )r   rY   r+   r   r*   r$   r$   r'   r+   .  s    zStarDep.get_numelr    c                 C   s    | j |v rt|| j  | jS | S r#   )r   r   rC   r%   r$   r$   r'   r(   1  s    
zStarDep.renamec                 C   s@   z&t jj|  tt j| j W S  ty:   Y dS 0 d S r   r   r*   r$   r$   r'   r,   6  s    zStarDep.numbytes_hintc                 C   s   t t|  dkS r   r   r*   r$   r$   r'   r-   >  s    zStarDep.has_unbacked_symbolsc                 C   s   dS NFr$   r*   r$   r$   r'   r.   A  s    zStarDep.is_contiguousc                 C   s   dS r   r$   r*   r$   r$   r'   r   D  s    zStarDep.is_scalarc                 C   s   dS r   r$   r*   r$   r$   r'   r   G  s    zStarDep.is_indirect)r3   r4   r5   r6   r7   rC   r   r   r8   r9   r   r+   r<   r(   r=   r,   r>   r-   r.   r   r   r$   r$   r$   r'   r   $  s   
r   c                   @   s   e Zd ZU eed< eed< eejdddZejdddZ	e
eef d dd	d
ZedddZedddZedddZdS )WeakDepr   mutating_bufr)   c                 C   s   t dd S )NzWeakDep does not have an indexr   r*   r$   r$   r'   r   Z  s    zWeakDep.indexc                 C   s   t jjS r#   )r8   r   r   r*   r$   r$   r'   r+   ^  s    zWeakDep.get_numelr    c                 C   s    | j |v rt|| j  | jS | S r#   )r   r   r   r%   r$   r$   r'   r(   a  s    
zWeakDep.renamec                 C   s   dS )Nr   r$   r*   r$   r$   r'   r,   f  s    zWeakDep.numbytes_hintc                 C   s   dS r   r$   r*   r$   r$   r'   r-   i  s    zWeakDep.has_unbacked_symbolsc                 C   s   dS r   r$   r*   r$   r$   r'   r.   l  s    zWeakDep.is_contiguousN)r3   r4   r5   r6   r7   r   r8   r9   r   r+   r<   r(   r=   r,   r>   r-   r.   r$   r$   r$   r'   r   S  s   
r   c                   @   s<   e Zd ZU ejed< eejdf ed< eejdf ed< dS )IndexExprDepr   .rA   rB   N)r3   r4   r5   r8   r9   r7   rt   r   r$   r$   r$   r'   r   p  s   

r   c                   @   s   e Zd ZU ee ed< ee ed< ee ed< dZee	e
j  ed< dZee ed< eeef d ddd	Zeeee f d d
ddZd d dddZee	d  d dddZee d dddZee dddZdeee dddZdS )
ReadWritesreadswritesindex_exprsN
range_varsr{   r    c                    s>   t t fdd| jD t fdd| jD | j| j| jS )Nc                 3   s   | ]}|  V  qd S r#   r(   rN   depr!   r$   r'   rP     rQ   z$ReadWrites.rename.<locals>.<genexpr>c                 3   s   | ]}|  V  qd S r#   r   r   r   r$   r'   rP     rQ   )r   r   r   r   r   r   r{   r%   r$   r   r'   r(     s    zReadWrites.rename)r   r"   c                 C   sJ   t |tttfsJ t |ts(t|g}tt| j|| j| j| j	| j
S r#   )r   r   r   r   r   unionr   r   r   r   r{   )r&   r   r$   r$   r'   	with_read  s    

zReadWrites.with_readrK   c                 C   s@   t | j|j}t | j|j}t | j|j}t|| ||S r#   )r   r   r   r   r   r   )r&   rL   r   r   r   r$   r$   r'   merge  s    zReadWrites.merge)read_writesr"   c                 C   sL   t jdd | D  }t jdd | D  | }t jdd | D  }t|||S )Nc                 S   s   g | ]
}|j qS r$   )r   rN   rwr$   r$   r'   rU     rQ   z)ReadWrites.merge_list.<locals>.<listcomp>c                 S   s   g | ]
}|j qS r$   )r   r   r$   r$   r'   rU     rQ   c                 S   s   g | ]
}|j qS r$   )r   r   r$   r$   r'   rU     rQ   )r   r   r   )r   Z
all_writesZ	all_readsZall_index_exprsr$   r$   r'   
merge_list  s    zReadWrites.merge_list)	rem_readsr"   c                 C   s   t | j| | j| j| j| jS r#   )r   r   r   r   r   r{   )r&   r   r$   r$   r'   remove_reads  s    zReadWrites.remove_readsr)   c                 C   s   t | j| jS r#   )rW   rX   r   r   r*   r$   r$   r'   reads_and_writes  s    zReadWrites.reads_and_writesT)ignore_integer_indexr"   c                 C   sF   t  }|  D ]2}t|tsq|r4t|jttjfs||j	 q|S )z6
        Integer index is used for load_seed.
        )
r   r   r   r@   r   r=   r8   r   addr   )r&   r   namesr   r$   r$   r'   buffer_names  s    
zReadWrites.buffer_names)T)r3   r4   r5   r   r   r7   r   r   r   r   r8   r9   r{   r   r<   r6   r(   r   r   r   staticmethodr   r   r   r   r>   r   r$   r$   r$   r'   r   w  s   
		r   c                
       sv  e Zd Zeedd fddZeeee	j
f ee	j
 ee	j
 ddddZee	j
eee	j
ee	jdf ee	j
df f d	d
dZe	j
ee	j
ee	jdf ee	j
df f dddZee	j
edddZeeedddZd ee	j
eee edddZee	j
eedddZe	j
eej edddZd!eeee	j
e	j
e	j
f eejeeeee	j
f  ee ddddZ  ZS )"rd   Nr{   rf   r"   c                    s2   t    t | _t | _t | _|| _|| _d S r#   )super__init__r   _reads_writes_index_exprs_var_ranges_should_normalize)r&   r{   rf   	__class__r$   r'   r     s    
z_RecordLoadStoreInner.__init__)r   rA   rx   r"   c                 C   s<   t | tjsdS | j}|r8|d |vr8|  |  qdS )zz
        Reduction has last (reduced) dim in its sizes, but
        downstream users won't.  Normalize this away.
        Nr   )r   r8   r9   r   pop)r   rA   rx   r   r$   r$   r'   drop_unused_symbols  s    
z)_RecordLoadStoreInner.drop_unused_symbols.)r   r{   r"   c           
         s   g |  }t| }tjj||t|g||\}}}tt	 \} t
t|| fdd|D }	tt||	}g |  }g |}| ||| |t|t|fS )Nc                    s   g | ]} |qS r$   r$   rj   rl   r$   r'   rU     rQ   z4_RecordLoadStoreInner._normalize.<locals>.<listcomp>)ru   rt   rv   r   rY   rZ   rp   r   rq   canonicalization_prefixr<   rr   r   r8   rs   r   )
clsr   r{   Z
index_varsrx   	new_sizesry   rz   Znew_varsr|   r$   rl   r'   re     s    
 z _RecordLoadStoreInner._normalize)r   r"   c                 C   s   | j sbdd | j D }dd t| j |D }dd |D }| ||| |t|t|fS dd | j D }| ||S )Nc                 S   s   g | ]}t jj|qS r$   r   rY   rZ   Zsimplifyrj   r$   r$   r'   rU     rQ   z6_RecordLoadStoreInner.canonicalize.<locals>.<listcomp>c                 S   s   g | ]\}}|d kr|qS r   r$   rN   kr   r$   r$   r'   rU     rQ   c                 S   s   g | ]}|d kr|qS r   r$   r   r$   r$   r'   rU     rQ   c                 S   s    i | ]\}}|t jj|qS r$   r   r   r$   r$   r'   rS     s   z6_RecordLoadStoreInner.canonicalize.<locals>.<dictcomp>)	r   r   rv   rr   ru   r   rt   itemsre   )r&   r   rx   rA   r{   r$   r$   r'   canonicalize  s    z"_RecordLoadStoreInner.canonicalize)r   r   r"   c                 C   s4   | j t|g| |R   d| dt| dS )Nzload(rE   rF   )r   r   r@   r   r   r&   r   r   r$   r$   r'   load  s    z_RecordLoadStoreInner.loadc                 C   s    t |tsJ | |t|S r#   )r   r=   r   r8   r   r   r$   r$   r'   	load_seed  s    z_RecordLoadStoreInner.load_seed)r   r   valuerC   r"   c              	   C   sF   | j t|g| |R d|i d| dt| d| d| d	S )NrC   zstore(rE   rF   )r   r   r@   r   r   )r&   r   r   r   rC   r$   r$   r'   store  s    $z_RecordLoadStoreInner.store)r   r   r   r"   c                 C   s   |  ||d| dS )Nzstore_reduction(rF   )r   )r&   r   r   r   r$   r$   r'   store_reduction  s    z%_RecordLoadStoreInner.store_reduction)r   dtyper"   c                 C   s,   | j t| |  dt| d| dS )Nzindex_expr(rE   rF   )r   r   r   r   r   )r&   r   r   r$   r$   r'   
index_expr  s    z _RecordLoadStoreInner.index_expr)rv   
boundariesboundary_indicesindexing_dtyperightsortersorter_indicesr"   c                 C   s4   | j t|d  |dur0| j t|d  dS )z?Records the names of the buffers that bucketize will read from.r   N)r   r   r   )r&   rv   r   r   r   r   r   r   r$   r$   r'   	bucketize  s    z_RecordLoadStoreInner.bucketize)N)NN)r3   r4   r5   r   r>   r   r   r   r=   r8   r9   r   r   classmethodrt   r   re   r   r6   r   r   r   r   r   torchr   r   r   r   __classcell__r$   r$   r   r'   rd     sF   ""   rd   c                       s&   e Zd Zeedd fddZ  ZS )RecordLoadStoreNr   c                    s   t ||d}t j|d d S )Nr{   rf   )parent_handler)rd   r   r   )r&   r{   rf   r   r   r$   r'   r   *  s    zRecordLoadStore.__init__)r3   r4   r5   r   r>   r   r   r$   r$   r   r'   r   )  s   r   r0   c                    s0   t   i tjtjd fdd}|fS )N)lengthr"   c                    s    t  t  }| |< |S r#   )r   next)r   r   Zcntr1   r{   r$   r'   rm   6  s    zvar_builder.<locals>.add_var)rW   countr8   r9   r   )r1   rm   r$   r   r'   rq   2  s    rq   )argsizesr1   r"   c                    s&   t | \}  fdd|D }||fS )Nc                    s   g | ]}t t |qS r$   )r   map)rN   rB   rl   r$   r'   rU   B  rQ   z)index_vars_no_squeeze.<locals>.<listcomp>)rq   )r1   r   r{   r   r$   rl   r'   index_vars_no_squeeze>  s    r   d)r1   c           
      G   sb   ddl m} t| \}}g }g }|D ]4}||\}}	|| ||	tt|| q$||fS )Nr   )SqueezeView)rg   r   rq   Zsqueezerappendr   r   )
r1   r   r   r{   rm   r   r   rB   Znew_sizery   r$   r$   r'   index_vars_squeezeF  s    
r   Fr$   )rf   r1   hidden_args.)fnr   rf   r1   r   r"   c                G   s   t |d|i\}}ddlm} t| |r@t| g ||||}nNt||d}	t|	" | g ||R   W d    n1 s~0    Y  |	j}|rg }
ng t	j
|}
tt|jt|j|j|
|S )Nr1   r   )LoopBody)rf   )r   	loop_bodyr   r   extract_loop_body_with_argsr   r   set_ops_handlerr   rW   rX   from_iterabler   r   r   r   r   )r   rf   r1   r   r   r   r{   r   innerr   r   r$   r$   r'   extract_read_writesU  s(    
0r   )r   r   r{   rf   r"   c                    sP  ddl m} t||d}| |}| jrRdd t| jD   fdd| D }| j|j D ]}|	|j
||j  q^| j|j D ]}||j
t||j  q| j|j D ]}||j
||j d |j q| j|j D ]}||j
||j d  q| j|j D ]}|||j d  q| j|j D ]"}|d |j
d d d fd d d  q(|S )Nr   )MemoryUsageTyper   c                 S   s   i | ]\}}|t tj|qS r$   )r   r   TMP)rN   rR   r   r$   r$   r'   rS     rQ   z/extract_loop_body_with_args.<locals>.<dictcomp>c                    s   i | ]\}}|t | qS r$   )r   r   replr$   r'   rS     rQ   )r   r   rd   Zindexing_from_argsZindirect_varsr^   r   Zmemory_usageZLOADr   Zbuffer_nameZ
index_nameZ	LOAD_SEEDr   r=   ZSTOREr   rC   ZSTORE_REDUCTIONr   Z
INDEX_EXPRr   Z	BUCKETIZEr   )r   r   r{   rf   r   r   Zname_to_indexentryr$   r   r'   r   y  sD    
r   ztorch._inductor.ir.IRNode)
input_noder"   c                 C   s  ddl m}m}m} t|  |rRg |  }g |  }t|dkrN||fS dS t| j	j	|sddS | 
 }d}d}|du rt|dkrt }g }|D ]}	t|	tsq|	j|v rq||	j tj|	j}
|
du rq|
 }|du st||rqt||rdt| dkrd|du r8g | }g | }n*|g | ks\|g | krr dS q||
  q||kr||fS t|}qt||fS )aX  
    Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
    It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
    In this case, reduction_sizes of the Reduction nodes need to be the same.
    Otherwise returns (None, None).
    r   )ComputedBufferExternKernelLoopsr   NNN)rg   r   r   r   r   Zget_defining_opget_sizeZget_reduction_sizerI   dataZ	get_readsr   r@   r   r   r   rY   Ztry_get_bufferextend)r   r   r   r   rB   Zreduction_sizer   seenZ	new_readsreadbufferopr$   r$   r'   #extract_input_node_reduction_ranges  sP    





r  r)   c                   C   s   dS )Ncr$   r$   r$   r$   r'   r     s    r   c                   @   s  e Zd ZU eej ed< deddddZe	e
edf ee	ef edd	d
Zdeeeejf eeejdddZee
d dddZeeee e
d dddZeee eee
d dddZejejeede
d f ede
d f dddZeedef eddddZdS ) FreeSymbolsOpsHandlersymbolsTN)unbacked_onlyr"   c                 C   s   t  | _|rtnt| _d S r#   )r   r  r   r   get_symbols)r&   r  r$   r$   r'   r     s    zFreeSymbolsOpsHandler.__init__.)r   r   kwargsr"   c                 C   sD   t || D ].}t|tjtjjjfr|  j	| 
|O  _	qd S r#   )rW   rX   rv   r   r8   r9   logicboolalgBooleanr  r  )r&   r   r   r  ar$   r$   r'   _default  s    zFreeSymbolsOpsHandler._default)	index_varrB   checkwrap_negr"   c                 C   sB   t |tjtjjjfrJ |  j| |O  _tdt	| dS )N(rF   )
r   r8   r9   r	  r
  r  r  r  r   r6   )r&   r  rB   r  r  r$   r$   r'   indirect_indexing  s    z'FreeSymbolsOpsHandler.indirect_indexing)N.)rk   r"   c                 C   s   dS )Nr   r$   )r&   rk   r$   r$   r'   frexp  s    zFreeSymbolsOpsHandler.frexp)dtypes
combine_fnrv   r"   c                 C   s   dt | S Nr#   rI   )r&   r  r  rv   r$   r$   r'   scan  s    zFreeSymbolsOpsHandler.scan)r  rv   stable
descendingr"   c                 C   s   dt | S r  r  )r&   r  rv   r  r  r$   r$   r'   sort  s    zFreeSymbolsOpsHandler.sort)r   	src_dtypereduction_typer   r"   c                 C   s   t |}|dkrd| S d S )Nr   r#   )r   )r&   r   r  r  r   Z
num_valuesr$   r$   r'   	reduction  s    zFreeSymbolsOpsHandler.reduction)maskbodyrL   r"   c                 C   s   t |sJ d|  d S )Nz$masked body must always be callable.)callable)r&   r  r   rL   r$   r$   r'   masked  s    zFreeSymbolsOpsHandler.masked)T)TT)r3   r4   r5   r   r8   r   r7   r>   r   r6   rt   r   r<   r  r   r=   r9   r  r  r   r  r  r   r   r   r  r   r"  r$   r$   r$   r'   r    s4   
$	  

r  )r   r   rindexr  r"   c              	   C   s   ddl m} |d ur||gn|g}t|}t|F t|dd | |  W d    n1 sb0    Y  W d    n1 s0    Y  |jS )Nr   )FlexibleLayoutZallow_indexingT)rg   r$  r  r   r   r
   objectr  )r   r   r#  r  r$  r   handlerr$   r$   r'   extract_free_symbols%  s    Dr'  )F)NT)Lr:   ZdataclassesrW   loggingrecollections.abcr   r   typingr   r   r   r   r   Ztyping_extensionsr	   Zunittest.mockr
   r8   r   Z%torch.fx.experimental.symbolic_shapesr   r   Ztorch.utils._ordered_setr   Zutils._sympy.symbolr   r   Zcodegen.commonr   Zops_handlerr   utilsr   r   r   r   r   r   Zvirtualizedr   r   r   	getLoggerr3   r\   compilesearchr   ABCr   Z	dataclassr@   r   r   r   r   ZMockHandlerrd   ZKernelFormatterHandlerr   r6   rt   r9   r   rq   r   r   r   r>   r   r   r  r   r  r'  r$   r$   r$   r'   <module>   s    

 c
.

Dm	&
	

( 1C7  
