a
    ¢º”hÎ·  ã                   @   s  U d Z ddlZddlZddlZddlZddlZddlZddlmZ ddl	m
Z
 ddl	mZ ddlmZ ddlmZ dd	lmZmZ dd
lmZ ddlZddlZG dd„ deƒZG dd„ dƒZeeef ZdZi Zeed< dddœZg d¢ZG dd„ de ƒZ!dd„ Z"G dd„ dƒZ#G dd„ dƒZ$eedœdd„Z%eee&dœd d!„Z'd"d#„ Z(dneeeee&e&ee d&œd'd(„Z)eeeeeee*f e&e&e$e&dd)œ
d*d+„Z+d,d-„ Z,d.d/„ Z-e .d0¡Z/d1d2„ Z0d3d4„ Z1d5d6„ Z2d7d8„ Z3e .d9¡Z4d:d;„ Z5e .d<¡Z6d=d>„ Z7e .d?¡Z8d@dA„ Z9dodBdC„Z:dDdE„ Z;dFdG„ Z<dHdI„ Z=dJdK„ Z>dLdM„ Z?G dNdO„ dOƒZ@G dPdQ„ dQƒZAeAƒ ZBi ZCeAƒ ZDi ZEeeeFf edR< i ZGe
D ]ªZHeIeHeƒsŽJ ‚eH J¡ D ]Š\ZKZLeLd ZMeLdd… ZNejOeNvrôeD PeK¡ ejQeNv rìeE ReKdS¡rìeMeGeK< neMeEeK< ejSeNvr–ejQeNvr–eB PeK¡ eMeCeK< q–qze .eB T¡ ¡ZUe .dTeD T¡ › dU¡ZVe .dV¡ZWe .dW¡ZXe .dX¡ZYe .dY¡ZZeeeeeee*f e&e&e$e&ed)œ
dZd[„Z[dpd\d]„Z\d^d_„ Z]d`da„ Z^e .db¡Z_dcdd„ Z`dedf„ Zadqee&eeeeeee&ee&e&e&e&ee$ edkœdldm„ZbdS )ra   The Python Hipify script.
##
# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
#               2017-2018 Advanced Micro Devices, Inc. and
#                         Facebook Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
é    Né   )Ú	constants)ÚCUDA_TO_HIP_MAPPINGS)ÚMATH_TRANSPILATIONS)ÚOptional)ÚIterator)ÚMappingÚIterable)ÚEnumc                   @   s   e Zd ZdZdZdS )ÚCurrentStater   é   N)Ú__name__Ú
__module__Ú__qualname__ÚINITIALIZEDÚDONE© r   r   úN/var/www/auris/lib/python3.9/site-packages/torch/utils/hipify/hipify_python.pyr   -   s   r   c                   @   s   e Zd Zdd„ Zdd„ ZdS )ÚHipifyResultc                 C   s   || _ || _d| _d S )NÚ ©Úcurrent_stateÚhipified_pathÚstatus)Úselfr   r   r   r   r   Ú__init__2   s    zHipifyResult.__init__c                 C   s   d| j › d| j› d| j› S )NzHipifyResult:: current_state: z, hipified_path : z
, status: r   ©r   r   r   r   Ú__str__7   s    zHipifyResult.__str__N)r   r   r   r   r   r   r   r   r   r   1   s   r   z;// !!! This is a file automatically generated by hipify!!!
ÚHIPIFY_FINAL_RESULTZscalar_t)ZDtypeÚT)!Ú
InputErrorÚopenfÚbcolorsÚGeneratedFileCleanerÚmatch_extensionsÚmatched_files_iterÚpreprocess_file_and_save_resultÚcompute_statsÚadd_dim3ÚprocessKernelLaunchesÚfind_closure_groupÚfind_bracket_groupÚfind_parentheses_groupÚreplace_math_functionsÚhip_header_magicÚreplace_extern_sharedÚget_hip_file_pathÚis_out_of_placeÚis_pytorch_fileÚis_cusparse_fileÚis_special_fileÚis_caffe2_gpu_filer5   ÚTrieÚpreprocessorÚfile_specific_replacementÚfile_add_headerÚfix_static_global_kernelsÚextract_argumentsÚstr2boolr   r   Úhipifyc                       s$   e Zd Z‡ fdd„Zdd„ Z‡  ZS )r    c                    s   t ƒ  |¡ || _d S ©N)Úsuperr   Úmessage)r   r@   ©Ú	__class__r   r   r   N   s    zInputError.__init__c                 C   s   d| j › S )NzInput error: )r@   r   r   r   r   r   R   s    zInputError.__str__)r   r   r   r   r   Ú__classcell__r   r   rA   r   r    K   s   r    c                 C   s   t | |ddS )NÚignore)Úerrors)Úopen)ÚfilenameÚmoder   r   r   r!   V   s    r!   c                   @   s,   e Zd ZdZdZdZdZdZdZdZ	dZ
d	S )
r"   z[95mz[94mz[92mz[93mz[91mz[0mz[1mz[4mN)r   r   r   ÚHEADERZOKBLUEÚOKGREENÚWARNINGÚFAILÚENDCZBOLDZ	UNDERLINEr   r   r   r   r"   [   s   r"   c                   @   s<   e Zd ZdZddd„Zdd„ Zdd„ Zdd	d
„Zdd„ ZdS )r#   z+Context Manager to clean up generated filesFc                 C   s   || _ tƒ | _g | _d S r>   )Úkeep_intermediatesÚsetÚfiles_to_cleanÚdirs_to_clean)r   rN   r   r   r   r   p   s    zGeneratedFileCleaner.__init__c                 C   s   | S r>   r   r   r   r   r   Ú	__enter__u   s    zGeneratedFileCleaner.__enter__c                 O   s6   t j |¡s | j t j |¡¡ t|g|¢R i |¤ŽS r>   )ÚosÚpathÚexistsrP   ÚaddÚabspathrF   )r   ÚfnÚargsÚkwargsr   r   r   rF   x   s    zGeneratedFileCleaner.openc                 C   sx   t j |¡\}}|s$t j |¡\}}|rF|rFt j |¡sF| j|dd t j |¡rV|stt  |¡ | j t j 	|¡¡ d S )NT)Úexist_ok)
rS   rT   ÚsplitrU   ÚmakedirsÚisdirÚmkdirrQ   ÚappendrW   )r   Údnr[   ÚparentÚnr   r   r   r]   }   s    
zGeneratedFileCleaner.makedirsc                 C   s@   | j s<| jD ]}t |¡ q| jd d d… D ]}t |¡ q,d S )Néÿÿÿÿ)rN   rP   rS   ÚunlinkrQ   Úrmdir)r   ÚtypeÚvalueÚ	tracebackÚfÚdr   r   r   Ú__exit__‡   s
    
