a
    kh                     @  s<  d dl mZmZ d dlZd dlZd dlZd dlZd dlZd dlZd dl	Z	d dl
mZ d dlmZ d dlmZ d dlmZmZmZmZmZmZmZmZmZmZ d dlmZ d dlmZ d	d
lm Z  d	dl!m"Z" d	dl#m$Z$m%Z%m&Z&m'Z' e(de)d  Z*edZ+G dd dej,Z-ddddZ.G dd dZ/i Z0g Z1dd Z2d?ddZ3G dd dee+ Z4dd Z5d d! Z6d"d# Z7eG d$d% d%Z8G d&d' d'e4e+ Z9edd(d)d*d+Z:edddddddd,d-d-d.d.d/d/d0d1d2d+Z:d@dddddddd,d3d-d-d.d.d/d/d4d5d6d+Z:G d7d8 d8Z;G d9d: d:Z<d;d< Z=d=d> Z>dS )A    )annotationsdivisionN)defaultdict)	dataclass)cached_property)
CallableGenericIterableOptionalTypeVarUnionoverloadDictAnyTupleTensorDescriptor)
ModuleType   )knobs)driver)find_paths_ifget_iterable_pathtype_canonicalisation_dictcanonicalize_dtypez.runtime.jitTc                      s   e Zd ZdZdd fddZedd Zdd	 Zd
d Zdd Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zdd Z  ZS )DependenciesFindera  
    This AST visitor is used to find dependencies of a JITFunction. This can
    be used to invalidate a JITFunction's hash when its source code -- or
    that of its dependencies -- changes.

    This visitor also keeps track of the global variables touched by the
    JITFunction.  When we launch the kernel, we check that these have the same
    values as they did when we ran this visitor.  If not, we raise an error (or
    otherwise we could recompile).
    Nonereturnc                   sH   t    || _t|d| _|| _|| _h d| _	i | _
d| _d S )Nutf-8>
   minfloatprintgetattrlistmaxrangelenint
isinstanceF)super__init__namehashlibsha256encodehasherglobals	nonlocalssupported_python_builtinsused_global_valsvisiting_arg_default_value)selfr-   r2   r3   src	__class__ @/var/www/auris/lib/python3.9/site-packages/triton/runtime/jit.pyr,   )   s    

zDependenciesFinder.__init__c                 C  s
   | j  S N)r1   	hexdigestr7   r;   r;   r<   retN   s    zDependenciesFinder.retc                 C  s&   t |jrdS t|dd}|tS )NT
__module__ )inspect	isbuiltinfuncr$   
startswithTRITON_MODULE)r7   noderE   moduler;   r;   r<   _is_triton_builtinR   s    z%DependenciesFinder._is_triton_builtinc                 C  s   t |tr| j |j @ D ]\}|\}}| j| \}}|j| \}}||krtd| d| d| j d|j d| dq| j|j |j}|t	t
|dd7 }| j|d	 d S )
NGlobal variable z has value z when compiling z, but inner kernel z has conflicting value z7 from when it was first compiled.  This is not allowed.noinlineFr    )r*   JITFunctionr5   keysRuntimeErrorr-   __name__update	cache_keystrr$   r1   r0   )r7   rE   kvar_name_Zv1Zv2Zfunc_keyr;   r;   r<   _update_hashX   s    
&zDependenciesFinder._update_hashc                   s   t |jtju r|jS |j jv r&d S  fdd}||j\}}|d ur jst |turt|t	st
|dds|j jvrt||f j|jt|f<  | |S )Nc                   sD    j | d }|d ur | j fS  j| d }|d ur@| jfS dS )N)NN)r2   getr3   )r-   valr?   r;   r<   name_lookupr   s    

