o
    vZh                     @  s  d dl mZ d dlZd dlZd dlmZ d dlmZmZ d dl	Z	d dl
mZ d dlmZmZmZmZmZmZmZmZmZmZmZmZmZmZ d dlmZ d dlmZ d d	lm Z m!Z!m"Z" d d
l#m$Z$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z* d dl+m,Z,m-Z- d dl.m/Z/ ddl0m1Z1 ddl2m3Z3 erd dl4m5Z5m6Z6 g dZ7dd e7D Z8g dZ9edddZ:dddZ;dd"d#Z<dd$d%Z=dd&d'Z>dd(d)Z?dd*d+Z@dd,d-ZAdd.d/ZBdd0d1ZCdd2d3ZDd4d5dd=d>ZEddCdDZFd4d5ddKdLZGddMdNZHddOdPZId4d5ddSdTZJdUdUdVdd[d\ZKdd^d_ZLedd`daZMddcddZNddfdgZOddidjZPedkZQedlZRedmZSednZTd4d5ddodpZUddrdsZVedtZWd4d5ddydzZXdd|d}ZYdd~dZZd4d5dddZ[d4d5dddZ\d4d5dddZ]dS )    )annotationsN)defaultdict)CallableTYPE_CHECKING)cpp)arg_parser_output_exprscpp_dispatch_exprscpp_dispatch_targetdispatch_lambda_argsdispatch_lambda_exprsdispatch_lambda_return_strhas_tensor_optionsPythonSignaturePythonSignatureDeprecatedPythonSignatureGroup!PythonSignatureNativeFunctionPair	signaturesignature_from_schemastructseq_fieldnames)CodeTemplatewith_native_function)
cpp_stringparse_native_yamlparse_tags_yaml)ArgumentBaseOperatorNameFunctionSchemaNativeFunction
SchemaKindTypeVariant)FileManagersplit_name_params)
YamlLoader   )is_tensor_list_type)should_trace)IterableSequence)Salias
contiguousZis_cudaZ	is_sparseZis_sparse_csrsizeZstrideZsym_sizeZ
sym_strideZsym_storage_offsetZ	sym_numelz.*_backwardz#.*_backward_(out|input|weight|bias)z
.*_forwardz.*_forward_outz.*_jvpZ_unsafe_viewZtensorz2_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*z_range.*Z_sparse_add_outz_sparse_div.*z_sparse_mul.*z_sparse_sub.*Z_sparse_dense_add_outindexZ	index_outZunique_dim_consecutivez	_cumsum.*z
_cumprod.*z_sum.*z_prod.*z_th_.*z_thnn_.*zrange.*z_solve.*z
_inverse.*z_cholesky.*z_triangular_solve.*z_qr.*z_svd.*sliceitemZ_local_scalar_densetoZ_to_copyZ_to_copy_outZ_reshape_copyZ_reshape_copy_outZcopy_sparse_to_sparse_Zcopy_Z_foreach_copyZnumpy_TZmatrix_HZmTZmHznonzero(_(out|numpy))?set_dataz.*_overrideabledataZis_leafZ	output_nr_versionZrequires_grad_Zretains_gradset_Z
_fw_primalZ)fake_quantize_per_tensor_affine_cachemaskZ*fake_quantize_per_channel_affine_cachemaskZ!_new_zeros_with_same_feature_metaZ_has_same_storage_numelZ_reshape_aliasZreplace_copyzfill.Tensorzfill.Scalarzlift.*Znormal_functionalnbytesitemsizeZ_batch_norm_with_updateZ_batch_norm_with_update_outZ_batch_norm_no_updatec                 C  s   g | ]}t d | dqS )^$)recompile).0pattern r>   ^/var/www/auris/lib/python3.10/site-packages/torchgen/packaged/autograd/gen_python_functions.py
<listcomp>   s    r@   )z?add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> TensorzHadd_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)z?sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> TensorzHsub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)z/mul.Scalar(Tensor self, Scalar other) -> Tensorz8mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)z/div.Scalar(Tensor self, Scalar other) -> Tensorz8div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)fr   returnboolc                 C  sd   d| j v rd| j vrdS t| j}tD ]
}||r dS qt| j}tD ]	}||kr/ dS q&dS )N	generatedZ	view_copyFT)tagsr   namefuncSKIP_PYTHON_BINDINGSmatchstrSKIP_PYTHON_BINDINGS_SIGNATURES)rA   rF   Z
skip_regexr   r=   r>   r>   r?   should_generate_py_binding   s   