zGeneratedFileCleaner.__exit__N)F)F)	r   r   r   Ú__doc__r   rR   rF   r]   rl   r   r   r   r   r#   n   s   


r#   )rT   Úreturnc                 C   s   |   tjd¡S )Nú/)ÚreplacerS   Úsep)rT   r   r   r   Ú_to_unix_path   s    rr   )rG   Ú
extensionsrn   c                    s   t ‡ fdd„|D ƒƒS )z<Helper method to see if filename ends with certain extensionc                 3   s   | ]}ˆ   |¡V  qd S r>   ©Úendswith)Ú.0Úe©rG   r   r   Ú	<genexpr>”   ó    z#match_extensions.<locals>.<genexpr>©Úany)rG   rs   r   rx   r   r$   ’   s    r$   c                    s   t ‡ fdd„|D ƒƒS )Nc                 3   s   | ]}t   ˆ |¡V  qd S r>   )Úfnmatch)rv   Úpattern©Úfilepathr   r   ry   ˜   rz   z_fnmatch.<locals>.<genexpr>r{   )r€   Úpatternsr   r   r   Ú_fnmatch—   s    r‚   r   F)Ú	root_pathÚincludesÚignoresrs   Úout_of_place_onlyÚis_pytorch_extensionrn   c                 c   sú   t |ƒ}tj| ddD ]Þ\}}}	tj || ¡}
|
dkrvd|v rH| d¡ d|v rZ| d¡ d|v rv| d¡ | d¡ |	D ]x}ttj ||¡ƒ}ttj |
|¡ƒ}t	||ƒrzt	||ƒszt
||ƒsÈ||v rz|sìt|ƒsÞt|ƒsÞqz|rìt|ƒsìqz|V  qzqd S )NT)ÚtopdownÚ.z.gitÚbuildZthird_partyzthird_party/nvfuser)rO   rS   ÚwalkrT   ÚrelpathÚremover`   rr   Újoinr‚   r$   r2   r5   r1   )rƒ   r„   r…   rs   r†   r‡   Zexact_matchesZabs_dirpathÚdirsÚ	filenamesZrel_dirpathrG   r€   Úrel_filepathr   r   r   r%   ›   s8    



ÿþýýr%   )
Úoutput_directoryr€   Ú	all_filesÚheader_include_dirsÚstatsÚhip_clang_launchr‡   Ú	clean_ctxÚshow_progressrn   c	              
   C   st   t j t j | |¡¡}	ttj|	d}
|
t|	< t| ||||||||ƒ	}|rhd|j	vrht
|	d|j|j	dd |t|	< d S )N)r   r   Zignoredz->T)Úflush)rS   rT   rW   rŽ   r   r   r   r   r7   r   Úprintr   )r’   r€   r“   r”   r•   r–   r‡   r—   r˜   Úfin_pathÚhipify_resultÚresultr   r   r   r&   Ç   s    
ÿ
þr&   c                 C   sP   dd„ | d D ƒ}t dt|ƒd›ƒ t d |¡ƒ t dt| d ƒd›ƒ d S )	Nc                 S   s   h | ]\}}|’qS r   r   )rv   Z	cuda_callZ	_filepathr   r   r   Ú	<setcomp>á   rz   z compute_stats.<locals>.<setcomp>Úunsupported_callsz1Total number of unsupported CUDA function calls: rk   ú, z+
Total number of replaced kernel launches: Úkernel_launches)rš   ÚlenrŽ   )r•   rŸ   r   r   r   r'   à   s    r'   c                 C   s¦  d}d}|   dd¡  dd¡} dd„ tdƒD ƒ}d|| d< t| ƒD ]Š\}}|d	krV qÎ|d
krh|d	7 }n|dkrx|d	8 }|dks|t| ƒd	 krB|dkrB||dk || d< |d	7 }|dk rB|d	 || d< qB| |d d |d d d	 … }| |d	 d |d	 d … }| |d d |d d …   dd¡ d¡}	| |d	 d |d	 d …   dd¡ d¡}
d|	› d}d|
› d}|  |	|¡}|  |
|¡}|  || || ¡}|S )zBadds dim3() to the second and third arguments in the kernel launchr   ú<<<r   ú>>>c                 S   s   g | ]}i ‘qS r   r   )rv   Ú_r   r   r   Ú
<listcomp>ò   rz   zadd_dim3.<locals>.<listcomp>r   Ústartr   ú(ú)ú,ÚendÚ
ú zdim3()rp   ÚrangeÚ	enumerater¢   Ústrip)Úkernel_stringÚcuda_kernelÚcountÚclosureZarg_locsÚindÚcZfirst_arg_rawZsecond_arg_rawZfirst_arg_cleanZsecond_arg_cleanZfirst_arg_dim3Zsecond_arg_dim3Zfirst_arg_raw_dim3Zsecond_arg_raw_dim3r   r   r   r(   í   s6    
  **r(   z([ ]+)(detail?)::[ ]+\\\n[ ]+c                    sV  t  dd„ ˆ ¡‰ ‡ fdd„}dd„ }dd„ }t||ˆ ƒƒƒ}ˆ }|D ]
}||ƒ}ˆ  d	|d
 ¡}	ˆ |d d |	d … }
ˆ |d |d
 … }|d d
 dkr¢dnd}ˆ |d d || d
 d … }t||
ƒ}ttd|d  dd	¡ dd¡ƒƒ}d|dd…  ddd|  d ¡ dd¡ dd¡ |d	| d ¡ }| |
|¡}|d  |¡ qD|S )zK Replace the CUDA style Kernel launches with the HIP style kernel launches.c                 S   s   |   d¡› |   d¡› dS )Nr   r   z::©Úgroup©Úinpr   r   r   Ú<lambda>  rz   z'processKernelLaunches.<locals>.<lambda>c           
         s„  | d | d dœdddœdddœdœ}ddi}d}d}d	}d
}|}t |d d d ddƒD ]"}ˆ | }	|||fv rà|	dkr¤||kr”|}||d d< |d  d7  < |	dkrà|d  d8  < |d dkrà||krà||d d< |}||krZˆ |  ¡ sˆ | dv rP||kr|}||d d< |dkr~d|d d< |d |d |d g  S qZ||krZ||d d< |d |d |d g  S qZd S )Nr§   r«   ©r§   r«   rd   )Úkernel_launchÚkernel_nameÚtemplatez<>r   r   r   é   r½   ú>r¿   ú<>   ú#ú:r©   r¨   r¥   r¾   )r®   Úisalnum)
Z	in_kernelÚposr³   ÚSTARTZAT_TEMPLATEZAFTER_TEMPLATEZAT_KERNEL_NAMEr   ÚiÚchar©Ústringr   r   Úgrab_method_and_template  sD    ý

z7processKernelLaunches.<locals>.grab_method_and_templatec                 S   sd   d}g }|   d|¡dkr`|   d|¡}|   d|¡d }|dkrDtdƒ‚| ||| ||… dœ¡ q|S )zKFinds the starting and ending points for all kernel launches in the string.r   r£   rd   r¤   rÀ   zno kernel end found)r§   r«   r¸   )Úfindr    r`   )rË   Z
kernel_endZkernel_positionsZkernel_startr   r   r   Úfind_kernel_boundsS  s    
ÿ
z1processKernelLaunches.<locals>.find_kernel_boundsc                 S   sâ   d}d}d}| D ]Ì}|dkrf|dkr2|dkr2d}q¾|dkrH|dkrHd}q¾|dkr¾|dkr¾|dkr¾d}nX|dkr„|d	ks~|d
kr¾d}n:|dkr¢|dkr¾|dkr¾d}n|dkr¾|dkr¾|dkr¾d}|}|dkrÔ||7 }q|d7 }q|S )Nr   ro   z//Ú*z/*ú"ú\ú'úr¬   Úxr   )rË   Z
in_commentZprev_cZ
new_stringr¶   r   r   r   Úmask_commentsk  s2    

z,processKernelLaunches.<locals>.mask_commentsr¨   r«   r   r§   r   rd   r¸   r£   r¤   r©   zhipLaunchKernelGGL(z, 0é   r    r¡   )	ÚRE_KERNEL_LAUNCHÚsubÚlistrÍ   r(   r¢   r;   rp   r`   )rË   r•   rÌ   rÎ   rÕ   Zget_kernel_positionsÚoutput_stringÚkernelÚparamsZparenthesisr²   r±   Zend_param_indexZkernel_name_with_templateZcuda_kernel_dim3Znum_klpZ
hip_kernelr   rÊ   r   r)     s6    ;!
 