z2DependenciesFinder.visit_Name.<locals>.name_lookupZ__triton_builtin__F)typectxastStoreidlocal_namesr6   r   r*   rM   r$   r4   copyr5   rW   )r7   rH   rZ   rY   Zvar_dictr;   r?   r<   
visit_Namej   s(    	





zDependenciesFinder.visit_Namec                   s    fdd|j D S )Nc                   s   g | ]}  |qS r;   )visit).0eltr?   r;   r<   
<listcomp>       z2DependenciesFinder.visit_Tuple.<locals>.<listcomp>)eltsr7   rH   r;   r?   r<   visit_Tuple   s    zDependenciesFinder.visit_Tuplec                 C  s\   |  |j}t|tjr&|  |j}q|d u s>t|ddtkrBd S t||j}| | |S )NrP   rB   )	rc   valuer*   r]   	Attributer$   rG   attrrW   )r7   rH   lhsr@   r;   r;   r<   visit_Attribute   s    
z"DependenciesFinder.visit_Attributec                 C  s"   dd |j j D | _| | d S )Nc                 S  s   h | ]
}|j qS r;   arg)rd   rq   r;   r;   r<   	<setcomp>   rg   z7DependenciesFinder.visit_FunctionDef.<locals>.<setcomp>)argsr`   generic_visitri   r;   r;   r<   visit_FunctionDef   s    z$DependenciesFinder.visit_FunctionDefc                   sn    fdd}t |j|j|jr&|jgng |jD ]} | q0||j |jd ur` |j ||j	 d S )Nc                   sD   z6 j rJ d _ | D ]}|d ur | qW d _ nd _ 0 d S )NTF)r6   rc   )defaultsexprr?   r;   r<   visit_defaults   s    
z:DependenciesFinder.visit_arguments.<locals>.visit_defaults)
	itertoolschainposonlyargsrs   vararg
kwonlyargsrc   kw_defaultskwargrv   )r7   rH   rx   rq   r;   r?   r<   visit_arguments   s    
(

z"DependenciesFinder.visit_argumentsc                 C  s8   |  |}t|tr(|  jt|O  _n| j| d S r=   )rc   r*   r%   r`   setadd)r7   rH   targetr;   r;   r<   visitAssnTarget   s    

z"DependenciesFinder.visitAssnTargetc                 C  s4   t |jdkrtd| |jd  | | d S )N   z2Simultaneous multiple assignment is not supported.r   )r(   targets	TypeErrorr   rt   ri   r;   r;   r<   visit_Assign   s    zDependenciesFinder.visit_Assignc                 C  s   |  |j | | d S r=   r   r   rt   ri   r;   r;   r<   visit_AnnAssign   s    z"DependenciesFinder.visit_AnnAssignc                 C  s   |  |j | | d S r=   r   ri   r;   r;   r<   	visit_For   s    zDependenciesFinder.visit_For)rP   rA   __qualname____doc__r,   propertyr@   rJ   rW   rb   rj   ro   ru   r   r   r   r   r   __classcell__r;   r;   r9   r<   r      s   %
'
 	r   rS   r   c                 C  s  dd l m  m} t| tr|  } | dr^| d} t| } | dsNJ d| dd   S | 	dr|dt| d d  S | drdt| dd   S | drt| dS nJt| |j
rdt| j S t| |jr| j} nt| tr| j} nt| } t| d	d
| S )Nr   zconst const**kr   ztl.Z_trB   )triton.language.corelanguagecorer*   rS   striprF   removeprefix_normalize_tyendswithZpointer_typeZ
element_tydtyper-   r[   rP   r   rX   replace)tyr   r;   r;   r<   r      s,    






r   c                   @  s   e Zd ZdZdddddddZedd	 Zed
dddZed
dddZedd Z	edd Z
edd Zedd ZdS )KernelParamzBRepresents a parameter (name plus metadata) to a @jit'ed function.r)   zinspect.Parameterbool)numparamdo_not_specializedo_not_specialize_on_alignmentc                 C  s   || _ || _|| _|| _d S r=   )r   _paramr   r   )r7   r   r   r   r   r;   r;   r<   r,   
  s    zKernelParam.__init__c                 C  s   | j jS r=   )r   r-   r?   r;   r;   r<   r-     s    zKernelParam.namerS   r   c                 C  s(   | j jr| j jtjjkrdS t| j jS )NrB   )r   
annotationrC   	Parameteremptyr   r?   r;   r;   r<   r     s    zKernelParam.annotationc                 C  sN   | j }|dr|dd  }n|dr4|dd  }|tt v rJ| j S dS )Nr   r   r   r   rB   )r   rF   r   r   values)r7   ar;   r;   r<   annotation_type  s    

zKernelParam.annotation_typec                 C  s
   d| j v S N	constexpr)r   r?   r;   r;   r<   is_constexpr&  s    zKernelParam.is_constexprc                 C  s    | j r