rL   rF   r   rJ   c                 C  s
   d|  S )NZTHPVariable_r>   rF   r>   r>   r?   get_pycname      
rN   	overloads+Sequence[PythonSignatureNativeFunctionPair]c                 C  s   t | dko| d j dkS )Nr%   r   )lenr   arguments_count)rP   r>   r>   r?   is_noarg   s   rT   c                 C     | j d u o
tj| jv S N)python_moduler!   methodvariantsrA   r>   r>   r?   is_py_variable_method      r[   c                 C  rU   rV   )rW   r!   functionrY   rZ   r>   r>   r?   is_py_torch_function   r\   r^   c                 C  
   | j dkS )NnnrW   rZ   r>   r>   r?   is_py_nn_function   rO   rb   c                 C  r_   )NZfftra   rZ   r>   r>   r?   is_py_fft_function   rO   rc   c                 C  r_   )NZlinalgra   rZ   r>   r>   r?   is_py_linalg_function   rO   rd   c                 C  r_   )Nnestedra   rZ   r>   r>   r?   is_py_nested_function   rO   rf   c                 C  r_   )Nsparsera   rZ   r>   r>   r?   is_py_sparse_function   rO   rh   c                 C  r_   )NZspecialra   rZ   r>   r>   r?   is_py_special_function   rO   ri   Tsymintoutnative_yaml_pathtags_yaml_pathdeprecated_yaml_pathtemplate_pathrk   Nonec             
     s>  t | |dd}t||j}ttt|}t||dd}t||td dd|d t||dd}	t	||	t
dddd	|d
 t||	tddd|d t||	tddd|d t||	tddd|d t||	tdddd t||	tddd|d t||	tddd|d t||	dd d t||	dd d t| d! fdd}
|d |
 d S )"NF)Zinstall_dirZtemplate_dirdry_runTrX   zpython_variable_methods.cpprX   rk   torchzpython_torch_functions.cpp   )rX   
num_shardsrk   torch.nnzpython_nn_functions.cpp	torch.fftzpython_fft_functions.cpptorch.linalgzpython_linalg_functions.cpptorch.nestedzpython_nested_functions.cpptorch.sparsezpython_sparse_functions.cpptorch.specialzpython_special_functions.cppc                 S     dS NTr>   fnr>   r>   r?   <lambda>q      zgen.<locals>.<lambda>zpython_return_types.cppc                 S  r~   r   r>   r   r>   r>   r?   r   t  r   zpython_return_types.hrB   dict[str, str]c                     s   dd dd t D iS )NZenum_of_valid_tags c                 S  s   g | ]}d | d| dqS )z	
.value("z", at::Tag::)r>   )r<   tagr>   r>   r?   r@   }  s    z.gen.<locals>.gen_tags_enum.<locals>.<listcomp>)joinsortedr>   
valid_tagsr>   r?   gen_tags_enumy  s
   zgen.<locals>.gen_tags_enumzpython_enum_tag.cpp)rB   r   )r"   r   native_functionslistfilterrL   load_signaturescreate_python_bindingsr[   create_python_bindings_shardedr^   rb   rc   rd   rf   rh   ri   "create_python_return_type_bindings)create_python_return_type_bindings_headerr   write)rl   rm   rn   ro   rp   rk   fmr   methodsZ	functionsr   r>   r   r?   gen	  s   	


	
	r   pairspred Callable[[NativeFunction], bool]?dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]c                 C  s6   t t}| D ]}||jr||jjjj | q|S rV   )r   r   r]   rG   rF   append)r   r   groupedpairr>   r>   r?   group_filter_overloads  s   
r   r   r"   module
str | NonefilenamerX   c          
   
     s   g g g g t ||}t| tdD ]1}|| }	t|||	||d t|||	|d t||	|d d|j	 d q
   fdd dS )	+Generates Python bindings to ATen functionskeyrt   rs   #include <ATen/ops/.h>c                     s$   dd   d   dS )N@generated from /)generated_commentops_headerspy_forwards
