a
    hPE                     @   s  U d Z ddlZddlZddlmZ ddlmZ ddlmZm	Z	m
Z
mZ ddlZe
e	egef  ed< ej r|ddlmZ ndZi Zeeef ed< i Zeeef ed	< G d
d dZG dd dZG dd deZe
e	egef  ed< ej rddlmZ ndZG dd deZeG dd dZG dd deZ G dd deZ!i Z"eee#e f ed< da$eeej%f e#e dddZ&eeej%f e#e dd d!Z'ee(ee#e f  d"d#d$Z)d%d& Z*dS )'a  
Device abstraction layer for TorchDynamo and Inductor backends.

This module provides a unified interface for different hardware backends (CUDA, XPU,
CPU, MPS) through a common device interface. Key components include:

- DeviceInterface: Base class defining the common API for all device types
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface
- Device registration system for managing available backends
- Worker APIs for multi-processing scenarios
- Stream and event management across different devices
- Device property caching for worker processes

The abstraction layer enables device-agnostic code in TorchDynamo while allowing
specialized implementations for each hardware backend's unique features.
    N)Iterable)	dataclass)AnyCallableOptionalUnionget_cuda_stream)_cuda_getCurrentRawStream caching_worker_device_propertiescaching_worker_current_devicesc                   @   s  e Zd ZdZG dd dZG dd dZG dd dZG dd	 d	Zed
d Z	ee
jjdddZeeedddZeeedddZedd ZeedddZee
jdddZedd Zee
jdddZeeeed d!d"Zeeed#d$d%Zed:e
jjdd'd(Zed;e
jjdd)d*Zed<e
jjdd+d,Zed=ed.d/d0Zed>e
jeed1d2d3Zed?e
jjedd4d5Z ed@e
jjedd6d7Z!edAe
jjd&dd8d9Z"d&S )BDeviceInterfacez
    This is a simple device runtime interface for Inductor. It enables custom
    backends to be integrated with Inductor in a device-agnostic semantic.
    c                   @   s   e Zd ZejjdddZdS )zDeviceInterface.devicedevicec                 C   s   t d S NNotImplementedErrorclsr    r   L/var/www/auris/lib/python3.9/site-packages/torch/_dynamo/device_interface.py__new__/   s    zDeviceInterface.device.__new__N)__name__
__module____qualname__torchtypesDevicer   r   r   r   r   r   .   s   r   c                   @   s   e Zd Zdd ZdS )zDeviceInterface.Eventc                 O   s   t dd S )NzYEvent should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo.r   r   argskwargsr   r   r   r   3   s    zDeviceInterface.Event.__new__Nr   r   r   r   r   r   r   r   Event2   s   r!   c                   @   s   e Zd Zdd ZdS )zDeviceInterface.Streamc                 O   s   t dd S )Nz[Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo.r   r   r   r   r   r   9   s    zDeviceInterface.Stream.__new__Nr    r   r   r   r   Stream8   s   r"   c                   @   sL   e Zd ZdZeedddZeedddZedej	j
dd	d
ZdS )zDeviceInterface.Workera  
        Worker API to query device properties that will work in multi processing
        workers that cannot use the GPU APIs (due to processing fork() and
        initialization time issues). Properties are recorded in the main process
        before we fork the workers.
        r   c                 C   s   t d S r   r   r   r   r   r   