dS d| jv p| jdS )NFr   r   )r   r   rF   r?   r;   r;   r<   is_const*  s    zKernelParam.is_constc                 C  s   | j jS r=   )r   defaultr?   r;   r;   r<   r   0  s    zKernelParam.defaultc                 C  s   | j jtjjkS r=   )r   r   rC   r   r   r?   r;   r;   r<   has_default4  s    zKernelParam.has_defaultN)rP   rA   r   r   r,   r   r-   r   r   r   r   r   r   r   r;   r;   r;   r<   r     s    




r   c                   s0   ddl m ddlm  d	 fdd	S )
Nr   r   r   r   FTc                   s6   d u rdS t  trdS t  tr|r6 d|dnd } dkrJ|rJdS d krb dkrbd	|fS d
 krz dkrzd|fS d|fS nt  trdS t dr j|f}t|d }|d u r|d rdndt|d  }|t|< | r d|dnd }||fS t  t	rd j
fS t  r,d fS t dr<dS t  trfdd D } fdd}|dd |D }	|dd |D }
|	|
fS t  trt jdsJ t jj}d| t j dd fS t  r"t jdsJ t jj}d| t j d jdd fS td t  d S )!N)r   N)u1Nr)   )alignr   )r   r   i   iZi32l            l    Zu64Zi64)Zfp32Ndata_ptrr   r   r   tensorr   Ztma_desc_cpu_ptr)Z	nvTmaDescNc                   s   g | ]} |qS r;   r;   rd   x)specialize_implr;   r<   rf   c  rg   zCcreate_specialize_impl.<locals>.specialize_impl.<locals>.<listcomp>c                   s   t  drt |  S t| S )N_fields)hasattrr[   tuple)valsrp   r;   r<   <lambda>d  rg   zAcreate_specialize_impl.<locals>.specialize_impl.<locals>.<lambda>c                 S  s   g | ]}|d  qS r   r;   r   r;   r;   r<   rf   e  rg   c                 S  s   g | ]}|d  qS r   r;   r   r;   r;   r<   rf   f  rg   ztensordesc<>,zUnsupported type: %s)r*   r   r)   r"   r   r   	dtype2strrX   r   rM   rR   r   r   baser%   Zblock_shapeZlayoutr   r[   )rq   r   specialize_valuer   keyZdskresspecZ
make_tupleZtysrN   innerZGluonTensorDescriptorr   specialize_extrar   rp   r<   r   B  sX    




"z/create_specialize_impl.<locals>.specialize_impl)FTT)r   r   Z'triton.experimental.gluon.nvidia.hopperr   )r   r;   r   r<   create_specialize_impl=  s    1r   Fc                 C  s6   t tdkrttdd  td }|| |dd S )Nr   c                 [  s   d S r=   r;   )rV   kwargsr;   r;   r<   r   x  rg   zmangle_type.<locals>.<lambda>)r   )r(   specialize_impl_cacheappendr   )rq   
specializer   r;   r;   r<   mangle_typev  s    r   c                   @  s$   e Zd ZU ded< ddddZdS )KernelInterfacer   runr   c                   s    fddS )z
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        c                    s   j |  dd|S )NFgridwarmup)r   )rs   r   r   r7   r;   r<   r     rg   z-KernelInterface.__getitem__.<locals>.<lambda>r;   )r7   r   r;   r   r<   __getitem__  s    zKernelInterface.__getitem__N)rP   rA   r   __annotations__r   r;   r;   r;   r<   r   }  s   
r   c           	   	   C  sl   dd |  D }dd l}| |dd | D t| dd | D t| |j|d}||}|S )Nc                 S  s*   i | ]"\}}||j jd kr"t|n|qS r   )r:   rP   rS   rd   r   rk   r;   r;   r<   
<dictcomp>  rg   z1serialize_specialization_data.<locals>.<dictcomp>r   c                 S  s   g | ]}t |qS r;   r%   r   r;   r;   r<   rf     rg   z1serialize_specialization_data.<locals>.<listcomp>c                 S  s   g | ]}t |qS r;   r   r   r;   r;   r<   rf     rg   )r-   	signatureconstant_keysconstant_vals
attrs_keys
attrs_valsoptionsr   )itemsjsonrN   r%   r   __dict__dumps)	r-   r   	constantsattrsr   r   r   objZserialized_objr;   r;   r<   serialize_specialization_data  s    $
r   c              
   C  s  t | jt |ksJ g }t| j |D ]\}}|jrL|d| d q*|jrVdnd}|jrddnd}|jrrdnd}d| d| d| d| d	}	|j	rt
