a
    jh5c                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlZd dlmZ d dlmZ d dl	m
Z
 d dlmZ d dlmZ d dlmZ ejejeZejedgZejed	Zd
gZe  dd Ze  dd ZG dd deZdd ZddddddZddddddZ dZ!dd Z"G dd dZ#e$d d! e%d"D Z&d#e&d$< d$e&d%< d%e&d#< d&d' Z'd(d) Z(G d*d+ d+eZ)G d,d- d-eZ*dS ).    N)Path)knobs)compile_module_from_src)_allocation)	GPUTarget)	GPUDriverincludelibcudac                  C   s   t jj } r| gS tddg }dd | D }dd |D }td}|rj|sjdd |	dD }d	}|r|d
t
| 7 }|d7 }n|d7 }|d7 }tdd |D sJ ||S )Nz/sbin/ldconfigz-pc                 S   s    g | ]}d |v r|  d qS )libcuda.so.1)split).0line r   K/var/www/auris/lib/python3.9/site-packages/triton/backends/nvidia/driver.py
<listcomp>       z libcuda_dirs.<locals>.<listcomp>c                 S   s   g | ]}t j|qS r   )ospathdirname)r   locr   r   r   r      r   ZLD_LIBRARY_PATHc                 S   s&   g | ]}t jt j|d r|qS )r   r   r   existsjoin)r   dirr   r   r   r       r   :zlibcuda.so cannot found!
z!Possible files are located at %s.z:Please create a symlink of libcuda.so to any of the files.z<Please make sure GPU is set up and then run "/sbin/ldconfig"z- (requires sudo) to refresh the linker cache.c                 s   s$   | ]}t jt j|d V  qdS )r   Nr   )r   r   r   r   r   	<genexpr>(   r   zlibcuda_dirs.<locals>.<genexpr>)r   ZnvidiaZlibcuda_path
subprocesscheck_outputdecode
splitlinesr   getenvr   strany)Zenv_libcuda_pathZlibsZlocsdirsZenv_ld_library_pathmsgr   r   r   libcuda_dirs   s     

r'   c                   C   s   t gt S N)libdevice_dirr'   r   r   r   r   library_dirs,   s    r*   c                       s$   e Zd Z fddZdd Z  ZS )	CudaUtilsc                    s"   t | dstt| | | _| jS )Ninstance)hasattrsuperr+   __new__r,   )cls	__class__r   r   r/   8   s    
zCudaUtils.__new__c                 C   sR   t ttjtd dt tt	d}|j
| _
|j| _|j| _|j| _|j| _d S )Nzdriver.cZ
cuda_utilssrcnamer*   include_dirs	libraries)r   r   r   r   r   r   	read_textr*   r6   r7   Zload_binaryZget_device_propertiesZcuOccupancyMaxActiveClustersZset_printf_fifo_sizefill_tma_descriptor)selfmodr   r   r   __init__=   s    zCudaUtils.__init__)__name__
__module____qualname__r/   r<   __classcell__r   r   r1   r   r+   6   s   r+   c                 C   sH   | d dkrdS |  drdS ddddd	d
ddd
dddddddd|  S )Nr   *ZCUdeviceptr
tensordescZCUtensorMapint32_tint8_tint16_tint64_tuint32_tuint8_tuint16_tuint64_tdouble)i1i8Zi16i32i64u1u8u16u32Zu64fp16bf16fp32f32fp64	nvTmaDesc)
startswith)tyr   r   r   	ty_to_cppQ   s.    
r\   rI   rG   rJ   )rT   rU   rV   rW   rX   Z	pack_fp16Z	pack_bf16Z	pack_fp32Z	pack_fp64ZiiiKKppOOOOOc                    s  fdd}fdd fdd fdd||  }d	d
 t|D }dfdd|  D }t| }g }|  D ]}|| q~dd
 t|D }t|dkrdddd | D  nd}	g }