set_deviceF   s    z!DeviceInterface.Worker.set_devicereturnc                   C   s   t d S r   r   r   r   r   r   current_deviceJ   s    z%DeviceInterface.Worker.current_deviceNc                 C   s   t d S r   r   r   r   r   r   get_device_propertiesN   s    z,DeviceInterface.Worker.get_device_properties)N)r   r   r   __doc__staticmethodintr#   r&   r   r   r   r'   r   r   r   r   Worker>   s   r+   c                   C   s   t d S r   r   r   r   r   r   r&   R   s    zDeviceInterface.current_devicer   c                 C   s   t d S r   r   r   r   r   r   r#   V   s    zDeviceInterface.set_devicer   r%   c                 C   s   t d S r   r   r   r   r   r   maybe_exchange_deviceZ   s    z%DeviceInterface.maybe_exchange_devicec                 C   s   t d S r   r   r   r   r   r   exchange_device^   s    zDeviceInterface.exchange_devicec                   C   s   t d S r   r   r   r   r   r   device_countb   s    zDeviceInterface.device_countr$   c                   C   s   t d S r   r   r   r   r   r   is_availablef   s    zDeviceInterface.is_availablestreamc                 C   s   t d S r   r   r1   r   r   r   r2   j   s    zDeviceInterface.streamc                   C   s   t d S r   r   r   r   r   r   current_streamn   s    zDeviceInterface.current_streamc                 C   s   t d S r   r   r1   r   r   r   
set_streamr   s    zDeviceInterface.set_streamZ	stream_idZdevice_indexZdevice_typec                 C   s   t d S r   r   r5   r   r   r   _set_stream_by_idv   s    z!DeviceInterface._set_stream_by_id)
device_idxr%   c                 C   s   t d S r   r   r7   r   r   r   get_raw_streamz   s    zDeviceInterface.get_raw_streamNc                 C   s   t d S r   r   r   r   r   r   synchronize~   s    zDeviceInterface.synchronizec                 C   s   | j |S r   )r+   r'   r   r   r   r   r'      s    z%DeviceInterface.get_device_propertiesc                 C   s   t d S r   r   r   r   r   r   get_compute_capability   s    z&DeviceInterface.get_compute_capabilityFincluding_emulationc                 C   s   t d S r   r   r<   r   r   r   is_bf16_supported   s    z!DeviceInterface.is_bf16_supporteddtyper=   r%   c                 C   s   |t jkp| |S r   )r   bfloat16r>   r   r@   r=   r   r   r   is_dtype_supported   s    z"DeviceInterface.is_dtype_supportedc                 C   s   t d S r   r   r   r   r   r   memory_allocated   s    z DeviceInterface.memory_allocatedc                 C   s   dS )z
        Returns True if the device has Triton support, False otherwise, even if
        the appropriate Triton backend is not available.
        Fr   r   r   r   r   is_triton_capable   s    z!DeviceInterface.is_triton_capablec                 C   s   |   stddS )aH  
        Raises a `RuntimeError` with the appropriate human-readable instructions
        to resolve the issue if Triton is not available for the given device, or
        the default device if `device` is `None`.

        The caller should ensure the presence of the 'triton' package before
        calling this method.
        z/This device is not capable of supporting TritonN)rE   RuntimeErrorr   r   r   r   raise_if_triton_unavailable   s    
z+DeviceInterface.raise_if_triton_unavailable)N)N)N)F)F)N)N)N)#r   r   r   r(   r   r!   r"   r+   r)   r&   r   r   r   r#   r*   r-   r.   r/   boolr0   r2   r3   r4   r6   r9   r:   classmethodr'   r;   r>   r@   rC   rD   rE   rG   r   r   r   r   r   (   s^   


 r   c                   @   sD   e Zd ZdZee ee ddddZdd Z	e
e
e
dd	d
ZdS )DeviceGuarda_  
    This class provides a context manager for device switching. This is a stripped
    down version of torch.{device_name}.device.

    The context manager changes the current device to the given device index
    on entering the context and restores the original device on exiting.
    The device is switched using the provided device interface.
    N)device_interfaceindexr%   c                 C   s   || _ || _d| _d S )N)rK   idxprev_idx)selfrK   rL   r   r   r   __init__   s    zDeviceGuard.__init__c                 C   s   | j d ur| j| j | _d S r   )rN   rK   r.   rO   )rP   r   r   r   	__enter__   s    
zDeviceGuard.__enter__)typevalue	tracebackc                 C   s   | j d ur| j| j| _ dS NF)rN   rK   r-   rO   )rP   rS   rT   rU   r   r   r   __exit__   s    
zDeviceGuard.__exit__)r   r   r   r(   rS   r   r   r*   rQ   rR   r   rW   r   r   r   r   rJ      s   
rJ   c                   @   s4  e Zd ZejjZejjZejjZG dd dZe	ejj
Z
e	ejjZe	ejjZe	ejjZe	ejjZe	ejjZe	ejjZe	ejjZe	ejjZe	eZe	ejjZe	ejjZe	ejjZe	ejjZe	edddZe	dejjddd	Ze	dejjed
ddZ e	dejjdd
ddZ!dS )CudaInterfacec                   @   sH   e Zd ZeedddZeedddZed
ejj	ddd	Z
dS )zCudaInterface.Workerr   c                 C   s   | t d< d S Ncudar   r   r   r   r   r#      s    zCudaInterface.Worker.set_devicer$   c                   C   s   dt v rt d S tj S rY   )r   r   rZ   r&   r   r   r   r   r&      s    z#CudaInterface.Worker.current_deviceNc                 C   s   | d ur<t | tr*t| } | jdks*J t | tjr<| j} | d u rNtj } dt	vrvdd t
tj D }|t	d< t	d |  S )NrZ   c                 S   s   g | ]}t j|qS r   )r   rZ   r'   .0ir   r   r   
<listcomp>   s   z>CudaInterface.Worker.get_device_properties.<locals>.<listcomp>)
isinstancestrr   r   rS   rL   rX   r+   r&   r
   rangerZ   r/   r   Zdevice_propr   r   r   r'      s    