|j	tr|j	dks|j	dd	 d
v rd}|r|d|j	 d|	 d n|d|j	 d q*||	  q*dd }
ddtt|
| j dg  dddd | j D  dd| d}dd | j D }t|d< t|j|d< t|| |d S )a2  
    Equivalent to sig.bind followed by apply_defaults. This generates a
    native Python function (using exec) which can be memoized on a per-kernel
    basis to avoid having to run these expensive functions -- which constitute
    much of the kernel launch overhead -- every time we run the kernel.
    z("constexpr", )TrueFalsezspecialize_impl(, r   Nr   )fpZbfFz("z",) + z[1:]z", None)c                 S  s0   | d j tjju r| d S | d  d| d  S )Nr   r   z	=default_r   rC   r   r   )r   r;   r;   r<   r     rg   z0create_function_from_signature.<locals>.<lambda>z
def dynamic_func(z	**optionsz):
    params = {c                 S  s   g | ]}d | d| qS )'z': r;   )rd   r-   r;   r;   r<   rf     rg   z2create_function_from_signature.<locals>.<listcomp>z}
    specialization = [r   z-]
    return params, specialization, options
c                 S  s,   i | ]$\}}|j tjjurd | |j qS )Zdefault_r   )rd   r-   r   r;   r;   r<   r     s   z2create_function_from_signature.<locals>.<dictcomp>rM   r   Zdynamic_func)r(   
parametersziprN   r   r   r   r   r   r   r*   rS   joinr%   mapr   rM   r   Zget_arg_specializationexec)sigZkparamsbackendspecializationr-   Zkpr   r   r   r@   rq   Z	func_bodyZfunc_namespacer;   r;   r<   create_function_from_signature  s@    
r   c                 C  s   | j  d| j S )N.)rA   r   fnr;   r;   r<   get_full_name  s    r   c                   @  s&   e Zd ZU ded< ded< ded< dS )JitFunctionInfor   rI   rS   r-   rM   Zjit_functionN)rP   rA   r   r   r;   r;   r;   r<   r     s   
r   c                      s   e Zd Zdd ZddddZdd Zd	d
 Zdd Zdd Zd&ddZ	dd Z
edd Zedd Zdd Zdd Zdd Zdd Z fd d!Z fd"d#Zd$d% Z  ZS )'rM   c                 C  s   dS )NFr;   r?   r;   r;   r<   is_gluon  s    zJITFunction.is_gluonzbool | Noner   c	                 C  s   |sd S | j j}	| j j}
ddd t| j|d D }|	 d|j d|j d|j d|j	 d	|j
 d
| d}t| j }t||||d ||}||||j|j|j|j	|j
|j|||d}|||t|
|	| d|i||ddS )Nr   c                 S  s    g | ]\}}|j  d | qS )z: r-   )rd   r   r   r;   r;   r<   rf     rg   z*JITFunction._call_hook.<locals>.<listcomp>r   z[num_warps=z, num_ctas=z, num_stages=z, enable_fp_fusion=z, launch_cooperative_grid=](r   r   )r   devicer   	num_warpsnum_ctas
num_stagesenable_fp_fusionlaunch_cooperative_gridextern_libsconfigsspecialization_data	is_warmupr   F)r   reprr   compileZis_manual_warmupZalready_compiled)r   r   rA   r   r   paramsr  r  r  r  r  r   r   r	  r   )r7   hookr   r   r  r   r   r
  r  r-   rI   Z	arg_reprsr  Z	full_namer  r   r;   r;   r<   
_call_hook  s:     8


zJITFunction._call_hookc                 C  s   t |sJ | j| dS )z
        Add a hook that will be executed prior to the execution of run
        function with args and kwargs passed into the kernel
        N)callablepre_run_hooksr   )r7   r  r;   r;   r<   add_pre_run_hook  s    zJITFunction.add_pre_run_hookc                 C  sX   ddl m}m}m}m} tj }||}|| _|| _|| _t| j	| j
|}i |||fS )z1
        Precompute as much as possible.
        r   )CompiledKernelr  	ASTSourcemake_backend)compilerr  r  r  r  r   activeZget_current_targetr   r   r  )r7   r  r  r  r  r   r   binderr;   r;   r<   create_binder  s    
zJITFunction.create_binderc          !   
     s  | d| jptjj|d< tj }tj|}| jD ]}||i | q6| j	| \}}	}