py_methodspy_method_defstemplate_dir_for_commentsr>   r   r   r   r   r   r   r>   r?   r     s   z(create_python_bindings.<locals>.<lambda>N)r   r   keysrJ   r   method_impl
method_defextendforward_declsbasewrite_with_template)
r   r   r   r   r   rX   rk   r   rF   rP   r>   r   r?   r     s$   
r   c           	        s   g g t ||}t| tdD ]$}|| }t|\}}|s#dnd| |s/dnd| q   fdd dS )z
    Generate function to initialize and return named tuple for native functions
    which returns named tuple and registration invocations in `python_return_types.cpp`.
    r   r   
c                     s    dd   d   dS )Nr   r   r   )r   Zpy_return_typespy_return_types_registrationsr   r>   r   r   Zpy_return_types_definitionr   r>   r?   r     s   z4create_python_return_type_bindings.<locals>.<lambda>N)r   r   r   rJ   1generate_return_type_definition_and_registrationsr   r   r   )	r   r   r   r   r   rF   rP   definitionsregistrationsr>   r   r?   r     s&   

r   c                   sj   g t ||}t| tdD ]}|| }t|}|sdnd| q   fdd dS )z
    Generate function to initialize and return named tuple for native functions
    which returns named tuple and relevant entry for the map in `python_return_types.cpp`.
    r   r   r   c                     s   dd   d   dS )Nr   r   r   )r   py_return_types_declarationsr   r>   r   r   r   r>   r?   r     s
   z;create_python_return_type_bindings_header.<locals>.<lambda>N)r   r   r   rJ   !generate_return_type_declarationsr   r   r   )r   r   r   r   r   rF   rP   declarationsr>   r   r?   r     s   

r   rw   intc             	     sb   t ||}ddd}	d fdd	}
| j|| d
dd|   d|  i|	|
|h dd dS )r   kv@tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]rB   rJ   c                 S  s
   | d j S )Nr   )r   )r   r>   r>   r?   key_func  s   
z0create_python_bindings_sharded.<locals>.key_funcdict[str, list[str]]c              	     sN   | \}}d|j  dgtt|| dt|| dgt|| dgdS )Nr   r   rs   rt   )r   r   r   r   )r   r   r   r   r   )r   rF   Zfn_pairsrX   r   rk   r>   r?   env_func  s   z0create_python_bindings_sharded.<locals>.env_funcr   r   r   r   >   r   r   r   r   )Zbase_envZkey_fnZenv_callablerw   Zsharded_keysN)r   r   rB   rJ   )r   r   rB   r   )r   Zwrite_shardeditemsr   )r   r   r   r   r   rX   rw   rk   r   r   r   r>   r   r?   r     s   


r   F)skip_deprecatedpyir   list[NativeFunction]r   r   c                  sB   t d fdd}tt|| }t|| d}|r|S || S )	NrA   r   rB   r   c                   s   t t|  d| dS )NrX   r   r   r]   )r   r   rZ   r   r>   r?   gen_signature_pairs8  s   z,load_signatures.<locals>.gen_signature_pairsr   )rA   r   rB   r   )r   r   mapload_deprecated_signatures)r   ro   rX   r   r   r   r   
deprecatedr>   r   r?   r   0  s   r   'list[PythonSignatureNativeFunctionPair]c                  s  t t}| D ]}||jj | qg }t|}tj|td}W d    n1 s*w   Y  |D ]}	t	
|	d t|	d \}
 |
drO|
dd}
dt
didd	 jjD  D ]}|v sr|v srJ d
| qad fdd}d}||
 D ]3}||jjsqd}t|jj||d}|tt|j|j|j|j|j|jt |jd	|jd q|sJ d|
 dt q1|S )N)LoaderrF   ZatenZ_outr   1Scalarc                 S  s   i | ]}|j |qS r>   rM   r<   ar>   r>   r?   