z*CudaInterface.Worker.get_device_properties)Nr   r   r   r)   r*   r#   r&   r   r   r   r'   r   r   r   r   r+      s   r+   r$   c                   C   s
   t j S r   )r   rZ   r0   r   r   r   r   r0      s    zCudaInterface.is_availableNr   c                 C   sF   t jjd u r(t j| \}}|d | S t j| jddd S d S )N
   :   r   )r   versionhiprZ   get_device_capabilityr'   ZgcnArchNamesplit)r   majorminr   r   r   r;     s    z$CudaInterface.get_compute_capabilityr,   c                 C   s   t jjd upt j| jdkS )N   )r   rh   ri   rZ   r'   rl   r   r   r   r   rE     s    zCudaInterface.is_triton_capablec                 C   sr   ddl m} t| s0tj| }||t dd l	}tj
jd urZd|jjvrntdnd|jjvrntdd S )Nr   )GPUTooOldForTritonZamdz'triton not built with the 'amd' backendZnvidiaz*triton not built with the 'nvidia' backend)Ztorch._inductor.excro   rX   rE   r   rZ   r'   inspectcurrentframetriton.backendsrh   ri   backendsrF   )r   ro   Zdevice_propstritonr   r   r   rG     s    

z)CudaInterface.raise_if_triton_unavailable)N)N)N)"r   r   r   r   rZ   r   r!   r"   r+   r)   r&   r#   r/   r2   r3   r4   r6   r:   r'   r   r9   _exchange_devicer.   _maybe_exchange_devicer-   rD   r>   rH   r0   r   r   r;   rE   rG   r   r   r   r   rX      s4   rX   get_xpu_stream)_xpu_getCurrentRawStreamc                   @   s>  e Zd ZejjZejjZejjZG dd dZe	ejj
Z
e	ejjZe	ejjZe	ejjZe	ejjZe	ejjZe	ejjZe	ejjZe	ejjZe	eZe	ejjZe	ejjZe	ejjZe	edddZe	dejjddd	Ze	deedddZe	dejjedddZ e	dejjddddZ!dS )XpuInterfacec                   @   sH   e Zd ZeedddZeedddZed
ejj	ddd	Z
dS )zXpuInterface.Workerr   c                 C   s   | t d< d S Nxpur[   r   r   r   r   r#   1  s    zXpuInterface.Worker.set_devicer$   c                   C   s   dt v rt d S tj S rz   )r   r   r{   r&   r   r   r   r   r&   5  s    z"XpuInterface.Worker.current_deviceNc                 C   s   | d ur<t | tr*t| } | jdks*J t | tjr<| j} | d u rNtj } dt	vrvdd t
tj D }|t	d< t	d |  S )Nr{   c                 S   s   g | ]}t j|qS r   )r   r{   r'   r\   r   r   r   r_   G  s   z=XpuInterface.Worker.get_device_properties.<locals>.<listcomp>)r`   ra   r   r   rS   rL   ry   r+   r&   r
   rb   r{   r/   rc   r   r   r   r'   ;  s    


z)XpuInterface.Worker.get_device_properties)Nrd   r   r   r   r   r+   0  s   r+   r$   c                   C   s
   t j S r   )r   r{   r0   r   r   r   r   r0   ^  s    zXpuInterface.is_availableNr   c                 C   s   t j| }|S r   )r   r{   rj   )r   ccr   r   r   r;   b  s    z#XpuInterface.get_compute_capabilityFr=   r%   c                 C   s
   t j S r   )r   r{   r>   r<   r   r   r   r>   g  s    zXpuInterface.is_bf16_supportedr,   c                 C   s   dS NTr   r   r   r   r   rE   k  s    zXpuInterface.is_triton_capable)evicer%   c                 C   s    dd l }d|jjvrtdd S )Nr   intelz)triton not built with the 'intel' backendrr   rs   rF   )r   rt   r   r   r   rG   o  s    z(XpuInterface.raise_if_triton_unavailable)N)F)N)N)"r   r   r   r   r{   r   r!   r"   r+   r)   r&   r#   r/   r2   r3   r4   r6   r:   r'   rw   r9   ru   r.   rv   r-   rD   rH   r0   r   r   r;   r>   rE   rG   r   r   r   r   ry   +  s6   ry   c                   @   s   e Zd ZU eed< dS )CpuDevicePropertiesZmulti_processor_countN)r   r   r   r*   __annotations__r   r   r   r   r   w  s   
r   c                   @   s   e Zd ZG dd dejZG dd dZeedddZeded	d
dZ	edej
jedddZeedddZedd Zedej
jdddZedej
jedddZedej
jddddZdS ) CpuInterfacec                   @   s.   e Zd Zd
ddZedddZddd	ZdS )zCpuInterface.EventTc                 C   s
   d| _ d S )Ng        time)rP   Zenable_timingr   r   r   rQ   ~  s    zCpuInterface.Event.__init__r$   c                 C   s   |j | j  d S )Ni  r   )rP   Z	end_eventr   r   r   elapsed_time  s    zCpuInterface.Event.elapsed_timeNc                 C   s   t  | _ d S r   )r   perf_counter)rP   r2   r   r   r   record  s    zCpuInterface.Event.record)T)N)r   r   r   rQ   floatr   r   r   r   r   r   r!   }  s   