"ÿÿþþr)   c                 C   sŽ   d}d}|}d\}}|t | ƒk rŠ| | |d krP|du rFd}d}|}q€|d7 }n0| | |d kr€|r€|d8 }|dkr€|}||fS |d7 }qdS )aÊ  Generalization for finding a balancing closure group

         if group = ["(", ")"], then finds the first balanced parentheses.
         if group = ["{", "}"], then finds the first balanced bracket.

    Given an input string, a starting position in the input string, and the group type,
    find_closure_group returns the positions of group[0] and group[1] as a tuple.

    Example:
        >>> find_closure_group("(hi)", 0, ["(", ")"])
        (0, 3)
    Fr   )rd   rd   Tr   )NN)r¢   )Úinput_stringr§   r¸   Zinside_parenthesisÚparensrÆ   Zp_startZp_endr   r   r   r*   ­  s$    

r*   c                 C   s   t | |ddgdS )z%Finds the first balanced parantheses.Ú{Ú}r·   ©r*   ©rÝ   r§   r   r   r   r+   Ó  s    r+   c                 C   s   t | |ddgdS )z!Finds the first balanced bracket.r¨   r©   r·   rá   râ   r   r   r   r,   Ø  s    r,   z\bassert[ ]*\(c                 C   s.   | }t D ] }| |› dt | › d¡}q|S )a‹  FIXME: Temporarily replace std:: invocations of math functions
        with non-std:: versions to prevent linker errors NOTE: This
        can lead to correctness issues when running tests, since the
        correct version of the math function (exp/expf) might not get
        called.  Plan is to remove this function once HIP supports
        std:: math function calls inside device code

    r¨   )r   rp   )rÝ   rÚ   Úfuncr   r   r   r-   à  s    	r-   z:?:?\b(__syncthreads)\b(\w*\()c                    sh   | ‰ ddg}t ‡ fdd„|D ƒƒr&ˆ S dˆ v }|dˆ v 7 }|dˆ v 7 }|t ˆ ¡du7 }|rdd	|  ‰ ˆ S )