|
|i |\}}t
|t
| }| |d }|d u r|}dd | jD }dd |D }dd t||D }d|vsJ dd	|vsJ d
d|vsJ d|D ](}||jvr||vrtd| qt|dd }fdd|D }dd |D  t dd } fdd|D }| tjj||||||g|rd S | | |||}| j||	|jd}|||< | tjj||||||g| t }| j D ]B\\}}\}}| || }|krtd| d| d| q|s|d usVJ t|rh|}t|}|d }|dkr|d nd}|dkr|d nd}|j||g R  } |j|||||j|j | tjj!tjj"g	 R   |S )Ndebugc                 S  s   g | ]
}|j qS r;   r  r   r;   r;   r<   rf   <  rg   z#JITFunction.run.<locals>.<listcomp>c                 S  s   g | ]}|d  qS r   r;   r   r;   r;   r<   rf   =  rg   c                 S  s   i | ]\}}||qS r;   r;   )rd   rT   vr;   r;   r<   r   >  rg   z#JITFunction.run.<locals>.<dictcomp>Zdevice_typez=device_type option is deprecated; current target will be usedr  z8device option is deprecated; current device will be usedstreamz8stream option is deprecated; current stream will be usedz2Keyword argument %s was specified but unrecognisedc                 S  s   |dkS r   r;   )rV   rY   r;   r;   r<   r   G  rg   z!JITFunction.run.<locals>.<lambda>c                   s    i | ]}|t t  |qS r;   )r   r%   r   )rd   path)
bound_argsr;   r<   r   H  rg   c                 S  s   g | ]}|d  qS r   r;   r   r;   r;   r<   rf   J  rg   c                 S  s
   t |tS r=   )r*   rS   )rV   r   r;   r;   r<   r   K  rg   c                   s   i | ]}| t |qS r;   )Z
parse_attrr   )rd   rT   )attrvalsr   r;   r<   r   L  rg   )r   r   rK   z1 has changed since we compiled this kernel, from z to r   r   r   )#rX   r  r   runtimer   r  get_current_deviceZget_current_streamr  device_cachesrS   Zparse_optionsr  r   r   KeyErrorr   r  Zjit_cache_hookr  r  Zjit_post_compile_hookobjectr5   r   rO   r  r(   launch_metadatar   r   functionZpacked_metadataZlaunch_enter_hookZlaunch_exit_hook)!r7   r   r   rs   r   r  r  r  Zkernel_cacher   r  r   r   r   kernelZsigkeysZsigvalsr   rT   
constexprsr   r8   Znot_presentr-   rV   rY   Zglobals_dictZnewValZ	grid_sizeZgrid_0Zgrid_1Zgrid_2r'  r;   )r!  r   r   r<   r   #  sp    




zJITFunction.runc                 C  s   | j d u r| jS |  |S r=   )_repr_fn_name)r7   rV   r;   r;   r<   r  m  s    zJITFunction.reprNc	                 C  sz  |r|ng }|r|ng }|| _ |j| _|| _t|| _|| _|| _t|d | _	|| _
t|| _|| _g | _t| jj D ]B\}	}
|	|v p|
j|v }|	|v p|
j|v }| jt|	|
|| qtt|}|td|tj d  }| | t| j| _d | _ i | _!d | _"|| _#|| _$dd | jD | _%dd | jD | _&g | _'|j(| _(|j)| _)|j*| _*|j+| _+|j| _d S )Nr   z^def\s+\w+\s*\(c                 S  s   g | ]
}|j qS r;   r  rd   pr;   r;   r<   rf     rg   z(JITFunction.__init__.<locals>.<listcomp>c                 S  s   g | ]}|j r|jqS r;   )r   r   r-  r;   r;   r<   rf     rg   ),r   rA   rI   versionrC   r   r   r   getsourcelinesstarting_line_numberr+  r   r,  r'  r  	enumerater   r   r-   r   r   textwrapdedent	getsourceresearch	MULTILINEstart_unsafe_update_srcr   r  r$  hashr5   r)  r  rL   	arg_namesr*  r  r   rP   r   __globals__)r7   r   r/  r   r   r  rL   r  r'  ir   ZdnsZdns_oar8   r;   r;   r<   r,   p  sD    