r!   c                   @   s$   e Zd ZedejjdddZdS )zCpuInterface.WorkerNr   c                 C   s   dd l }| }t|S Nr   )multiprocessing	cpu_countr   )r   r   r   r   r   r   r'     s    z)CpuInterface.Worker.get_device_properties)N)r   r   r   r)   r   r   r   r'   r   r   r   r   r+     s   r+   r$   c                   C   s   dS r~   r   r   r   r   r   r0     s    zCpuInterface.is_availableFr<   c                 C   s   dS r~   r   r<   r   r   r   r>     s    zCpuInterface.is_bf16_supportedNr,   c                 C   s   dS N r   r   r   r   r   r;     s    z#CpuInterface.get_compute_capabilityc                 C   s   dS r   r   r8   r   r   r   r9     s    zCpuInterface.get_raw_streamc                   C   s   dS r   r   r   r   r   r   r&     s    zCpuInterface.current_devicer   c                 C   s   d S r   r   r   r   r   r   r:     s    zCpuInterface.synchronizec                 C   s   dS r~   r   r   r   r   r   rE     s    zCpuInterface.is_triton_capablec                 C   s    dd l }d|jjvrtdd S )Nr   cpuz'triton not built with the 'cpu' backendr   )r   rt   r   r   r   rG     s    z(CpuInterface.raise_if_triton_unavailable)F)N)N)N)N)r   r   r   r   r!   r+   r)   rH   r0   r>   r   r   ra   r;   r*   r9   r&   r:   rE   rG   r   r   r   r   r   |  s$   

r   c                   @   s   e Zd ZedeedddZedejeedddZ	eedd	d
Z
edd ZedejjedddZedejjdddZG dd dZdS )MpsInterfaceFr}   c                 C   s   t jjddS )N   r   )r   rs   mpsZis_macos_or_newerr<   r   r   r   r>     s    zMpsInterface.is_bf16_supportedr?   c                 C   s(   |t jt jfv rdS |t jkp&| |S rV   )r   float64Z
complex128rA   r>   rB   r   r   r   rC     s    zMpsInterface.is_dtype_supportedr$   c                   C   s   t jj S r   )r   rs   r   r0   r   r   r   r   r0     s    zMpsInterface.is_availablec                   C   s   dS r   r   r   r   r   r   r&     s    zMpsInterface.current_deviceNr,   c                 C   s   dS r   r   r   r   r   r   r;     s    z#MpsInterface.get_compute_capabilityr   c                 C   s   t j  d S r   )r   r   r:   r   r   r   r   r:     s    zMpsInterface.synchronizec                   @   s0   e Zd ZedejjdddZedd ZdS )zMpsInterface.WorkerNr   c                 C   s   i S r   r   r   r   r   r   r'     s    z)MpsInterface.Worker.get_device_propertiesc                   C   s   dS r   r   r   r   r   r   r&     s    z"MpsInterface.Worker.current_device)N)	r   r   r   r)   r   r   r   r'   r&   r   r   r   r   r+     s   r+   )F)F)N)N)r   r   r   r)   rH   r>   rI   r   r@   rC   r0   r&   r   r   ra   r;   r:   r+   r   r   r   r   r     s"    
r   device_interfacesFr   rK   c                 C   s   t | tjr| j} |t| < d S r   )r`   r   r   rS   r   r   r   r   r   register_interface_for_device  s    r   r,   c                 C   s>   t | tjr| j} tst  | tv r,t|  S td|  d S )NzNo interface for device )r`   r   r   rS   _device_initializedinit_device_regr   r   r   r   r   r   get_interface_for_device  s    r   r$   c                   C   s   t s
t  t S r   )r   r   r   itemsr   r   r   r    get_registered_device_interfaces  s    r   c                  C   sx   t dt ttj D ]} t d|  t qt dt ttj D ]} t d|  t qFt dt t dt	 da
d S )NrZ   zcuda:r{   zxpu:r   r   T)r   rX   rb   r   rZ   r/   ry   r{   r   r   r   )r^   r   r   r   r     s    



r   )+r(   rp   r   collections.abcr   Zdataclassesr   typingr   r   r   r   r   r*   r   rZ   Z_is_compiledZtorch._Cr	   r   r
   dictra   r   r   rJ   rX   r{   rx   rw   ry   r   r   r   r   rS   r   r   r   r   tupler   r   r   r   r   r   <module>   s@   
 [L7'