<dictcomp>i      z.load_deprecated_signatures.<locals>.<dictcomp>z*deprecation definiton: Unrecognized value aten_schemar   rB   rC   c                   s   rt | jj| jj}n| jj}t|D ]9\}}|t k rF | }|v r.| }d }n
| }|j}|j	}||jksB||j	krE dS q|j
d u rN dS qtjt| jkoftdd tj| jD S )NFc                 s  s    | ]	\}}||kV  qd S rV   r>   )r<   r   br>   r>   r?   	<genexpr>  s    
zKload_deprecated_signatures.<locals>.is_schema_compatible.<locals>.<genexpr>)	itertoolschain	argumentsrl   Zflat_non_outflat_all	enumeraterR   type
annotationdefaultreturnsallzip)r   r   iargZarg_nameZschema_typeZschema_annotationZ
schema_argZ	call_argsZis_outZknown_constantsZschemaZschema_args_by_namer>   r?   is_schema_compatibleq  s.   
z8load_deprecated_signatures.<locals>.is_schema_compatibleFT)category_overriderX   r   )	rF   
input_argsinput_kwargsoutput_argstensor_options_argsrX   deprecated_schemaZdeprecated_args_exprsr   r   zNo native function with name z matched signature:
  )r   r   rB   rC   )r   r   r   rF   r   openyamlloadr$   r   parser#   endswithreplacer    r   r   r]   rG   r   r   r   r   r   r   r   r   rX   tupler   rJ   )r   ro   rX   r   r   r   resultsrA   Zdeprecated_defsr   Z	aten_namerF   r   Zany_schema_foundZ
python_sigr>   r   r?   r   F  sj   


 r   c                 C  s(   t | j}t| jj}d|g| S )N_)r   rF   rG   r   r   r   )rA   rF   
fieldnamesr>   r>   r?   gen_structseq_typename_key  s   r    tuple[list[str], dict[str, str]]c                 C  s   i }g }| D ]<}t |jjj}|sqt|jj}t|j}||}|du rBd|s,dnt| }|||< |	d| d| d q||fS )zr
    Generate block of named tuple type def inits, and add typeref snippets
    to declarations that use them
    N
NamedTupler   zstatic PyTypeObject* z = generated::get__structseq();
r   r]   rG   r   r   rF   r   getrR   r   )rP   	typenamesZtypedefsoverloadr   rF   tn_keytypenamer>   r>   r?   emit_structseq_call  s,   	

r
  tuple[list[str], list[str]]c           
      C  s   i }g }g }| D ]h}t |jjj}|sqddd |D }t|jj}t|j}||}	|	du rp| d|s:dnt	| }	|	||< |
d| d| d	|	 d
| dt	| d|	 d|	 d|	 d |
d| d| d q||fS )z
    Generate block of function in `python_return_types.cpp` to initialize
    and return named tuple for a native function which returns named tuple
    and registration invocations in same file.
    , c                 s  s    | ]	}d | dV  qdS ){"z", ""}Nr>   )r<   r   r>   r>   r?   r     s    zDgenerate_return_type_definition_and_registrations.<locals>.<genexpr>Nr  r   PyTypeObject* get_zI_structseq() {
    static PyStructSequence_Field NamedTuple_fields[] = { z(,  {nullptr} };
    static PyTypeObject zh;
    static bool is_initialized = false;
    static PyStructSequence_Desc desc = { "torch.return_types.z", nullptr, NamedTuple_fields, zB };
    if (!is_initialized) {
        PyStructSequence_InitType(&z, &desc);
        zm.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
        is_initialized = true;
    }
    return &z;
}
z$addReturnType(return_types_module, "z", generated::get_z_structseq());)r   r]   rG   r   r   r   rF   r   r  rR   r   )
rP   r  r   r   r  r   fieldsrF   r  r	  r>   r>   r?   r     sN   


r   	list[str]c                 C  s   i }g }| D ];}t |jjj}|sqt|jj}t|j}||}|du rA| d|s.dnt| }|||< |	d| d q|S )z
    Generate block of function declarations in `python_return_types.h` to initialize
    and return named tuple for a native function.
    Nr  r   r  r  r  )rP   r  r   r  r   rF   r  r	  r>   r>   r?   r     s"   	

r   a  \
// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  switch (_r.idx) {
    ${dispatch}
  }
  ${method_footer}
}