a  If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
    then automatically add an #include to match the "magic" includes provided by NVCC.
    TODO:
        Update logic to ignore cases where the cuda_runtime.h is included by another file.
    zhip/hip_runtime.hzhip/hip_runtime_api.hc                 3   s(   | ] }t  d |› d|› dˆ ¡V  qdS )z#include ("z"|<z>)N)ÚreÚsearch)rv   Úext©rÚ   r   r   ry   ÿ  rz   z#hip_header_magic.<locals>.<genexpr>ZhipLaunchKernelGGLÚ
__global__Z
__shared__Nz#include "hip/hip_runtime.h"
)r|   ÚRE_SYNCTHREADSrå   )rÝ   ÚheadersZhasDeviceLogicr   rç   r   r.   ó  s    r.   zGextern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;c                 C   s   | }t  dd„ |¡}|S )a€  Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
       https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
    Example:
        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
    c                 S   s.   d|   d¡pd› d|   d¡› d|   d¡› dS )	NzHIP_DYNAMIC_SHARED(r   r   r­   r   r    rÀ   r©   r·   r¹   r   r   r   r»     rz   z'replace_extern_shared.<locals>.<lambda>)ÚRE_EXTERN_SHAREDrØ   )rÝ   rÚ   r   r   r   r/     s
    ÿr/   c                 C   sð   t j | ¡rJ ‚|s t| ƒs | S t j | ¡\}}t j |¡\}}|dkrLd}|}|}| dd¡}| dd¡}| dd¡}| dd¡}| dd¡}|d	kr¤| dd¡}|s¾||kr¾t j |d¡}|rÞ||krÞ|| |krÞ|d
 }t j ||| ¡S )z3
    Returns the new name of the hipified file
    ú.cuú.hipZcudaZhipÚCUDAÚHIPÚTHCÚTHHzcaffe2/coreZ_hip)rS   rT   Úisabsr1   r\   Úsplitextrp   rŽ   )r‘   r‡   ÚdirpathrG   Úrootræ   Úorig_filenameZorig_dirpathr   r   r   r0   !  s*    $r0   c                 C   s>   t j | ¡rJ ‚|  d¡rdS |  d¡r,dS |  d¡r:dS dS )Nútorch/Fúthird_party/nvfuser/útools/autograd/templates/T©rS   rT   rò   Ú