zJITFunction.__init__c                 C  s   | j t| jjB S r=   )r=  rC   getclosurevarsr   r3   r?   r;   r;   r<   get_capture_scope  s    zJITFunction.get_capture_scopec                 C  sh   | j d u rbt| jj}t| j| j|| jd}|	| 
  |jt| j | _ tt|j | _| j S )N)r-   r2   r3   r8   )r;  rC   r?  r   r3   r   r,  r=  r8   rc   parser@   rS   r1  dictsortedr5   r   )r7   r3   Zdependencies_finderr;   r;   r<   rR     s    
zJITFunction.cache_keyc                 C  s   ddl m} |S )Nr   r   )r   r   )r7   r   r;   r;   r<   r[     s    zJITFunction.typec                O  s   | j ttj||dd|S )NTr   )r   r   
MockTensor
wrap_dtype)r7   r   rs   r   r;   r;   r<   r     s    zJITFunction.warmupc                   s  ddl m}m} dd l}dd lm  tj }|	|}|d | j
kr`td|d  d| j
 tt|d }|d } fd	d
t||D }	tt|d }
|d }tt|
|}t|d  }|| ||	|}dd
 |d  D }|d }||d |}|| j| d |< |S )Nr   )r  r  r   r-   zSpecialization data is for z but trying to preload for r   r   c                   s,   i | ]$\}}| j |r$  |n|qS r;   )r   Zis_dtyper   tlr;   r<   r     s   z'JITFunction.preload.<locals>.<dictcomp>r   r   r   c                 S  s(   i | ] \}}|t |tr t|n|qS r;   )r*   r%   r   r   r;   r;   r<   r     s   r   r   )r  r  r  r   Ztriton.languager   r   r  r#  loadsr,  rO   r   r   r   rB  r   r$  )r7   r  r  r  r   r  Zdeserialized_objr   r   r   r   r   r   r   r8   r   r   r)  r;   rF  r<   preload  s4    



zJITFunction.preloadc                 C  sH   t | j}t|t jsJ t|jdks.J t|jd t jsDJ |S )Nr   r   )r]   rA  r8   r*   Moduler(   bodyFunctionDef)r7   treer;   r;   r<   rA    s
    zJITFunction.parsec                 O  s   t dd S )Nz:Cannot call @triton.jit'd outside of the scope of a kernel)rO   )r7   rs   r   r;   r;   r<   __call__  s    zJITFunction.__call__c                   s.   |dkrt d| dtt| || d S )Nr8   zCannot set attribute 'zX' directly. Use '_unsafe_update_src()' and manually clear `.hash` of all callersinstead.)AttributeErrorr+   rM   __setattr__)r7   r-   rk   r9   r;   r<   rP    s    zJITFunction.__setattr__c                   s   d| _ t d| dS )z
        The only method allowed to modify src.
        Bypasses the __setattr__ restriction by calling super().__setattr__ directly.
        Nr8   )r;  r+   rP  )r7   Znew_srcr9   r;   r<   r:    s    zJITFunction._unsafe_update_srcc                 C  s   d| j  d| jj dS )NzJITFunction(:r   )rI   r   r   r?   r;   r;   r<   __repr__  s    zJITFunction.__repr__)NNNNNNN)rP   rA   r   r   r  r  r  r   r  r,   r@  r   rR   r[   r   rI  rA  rN  rP  r:  rR  r   r;   r;   r9   r<   rM     s*   .J  
>

 rM   JITFunction[T]r   r   c                 C  s   d S r=   r;   r   r;   r;   r<   jit  s    rU  r/  r  r'  r   r   r  rL   zOptional[Callable]zOptional[Iterable[int | str]]zOptional[bool]zCallable[[T], JITFunction[T]])r  r'  r   r   r  rL   r   c                 C  s   d S r=   r;   rV  r;   r;   r<   rU    s    zOptional[T]z4Union[JITFunction[T], Callable[[T], JITFunction[T]]])r   r  r'  r   r   r  rL   r   c          	        s8   ddd fdd}| dur0|| S |S dS )a<  
    Decorator for JIT-compiling a function using the Triton compiler.

    :note: When a jit'd function is called, arguments are
        implicitly converted to pointers if they have a :code:`.data_ptr()` method
        and a `.dtype` attribute.

    :note: This function will be compiled and run on the GPU. It will only have access to:

           * python primitives,
           * builtins within the triton package,
           * arguments to this function,
           * other jit'd functions

    :param fn: the function to be jit-compiled
    :type fn: Callable
    r   rS  rT  c              
     sT   t | sJ tjjr8ddlm} ||  dS t|  dS d S )Nr   )InterpretedFunction)r/  r   r   r  rL   r  r'  )r  r   r"  Z	interpretinterpreterrW  rM   )r   rW  r  r   r   r'  rL   r  r/  r;   r<   	decorator8  s"    zjit.<locals>.decoratorNr;   )	r   r/  r  r'  r   r   r  rL   rZ  r;   rY  r<   rU    s     c                   @  s<   e Zd ZdZedd Zdd Zedd Zedd	 Zd
S )rD  zr
    Can be used in place of real tensors when calling:
        kernel.warmup(MockTensor(torch.float32), ...)
    c                 C  s"   | j jdkr| jdkrt| S | S )Nr   Ztorch)r:   rP   rA   rD  rp   r;   r;   r<   rE  ]  s    zMockTensor.wrap_dtypec                 C  s
   || _ d S r=   r   )r7   r   r;   r;   r<   r,   c  s    zMockTensor.__init__c                   C  s   dS Nr   r;   r;   r;   r;   r<   r   f  s    zMockTensor.data_ptrc                   C  s   dS r[  r;   r;   r;   r;   r<   	ptr_rangej  s    zMockTensor.ptr_rangeN)	rP   rA   r   r   staticmethodrE  r,   r   r\  r;   r;   r;   r<   rD  W  s   

rD  c                   @  sb   e Zd Zdd Zdd Zdd Zddd	d
Zdd Zdd Zdd Z	dd Z
dd Zdd ZdS )TensorWrapperc                 C  s*   || _ || _|j| _|j| _| jj| _d S r=   )r   r   datar  shape)r7   r   r   r;   r;   r<   r,   q  s
    zTensorWrapper.__init__c                 C  s
   | j  S r=   )r   r   r?   r;   r;   r<   r   x  s    zTensorWrapper.data_ptrc                 G  s   | j j| S r=   )r   stride)r7   rs   r;   r;   r<   ra  {  s    zTensorWrapper.striderS   r   c                 C  s   d| j  d| j dS )NzTensorWrapper[r  r   )r   r   r?   r;   r;   r<   __str__~  s    zTensorWrapper.__str__c                 C  s
   | j  S r=   )r   element_sizer?   r;   r;   r<   rc    s    zTensorWrapper.element_sizec                 C  s   t | j | jS r=   )r^  r   cpur   r?   r;   r;   r<   rd    s    zTensorWrapper.cpuc                 C  s   | j |j  d S r=   )r   copy_)r7   otherr;   r;   r<   re    s    zTensorWrapper.copy_c                 C  s   t | j | jS r=   )r^  r   cloner   r?   r;   r;   r<   rg    s    zTensorWrapper.clonec                 C  s   t | j|| jS r=   )r^  r   tor   )r7   r  r;   r;   r<   rh    s    zTensorWrapper.toc                 C  s   t | j|| jS r=   )r^  r   	new_emptyr   )r7   sizesr;   r;   r<   ri    s    zTensorWrapper.new_emptyN)rP   rA   r   r,   r   ra  rb  rc  rd  re  rg  rh  ri  r;   r;   r;   r<   r^  o  s   r^  c                 C  sV   t | tr*|| jjkr| jS t| j|S n(t| dr>t| |S tdt|  dd S )Nr   zCannot reinterpret a r   )r*   r^  r   r   r   r   r[   )r   r   r;   r;   r<   reinterpret  s    


rk  c                 C  sd   | }t |ts|j}q|jjj}t|j\}}t|D ]"\}}| 	dr8||7 } q\q8||fS )Nzdef )
r*   rM   r   __code__co_filenamerC   r0  r2  r   rF   )r   Zbase_fn	file_namelinesZ
begin_lineidxliner;   r;   r<   get_jit_fn_file_line  s    

rr  )F)N)?
__future__r   r   r]   ra   r.   rC   ry   r6  r3  collectionsr   Zdataclassesr   	functoolsr   typingr   r   r	   r
   r   r   r   r   r   r   Ztriton.tools.tensor_descriptorr   typesr   rB   r   Zruntime.driverr   _utilsr   r   r   r   rP   r(   rG   r   NodeVisitorr   r   r   r   r   r   r   r   r   r   r   r   rM   rU  rD  r^  rk  rr  r;   r;   r;   r<   <module>   sx   0 Q29
:  /   <%