z&case ${overload_index}: {
  ${body}
}
ao  // ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  ${method_header}
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});

  ParsedArgs<${max_args}> parsed_args;
  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

z// ${name}
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  ${method_header}
  ${check_has_torch_function}
  ${dispatch}
  ${method_footer}
}

c                C  sH  t | }t|}t|\}}dg}	|	|7 }	|	|rdgng 7 }	|r"g ndgdg }
tdd |D r3dnd}t||d	}t|d
k}g }g }t|D ],\}}|jj|d	}|	t
t| d t|||d	}|	|srtj||dn| qI|r{t}n|rt}nt}|j| ||	tdd |D ||t| |||d||
|rdd
S dd
S )z?
    Generate a python binding for all overloads of an op.
    ZHANDLE_TH_ERRORSz/const Tensor& self = THPVariable_Unpack(self_);Py_RETURN_NONE;ZEND_HANDLE_TH_ERRORSc                 s  s    | ]}t |jV  qd S rV   )r'   r]   r<   or>   r>   r?   r         zmethod_impl.<locals>.<genexpr>truefalserj   r%   ,)overload_indexbodyc                 s  s    | ]}|j  V  qd S rV   )r   rS   r  r>   r>   r?   r     r  )rF   r   noargrX   self_nullptr)
rF   pycnamemethod_headerZmax_args
signatures	traceableZcheck_has_torch_functiondispatchmethod_footerr  )rN   rT   r
  r   group_overloadsrR   r   r   signature_strr   r   rJ   emit_dispatch_casePY_VARIABLE_CASE
substitutePY_VARIABLE_METHOD_NOARGS$PY_VARIABLE_METHOD_VARARGS_SINGLETONPY_VARIABLE_METHOD_VARARGSmaxgen_has_torch_function_check)rF   r   rP   rX   rk   r  r  Zstructseq_initsstructseq_typenamesr  r"  r   grouped_overloadsZis_singletonr  r!  r  r  r   Zdispatch_bodytemplater>   r>   r?   r     sf   r   r  c                C  s`   |r|r
d|  dS dS |rdnd}|r dddd	d
ddd| nd}d| d| d|p,d dS )NzMif(check_has_torch_function(self_)) {
  return handle_torch_function(self_, "z");
}
r   r  r  ZTHPVariableFunctionsModuleZTHPNNVariableFunctionsModuleZTHPFFTVariableFunctionsModuleZ THPLinalgVariableFunctionsModuleZ THPNestedVariableFunctionsModuleZ THPSparseVariableFunctionsModuleZ!THPSpecialVariableFunctionsModule)ru   rx   ry   rz   r{   r|   r}   ZTHPVariableClasszAif(_r.has_torch_function()) {
  return handle_torch_function(_r, z, args, kwargs, z, "ztorch.Tensorr>   )rF   r   r  rX   r  	namespacer>   r>   r?   r,    s8   
r,  zRif (_r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
r  r   r-  r   c             	   C  sT   | j dur tj| j t| j| j||dt| j| j ||ddS t| j| j||dS )a0  
    Emit dispatch code for a single parsed signature. This corresponds to either
    a single native function, or a pair that differ only in output params. In the
    latter case, a single python signature is used for both and dispatching
    switches on the presence/absence of passed output args.
    Nrj   )Zout_idxZcall_dispatchZcall_dispatch_out)ZoutplacePY_VARIABLE_OUTr'  r   Z