startswith©r‘   r   r   r   r1   i  s    


r1   c                 C   sZ   t j | ¡rJ ‚|  d¡r,|  d¡r(dS dS |  d¡r:dS |  d¡rHdS |  d¡rVdS dS )Nzaten/zaten/src/ATen/core/FTr÷   rø   rù   rú   rü   r   r   r   r2   u  s    




r2   c                 C   s   t | ƒrd|  ¡ v S dS )NÚsparseF©r2   Úlowerrü   r   r   r   r3   „  s    r3   c                 C   s<   t | ƒr8d|  ¡ v rdS d|  ¡ v r8d|  ¡ v r4dS dS dS )Nrý   TZlinalgZbatchlinearalgebralibblasFrþ   rü   r   r   r   r4   Š  s    r4   c                 C   sR   t j | ¡rJ ‚|  d¡rdS t j | ¡}t j |¡\}}d|v sJ|dv oPd|vS )Nzc10/cudaTZgpu©rì   ú.cuhZcudnn)rS   rT   rò   rû   Úbasenameró   )r‘   rG   r¥   ræ   r   r   r   r5   ”  s    
r5   c                   @   s   e Zd ZdZdd„ ZdS )ÚTrieNodezA Trie node whose children are represented as a directory of char: TrieNode.
       A special char '' represents end of word
    c                 C   s
   i | _ d S r>   )Úchildrenr   r   r   r   r   ¡  s    zTrieNode.__init__N)r   r   r   rm   r   r   r   r   r   r  œ  s   r  c                   @   sV   e Zd ZdZdd„ Zdd„ Zdd„ Zdd	„ Zd
d„ Ze	j
dd„ ƒZdd„ Zdd„ ZdS )r6   z£Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
    The corresponding Regex should match much faster than a simple Regex union.c                 C   s&   t ƒ | _tjdd| _| j ¡ | _dS )z,Initialize the trie with an empty root node.F)ÚusedforsecurityN)r  rõ   ÚhashlibÚmd5Ú_hashÚdigestÚ_digestr   r   r   r   r   ¨  s    zTrie.__init__c                 C   sT   | j  | ¡ ¡ | j  ¡ | _| j}|D ]}|j |tƒ ¡ |j| }q&d|jd< dS )zAdd a word to the Trie. Tr   N)	r  ÚupdateÚencoder	  r
  rõ   r  Ú
setdefaultr  ©r   ÚwordÚnoderÉ   r   r   r   rV   ®  s    zTrie.addc                 C   s   | j S )zReturn the root node of Trie. )rõ   r   r   r   r   Údump¹  s    z	Trie.dumpc                 C   s
   t  |¡S )z Escape a char for regex. )rä   Úescape)r   rÉ   r   r   r   Úquote½  s    z
Trie.quotec                 C   s6   | j }|D ] }||jv r$|j| }q
 dS q
d|jv S )zZSearch whether word is present in the Trie.
        Returns True if yes, else return FalseFr   )rõ   r  r  r   r   r   rå   Á  s    
zTrie.searchc              	   C   sH  |}d|j v r$t|j  ¡ ƒdkr$dS g }g }d}t|j  ¡ ƒD ]j}t|j | tƒr¤z,|  |j | | j¡}| |  	|¡| ¡ W q¨ t
y    | |  	|¡¡ Y q¨0 q>d}q>t|ƒdk }	t|ƒdkrøt|ƒdkrà| |d ¡ n| dd |¡ d ¡ t|ƒdkr|d }
ndd |¡ d	 }
|rD|	r8|
d
7 }
nd|
› d}
|
S )zžConvert a Trie into a regular expression pattern

        Memoized on the hash digest of the trie, which is built incrementally
        during add().
        r   r   Nr   ú[ú]z(?:ú|r©   ú?z)?)r  r¢   ÚkeysÚsortedÚ