| D ]N\}}|dkrq|tv r|
t|  d|  q|
t| d|  qd|
}g }| D ]\}}|d dkrh|d| d nT|tv r|d| d n6|dkr|d|  n|dkr>|d|  q>t	t|}d}dd | D }dd | D }dd | D }d d | D }|d! d"t|dkr:d| nd d#d| d$| fd%d| D  d&| d'|	 d(|| d|| d|| d)t|dkrdd| nd d*}|S )+Nc           
         s   g }d}| D ]}t |tr|drڈ r0 | nd }|d7 }td|}|d}|d}|dd }|d u r|d|  td| D ]}	|d qn
|d	 t|D ]}	|d
 qt|D ]}	|d qq|| q r|t	 ksJ |S )Nr   rB      ztensordesc<([^[>]*)\[([^]]*)\]   ,rA   rO   rY   rN   )

isinstancer#   rZ   rematchgroupcountappendrangelen)
	signatureoutputtensordesc_idxsigmetarb   dtypeshapendim_)tensordesc_metar   r   _expand_signature~   s,    


z(make_launcher.<locals>._expand_signaturec                    s.   t | tr | D ]} || qn
||  d S r(   )r`   tuplere   )rk   ri   x)_flatten_signaturer   r   ru      s    
z)make_launcher.<locals>._flatten_signaturec                    sJ   t | tr&dt | }d| dS | d dkr6dS | dv rBdS t| S )Nr_   []r   rA   z	PyObject*	constexprrY   )r`   rs   r   mapr\   r[   val_extracted_typer   r   r~      s    
z&make_launcher.<locals>._extracted_typec                    sr   t | tr&dt | }d| dS | d dkr6dS | dv rBdS | drPdS d	d
ddddddddd
t|  S )N ()r   rA   Orx   rB   dlbhiLBHIK)
rK   longrD   rE   rC   rF   rH   rI   rG   rJ   )r`   rs   r   rz   rZ   r\   r{   	format_ofr   r   r      s,    

z make_launcher.<locals>.format_ofc                 S   s   i | ]\}}||qS r   r   r   r   sr   r   r   
<dictcomp>   r   z!make_launcher.<locals>.<dictcomp>r   c                    s   g | ]} |qS r   r   )r   r[   r   r   r   r      r   z!make_launcher.<locals>.<listcomp>c                 S   s   i | ]\}}||qS r   r   r   r   r   r   r      r   r   , c                 s   s   | ]\}}d | V  qdS )z&_argNr   r   r   r[   r   r   r   r      r   z make_launcher.<locals>.<genexpr>ry   z argrA   Zptr_infoz.dev_ptrZ_argZ_storagerY   z*tma_ptrz
  c                 S   s:   g | ]2\}}|d  dkrd| d| d| d| d	qS )r   rA   zDevicePtrInfo ptr_infoz = getPointer(_argr   z); if (!ptr_infoz.valid) return NULL;r   r   r   r   r   r      s   c              	   S   s0   g | ](\}}|d krd| d| d| dqS )rY   zCUtensorMap* tma_ptrz = getTmaDesc(_argz); if (!tma_ptrz) return NULL;r   r   r   r   r   r      s   c              
   S   s<   g | ]4\}}|t v rt |  d | dt|  d| dqS ) _argz_storage = z(_argz);)FLOAT_STORAGE_TYPEFLOAT_PACK_FUNCTIONr   r   r   r   r      s   c                 S   s"   g | ]\}}|d krd| qS )ry   z&argr   r   r   r   r   r      r   z&global_scratchaR  
#include "cuda.h"
#include <stdbool.h>
#include <Python.h>
#include <dlfcn.h>

static inline void gpuAssert(CUresult code, const char *file, int line)
{
   if (code != CUDA_SUCCESS)
   {
      const char* prefix = "Triton Error [CUDA]: ";
      const char* str;
      cuGetErrorString(code, &str);
      char err[1024] = {0};
      strcat(err, prefix);
      strcat(err, str);
      PyGILState_STATE gil_state;
      gil_state = PyGILState_Ensure();
      PyErr_SetString(PyExc_RuntimeError, err);
      PyGILState_Release(gil_state);
   }
}

#define CUDA_CHECK(ans) { gpuAssert((ans), __FILE__, __LINE__); }

typedef CUresult (*cuLaunchKernelEx_t)(const CUlaunchConfig* config, CUfunction f, void** kernelParams, void** extra);

static cuLaunchKernelEx_t getLaunchKernelExHandle() {
  // Open the shared library
  void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
  if (!handle) {
    PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1");
    return NULL;
  }
  // Clear any existing error
  dlerror();
  cuLaunchKernelEx_t cuLaunchKernelExHandle = (cuLaunchKernelEx_t)dlsym(handle, "cuLaunchKernelEx");
  // Check for errors
  const char *dlsym_error = dlerror();
  if (dlsym_error) {
    PyErr_SetString(PyExc_RuntimeError, "Failed to retrieve cuLaunchKernelEx from libcuda.so.1");
    return NULL;
  }
  return cuLaunchKernelExHandle;
}

static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int launch_pdl, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratchz) {
  void *params[] = { a,   };
  if (gridX*gridY*gridZ > 0) {
    // 4 attributes that we can currently pass maxmimum
    CUlaunchAttribute launchAttr[4];
    static cuLaunchKernelEx_t cuLaunchKernelExHandle = NULL;
    if (cuLaunchKernelExHandle == NULL) {
      cuLaunchKernelExHandle = getLaunchKernelExHandle();
    }
    CUlaunchConfig config;
    config.gridDimX = gridX;
    config.gridDimY = gridY;
    config.gridDimZ = gridZ;

    if (num_ctas != 1) {
      config.gridDimX *= clusterDimX;
      config.gridDimY *= clusterDimY;
      config.gridDimZ *= clusterDimZ;
    }

    config.blockDimX = 32 * num_warps;
    config.blockDimY = 1;
    config.blockDimZ = 1;
    config.sharedMemBytes = shared_memory;
    config.hStream = stream;
    config.attrs = launchAttr;
    int num_attrs = 0;

    if (launch_pdl != 0) {
      CUlaunchAttribute pdlAttr = { .id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION, .value = 1};
      launchAttr[num_attrs] = pdlAttr;
      ++num_attrs;
    }

    if (launch_cooperative_grid != 0) {
      CUlaunchAttribute coopAttr = { .id = CU_LAUNCH_ATTRIBUTE_COOPERATIVE, .value = 1};
      launchAttr[num_attrs] = coopAttr;
      ++num_attrs;
    }

    if (num_ctas != 1) {
      CUlaunchAttribute clusterAttr = {};
      clusterAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
      clusterAttr.value.clusterDim.x = clusterDimX;
      clusterAttr.value.clusterDim.y = clusterDimY;
      clusterAttr.value.clusterDim.z = clusterDimZ;
      launchAttr[num_attrs] = clusterAttr;
      ++num_attrs;

      CUlaunchAttribute clusterSchedulingAttr = {};
      clusterSchedulingAttr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
      clusterSchedulingAttr.value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
      launchAttr[num_attrs] = clusterSchedulingAttr;
      ++num_attrs;
    }

    config.numAttrs = num_attrs;

    CUDA_CHECK(cuLaunchKernelExHandle(&config, function, params, 0));
  }
}

typedef struct _DevicePtrInfo {
    CUdeviceptr dev_ptr;
    bool valid;
} DevicePtrInfo;

static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {
  DevicePtrInfo ptr_info;
  ptr_info.dev_ptr = 0;
  ptr_info.valid = true;
  if (PyLong_Check(obj)) {
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
    return ptr_info;
  }
  if (obj == Py_None) {
    // valid nullptr
    return ptr_info;
  }
  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
  if(ptr){
    PyObject *empty_tuple = PyTuple_New(0);
    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
    Py_DECREF(empty_tuple);
    Py_DECREF(ptr);
    if (!PyLong_Check(ret)) {
      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
      ptr_info.valid = false;
      return ptr_info;
    }
    ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
    if(!ptr_info.dev_ptr)
      return ptr_info;
    uint64_t dev_ptr;
    int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
    if (status == CUDA_ERROR_INVALID_VALUE) {
        PyErr_Format(PyExc_ValueError,
                     "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
        ptr_info.valid = false;
    } else if (status != CUDA_SUCCESS) {
        CUDA_CHECK(status);  // Catch any other cuda API errors
        ptr_info.valid = false;
    }
    ptr_info.dev_ptr = dev_ptr;
    Py_DECREF(ret);  // Thanks ChatGPT!
    return ptr_info;
  }
  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
  ptr_info.valid = false;
  return ptr_info;
}

static inline CUtensorMap* getTmaDesc(PyObject *obj) {
  if (sizeof(CUtensorMap*) != 8) {
    PyErr_SetString(PyExc_SystemError, "getTmaDesc() requires 64-bit compilation");
    return NULL;
  }

  PyObject *method_handle = PyObject_GetAttrString(obj, "tma_desc_cpu_ptr");
  if (!method_handle) {
    PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() method does not exist");
    return NULL;
  }

  PyObject *empty_tuple = PyTuple_New(0);
  if (!empty_tuple) {
    Py_DECREF(method_handle);
    PyErr_SetString(PyExc_SystemError, "Internal Python error!");
    return NULL;
  }
  PyObject *method_ret = PyObject_Call(method_handle, empty_tuple, NULL);
  Py_DECREF(empty_tuple);
  Py_DECREF(method_handle);
  if (!method_ret) {
    PyErr_SetString(PyExc_SystemError, "Internal Python error!");
    return NULL;
  }

  if (!PyLong_Check(method_ret)) {
    PyErr_SetString(PyExc_TypeError, "tma_desc_cpu_ptr() must return 64-bit int");
    Py_DECREF(method_ret);
    return NULL;
  }

  uint64_t ptr_as_uint = PyLong_AsUnsignedLongLong(method_ret);
  Py_DECREF(method_ret);
  if (!ptr_as_uint) {
    PyErr_SetString(PyExc_ValueError, "received NULL ptr from tma_desc_cpu_ptr()");
    return NULL;
  }
  if (ptr_as_uint % 64 != 0) {
    PyErr_SetString(PyExc_ValueError, "tma_desc_cpu_ptr() must be 64-byte aligned");
    return NULL;
  }

  return (CUtensorMap*)(ptr_as_uint);
}

static void ensureCudaContext() {
  CUcontext pctx;
  CUDA_CHECK(cuCtxGetCurrent(&pctx));
  if (!pctx) {
    // Ensure device context.
    CUdevice device;
    CUDA_CHECK(cuDeviceGet(&device, 0));
    CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
    CUDA_CHECK(cuCtxSetCurrent(pctx));
  }
}

static uint16_t pack_fp16(double f) {
    uint16_t result;
    // from https://github.com/python/pythoncapi-compat
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
    _PyFloat_Pack2(f, (unsigned char*)&result, 1);
#else
    PyFloat_Pack2(f, (unsigned char*)&result, 1);
#endif
    return result;
}

static uint16_t pack_bf16(double f) {
    float f32 = (float)f;
    uint32_t u32 = *(uint32_t*)&f32;
    return (uint16_t)(u32 >> 16);
}

static uint32_t pack_fp32(double f) {
    float f32 = (float)f;
    return *(uint32_t*)&f32;
}

static uint64_t pack_fp64(double f) {
    return *(uint64_t*)&f;
}

static PyObject* launch(PyObject* self, PyObject* args) {
  // ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
  ensureCudaContext();

  int gridX, gridY, gridZ;
  uint64_t _stream;
  uint64_t _function;
  int launch_cooperative_grid;
  int launch_pdl;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  PyObject *global_scratch_obj = NULL;
  c                    s$   g | ]\}} | d | dqS )r   ;r   r   r}   r   r   r     r   z
  if(!PyArg_ParseTuple(args, "a7  ", &gridX, &gridY, &gridZ,
                                           &_stream, &_function, &launch_cooperative_grid, &launch_pdl, &global_scratch_obj,
                                           &kernel_metadata, &launch_metadata,
                                           &launch_enter_hook, &launch_exit_hooka  )) {
    return NULL;
  }

  int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
  if (!PyArg_ParseTuple(kernel_metadata, "iiiiii", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {
    PyErr_SetString(PyExc_TypeError, "kernel_metadata must be a tuple");
    return NULL;
  }

  // extract launch metadata
  if (launch_enter_hook != Py_None){
    PyObject* args = Py_BuildValue("(O)", launch_metadata);
    PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
    Py_DECREF(args);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }

  CUdeviceptr global_scratch = 0;
  if (global_scratch_obj != Py_None) {
    DevicePtrInfo global_scratch_info = getPointer(global_scratch_obj, -1);
    if (!global_scratch_info.valid) {
      return NULL;
    }
    global_scratch = global_scratch_info.dev_ptr;
  }

  // raise exception asap
  z
  Py_BEGIN_ALLOW_THREADS;
  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratchaC  );
  Py_END_ALLOW_THREADS;
  if (PyErr_Occurred()) {
    return NULL;
  }

  if(launch_exit_hook != Py_None){
    PyObject* args = Py_BuildValue("(O)", launch_metadata);
    PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
    Py_DECREF(args);
    if (!ret)
      return NULL;
    Py_DECREF(ret);
  }

  Py_RETURN_NONE;
}

static PyMethodDef ModuleMethods[] = {
  {"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"},
  {NULL, NULL, 0, NULL} // sentinel
};

static struct PyModuleDef ModuleDef = {
  PyModuleDef_HEAD_INIT,
  "__triton_launcher",
  NULL, //documentation
  -1, //size
  ModuleMethods
};

PyMODINIT_FUNC PyInit___triton_launcher(void) {
  PyObject *m = PyModule_Create(&ModuleDef);
  if(m == NULL) {
    return NULL;
  }
  PyModule_AddFunctions(m, ModuleMethods);
  return m;
}
)
values	enumerater   _BASE_ARGS_FORMATrg   itemsr   re   r\   rf   )	constantsrh   rq   rr   Zexpand_signatureZargs_formatformatZflat_signaturerk   Z	args_listZarg_decl_listr   r[   Z	arg_declsZinternal_args_listparamsnewlineZ	ptr_declsZ	tma_declsZfloat_storage_declsr4   r   )r~   ru   r   rq   r   make_launcher|   s    &
,





./               $    %    &    (    Pr   c                   @   s    e Zd ZdZdd Zdd ZdS )TmaDescKernelParam   c                 C   s"   dd l }|j| j|jdd| _d S )Nr   cpurm   device)torchemptyTMA_DESC_SIZEZuint8descr:   r   r   r   r   r<   V  s    zTmaDescKernelParam.__init__c                 C   s
   | j  S r(   )r   data_ptrr:   r   r   r   tma_desc_cpu_ptr[  s    z#TmaDescKernelParam.tma_desc_cpu_ptrN)r=   r>   r?   r   r<   r   r   r   r   r   r   S  s   r   c                 c   s   | ]}||fV  qd S r(   r   )r   r   r   r   r   r   `  r   r      
      	   c              
   C   s   |d u r(| j g| j| j| j| jS |d }|d }|d }|d }|d }| j  }| j}| j}	|	d dksvJ t }