output_idxemit_single_dispatchr   )r  r-  rk   r>   r>   r?   r%    s    
r%  tuple[str, ...]c                C  s4   |rdS t | }t|rd| dfS d| dfS )Nr>   zstatic PyObject * z#(PyObject* self_, PyObject* args);
z5(PyObject* self_, PyObject* args, PyObject* kwargs);
)rN   rT   )rF   rP   rX   r  r>   r>   r?   r     s   r   c                C  sh   t | }| jrd| d}t|r|rdnd}nd| d}d}|dkr(|d7 }d	|  d
| d| dS )z$
    Generate method def entry.
    zTypeError_to_NotImplemented_<>ZMETH_NOARGSzMETH_VARARGS | METH_KEYWORDSzcastPyCFunctionWithKeywords(r   ru   z | METH_STATICr  z", r  z, nullptr},)rN   Zdunder_methodrT   )rF   r   rP   rX   r  flagsr>   r>   r?   r   <  s   
r   Sequence[PythonSignatureGroup]c          	        s`  i }i  | D ]C}|j jd|d}|jj r0| v r+td|jj d | jj d| |< q||v rEtd|jj d|| jj d|||< q  D ]P\}}||vrg }| D ]'}t|jjjjt|jjjjkr|jj s|j j	s|
|j jd|d qZ|j j|d}td| d| d	d
dd |D  qN fdd| D }t||dS )NT)skip_outputsrk   z(Found duplicated function definition:
- z.
Existing definition:
- .rj   z4While identifying overloads, we found an out schema z] without a corresponding non-out variant. We expected the non-out variant to have schema: 
- zy
Please check that you spelled the schema correctly in native_functions.yaml. We discovered the following candidate(s): 
r   c                 s  s    | ]}d | V  qdS )z- Nr>   )r<   	candidater>   r>   r?   r     r  z"group_overloads.<locals>.<genexpr>c                   s$   g | ]\}}t j| |d qS ))Z
functionalrl   )r   Z
from_pairsr  )r<   sigr   Z	outplacesr>   r?   r@     s    z#group_overloads.<locals>.<listcomp>)r   r$  r]   rG   Z	is_out_fnRuntimeErrorr   rJ   rF   r   r   r   sort_overloads)	rP   rk   basesr  r:  rl   
candidatesZout_sigr   r>   r;  r?   r#  _  sb   





r#  r.  c                  s  ddddfdd}t  fddd ttt D ]\}}t D ]\}}||j|jr8| | q&qs@t S t }ttfddt	|}t	|D ]"}	||	 }
t 
 D ]}| }||
 |sv|= || qaqU fdd|D S )Nt1r    t2rB   rC   c                 S  s   t | dkrt |dkpbt | dkot |dkpbdt | v o#dt |vpbt | dko5t |dkp5t |dkpbt | d	koDt |d
dkpbt | dkoPt |dkpbt | dks]t | dkobt |dkS )Nr   ZTensorzScalar?zTensor?ZDimnamezint[]r   zint?zTensor[]z[]zSymInt[]ZSymInt)rJ   find)r@  rA  r>   r>   r?   is_arg_smaller  s*   


	

z&sort_overloads.<locals>.is_arg_smallers1r   s2c                   sl   | j dd|j dd}}t|t|krdS tdd t||D }t fddt||D }|o5| S )z-Returns True if s1 < s2 in the partial order.T)r7  Fc                 s  s     | ]\}}|j |j kV  qd S rV   )r   r<   Zarg1Zarg2r>   r>   r?   r     s    z5sort_overloads.<locals>.is_smaller.<locals>.<genexpr>c                 3  s6    | ]\}}t |jt |jkp |j|jV  qd S rV   )rJ   r   rG  rD  r>   r?   r     s
     
)r   rR   r   r   )rE  rF  Zargs1Zargs2equalZsmaller_or_equalrH  r>   r?   
is_smaller  s   
z"sort_overloads.<locals>.is_smallerc                   s   | j j dS )Nrj   )r   r$  xrj   r>   r?   r     s    z sort_overloads.<locals>.<lambda>r   c                   s   |  vS rV   r>   rK  )larger_thanr>   r?   r     s    c                   s   g | ]} | qS r>   r>   )r<   rL  )r.  r>   r?   r@     r   z"sort_overloads.<locals>.<listcomp>)r@  r    rA  r    rB   rC   )rE  r   rF  r   rB   rC   )r   r   setr   r   addr   rR   r   ranger   discardr   )r.  rk   rJ  i1Z	overload1i2Z	overload2NZ
sorted_idsidxr   jZlargerr>   )r.  rD  rM  rk   r?   r=    s6   


r=  psr   c                  s   t d fdd}||S )	z:
    Emit dispatch code for a single native function.
    rA   r   rB   rJ   c                   s  t  trd j }nd| j }t| j}ddd t | dD }t| }t	| }dt
|  d}t | d}t | d}d|j}	d|j}
 jo`t|  p` jo`d	|v }|rld