isinstancer  Ú_patternr
  r`   r  Ú	ExceptionrŽ   )r   rõ   r	  r  ZaltÚccÚqrÉ   ZrecurseZcconlyr   r   r   r   r  Î  s6    

zTrie._patternc                 C   s   |   | j| j¡S ©z#Export the Trie to a regex pattern.©r  rõ   r
  r   r   r   r   r~   ú  s    zTrie.patternc                 C   s   |   | j| j¡S r  r   r   r   r   r   Úexport_to_regexþ  s    zTrie.export_to_regexN)r   r   r   rm   r   rV   r  r  rå   Ú	functoolsÚ	lru_cacher  r~   r!  r   r   r   r   r6   ¤  s   
+r6   ÚPYTORCH_MAPr   z(?<=\W)(z)(?=\W)z#include "([^"]+)"z#include <([^>]+)>z"#define THC_GENERIC_FILE "([^"]+)"z\.cu\bc	                    sÊ  t j t j ˆ|¡¡‰t|ƒ}tˆ }	|ˆ vrFd|	_d|	_tj	|	_
|	S tt j |ˆ¡ƒ}
tˆddT}| ¡ tkr˜d|	_d|	_tj	|	_
|	W  d  ƒ S | d¡ | ¡ }W d  ƒ n1 s¾0    Y  |}t j t j ˆt|
ˆƒ¡¡}t j t j |¡¡sˆ t j |¡¡ dd„ ‰‡fd	d
„}ˆr8t ˆ|¡}nDt|
ƒrPt ||¡}n,t|
ƒrht ˆ|¡}ndd„ }t ||¡}d'‡ ‡‡‡‡‡‡‡‡	f	dd„	}t |ddƒ|¡}t |ddƒ|¡}t |dƒ|¡}| d¡rþ| dd¡}| dd¡}t d|¡}ˆst |ˆ	ƒ}| d¡r,d|vr,t!|ƒ}t"|ƒ}ˆrv||krvt j ˆ¡t j |¡krvˆ|	_d|	_tj	|	_
|	S ˆ|kr”t#ˆdƒr”t| }d}t j |¡ràt|dd}| ¡ |k}W d  ƒ n1 sÖ0    Y  |r®zVˆj|ddd}| $|¡ W d  ƒ n1 s0    Y  ||	_d|	_tj	|	_
|	W S  t%yª } zTt&t'j(› d |› d!|j)› d"ˆ› d#t'j*› 	t+j,d$ ˆ|	_d%|	_tj	|	_
|	W  Y d}~S d}~0 0 n||	_d&|	_tj	|	_
|	S dS )(z< Executes the CUDA -> HIP conversion on the specified file. Nz[ignored, not to be hipified]zutf-8)Úencodingz#[ignored, input is hipified output]r   c                 S   s   t |  d¡ S ©Nr   )r$  r¸   ©Úmr   r   r   Úpt_repl[  s    zpreprocessor.<locals>.pt_replc                    s   t  |  d¡ˆ | ƒ¡S r&  )ÚPYTORCH_SPECIAL_MAPÚgetr¸   r'  )r)  r   r   Úpt_special_repl^  s    z%preprocessor.<locals>.pt_special_replc                 S   s   t |  d¡ S r&  )Ú
CAFFE2_MAPr¸   r'  r   r   r   Úc2_replk  s    zpreprocessor.<locals>.c2_replTc                    s$   ‡‡‡‡‡‡ ‡‡‡	‡
‡fdd„}|S )Nc              
      sà  |   d¡}tj |¡‰ | d¡s4| d¡rJ| d¡sJˆ t|   d¡ˆƒ¡S ˆrÖt‡ fdd„ˆD ƒƒrÖd }d }ˆrªtj ˆ¡}tj 	tj 
||¡¡}tj |¡rª|}|}|d u rôˆD ]<}tj 
ˆ|¡}tj 	tj 
||¡¡}tj |¡r¶|}|}q¶|d u r|   d¡S |tvr,tˆ|ˆˆˆ
ˆˆˆˆ	ƒ	 nz|tv r¦t| }|jtjkr¦tj |ˆ¡}tj 	tj 
ˆt|ˆƒ¡¡}	|	|_|t|< ˆ tj |	d urœ|	n||¡¡S t| j}
ˆ ttj |
d urÊ|
n||¡ƒ¡S |   d¡S )Nr   )z	ATen/cudazATen/native/cudazATen/native/nested/cudazATen/native/quantized/cudazATen/native/sparse/cudazATen/native/transformers/cudazTHC/rð   ZTHCPc                 3   s   | ]}|  ˆ ¡V  qd S r>   rt   )rv   Úsrx   r   r   ry   €  rz   z>preprocessor.<locals>.mk_repl.<locals>.repl.<locals>.<genexpr>r   )r¸   rS   rT   r  rû   Úformatr0   r|   ÚdirnamerW   rŽ   rU   r   r&   r   r   r   rŒ   r   rr   )r(  rj   Ú
header_dirZheader_filepathZheader_dir_to_checkZheader_path_to_checkÚheader_include_dirZheader_resultZheader_rel_pathZheader_fout_pathZhipified_header_filepath)r“   r—   r›   r”   r–   Úinclude_current_dirr‡   r’   r˜   r•   Útemplrx   r   Úreplq  sd    
ÿøø



ý
ÿÿ
ÿz+preprocessor.<locals>.mk_repl.<locals>.replr   )r5  r4  r6  )	r“   r—   r›   r”   r–   r‡   r’   r˜   r•   )r4  r5  r   Úmk_replp  s     :zpreprocessor.<locals>.mk_replz#include "{0}"z#include <{0}>Fz#define THC_GENERIC_FILE "{0}"zCMakeLists.txtrî   rï   rð   rñ   rí   r   Z	PowKernelz[skipped, no changes])rì   r  ú.cú.ccú.cppú.hú.hppÚwz[ok]zFailed to save z with "z", leaving z unchanged.©Úfilez[skipped, no permissions]z[skipped, already hipified])T)-rS   rT   rW   rŽ   rr   r   r   r   r   r   r   rŒ   rF   ÚreadlineÚHIPIFY_C_BREADCRUMBÚseekÚreadr0   rU   r1  r]   ÚRE_PYTORCH_PREPROCESSORrØ   r4   r2   ÚRE_CAFFE2_PREPROCESSORÚRE_QUOTE_HEADERÚRE_ANGLE_HEADERÚRE_THC_GENERIC_FILEru   rp   ÚRE_CU_SUFFIXr)   r-   r.   r$   ÚwriteÚOSErrorrš   r"   rK   ÚstrerrorrM   ÚsysÚstderr)r’   r€   r“   r”   r•   r–   r‡   r—   r˜   rœ   r‘   ZfinZoutput_sourceZorig_output_sourceZ	fout_pathr,  r.  r7  Zdo_writeZfout_oldZfoutrw   r   )
r“   r—   r›   r”   r–   r‡   r’   r)  r˜   r•   r   r7   4  s     
&

<
ÿþý,*&ÿr7   c                    sˆ   t | dƒj}| ¡ }|r>t dt |¡› d‡ fdd„|¡}n| |ˆ ¡}| d¡ | |¡ | ¡  W d   ƒ n1 sz0    Y  d S )Núr+z\b(z)\bc                    s   ˆ S r>   r   )rÔ   ©Úreplace_stringr   r   r»   ñ  rz   z+file_specific_replacement.<locals>.<lambda>r   )	r!   rC  rä   rØ   r  rp   rB  rJ  Útruncate)r€   Zsearch_stringrQ  Ústrictrj   Úcontentsr   rP  r   r8   í  s    &

r8   c                 C   s†   t | dƒh}| ¡ }|d dkr8|d dkr8d|› d}d|› d| }| d¡ | |¡ | ¡  W d   ƒ n1 sx0    Y  d S )	NrO  r   rÂ   rd   rÁ   rÐ   z	#include z 
)r!   rC  rB  rJ  rR  )r€   Úheaderrj   rT  r   r   r   r9   ù  s    

r9   c                 C   s   |   dd¡} | S )z<Static global kernels in HIP results in a compilation error.z __global__ staticrè   ©rp   )Zin_txtr   r   r   r:     s    r:   z#include .*\nc                 C   s6  g }dddœ}| }|d }|t |ƒk r2|| dkrF|d  d7  < nt|| dkrd|d  d8  < nV|| dkr‚|d  d7  < n8|| dkrº||d  dkrº|d dkrº|d  d8  < |d dkræ|d dkræ| ||d	œ¡ q2|d dkr(|d dkr(|| d
kr(| ||d	œ¡ |d }|d7 }q|S )ad   Return the list of arguments in the upcoming function parameter closure.
        Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            '[{'start': 1, 'end': 7},
            {'start': 8, 'end': 16},
            {'start': 17, 'end': 19},
            {'start': 20, 'end': 53}]'
    r   )rÂ   r¨   r   r¨   r©   rÂ   rÁ   ú-r¼   rª   )r¢   r`   )r§   rË   Ú	argumentsZclosuresÚcurrent_positionZargument_start_posr   r   r   r;     s.    þ(*
r;   c                 C   s.   |   ¡ dv rdS |   ¡ dv r dS t d¡‚dS )zArgumentParser doesn't support type=bool. Thus, this helper method will convert
    from possible string types to True / False.)ÚyesÚtrueÚtÚyÚ1T)ÚnoÚfalserj   rc   Ú0FzBoolean value expected.N)rÿ   ÚargparseÚArgumentTypeError)Úvr   r   r   r<   ;  s
    r<   ©rì   r  r8  r9  r:  r;  z.inr<  ©r  r;  r<  ©rÏ   T)Úproject_directoryÚshow_detailedrs   Úheader_extensionsr’   r”   r„   Úextra_filesr†   r…   r˜   r–   r‡   Úhipify_extra_files_onlyr—   rn   c                    sÚ  ˆdkrt  ¡ ‰t j ˆ¡s.tdƒ t d¡ ˆsDˆ d¡ ˆd ‰ˆˆkrt‡‡fdd„ˆD ƒ‰‡‡fdd„ˆD ƒ‰t j ˆ¡sŒt 	ˆˆ¡ t