|
g||	}|rt|}|d  d9  < tjjj	j
|
 |||t| |||	 |S )	Nswizzle	elem_size	elem_type
block_size
fp4_paddedr   r]   r^   )basern   stridesr   r   listtritonZruntimeZdriveractiveutilsr9   r   TMA_DTYPE_DEVICE_TO_HOST)argmetadatar   r   r   r   r   r   rn   r   r   resultr   r   r   make_tensordesc_argf  s6     

r   c                    s.   ddl m ddlm   fdd}|S )Nr   )TensorDescriptorc                     s   | d t t }| t td  }d}g }t|D ]J\}}t| frprR| nd }|d7 }|t|| q0|| q0r|t ksJ g ||R  S )Nr   r]   )rg   r   r   r`   extendr   re   )argsZ	meta_argsZraw_kernel_argsrj   Z
final_argsr   r   rl   ZGluonTensorDescriptorr   launcherrq   r   r   inner  s    z%wrap_handle_tensordesc.<locals>.inner)Ztriton.tools.tensor_descriptorr   Z'triton.experimental.gluon.nvidia.hopper)r   rq   r   r   r   r   wrap_handle_tensordesc  s    r   c                   @   s   e Zd Zdd Zdd ZdS )CudaLauncherc                    s   t drjnt }fdd  fdd| D }dd j D }t|dd }t|||tdt t	t
d	}td
d | D }ttj|jd| _|rt|j|n|j| _|j| _|j| _|j| _|j| _d S )Nr   c                    s   t | tr jj| fS | S r(   )r`   r#   fn	arg_namesindex)rt   )r4   r   r   <lambda>  r   z'CudaLauncher.__init__.<locals>.<lambda>c                    s   i | ]\}} ||qS r   r   r   idxvalue)arg_idxr   r   r     r   z)CudaLauncher.__init__.<locals>.<dictcomp>c                 S   s   i | ]\}}||qS r   r   r   r   r   r   r     r   rq   Z__triton_launcherr3   c                 s   s"   | ]}t |to|d V  qdS )rB   N)r`   r#   rZ   )r   rk   r   r   r   r     r   z(CudaLauncher.__init__.<locals>.<genexpr>r]   )r-   r   dictr   rh   getattrr   r   r*   r6   r7   r$   r   	functoolsreduceoperatormulZcluster_dimsnum_ctasr   launchglobal_scratch_sizeglobal_scratch_alignlaunch_cooperative_grid
launch_pdl)r:   r4   r   r   rh   rq   r;   Zhas_tensor_desc_argr   )r   r4   r   r<     s(    zCudaLauncher.__init__c           
   	   G   sd   | j dkr8|| | }|| j | j  }t|| j|}	nd }	| j|||||| j| j|	g|R   d S Nr   )r   r   r   Z