|d	 j dnd}|dkr| jjj}t| jjdr| j tjkr|d urt|jjsJ d}nd}| d|	 d| d| d| d| d| d| d|
 d| d| dS t| }|d ur| dnd}| d|	 d| d| d| d| d| d| d| d|
 d| dS )Nz// [deprecated] aten::z	// aten::r  c                 s  s"    | ]}|j  d |j V  qdS ) N)Ztype_strrF   r   r>   r>   r?   r   2  s    
z3emit_single_dispatch.<locals>.go.<locals>.<genexpr>rj   )Zpython_signaturer   Zrequires_gradz.set_requires_grad(r   r   voidZ	_foreach_z\PyObject* self_tensorlist = _r.args[0];
Py_INCREF(self_tensorlist);
return self_tensorlist;
r  z
auto dispatch_z = [](z) -> z, {
  pybind11::gil_scoped_release no_gil;
  (z);
};
dispatch_z;
z3 {
  pybind11::gil_scoped_release no_gil;
  return z);
};
return wrap(Z	dispatch_z);
)
isinstancer   r   rG   r   rF   r   r
   r   r	   r   r   r   initsexprsr   r   rX   exprr   self_argrJ   
startswithkindr   Zinplacer&   argumentr   r  r   )rA   Zschema_commentrF   Zlambda_formalsZlambda_returnZdispatch_calleeZdispatch_argsZparser_outputsZlambda_arg_exprsr\  Zlambda_argsZneed_set_requires_gradZset_requires_gradr_  Zreturn_stmtr	  Zstructseq_typerefrW  r-  rk   r>   r?   go(  s   




z emit_single_dispatch.<locals>.goNrA   r   rB   rJ   r   )rW  rA   r-  rk   rd  r>   rc  r?   r2    s   Sr2  )rA   r   rB   rC   )rF   r   rB   rJ   )rP   rQ   rB   rC   )rl   rJ   rm   rJ   rn   rJ   ro   rJ   rp   rJ   rk   rC   rB   rq   )r   rQ   r   r   rB   r   )r   r"   r   rQ   r   r   r   r   r   rJ   rX   rC   rk   rC   rB   rq   )
r   r"   r   rQ   r   r   r   rJ   rB   rq   )r   r"   r   rQ   r   r   r   r   r   rJ   rX   rC   rw   r   rk   rC   rB   rq   )r   r   ro   rJ   rX   rC   r   rC   r   rC   rB   rQ   )
r   rQ   ro   rJ   rX   rC   r   rC   rB   r   re  )rP   rQ   rB   r  )rP   rQ   rB   r  )rP   rQ   rB   r  )rF   r   r   r   rP   rQ   rX   rC   rk   rC   rB   rJ   )
rF   r   r   r   r  rC   rX   rC   rB   rJ   )r  r   r-  r   rk   rC   rB   rJ   )rF   r   rP   rQ   rX   rC   rB   r3  )
rF   r   r   r   rP   rQ   rX   rC   rB   rJ   )rP   rQ   rk   rC   rB   r6  )r.  r6  rk   rC   rB   r6  )
rW  r   rA   r   r-  r   rk   rC   rB   rJ   )^
__future__r   r   r:   collectionsr   typingr   r   r   Ztorchgen.apir   Ztorchgen.api.pythonr   r   r	   r
   r   r   r   r   r   r   r   r   r   r   Ztorchgen.code_templater   Ztorchgen.contextr   Ztorchgen.genr   r   r   Ztorchgen.modelr   r   r   r   r   r    r!   Ztorchgen.utilsr"   r#   Ztorchgen.yaml_utilsr$   Zgen_inplace_or_view_typer&   Zgen_trace_typer'   collections.abcr(   r)   Z_SKIP_PYTHON_BINDINGSrH   rK   rL   rN   rT   r[   r^   rb   rc   rd   rf   rh   ri   r   r   r   r   r   r   r   r   r   r
  r   r   r*  r&  r)  r(  r   r,  r1  r%  r   r   r#  r=  r2  r>   r>   r>   r?   <module>   s   !@$	V










|
)
')3
v


3&	
E$
(
$g]