ttˆƒƒ‰t
ttˆƒƒ‰t
tˆˆˆ|||d	ƒ}t|ƒ}|D ]0}t j |¡sêt j ˆ|¡}||vrÌ| |¡ qÌd
dlm} |D ]R}t j |¡r*||ƒ}n|t j ˆ|¡ƒ}| ‡ ‡‡fdd„| d¡D ƒ¡ q|d u rvtdd}g g dœ}|sŠ|n|D ]}tˆ||||||||
ƒ	 qŽttjd tj tjd |rÖt|ƒ tS )Nr   z,The project folder specified does not exist.r   ro   Z_amdc                    s   g | ]}|  ˆˆ ¡‘qS r   rV  )rv   Úinclude©r’   rh  r   r   r¦   e  rz   zhipify.<locals>.<listcomp>c                    s   g | ]}|  ˆˆ ¡‘qS r   rV  )rv   rD   rn  r   r   r¦   f  rz   )r„   r…   rs   r†   r‡   r   )ÚPathc                 3   sF   | ]>}|  ¡ rtt|ƒˆƒrtt|ƒˆƒst|jˆ ƒrt|ƒV  qd S r>   )Úis_filer‚   Ústrr$   Úname)rv   rT   )rj  r…   r„   r   r   ry     s
   