_allocatorr   r   r   r   )
r:   ZgridXZgridYZgridZstreamfunctionr   Z	grid_sizeZ
alloc_sizeZglobal_scratchr   r   r   __call__  s    
zCudaLauncher.__call__N)r=   r>   r?   r<   r   r   r   r   r   r     s   r   c                       sX   e Zd Z fddZdd Zdd Zdd Zed	d
 Zdd Z	dd Z
dd Z  ZS )
CudaDriverc                    s   t  | _t| _t   d S r(   )r+   r   r   Zlauncher_clsr.   r<   r   r1   r   r   r<     s    zCudaDriver.__init__c                 C   s6   |   }| |}|d d |d  }d}td||S )Nr   r   r]       r
   )get_current_deviceZget_device_capabilityr   )r:   r   Z
capabilityZ	warp_sizer   r   r   get_current_target  s
    
zCudaDriver.get_current_targetc                 C   s   dd l }|d|  S )Nr   r
   )r   r   r   r   r   r   r   get_active_torch_device  s    z"CudaDriver.get_active_torch_devicec                 C   s   dd l }|jS r   )r   r
   r   r   r   r   get_device_interface  s    zCudaDriver.get_device_interfacec                  C   s:   z dd l } | j o| jjd u W S  ty4   Y dS 0 d S )Nr   F)r   r
   Zis_availableversionZhipImportError)r   r   r   r   	is_active  s
    zCudaDriver.is_activec                 C   s   ddl m} |S )Nr   )do_bench)Ztriton.testingr   )r:   r   r   r   r   get_benchmarker  s    zCudaDriver.get_benchmarkerc                 C   s&   dd l }d}|jt|d |jddS )Nr   i      r
   r   )r   r   int)r:   r   Z
cache_sizer   r   r   get_empty_cache_for_benchmark  s    z(CudaDriver.get_empty_cache_for_benchmarkc                 C   s   |   d S r(   )Zzero_)r:   cacher   r   r   clear_cache  s    zCudaDriver.clear_cache)r=   r>   r?   r<   r   r   r   staticmethodr   r   r   r   r@   r   r   r1   r   r     s   
	r   )+r   r   r   r   r   ra   pathlibr   r   Ztriton.runtime.buildr   Ztriton.runtimer   Ztriton.backends.compilerr   Ztriton.backends.driverr   r   r   realpath__file__r   r6   r)   r7   	lru_cacher'   r*   objectr+   r\   r   r   r   r   r   r   rf   r   r   r   r   r   r   r   r   r   <module>   s^   

	   Z($