ýzhipify.<locals>.<genexpr>rÏ   T)rN   )rŸ   r¡   z-Successfully preprocessed all matching files.r>  )rS   ÚgetcwdrT   rU   rš   rM  ÚexitÚrstripÚshutilÚcopytreerÙ   Úmaprr   r%   rO   rò   rŽ   r`   Úpathlibro  ÚextendÚrglobr#   r&   r"   rJ   rM   rN  r'   r   )rh  ri  rs   rj  r’   r”   r„   rk  r†   r…   r˜   r–   r‡   rl  r—   r“   Zall_files_setrj   ro  r3  Zheader_include_dir_pathr•   r€   r   )rj  r…   r„   r’   rh  r   r=   F  sZ    

ý
ÿ




ÿr=   )r   r   r   FF)F)F)Fre  rf  r   r   rg  r   Fr   TFFFN)crm   rb  r}   rä   rv  rM  rS   r   r   Zcuda_to_hip_mappingsr   r   Útypingr   Úcollections.abcr   r   r	   Úenumr
   r"  r  r   r   Údictrq  ZHipifyFinalResultrA  r   Ú__annotations__ZPYTORCH_TEMPLATE_MAPÚ__all__r  r    r!   r"   r#   rr   Úboolr$   r‚   r%   rÙ   r&   r'   r(   Úcompiler×   r)   r*   r+   r,   Z	RE_ASSERTr-   ré   r.   rë   r/   r0   r1   r2   r3   r4   r5   r  r6   ZCAFFE2_TRIEr-  ZPYTORCH_TRIEr$  Úobjectr*  Úmappingr  ÚitemsÚsrcrh   ÚdstZ	meta_dataZ
API_CAFFE2rV   ZAPI_SPECIALr+  ZAPI_PYTORCHr!  rE  rD  rF  rG  rH  rI  r7   r8   r9   r:   Z
RE_INCLUDEr;   r<   r=   r   r   r   r   Ú<module>   s&  	
!     úú-
÷#
 &



H
^








÷ :

.              ñð