a
    ho                    @  s1  d Z ddlmZ ddlZddlZddlZddlZddlZddlm	Z	m
Z
 ddlmZ ddlZddlm  mZ ddlZddlZddlmZ ddlmZmZmZmZ ddlmZ dd	lmZmZ e
rdd
lmZ ddlm Z  g dZ!ej"ej#ddZ$ddddZ%dd Z&e$dddddZ'e$dddddZ(e$de)ddddd Z*e$d!e)dddd"d#Z+e$d$d?ddd%d&Z,e$d'd@ddd(d)Z-e$d*dAddd+d,Z.e$d-ddd.d/Z/e$d0ddd1d2Z0e$d3e1d4d4d4d5dBddd7d8Z2e1d4d4d9ddd:d;Z3ddd<d=Z4ddd>d?Z5e$d@dddAdBZ6e$dCdddDdEZ7e$dFdddGdHZ8e$dIdddJdKZ9e$dLe1d4dMdddNdOZ:e$dPe1d4dMdddQdRZ;e$dSdddTdUZ<e$dVdddWdXZ=e$dYdddZd[Z>e$d\ddd]d^Z?e$d_e1d4d4d4d`d`dddadbZ@e$dcdddddeZAe$dfdddgdhZBe$didddjdkZCe$dlej)ddmdndodddpdqZDe$drdddsdtZEe$dudddvdwZFe$dxdddydzZGe$d{ddd|d}ZHe$d~ddddZIe$dddddZJe$dddddZKe$dej)ddddoddddZLe$dddddZMe)dddddZNe$deOddgde$deOddgde$dejOddddgddCddddddZPe$de1d4dMdddddZQe$dddddZRe$dddddZSe$ddddd`ZTe$de)dddddZUe$de)dddddZVe$de)dddddZWe$de)ddddddZXe$de)de1d4d4dMdd4ddddZYe$de)de1d4d4d4dMdMdMd4dMdM	ddddZZe$dej)ddddDddddZ[e$dÃe)de1d4dMdMddddńZ\e$dƃe1d4dǡddddɄZ]e$dʃe)ddddd̄Z^e$d̓ddddτZ_e$dЃe1d4dMdMdMdEdddd҄Z`e$dӃe1d4d4dMdMdFddddՄZae$dփdGdddd؄Zbe$dكe1d4ddMdMdHddddۄZce$d܃dIddddބZde$d߃e1d4dMdMdJddddZee$de)de1d4dMd4ddddZfe$dddddZge$ddKddddZhe$dddddZie$dddddZje$dddddZke$de)dddddZle$de)dddddZme$dddddZne$dddddZoe$d ddddZpe$de1d4d`d`ddddZqe$de)de1d4d5ddLddddd	d
dZre$de1d4dMddddZse$de1d4dMddMddddZte$dddddZue$ddd Zve$dejOdejwjxjyjzddde%dgde$dejOdejwjxjyj{ddde%dgde$dejOd ejwjxjyj|d!dde%d gdd"d# Z}e$d$e}d%ejwjxjyjzdddZ~e$d&e}d'ejwjxjyj{dddZe$d(e}d)ejwjxjyj|d!ddZe$d*eOd+ejwjxjyjze%d+gde$d,eOd-ejwjxjyj{e%d-gde$d.eOd/ejwjxjyj|e%d/gdd0d1 Ze$d2eOd3d4ejwjxjyjze%d3gde$d5eOd6d4ejwjxjyj{e%d6gde$d7eOd8d4ejwjxjyj|e%d8gde$d9eOd:d;ejwjxjyjze~e%d:gde$d<eOd=d;ejwjxjyj{ee%d=gde$d>eOd?d;ejwjxjyj|ee%d?gddNd@dAZdBdCdDdEZdFdG Ze$dHdddIdJZddddKdLdMZe$dNe$dOe$dPdddQdRZe$dSe$dTe$dUdddVdWZe$dXddddddYdZd[Ze$d\eOd]d!d^e%d]gde$d_eOd`dad^e%d`gde$dbeOdcddd^e%dcgde$deeOdfd!dge%dfgde$dheOdidadge%digde$djeOdkdddge%dkgdddBddldmdnZe$dodddpdqZe$drdddsdtZe$dudvdw Zdxdy Zdzdzd{d|d}Ze$d~ddddZe$de)ddddddZe$de)ddeddddZe$de)ddddddZddddZe$de)ddddddZddddZe$de)ddeddddZe$de)ddeddddZe$dddddZe$dddddZe$dddddZe$dedddddZe$dedddddZe$dedddddZe$dddddZe$dddddZe$dddddZe$de1d4d4d4dMdOddddZe$de1d4dMddPddddZe$de1d4dMdMddddZe$de1d4d4d4ddddMddMdMdMdMdMdQddddZe$de1d4d4d4dd9ddMddddZe$de1d4d4d4ddddMddM	dddÐdĄZe$dŃe1d4d4d4dd4ddMdddƐdǄZe$dȃe1d4d4d4dd4ddMdddɐdʄZe$d˃e1d4d4d4dd4ddMddd̐d̈́Ze$d΃e1d4d4d4ddddMdǡdddϐdЄZe$dуe1d4d4d4ddddMdǡdddҐdӄZe$dԃe1d4d4d4ddddMdǡdddՐdքZe$d׃e1d4d4d4d4d4dMd5d5dM	dddؐdلZe$dڃe)dddde1d4dd4d4d5dddېddddܐdݜdސd߄Ze$de)dddde1d4dd4d4d5ddddېddddddddZe$de1d4d4d4d4d4dd5d5d	ddddddddZe$de1d4dMdMdMddddZe$de)de1d4d`d`d`ddddZe$de)dddddZe$de1d4dMd4ddddZe$dddddZe$dddddZe$dddddZe$de1d4d4dddRdddd Ze$dddddZe$de1d4d4dMd5ddddZe$ddddd	Ze$d
ddddZe$dddddZe$dddddZe$dddddZe$dddddZe$dddddZe$dddddZe$de1d4d4ddd d!Ze$d"e1d4d4ddd#d$Ze$d%dSddd&d'Ze$d(e)ddddd)d*Ze$d+dTddd,d-Ze$d.e)ddddd/d0Ze$d1e)de1d4ddMddd2d3Ze$d4e)de1d4ddMddd5d6Ze$d7e)de1d4d4dMddd8d9Ze$d:ddd;d<Ze$d=e$d>e1d4d5dMddd?d@Ze$dAeOdAgde$dBeOdBgde$dCeOdCgde$dDeOdDgde$dEeOdEgde$dFeOdFgddddGdHZe$dIe1d4d`ddMd4dUdddJdKZe$dLe1d4d4d4dMdddMdNZe$dOe1d4dMdMdddPdQZe$dRe1d4dMdMdMdddSdTZe$dUedVdddWdXZe$dYedVdddZd[Ze$d\edVddd]d^Ze$d_edVddd`daZe$dbedVdddcddZe$deedVdddfdgZe$dhedVdddidjZe$dkedVdddldmZe$dnedVdddodpZe$dqe1d4dMd4d4d4d4dVdddrdsZe$dte1d4dMd4d4d4d4dWdddudvZe$dwdXdddxdyZe$dzddd{d|Ze$d}dYddd~dZe$ddZddddZe$de1d4dMd4d4d4d[ddddZe$de1d4dMd4d4d4d4d\ddddZe$dd]ddddZe$dddddZe$de1d4dMd4d4d4d^ddddZe$de1d4dMd4d4d4d4d_ddddZe$dd`ddddZe$ddaddddZe$ddbddddZe$ddcddddZe$dddddZe$dddddZe$de)de1d4d5d5dddddddZe$de)de1d4ddddZe$dej)ddddoe1d4ddddZe$de1d4ddddZe$de1d4d5ddddZe$de1d4d5ddddZe$dddddZe$de1d4dMddddZe$de1d4dMdMdddddddZe$dÃdddĐdńZe$dƃe1d4dMdMdMdMddedddǐdȄZe$dɃdddʐd˄Ze$d̃ddd͐d΄Ze$dσdddАdфZe$d҃dfdddӐdԄZe$dՃe1d4dMddd֐dׄZe$d؃e1d4dMdddِdڄZdgdddېd܄Z e1d4d4d4dMdMd5dMdMdM	dddݐdބZe1d4d4d4d4dMdMd5dMdM	dddߐdZe$dddddZe$dddddZe$deOde%dgde$deOde%dgde$deOde%dgdddddZe$de1d4dMddddZe$dddddZe$de1d4dMddddZe$de1d4d4dMddddZ	e$de1d4d4dMd`d4ddd dZ
e$dddddZe$dddddZe$dddd	d
Ze$dddddZe$ddhddddZe$ddiddddZe$de1d4d5d5dMdddddZe$ddjddddZe$de1d4ddddZe$de1d4ddddZe$d e)ddde1d4dMdMddd!d"Ze$d#e1d4ddd$d%Ze$d&dkddd'd(Ze$d)e1d4ddd*d+Ze$d,ddd-d.Ze$d/ddd0d1Ze$d2e1d4dMdMdMddd3d4Ze$d5e1d4d4ddd6d6dd7d8d9Ze$d:e1d4d4ddd6d6dd7d;d<Ze$d=e1d4dMd4d4ddd>d?Ze$d@e1d4dMd4d4dddAdBZe$dCdddDdEZ e$dFdddGdHZ!e$dIdddJdKZ"e$dLedddMdNZ#e$dOdddPdQZ$e$dRe1d4dMd4d4dldddSdTZ%e1d4ddMdMdddUdVZ&e$dWdddXdYZ'e$dZddd[d\Z(e$d]ddd^d_Z)e$d`dddadbZ*e$dce1d4ddMdddddeZ+e$dfdddgdhZ,e$didddjdkZ-e$dldddmdnZ.e$dodddpdqZ/e$drdddsdtZ0e$dudddvdwZ1e$dxe1d4d4ddd4dd6d6dydd6dzd{d|Z2e$d}e1d4d5ddd4dd6ddydd6dzd~dZ3e$de1d4d4ddd4dd6d6ddd6dzddZ4e$de1d4d4dMdmddddZ5e$de1d4dddnddddZ6e$de1d4dMdd4doddddZ7e$dddddZ8e$de1d4d9dpdddddZ9e$dddddZ:e$de1d4d9dqdd6ddddZ;e$de)dddde1d4dMd4d4d5dMddddZ<e$de1d4d4dMddddZ=e$dddddZ>e$dddddZ?e$dddddZ@e$dddddZAe$dddddZBddddZCddddZDe$de1d4d4dMdddddZEe$de1d4d4dMddddZFe$de)de1d4d4ddMdrddddZGe$dddddZHe$ddddÐdĄZIe$dŃdddƐdgZJe$dǃe1d4ddMd4d4d4d4dsddȐdɜdʐd˄ZKe$d̃ddd͐d΄ZLe$dσdddАdфZMe$d҃e1d4d`d`dddӐdԄZNe$dՃe1d4d4ddd֐dׄZOe$d؃dtdddِdڄZPe$dۃe1d4ddǡdddܐd݄ZQe$dރe1d4d4dMdudddߐdZRe$ddvddddZSddddZTe$dddddZUe$dddddZVe$ddwddddZWe$dddddZXe$dddddZYe$dddddZZe$dddddZ[e$ddxddddZ\e$d ddddZ]e$dddddZ^e$dddddZ_e$d	dd
dddZ`e$dddddZae$dddddZbe$dddddZce$dddddZde$dddddZee$dddddd Zfe$d!dd"dd#d$Zge$d%dd"dd&d'Zhe$d(ddd)d*Zie$d+ddd,d-d.Zje$d/ddd0d1Zke$d2e$d3ddd4d5d6Zle$d7e$d8ddd4d9d:Zme$d;dd6d6d<d=d>ZndS (y  zhThis file exports ONNX ops for opset 9.

Opset 9 is supported by ONNX release 1.4.1
release on 01/23/19
    )annotationsN)CallableTYPE_CHECKING)
deprecated)_C)
_constants_type_utilserrorssymbolic_helper)GLOBALS)	jit_utilsregistration)Sequence)Number(  absacosaddaddcmuladdmmaliasamaxaminaminmaxarangeargmaxargmin
as_strided	as_tensorasinatanatan2baddbmm
batch_norm	bernoullibitwise_not
bitwise_orbmmbroadcast_tensorsbroadcast_to	bucketizecatcdistceil	clamp_max	clamp_minclampcloneconstant_pad_nd
contiguousconv_tbcconv_transpose1dconv_transpose2dconv_transpose3dconv1dconv2dconv3dconvert_element_typeconvolutioncoscosine_similaritycrosscumsumdetachdimdivdotdropouteluembedding_bag	embedding
empty_likeemptyeqerfexp	expand_asexpandeyefillflattenfloor_dividefloorfloordivfrobenius_norm	full_likefullgathergegeluget_pool_ceil_paddingglu
group_normgthann_window
hardshrinkhardsigmoid	hardswishhardtanh	index_add
index_copy
index_fill	index_putindex_selectindexinstance_normis_floating_point	is_pinnedisnanitemkl_div
layer_normle
leaky_relulerpliftlinalg_crosslinalg_matrix_normlinalg_normlinalg_vector_normlinearlinspacelog_sigmoidlog_softmaxloglog10log1plog2logical_andlogical_not
logical_orlogical_xorlogit	logsumexp	lstm_celllstmltmasked_fillmasked_fill_matmulmax_pool1d_with_indicesmax_pool2d_with_indicesmax_pool3d_with_indicesmaxmaximummeshgridminminimummishmmmovedimmse_lossmulmultinomialmvnarrownative_layer_normneneg	new_emptynew_fullnew_ones	new_zerosnonzero_numpynonzeronormnumelnumpy_Tone_hot	ones_likeonesonnx_placeholderpadpairwise_distancepermutepixel_shufflepixel_unshufflepowpreluprim_constant_chunkprim_constant_splitprim_constant	prim_dataprim_device
prim_dtypeprim_ifprim_layoutprim_list_constructprim_list_unpack	prim_loopprim_maxprim_min
prim_shapeprim_tolistprim_tuple_construct	prim_typeprim_unchecked_castprim_uninitialized	rand_likerandrandint_likerandint
randn_likerandn
reciprocalreflection_padrelurelu6	remainderrepeat_interleaverepeatreplication_pad
reshape_asreshaperollrrelursqrtrsubscalar_tensorscatter_addscatterselectselusigmoidsignsilusinsizeslicesoftmaxsoftplus
softshrinksortsplit_with_sizessplitsqrtsquaresqueezestackstd_meanstdsubttaketantanh
tanhshrinktensor	thresholdtotopk	transposetrue_dividetype_asunbindunfoldunsafe_chunkunsafe_split_with_sizesunsafe_split	unsqueezeunsupported_complex_operatorsnoop_complex_operatorsunusedvar_meanvarview_asviewwherewrap_logical_op_with_cast_towrap_logical_op_with_negation
zeros_likezeroszero	   )opsetstrnamec                   s    fdd}|S )z5Exports the function in the current global namespace.c                   s   | t   < t  | S N)globals__all__appendfuncr   H/var/www/auris/lib/python3.9/site-packages/torch/onnx/symbolic_opset9.pywrapper4  s    

z_export.<locals>.wrapperr  )r  r  r  r  r  _export1  s    r   c                 C  s   |  d}|tj  |S )z%Represents "missing" optional inputs.prim::Constant)opsetTyper   OptionalTypeZofTensor)gnr  r  r  r  <  s    
r  zaten::_shape_as_tensorzjit_utils.GraphContextr%  c                 C  s   |  d|S NShaper"  r%  inputr  r  r  _shape_as_tensorC  s    r-  zaten::_reshape_from_tensorc                 C  s.   t |tr"| jdg|R ddi}t| ||S )NConcataxis_ir   )
isinstancelistr"  r   )r%  r,  shaper  r  r  _reshape_from_tensorH  s    
r3  zaten::reshapeTc                 C  s   t | ||S r  )r
   _reshape_helperr%  selfr2  r  r  r  r   O  s    r   zaten::reshape_asc                 C  s   |  d|}t| ||S r(  r"  r   r%  r6  otherr2  r  r  r  r   U  s    r   z	aten::addc                 C  sZ   t |r&t |r&t dddd|S |rLt t |dkrL| d||}| d||S )a  
    This function takes the add function and returns the corresponding ONNX operator.

    This function is not meant to be called directly by the user.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The first operand.
        other (Tensor): The second operand.
        alpha (float, optional): The scaling factor for the second operand. Defaults to None.

    Returns:
        ONNX operator.
    Addr     z)Add between list of tensors not supported   Mul)r
   	_is_value_is_tensor_list _onnx_opset_unsupported_detailed_scalar_maybe_get_scalarr"  r%  r6  r9  alphar  r  r  r   \  s    
r   z	aten::subc                 C  s4   |r&t t |dkr&| d||}| d||S )a  
    Consumes sub function and returns the corresponding ONNX operator.

    This function is not meant to be called directly by the user.

    Args:
        g (GraphContext): The graph context.
        self (Tensor): The first operand.
        other (Tensor): The second operand.
        alpha (Optional[Tensor]): A scaling factor to apply to the second operand.
            If `alpha` is not provided, it defaults to 1.

    Returns:
        ONNX operator
    r<  r=  Sub)r
   rA  rB  r"  rC  r  r  r  r   u  s    r   z
aten::rsubc                 C  s   t | |||dS )N)rD  )r   rC  r  r  r  r     s    r   z	aten::mulc                 C  s4   t |r"t |r"| d||S | d||S d S )NAndr=  )r
   _is_boolr"  r%  r6  r9  r  r  r  r     s    r   z	aten::divc                 G  s0   t |dkrt| ||S t| ||g|R  S d S Nr   )lenr   _div_rounding_mode)r%  r6  r9  argsr  r  r  rB     s    rB   zaten::addcmulvf      ?c              	   C  s2   | j dt|gd}t| |t| t| |||S NConstantZvalue_t)r"  torchr   r   r   )r%  r6  Ztensor1Ztensor2valueZ
value_tensr  r  r  r     s    r   sc                 C  sT   |d u rt | ||S |dkr(t| ||S |dkr<t| ||S td| d|d S )NrS   trunczUnsupported rounding mode: "z$". Expected None, "floor" or "trunc")r   _floor_divide_trunc_divider	   SymbolicValueError)r%  r6  r9  Zrounding_moder  r  r  rK    s    
rK  c                 C  s   |  d||}| j d|tjjd}tj|tjj}|tjjkrt	|sjt	|rj| j d|tjj
d}q| j d|| d}n| j d|tjj
d}|S )NDivCastZto_i)r"  _C_onnxTensorProtoDataTypeINT64r   JitScalarType
from_value	UNDEFINEDr
   _is_fpFLOAT	onnx_type)r%  r6  r9  outscalar_typer  r  r  rX    s    rX  c                 C  s   t |st |r,t| ||}| d|S | d||}| jdtjdtjdd}| dt | ||t | ||}| d|| d	||}| d
|| d| d||}| jdtjdtjdd}	| d	||	}
| d||
S d S )NFloorrZ  rQ  r   dtyperR  XorrE  r=  rF  NotEqualr<  )r
   rc  r   r"  rS  r   int64Z
_lt_helper)r%  r6  r9  rf  rB   r  negativemodZ
fixup_maskonefixupr  r  r  rW    s     rW  zaten::floor_dividec                 C  s   t | ||S r  )rX  rH  r  r  r  rR     s    rR   zaten::floordivc                 C  s   t | ||S r  )rR   rH  r  r  r  rT     s    rT   zaten::true_dividec                 C  s   t |st |r"| d||S t }tjj}|tju sJ|tj	u sJJ t tj	u r`tjj
}| jd||d}| jd||d}| d||S )a  Division where both inputs are cast to floating types

    If both inputs are floating, performs div as usual
    If only one input is a floating type, the other input is cast to its type
    If neither input is a floating type, both inputs are cast to the default scalar type
    rZ  r[  r\  )r
   rc  r"  rS  get_default_dtyper]  r^  rd  floatdoubleDOUBLE)r%  r6  r9  rg  Zonnx_scalar_typer  r  r  r     s    r   zaten::reciprocalc                 C  s*   t |s| jd|tjjd}| d|S )Nr[  r\  
Reciprocal)r
   rc  r"  r]  r^  rd  r%  r6  r  r  r  r     s    
r   z	aten::catic                   s   t |}g  |D ]&}t |r.t |ds.q | qt dksJJ t fdd D sdJ |    D ]}| 	| qtt |}| j
dg|R d|iS )a{  Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension.

    Parameters:
        g (jit_utils.GraphContext): Graph context.
        tensor_list (List[torch.Tensor]): List of tensors to concatenate.
        dim (int): Dimension along which to concatenate the tensors.

    Returns:
        ONNX graph node representing the concatenated tensor.
    r   c                 3  sF   | ]>}t  d  du p<t |du p<t |t  d  kV  qdS r   N)r
   _get_tensor_rank.0r   Znonempty_tensorsr  r  	<genexpr>6  s   zcat.<locals>.<genexpr>r.  r/  )r
   _unpack_list_is_constant_get_tensor_dim_sizer  rJ  allnodeZremoveAllInputsZaddInputr"  )r%  tensor_listrA   tensorsr   r  r~  r  r*     s"    

r*   zaten::stackc                   s2    fddt |D }jdg|R d iS )Nc                   s   g | ]}t | gqS r  r
   _unsqueeze_helperr|  rA   r%  r  r  
<listcomp>H  s   zstack.<locals>.<listcomp>r.  r/  )r
   r  r"  )r%  r  rA   Z
unsqueezedr  r  r  r   E  s    r   z
aten::listc                 C  s   |S r  r  rx  r  r  r  _listO  s    r  zaten::mmc                 C  s,   | j dtdgd}| j d|||dddS )NrQ  r<  rR  Gemm        rO  Zbeta_falpha_fr"  rS  r   )r%  r6  r9  Cr  r  r  r   T  s    r   z	aten::bmmc                 C  s   |  d||S NMatMulr*  rH  r  r  r  r&   \  s    r&   zaten::matmulc                 C  s   |  d||S r  r*  rH  r  r  r  r   a  s    r   zaten::addmmr   c              	   C  sH  d }t |}t |}t |}	|d ur0|}n|d ur>|}n|	d urJ|	}t |}
t |}dd }|d ur&||
ds||dr&| d||}|}t |}t |}|dkr| jdtj|| dd}| d	||}|dkr| jdtjt || dd}| d	||}| d
||S | jd|||t |t |dS )Nc                 S  s   | d uo| |kS r  r  )rM  ur  r  r  is_not_none_norw  s    zaddmm.<locals>.is_not_none_nor   r  r<  rQ  ri  rR  r=  r:  r  r  )r
   _try_get_scalar_typer{  r"  rA  rS  r   rj  )r%  r6  Zmat1Zmat2betarD  rg  self_scalar_typeZmat1_scalar_typeZmat2_scalar_typeZ	mat1_rankZ	mat2_rankr  Zres1Zres2r  r  r  r   f  sX    








r   z	aten::negc                 C  s   |  d|S )NZNegr*  rx  r  r  r  r     s    r   z
aten::sqrtc                 C  sT   t j|t jjt jjt jjt jjt jjt jjhv rH| j	d|t
jjd}| 	d|S )Nr[  r\  Sqrt)r   r`  ra  rb  UINT8INT8INT16INTr_  r"  r]  r^  rd  rx  r  r  r  r     s    
r   zaten::rsqrtc                 C  s"   |  dttd|t| |S )NrZ  r<  )r"  r
   _if_scalar_type_asrS  r   r   rx  r  r  r  r     s    r   z
aten::tanhg      ?   )scaleZ
zero_pointc                 C  s   |  d|S )NTanhr*  rx  r  r  r  r     s    r   z	aten::sinc                 C  s   |  d|S )NZSinr*  rx  r  r  r  r     s    r   z	aten::cosc                 C  s   |  d|S )NZCosr*  rx  r  r  r  r<     s    r<   z	aten::tanc                 C  s   |  d|S )NZTanr*  rx  r  r  r  r     s    r   z
aten::asinc                 C  s   |  d|S )NZAsinr*  rx  r  r  r  r     s    r   z
aten::acosc                 C  s   |  d|S )NZAcosr*  rx  r  r  r  r     s    r   z
aten::atanc                 C  s   |  d|S )NAtanr*  rx  r  r  r  r     s    r   zaten::atan2c              
   C  s   |  d||}|  d|}| j dtdd}| j dttjd}|  d||}|  d||  d|||  d	||}|  d
||}	|  d|	||}
|
S )NrZ  r  rQ  r   rR  GreaterWherer:  rE  Less)r"  rS  r   mathpi)r%  r6  r9  sloper   Z
const_zeroZconst_piZ"condition_second_or_third_quadrantZsecond_third_quadrantZcondition_14_or_23_quadrantresultr  r  r  r      s    r    zaten::sigmoidg      p?c                 C  s   |  d|S )a  Converts the corresponding PyTorch function into ONNX operators.

    It is not meant to be called directly by a user.

    Args:
        g (jit_utils.GraphContext): Graph context.
        self (Tensor): the input tensor.
    Returns:
        ONNX operator
    Sigmoidr*  rx  r  r  r  r     s    r   z
aten::signc                 C  s   |  d|S )NZSignr*  rx  r  r  r  r     s    r   c                 C  sR   t |t |ksJ t |dkr>|d dkr>|d tjkr>|S | jd||||dS )Nr<  r   Slice)axes_iZstarts_iZends_i)rJ  r   	INT64_MAXr"  )r%  r,  axesstartsendsr  r  r  _slice  s    &r  z	aten::sumZ	ReduceSumsum)Zdecoratez
aten::mean
ReduceMeanmeanz
aten::prodZ
ReduceProdprodF)allow_multi_dim_supportboolZonnx_opr  r  c                 C  s   t | ||S r  )r
   Z_reduce_with_dtype_helperr  r  r  r  _reduce_with_dtype  s    r  zaten::cumsumnonec                 C  s   t ddd| d S )Nr?   r  r;  r
   _onnx_opset_unsupported)r%  r,  rA   rj  r  r  r  r?   (  s    r?   zaten::_sample_dirichletc                 C  s   t d|S )N_sample_dirichletr
   _onnx_unsupportedr%  r6  	generatorr  r  r  r  .  s    r  zaten::_standard_gammac                 C  s   t d|S )N_standard_gammar  r  r  r  r  r  3  s    r  zaten::tc                 C  s6   t |}|d u s|dk r&| d|S | jd|ddS )Nr  Identity	Transpose)r<  r   Zperm_i)r
   r{  r"  )r%  r6  rankr  r  r  r   8  s    
zaten::numpy_Tc                 C  s8   t |}|d usJ tttd|}| jd||dS Nr   r  r  )r
   r{  r1  reversedranger"  )r%  r,  ndimpermr  r  r  r   C  s    
r   zaten::expandc              	   C  s   t |d}t |s,| jdt|d}n2t |r^t | t| |d| jdt	dgd}t
jj}t| ||}t| || jdt	dd}t| | d||||}| d||S )zXImplement the expand function for a pytorch tensor in ONNX according to specified `size`isrQ  rR  r   rm  Expandr
   _maybe_get_constr>  r"  rS  
LongTensor_is_packed_listr4  r   r   r   r`  r_  r   r   r  )r%  r6  r   Zimplicitrj  r   neg_onesr  r  r  rN   L  s    

 rN   zaten::broadcast_toc              	   C  s   t |d}t |s,| jdt|d}n2t |r^t | t| |d| jdt	dgd}t
jj}t| ||}t| || jdt	dd}t| | d||||}| d||S )Nr  rQ  rR  r   r  rm  r  r  )r%  r6  r   rj  r   r  r  r  r  r(   a  s    

 r(   zaten::expand_asc                 C  s   t |d}t|tjr|j}|tj}g }t|	 D ]J}t
|||||r:|| | jd|j|dd|d}q:| d|}| d||S )Nr   rQ  T)keepdimrR  r)  r  )r
   r  r0  rS  Tensorrj  r   ru  r  rA   equalr  r  rM   r  r"  )r%  r6  r9  Zself_t	orig_typedimsdr2  r  r  r  rM   u  s    
rM   zaten::embeddingbc                 C  s<   |rt jrtd||dkr.t jr.td | d||S )NzUnsupported: ONNX export of embedding with scale_grad_by_freq=True for training mode. ONNX does not support scaling the gradients.r   zWarning: ONNX export of embedding with padding_idx >= 0 for training mode. ONNX does not support not updating the embedding vector at padding_idx during training.Gather)r   Zexport_trainingr	   rY  warningswarnr"  )r%  weightindicespadding_idxscale_grad_by_freqsparser  r  r  rG     s    
rG   zaten::embedding_bagc
           
      C  s    t |st dS t d|S )Nz%embedding_bag with per_sample_weightsrF   )r
   _is_noner  )
r%  Zembedding_matrixr  offsetsr  moder  Zper_sample_weightsZinclude_last_offsetr  r  r  r  rF     s
    
rF   z
aten::size)Zquantize_outputc                 C  sh   |d u r|  d|S t|ddk rZt|}|d urZt|d| }| j dt|d}t| ||S )Nr)  ry  r   rQ  rR  )r"  r
   r  r{  rS  r   Z_size_helperr%  r6  rA   r  r  r  r  r     s    
r   zaten::transposec                 C  sd   ||kr|S t |}|d urTtt|}|| ||  ||< ||< | jd||dS td|d S )Nr  r  zAUnsupported: ONNX export of transpose for tensor of unknown rank.)r
   r{  r1  r  r"  r	   rY  )r%  r6  Zdim0Zdim1r  r  r  r  r  r     s    
r   zaten::permuter  c                 C  s*   |t tdt|kr|S | jd||dS r  )r1  r  rJ  r"  )r%  r6  r  r  r  r  r     s    r   z
aten::viewc                 C  s   t | ||S r  )r   )r%  r6  r   r  r  r  r    s    r  zaten::view_asc                 C  s   |  d|}t| ||S r(  r7  r8  r  r  r  r
    s    r
  zaten::unsafe_chunkc           	      C  s   |d u rt dddd|S t ||}|d u r<t dd|S || d | }|g||  }|| }|rp|| | jd||||dS )	Nr  r  r;  'Dynamic number of outputs not supportedunknown dimension sizer<  SplitZsplit_ir/  outputs)r
   r@  r  _unimplementedr  r"  )	r%  r6  chunksrA   _outputsr   
split_sizesplitsleftoverr  r  r  r    s    

r  zaten::splitc           
      C  s   t ||st dddd|S t | d}| dkrJt| ||||S t |dd}t ||}|d u r|d ur~|| }nt dddd	|S |g||  }|| }	|	r|	|	 | j
d
||||dS )Nr   r  r;  r  rT  r   ry  r  z$Unknown dimension size not supportedr  r  )r
   _is_split_staticr@  	_node_getr  rA   r   
_get_constr  r  r"  )
r%  r6  split_size_or_sizesrA   r  Z	split_valr  r   r  r  r  r  r  r     s(    



r   zaten::unsafe_splitc                 C  s   t | ||||S r  )r   )r%  r6  r  rA   r  r  r  r  r    s    r  zaten::split_with_sizesc                 C  s2   t ||st dddd|S | jd||||dS )Nr   r  r;  r  r  r  )r
   r  r@  r"  r%  r6  Zsplit_sizesrA   r  r  r  r  r   %  s
    
r   zaten::unsafe_split_with_sizesc                 C  s   t | ||||S r  )r   r  r  r  r  r  /  s    r  zaten::unbindc                   s^   |d u rt dddd|S jd|dg|  |d}|dkrB|gn|} fdd	|D }|S )
Nr   r  r;  r  r  r<  r  c                   s   g | ]}t | gqS r  )r
   _squeeze_helper)r}  rf  r  r  r  r  @  s   zunbind.<locals>.<listcomp>)r
   r@  r"  )r%  r6  rA   r  r  Zsqueezed_outputsr  r  r  r   6  s    
r   zaten::selectc                 C  st   t |}t |s^|dk r^|dkr,tj}n|d }t j| ||g|g|gd}t | ||gS | jd|||dS dS )zImplement the select functionality for a pytorch tensor in ONNX.

    Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor.
    r   r  r<  r  r  r  r  r/  N)r
   rB  r>  r   r  _slice_helperr  r"  )r%  r6  rA   ri   Z	end_indexZ
slice_noder  r  r  r   F  s    
r   zaten::squarec                 C  s   |  d||S Nr=  r*  rx  r  r  r  r   ]  s    r   zaten::squeezec                 C  sJ  |d u r|  d|S t|dd}|dk rt|}|d urxtdt| d d d t||  d	 d
  ||7 }ntdd|S t||}|d u rtdt| d d t| d d d d  tj	| ||gdS |dkrtdt| d d t| d d d d  |S tdt| d d  tj	| ||gdS )NZSqueezery  rA   r   z'ONNX export squeeze with negative axis - might cause the onnx model to be incorrect. (Negative axis is not supported in ONNX. Axis is converted to & based on input shape at export time. CPassing an tensor of different rank in execution will be incorrect.r   %negative axis with unknown input rankz5This model contains a squeeze operation on dimension z on an input z7with unknown shape. Note that if the size of dimension z of the input zVis not 1, the ONNX model will return an error. Opset version 11 supports squeezing on zMnon-singleton dimensions, it is recommended to export this model using opset zversion 11 or higher.r  r<  z. The size of z%this dimension in the given input is z. The model will zWbe exported without the squeeze node. If the model is intended to be used with dynamic z-input shapes, please use opset version 11 to zexport the model.z. If the model is z_intended to be used with dynamic input shapes, please use opset version 11 to export the model.)
r"  r
   r  r{  r  r  r  r  r  r  )r%  r6  rA   Zsqueeze_dimr  dim_sizer  r  r  r   b  s    




r   zaten::preluc              	   C  s   t |}t |}t|}|d urp|dkrJt | |ttd|d }n&|dkrp|dgkrpt | |dg}d}|d ur|d ur||ksJ d| d| | d||S )Nr  r<  r   z)rank(x) should be >= rank(slope) but got z < PRelu)	r
   r{  _get_tensor_sizesrJ  r  r1  r  r  r"  )r%  r6  r  	self_rankZweight_sizesZweight_rankr  r  r  r     s     


r   z
aten::siluc                 C  s   |  d||  d|S )Nr=  r  r*  r+  r  r  r  r     s    r   z
aten::mishc                 C  s   |  d||  d|  d|S )Nr=  r  Softplusr*  r+  r  r  r  r     s    r   z
aten::reluc                 C  s   t j| d|ddS )NRelu   opset_beforer
   _op_with_optional_float_castr+  r  r  r  r     s    r   zaten::relu6c                 C  s   t | |ddS )Nr      )r/   r+  r  r  r  r     s    r   z
aten::ceilc                 C  s   |  d|S )NCeilr*  r+  r  r  r  r,     s    r,   zaten::floorc                 C  s   |  d|S )Nrh  r*  r+  r  r  r  rS     s    rS   z	aten::lenc                 C  s.   t | || jdtdgd}t| |dgS NrQ  r   rR  )r   r"  rS  r  r
   r  )r%  r6  Zsz_0r  r  r  _len  s    r  zaten::thresholdc                 C  sD   t |dkrt dd|S t |dkr8t dd|S | d|S )Nr   r   znon-zero thresholdznon-zero valuer	  )r
   rA  r  r"  )r%  r6  r   rT  r  r  r  r     s
    r   zaten::leaky_relu_C.Valuert  r%  r,  Znegative_slopeZinplacec                 C  s   | j d||dS )N	LeakyRelur  r*  r  r  r  r  rr     s    
rr   z	aten::gluc                 C  sP   t ||}|d ur$|d dks$J | jd||dd\}}| d|| d|S )Nr  r   r  )r/  r  r=  r  )r
   r  r"  )r%  r,  rA   r  firstsecondr  r  r  r\     s
    r\   zaten::softmaxc              
   C  sb  t |}|d ur|dk r"|| }||d k}|rptt|}|d ||  ||< |d< | jd||d}|d }| jd||d}|r|  dkrt |d	d
}| jd|t	|
 d}|r| jd||d}|S | d|| jd||gdd}| d|}	t j| |	|gd}
| d|	|
}|r^|  dkr^t |d	d
}| jd|t	|
 d}|S )Nr   r<  r  r  r  ZSoftmaxr  r!  ry  rj  r[  r\  rE  	ReduceMaxr  
keepdims_iExpr  rZ  )r
   r{  r1  r  r"  r  kindr  r   r`  re  _reducesum_helper)r%  r,  rA   rj  	input_dimis_transpose_requiredr  r   parsed_dtyperL   r  r  r  r  r     s>    
r   zaten::softplusc                 C  s@   t |d}|dkr4| d| d| d|||S | d|S )NrN  r<  rZ  r  r=  )r
   r  r"  )r%  r6  r  r   Z
beta_constr  r  r  r   C  s     r   zaten::get_pool_ceil_paddingc                   s   t | }|d ur$|t d  nd d u sBtdd D rPt dd| S fddtdtD   fddtdt D   fd	dtdtD fd
dtdtD S )Nc                 s  s   | ]}|d u V  qd S r  r  r}  ry  r  r  r  r  P      z(get_pool_ceil_padding.<locals>.<genexpr>r[   input size not accessiblec              	     sB   g | ]:}t t | d |   |  t|  d qS r  r<  )intr  r,   rt  r"  )rA   kernel_sizepaddingstrider  r  r  T  s   0z)get_pool_ceil_padding.<locals>.<listcomp>r   c                   sD   g | ]<} | d  |  | |  kr8 | d  n | qS r<  r  r"  )ceiled_output_dimrA   r(  r)  r  r  r  Z  s   "c                   sP   g | ]H}| d krdn2| | d|    | d  |  d    qS r<  r   r  r  r"  )r+  rA   r'  r(  r)  r  r  r  b  s   

c                   sd   g | ]\}| d |    | krT|  | d k rDt | q^t  | d n
t | qS r%  r&  r"  )r'  r(  padding_ceilr  r  r  r  s   
)r
   r  rJ  anyr  r  )r,  r'  r)  r(  sizesr  )r+  rA   r'  r(  r.  r)  r  r[   K  s&    

r[   zaten::max_pool1dZ
max_pool1dr<  )return_indiceszaten::max_pool2dZ
max_pool2dr  zaten::max_pool3dZ
max_pool3d   c              	     s>   t ddddddt dddddd fdd}|S )NTFrM  r  ry  c                   s<  t |dhkr t d|S |s(|}t|}|rdt||||}|tdd t||D  }n|d }|||d}r| jd|fddi|\}	}
| jd|dd	d
 tD dd
 tD d\}}tj| |dd
 tD t	dt	dd}t
| |
|}
|	|
fS | jd|fddi|}	|	S d S )Nr<  dilationc                 s  s   | ]\}}|| V  qd S r  r  r}  ar  r  r  r  r    r#  z1_max_pool.<locals>.symbolic_fn.<locals>.<genexpr>r  )kernel_shape_ipads_i	strides_iMaxPoolr  c                 S  s   g | ]}d qS r*  r  r}  _r  r  r  r    r#  z2_max_pool.<locals>.symbolic_fn.<locals>.<listcomp>c                 S  s   g | ]}d qS r*  r  r:  r  r  r  r    r#  )r  r6  r8  c                 S  s   g | ]}d | qS )r  r  r"  r  r  r  r    r#  r   r  )setr
   r  tupler[   zipr"  r  r  r1  r   )r%  r,  r'  r)  r(  r3  	ceil_moder.  kwargsrr  r;  Zflattened_indicesrU  r  ndimsr1  tuple_fnr  r  symbolic_fn  sB    


z_max_pool.<locals>.symbolic_fnr
   quantized_args
parse_args)r  rD  rC  r1  rE  r  rB  r  	_max_pool  s    4rI  zaten::max_pool1d_with_indicesr   zaten::max_pool2d_with_indicesr   zaten::max_pool3d_with_indicesr   zaten::avg_pool1dZ
avg_pool1dzaten::avg_pool2dZ
avg_pool2dzaten::avg_pool3dZ
avg_pool3dc              
     sD   t dt ddddddddddddd	d	d
 fdd}|S )NTrM  r  ry  r  r  Sequence[int]zint | Sequence[int]r&  )r,  r'  r)  r(  r?  count_include_padc              	     s   |s|}t |||| }t|ts*J |}|r^t j| d|d| d dddd}dt| }|rt||||}	|td	d
 t|	|D  }n|d }| jd||||d}
|
S )NPad)r   r   r  constantr  r;  r7  mode_sZvalue_fr  r   c                 s  s   | ]\}}|| V  qd S r  r  r4  r  r  r  r  +  s   z1_avg_pool.<locals>.symbolic_fn.<locals>.<genexpr>AveragePool)r6  r8  r7  )	r
   Z_avgpool_helperr0  r=  r  rJ  r[   r>  r"  )r%  r,  r'  r)  r(  r?  rK  Zdivisor_overrideZadjusted_paddingr.  outputr  rD  r  r  rE    s@    
	
z_avg_pool.<locals>.symbolic_fn)NrF  )r  rD  rE  r  rS  r  	_avg_pool  s
    	 $1rT  zaten::adaptive_avg_pool1dZadaptive_avg_pool1drQ  zaten::adaptive_avg_pool2dZadaptive_avg_pool2dzaten::adaptive_avg_pool3dZadaptive_avg_pool3dzaten::adaptive_max_pool1dZadaptive_max_pool1dr9  zaten::adaptive_max_pool2dZadaptive_max_pool2dzaten::adaptive_max_pool3dZadaptive_max_pool3dc                   s"   t dd fdd}|S )NTFc              	     s  }zt dW n ty2   t d| Y S 0 dgt krZdkrZ| d|S t |}z|dd   W n ty   d  Y n0  d u stdd  D rֈdgt kr| d	|d fS t d
|S  fddt	dt D }|dgt| kr:dgt kr,| d	|d fS t d|S  fddt	dt D }dkr| |||dt  dt  dS | j|||d}|S )Nr  z4adaptive pooling, since output_size is not constant.r<  rQ  ZGlobalAveragePoolr  c                 s  s   | ]}|d u V  qd S r  r  r"  r  r  r  r    r#  z6_adaptive_pool.<locals>.symbolic_fn.<locals>.<genexpr>ZGlobalMaxPoolr$  c                   s   g | ]} | |  qS r  r  r"  rA   output_sizer  r  r    r#  z7_adaptive_pool.<locals>.symbolic_fn.<locals>.<listcomp>r   z-output size that are not factor of input sizec                   s    g | ]}t  | |  qS r  r-  r"  rU  r  r  r    r#  r9  rP  r*  F)r6  r8  )
r
   
_parse_arg	Exceptionr  rJ  r"  r  r/  r  r  )r%  r,  rV  Zoutput_size_valuer0  rp  krR  fnr  rD  typerU  r  rE  |  s@    



$z#_adaptive_pool.<locals>.symbolic_fn)r
   rG  )r  r\  rD  r[  rE  r  rZ  r  _adaptive_pool<  s    @
1r]  r&  rA   c                 C  sF   t |dd dg| d t|   }|ddd |ddd  }|S )zGenerate paddings in ONNX order based on pad in pytorch.
    Args:
        dim: the dimension of the tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
    Nr   r  r  )r1  rJ  )rA   r   paddingsr  r  r  _prepare_onnx_paddings  s    &ra  c              
   C  sf   t | d}t |rbt |rbt |}zdd |D }W n$ ty`   t dddd|  Y S 0 |S )Nr  c                 S  s   g | ]}t |d dqS )ry  r(  )r
   r  )r}  rM  r  r  r  r    s   z)_convert_padding_node.<locals>.<listcomp>rL  r  r;  z)The sizes of the padding must be constant)r
   r  r>  r  r  rX  r@  )r,  r(  
input_listr  r  r  _convert_padding_node  s    



rc  zaten::constant_pad_ndc              
   C  sl   d}zt |dd}W n$ ty:   t dddd| Y S 0 t|}tt ||}t j| d||||ddS )	NrM  rN  rT  rL  r  r;  z*The value for the padding must be constantrN  )r
   r  rX  r@  rc  ra  r{  r  )r%  r,  r(  rT  r  r`  r  r  r  r1     s    

r1   )r%  r,  r   c                 C  sL  t |}t|d dksJ t|d }|}t|D ]}|d| d   }|d| d   }g }	|dkrtj| |d| g| gtjgd}
|	|
 |dk s|dk rt	d| }t	d|  }tj| |d| g|g|gd}|	| n
|	| |dkr*tj| |d| gdg|gd}|	| | j
dg|	R dd| i}q4|S )Nr  r   r<  r  r.  r/  )rc  rJ  r  r
   r  r   r  r  builtinsr   r"  )r%  r,  r   r(  r  curidxZpad_rZpad_lr  leftstartendZmiddlerightr  r  r  _pad_circular  s@    



rk  zaten::reflection_pad1dzaten::reflection_pad2dzaten::reflection_pad3dc                 C  s2   d}t |}tt||}tj| d|||ddS )NreflectrL  r;  r7  rO  r  rc  ra  r
   r{  r  r%  r,  r(  r  r`  r  r  r  r     s    r   zaten::replication_pad1dzaten::replication_pad2dzaten::replication_pad3dc                 C  s2   d}t |}tt||}tj| d|||ddS )NZedgerL  r;  rm  rn  ro  r  r  r  r     s    r   z	aten::padr%  r,  r   r  rT  c                 C  st   t |d}|dkr t| ||S |dkr4t| ||S |dkrJt| |||S |dkr^t| ||S td| |d S )NrU  Z	replicaterl  rM  ZcircularzUnrecognized padding mode )r
   rW  r   r   r1   rk  r	   rY  rp  r  r  r  r   '  s    r   zaten::upsample_nearest1dZupsample_nearest1dZnearestzaten::upsample_nearest2dZupsample_nearest2d   zaten::upsample_nearest3dZupsample_nearest3d   zaten::upsample_linear1dZupsample_linear1dry   zaten::upsample_bilinear2dZupsample_bilinear2dzaten::upsample_trilinear3dZupsample_trilinear3d)r  rA   interpolate_modec                   s    fdd}|S )Nc                   sb   t | |\}}t  t |}|r8t d|S |d u rPt | || }| jd||dS )Nzalign_corners == TrueUpsamplerO  )r
   Z_get_interpolate_attributesZ_interpolate_warningrB  r  Z_interpolate_size_to_scalesr"  )r%  r,  rV  rL  scalesalign_cornersrA   rs  r  r  r  rE  g  s    

z!_interpolate.<locals>.symbolic_fnr  )r  rA   rs  rE  r  rx  r  _interpolate<  s    +ry  zaten::__interpolatec           	      C  s*   t | |||||\}}| jd|||dS )Nrt  ru  )r
   Z _interpolate_get_scales_and_moder"  )	r%  r,  r   Zscale_factorr  rw  Zrecompute_scale_factorZ	antialiasrv  r  r  r  __interpolatex  s    rz  zaten::bitwise_notc                 C  s"   t |std|| d|S NzOONNX export does NOT support exporting bitwise Not for non-boolean input valuesrl  r
   rG  r	   rY  r"  r+  r  r  r  r$     s    
r$   zaten::bitwise_orc                 C  s:   t |std|t |s,td|| d||S )NzVONNX export does NOT support exporting bitwise OR for non-boolean input values. self: zWONNX export does NOT support exporting bitwise OR for non-boolean input values. other: Orr|  rH  r  r  r  r%     s    

r%   c                   s    fdd}|S )Nc                   s   t   fdd}|S )Nc                   s,   t  d  } | || |d|| |dS )NZ_cast_F)r  )r%  r,  r9  Zto_cast_func)r[  to_typer  r  wrap_with_cast  s    zGwrap_logical_op_with_cast_to.<locals>.decorator.<locals>.wrap_with_cast	functoolswraps)r[  r  r~  )r[  r  	decorator  s    z/wrap_logical_op_with_cast_to.<locals>.decoratorr  )r~  r  r  r  r  r    s    r  r   )r  returnc                   s   t   fdd}|S )Nc                   s   |  d | ||S )Nrl  r*  r%  r,  r9  r  r  r  wrap_with_not  s    z4wrap_logical_op_with_negation.<locals>.wrap_with_notr  )r  r  r  r  r  r    s    r  zaten::__not_c                 C  s"   t |std|| d|S r{  r|  rx  r  r  r  __not_  s    
r  zaten::eqc                 C  s   t | tjr:t | tjr:| jdtjdtjddS | }| }|	 |	   krfdkrn nN|
d|
d  krdkrn n*| jdtj|d|dktjddS | d||S )	NrQ  Tri  rR  onnx::ConstantrT  rU  rm  )r0  r\  r   DeviceObjTyper"  rS  r   r  r  r  kindOfrU  )r%  r6  r9  Z	self_nodeZ
other_noder  r  r  rJ     s     
 $rJ   zaten::nec                 C  s   t | ||S r  )rJ   rH  r  r  r  r     s    r   zaten::gtc                 C  s   t | ||S r  _gt_implr  r  r  r  r^     s    r^   c                 C  sJ   t |r<t |r<| jd|tjjd}| jd|tjjd}| d||S )Nr[  r\  r  r
   rG  r"  r]  r^  INT32r  r  r  r  r    s    r  zaten::ltc                 C  s   t | ||S r  _lt_implr  r  r  r  r     s    r   c                 C  sJ   t |r<t |r<| jd|tjjd}| jd|tjjd}| d||S )Nr[  r\  r  r  r  r  r  r  r    s    r  zaten::gec                 C  s   t | ||S r  r  r  r  r  r  rY      s    rY   zaten::lec                 C  s   t | ||S r  r  r  r  r  r  rq     s    rq   zaten::__and_c                 C  s:   t |std|t |s,td|| d||S )NzOONNX export does NOT support exporting bitwise AND for non-boolean input valuesrF  r|  r  r  r  r  __and_  s    

r  zaten::__or_c                 C  s:   t |std|t |s,td|| d||S )NzNONNX export does NOT support exporting bitwise OR for non-boolean input valuesr}  r|  r  r  r  r  __or_  s    

r  zaten::__xor_c                 C  s:   t |std|t |s,td|| d||S )NzOONNX export does NOT support exporting bitwise XOR for non-boolean input valuesrk  r|  r  r  r  r  __xor_0  s    

r  zaten::logical_andZBoolc                 C  s   |  d||S )NrF  r*  r  r  r  r  r   A  s    r   zaten::logical_orc                 C  s   |  d||S )Nr}  r*  r  r  r  r  r   G  s    r   zaten::logical_xorc                 C  s   |  d||S )Nrk  r*  r  r  r  r  r   M  s    r   zaten::logical_notc                 C  s   |  d| j d|tjjdS )Nrl  r[  r\  r"  r]  r^  BOOLr+  r  r  r  r   S  s    r   zaten::__rshift_c                 C  s   t j|}t j|t jj|kr6| jd|| d}| jdtjdtjdd}t	
|sn| jd|tjjd}| d||}| jd|| d}| d||}|S )	Nr[  r\  rQ  r  ri  rR  PowrZ  r   r`  ra  rb  r"  re  rS  r   float32r
   rc  r]  r^  rd  )r%  r6  r9  r  twotwo_powrshiftr  r  r  	__rshift_X  s*    
r  zaten::__lshift_c                 C  s   t j|}t j|t jj|kr6| jd|| d}| jdtjdtjdd}t	
|sn| jd|tjjd}| d||}| jd|| d}| d||}|S )	Nr[  r\  rQ  r  ri  rR  r  r=  r  )r%  r6  r9  r  r  r  lshiftr  r  r  	__lshift_u  s*    
r  zaten::wherec              	   C  s`   t |s| jd|tjjd}|d u rPt| |}t | || jdt	dd|S | d|||S )Nr[  r\  rQ  r<  rR  r  )
r
   rG  r"  r]  r^  r  r   Z_unbind_helperrS  r   )r%  	conditionr6  r9  r  r  r  r  r    s    

r  zaten::log_softmaxc           	      C  s   t |}|d u rt ddS |dk r.|| }||d k}|r|tt|}|d ||  ||< |d< | jd||d}|d }| jd||d	}|r|  d
krt |dd}| jd|t	
| d}|r| jd||d}|S )NrA   fONNX and PyTorch use different strategies to split the input. Input rank must be known at export time.r   r<  r  r  r  Z
LogSoftmaxr  r!  ry  rj  r[  r\  )r
   r{  r  r1  r  r"  r  r  r  r   r`  re  )	r%  r,  rA   rj  r  r   r  Z	return_opr!  r  r  r  r|     s.    
r|   zaten::_log_softmaxc                 C  s>   |r2t j|t jjt jjkr2| jd|tjjd}t	| ||S Nr[  r\  )
r   r`  ra  rb  HALFr"  r]  r^  rd  r|   )r%  r,  rA   Zhalf_to_floatr  r  r  _log_softmax  s    r  zaten::_convolutionc                 C  s&  t |}z|dd  }W n ty0   d }Y n0 |d u sLtdd |D rXtd|||g}t |st |dkr|| |dd  ||| ||	d}tdd |D r|sJ t	|t	|ksJ ||d< | j
|rd	nd
g|R i |}t |st |dkr| 
d||S |S d S )Nr  c                 s  s   | ]}|d u V  qd S r  r  r"  r  r  r  r    r#  z_convolution.<locals>.<genexpr>DUnsupported: ONNX export of convolution for kernel of unknown shape.r<  )r6  r8  r7  dilations_igroup_ic                 s  s   | ]}|d kV  qdS rz  r  )r}  or  r  r  r  	  r#  Zoutput_padding_iZConvTransposeConvr:  )r
   r  rX  r/  r	   rY  r  r{  r  rJ  r"  )r%  r,  r  biasr)  r(  r3  
transposedoutput_paddinggroupsZ	benchmarkZdeterministiccudnn_enabledZ
allow_tf32weight_sizekernel_shaperL  r@  r&  r  r  r  _convolution  sB    




 r  zaten::_convolution_modec                 C  s   t |}z|dd  }	W n ty0   d }	Y n0 |	d u sLtdd |	D rXtd|||g}
t |st |dkr|
| |dkrd}n|dkrd	}|dd  ||||d
}| j	dg|
R i |}t |st |dkr| 	d||S |S d S )Nr  c                 s  s   | ]}|d u V  qd S r  r  r"  r  r  r  r  2	  r#  z$_convolution_mode.<locals>.<genexpr>r  r<  validZVALIDsameZ
SAME_UPPER)r6  r8  Z
auto_pad_sr  r  r  r:  )
r
   r  rX  r/  r	   rY  r  r{  r  r"  )r%  r,  r  r  r)  r(  r3  r  r  r  rL  r@  r&  r  r  r  _convolution_mode	  sB    



r  zaten::convolutionc
           
      C  s"   t | |||||||||	d d d d S r  r  )
r%  r,  r  r  r)  r(  r3  r  r  r  r  r  r  r;   W	  s     r;   zaten::conv1dc           	      C  s\   t |d}|dv r*t| |||||||S t |d}t| ||||||dd|d d d d S d S NrU  )r  r  r  Fr  r
   rW  r  r  	r%  r,  r  r  r)  r(  r3  r  Zstr_paddingr  r  r  r7   w	  s:    r7   zaten::conv2dc           	      C  s\   t |d}|dv r*t| |||||||S t |d}t| ||||||dd|d d d d S d S r  r  r  r  r  r  r8   	  s:    r8   zaten::conv3dc           	      C  s\   t |d}|dv r*t| |||||||S t |d}t| ||||||dd|d d d d S d S r  r  r  r  r  r  r9   	  s:    r9   zaten::conv_transpose1dc	           	      C  s"   t | ||||||d||d d d d S NTr  	r%  r,  r  r  r)  r(  r  r  r3  r  r  r  r4   	  s     r4   zaten::conv_transpose2dc	           	      C  s"   t | ||||||d||d d d d S r  r  r  r  r  r  r5   
  s     r5   zaten::conv_transpose3dc	           	      C  s"   t | ||||||d||d d d d S r  r  r  r  r  r  r6   $
  s     r6   zaten::batch_normc
                 C  s   t |d t rDt |||||gsDtjdk rDt dddd|S t | |||||\}}}}| j	d||||||d| |sdndd	}
|s|
S |
\}}}}}|
|  |
|  |d	|   |d	|   |S d S )
Nr"      ZBatchNormalizationr  zaAll input tensors must have the same `dtype`. Turn off Autocast or export using opset version 15.r<  rr  )	epsilon_fZ
momentum_fr  zbatch_norm_dead_output-)r
   check_training_moderS  Zis_autocast_enabledZargs_have_same_dtyper   export_onnx_opset_versionr@  Z_batchnorm_helperr"  r#  r\  ZsetDebugNameZ	debugName)r%  r,  r  r  running_meanrunning_vartrainingmomentumepsr  rf  resZnew_running_meanZnew_running_varZ
saved_meanZ	saved_varr  r  r  r"   C
  sL    	
r"   zaten::native_layer_normrJ  z#tuple[_C.Value, _C.Value, _C.Value])r%  r,  normalized_shaper  r  r  r  c              
   C  s  dd t t|ddD }t| d}t| |}| jdk rN| jd||d}	n$| d|| jd	tj|tjd
d}	t	| ||	}
t
j|
t
jjk}|rt
j|}| jd|
t
| d}
| jdk r| jdt| |
||d}n,| dt| |
|| jd	tj|tjd
d}t| | d||}| d|
|}|rZt
j|}| jd|t
| d}|d u s|t|s|t| ||}|d u st|st| ||}|r| jd|t
| d}| d|}n
t| |}||	|fS )Nc                 S  s   g | ]
}| qS r  r  r"  r  r  r  r  
  r#  z%native_layer_norm.<locals>.<listcomp>r   r         @   r  r  rQ  ri  rR  r[  r\  r:  rZ  rw  )r  rJ  r
   _generate_wrapped_numberr  r"  rS  r   longr   r   r`  ra  r  re  r   r   r  r   r   r   )r%  r,  r  r  r  r  r  Ztwo_cstZeps_cstr  	numeratorZis_type_halfZ	eps_dtypeZvariancedenominator
normalizedZinput_dtypeZrdenominatorr  r  r  r   |
  s^    




r   zaten::layer_norm)r%  r,  r  r  r  r  cudnn_enabler  c           	      C  s   t | |||||\}}}|S r  )r   )	r%  r,  r  r  r  r  r  r  r;  r  r  r  rp   
  s    rp   zaten::instance_normr   )r%  use_input_statsr  r  r  c
                 C  s,  t |d t |d}
|d u s*t |rl|
d u r>td|tjdg|
 tj	
| d}| jd|d}|d u s~t |r|
d u rtd|tjdg|
 tj	
| d}| jd|d}|d u st |s|d u st |r| jd	||||d
S t |}| }|d }|d u r(td||d }d|d< || |d< t| || jdtj|gtjdd}t| || jdtj|gtjdd}t| || jdtj|gtjdd}t| || jdtj|gtjdd}| d|| jdt|d}t| |||||||||	
}t| || jdt|dS d S )Nrj   r<  zCUnsupported: ONNX export of instance_norm for unknown channel size.rO  ri  rQ  rR  r  InstanceNormalizationr  r   zJUnsupported: ONNX export of instance_norm training for unknown batch size.ZReshape)r
   r  r  r  r	   rY  rS  r   r   r`  ra  rj  r"  r  copyr   rn  r  r"   r  )r%  r,  r  r  r  r  r  r  r  r  channel_sizeweight_value
bias_value
input_sizeZinput_size_reshaper&  cweight_bias_Zrunning_mean_Zrunning_var_input_reshapedrf  r  r  r  rj   
  s    

rj   zaten::unfoldc                   s   t }z|  }W n ty,   d }Y n0 |d urtd||}t||d |} fddt||D }	t|}
ttd|
   fdd|	D }j	dg|R d iS t 
dd	S d S )
Nr   r<  c              	     s*   g | ]"\}}t j g|g|gd qS )r  r
   r  )r}  lowhi)	dimensionr%  r,  r  r  r  J  s   zunfold.<locals>.<listcomp>c              
     s(   g | ] }t jd |d gqS )r  r  )r
   r  r"  r|  )r  r%  r  r  r  r  S  s   r.  r/  ZUnfoldr$  )r
   r  rX  r  r>  rJ  r1  r  popr"  r  )r%  r,  r  r   stepr0  ZsizedimZlow_indicesZ
hi_indicesr   r  r  r  )r  r%  r,  r  r  r   <  s*    

r   z	aten::eluc                 C  sJ   |r|dkrt dd|S |r4|dkr4t dd|S | jd|t |dS )NrO  r  zdoes not support scale in Eluinput_scalez#does not support input_scale in EluElur  )r
   r  r"  rA  )r%  r,  rD  r  r  r  r  r  rE   `  s    rE   z
aten::seluc                 C  s   |  d|S )NZSelur*  r+  r  r  r  r   p  s    r   zaten::index_selectc                 C  s   t | |||S r  )r
   _select_helper)r%  r6  rA   ri   r  r  r  rh   v  s    rh   zaten::index_putc                 C  s\   t |rt |}n|g}t |d}t|dkrH|rDt| ||S |S t ddd| d S )Nr  r   rg   r  r;  )r
   r  r  rW  rJ  r   r  )r%  r6  Zindices_list_valuevalues
accumulateZindices_listr  r  r  rg     s    
rg   zaten::index_fillc                 C  sH   t | |||\}}t |}t ||}t| ||d }t| ||||S r  )r
   _index_fill_reshape_helperrB  r  rN   r   )r%  r6  rA   ri   rT  Zexpanded_index_shapeexpanded_indexZexpanded_valuer  r  r  rf     s    
rf   zaten::index_copyc                 C  s$   t | |||\}}t| ||||S r  )r
   r  r   )r%  r6  rA   ri   sourceZ_expanded_index_shaper  r  r  r  re     s    re   zaten::bucketizec                 C  s   t jj}|rt jj}| jd| d|| d|dd}t|}|d usLJ ttd|d }t	| t
| |||d }	|rt| ||	}
nt| ||	}
| jd|
|d}tj| |dgddS )	Nr.  r)  r   r  r<  r[  r\  r  )r]  r^  r_  r  r"  r
   r{  r1  r  rN   r  rY   r^   r  )r%  r6  Z
boundariesZ	out_int32rj  Zout_type	new_shapeZtensor_rankZunsqueeze_axesZexpanded_boundariescondZcond_outr  r  r  r)     s$    "

r)   zaten::type_asc                 C  sT   t |}t |}||kr(|d ur(|S |d urD| jd|| dS td|d S )Nr[  r\  zUnsupported: ONNX export of type_as for tensor of unknown dtype. Please check if the dtype of the parameter passed to the type_as function is correct.)r
   r  r"  re  r	   rY  )r%  r6  r9  
self_dtypeZother_dtyper  r  r  r     s    

r   zaten::cosine_similarityc           	      C  s   t j| t| |||gdd}t j| t| |||gdd}t j| t| |||gdd}t| t| t| ||| jdt|gd}t| ||S )Nr   r  rQ  rR  )	r
   r  r   r   r   r"  rS  r   rB   )	r%  x1x2rA   r  r>   Zx1_l2Zx2_l2Zdiv_tensr  r  r  r=     s    &r=   zaten::pairwise_distancec                 C  s   t |s | jdt|gd}t| | jdtjdgtjddt| ||}t j| t	| t
| |||dgt |dd}t	| ||S )NrQ  rR  r<  ri  r  ry  r  )r
   r>  r"  rS  r   rB   rt  r   r  r   r   rW  )r%  Zinput1Zinput2pr  r  Zinv_pZ	summationr  r  r  r     s    


r   zaten::clonec                 C  s   |S r  r  )r%  r,  Zunused_memory_formatr  r  r  r0     s    r0   z	aten::absc                 C  s   |  d|S )NAbsr*  rx  r  r  r  r     s    r   z	aten::logc                 C  s   |  d|S )NLogr*  rx  r  r  r  r}     s    r}   zaten::log1pc              	   C  s    t | t| ttd||S )Nr<  )r}   r   r
   r  rS  r   rx  r  r  r  r     s    r   zaten::log10c              	   C  s*   d}|  dt| || j dt|gdS )NgUk@rZ  rQ  rR  r"  r}   rS  r   )r%  r6  Z_ln10r  r  r  r~     s    r~   z	aten::powc                 C  sb   t j|}t|s2t jj}| jd|| d}t|sP| jd|| d}| d||}|S )Nr[  r\  r  )r   r`  ra  r
   rc  rd  r"  re  )r%  r6  exponentZf_dtyper   r  r  r  r     s    

r   zaten::clampc              	   C  s~   t |rt| ||S t |r,t| ||S t |rft |rft j| d|t |dt |dddS t| t| |||S d S )NCliprN     min_fmax_fr  )r
   r  r-   r.   r  r  rW  )r%  r6  r   r   r  r  r  r/   ,  s    



	r/   zaten::clamp_minc                 C  s^   t |r&t j| d|t |dddS tj|}| jd|| d}t j| d||ddS d S )	Nr  rN  r  )r  r  r[  r\  Maxr  	r
   r  r  rW  r   r`  ra  r"  re  )r%  r6  r   rj  r  r  r  r.   B  s    

r.   zaten::clamp_maxc                 C  s^   t |r&t j| d|t |dddS tj|}| jd|| d}t j| d||ddS d S )	Nr  rN  r  )r  r  r[  r\  ZMinr  r  )r%  r6  r   rj  r  r  r  r-   Q  s    

r-   z	aten::maxc                 C  s   t | |||S r  )r
   Z_max_helperr%  r6  dim_or_yr  r  r  r  r   `  s    r   zaten::maximumc                 C  s   t | ||dS N)r  )r   r  r  r  r  r   h  s    r   z	aten::minc                 C  s   t | |||S r  )r
   Z_min_helperr  r  r  r  r   n  s    r   zaten::minimumc                 C  s   t | ||dS r  )r   r  r  r  r  r   t  s    r   z
aten::amaxc                 C  s   | j d|||dS )Nr  r  r*  r%  r6  rA   r  r  r  r  r   z  s    r   z
aten::aminc                 C  s   | j d|||dS )N	ReduceMinr  r*  r  r  r  r  r     s    r   zaten::aminmaxc                 C  sR   d|i}t |s*t |dd}|g|d< | jd|fi || jd|fi |fS )Nr  ry  rA   r  r  r  )r
   r  r  r"  )r%  r6  rA   r  Zreduce_kwargsr  r  r  r     s    

r   z	aten::expc                 C  s   |  d|S )Nr  r*  rx  r  r  r  rL     s    rL   zaten::dropout_zaten::dropoutc                 C  s.   t |d |s|S | jd||dd\}}|S )NrD   ZDropoutr  )Zratio_fr  )r
   r  r"  )r%  r,  r  trainrA  r;  r  r  r  rD     s
    rD   zaten::alpha_dropout_zaten::feature_alpha_dropout_zaten::feature_dropout_zaten::feature_alpha_dropoutzaten::alpha_dropoutzaten::feature_dropoutc                   s   t ddd fdd}|S )NrM  r  r  c                   s   |rt  d|S |S )Nztraining mode)r
   r  )r%  r,  r  r  r  r  r  feature_dropout  s    z-_unsupported_dropout.<locals>.feature_dropoutr
   rH  )r  r  r  r  r  _unsupported_dropout  s    r  z
aten::normc                 C  sx   |dkrt d}n |dkr(t d}ntd||| |||d}|d urtt |dd}| jd	|t| d
}|S )Nr<  ZReduceL1r  ZReduceL2z)ONNX export only p-norms with p of 1 or 2)rA   r  ry  rj  r[  r\  )	r
   Z_reduce_op_symbolic_helperr	   rY  r  r"  r   r`  re  )r%  r6  r  rA   r  rj  rN  r  r  r  r  r     s    r   zaten::conv_tbcc              	   C  sX   | j d|g dd}| j d|g dd}t| |||dg|gdgd}| j d|g ddS )Nr  )r<  r  r   r  )r  r<  r   r<  )r  r   r<  )r"  r7   )r%  r,  r  r  r   convr  r  r  r3     s    r3   zaten::_uniquec                 C  s   t d|S )N_uniquer  )r%  r,  sortedreturn_inverser  r  r  r    s    r  zaten::_unique2c                 C  s   t ddd| d S )N_unique2r  r;  r  )r%  r,  r  r  Zreturn_countsr  r  r  r    s    r  zaten::_cast_Bytez8Avoid using this function and create a Cast node insteadc                 C  s   | j d|tjjdS r  )r"  r]  r^  r  r%  r,  Znon_blockingr  r  r  
_cast_Byte  s    r  zaten::_cast_Charc                 C  s   | j d|tjjdS r  )r"  r]  r^  r  r  r  r  r  
_cast_Char  s    r	  zaten::_cast_Shortc                 C  s   | j d|tjjdS r  )r"  r]  r^  r  r  r  r  r  _cast_Short  s    r
  zaten::_cast_Intc                 C  s   | j d|tjjdS r  )r"  r]  r^  r  r  r  r  r  	_cast_Int  s    r  zaten::_cast_Longc                 C  s   | j d|tjjdS r  )r"  r]  r^  r_  r  r  r  r  
_cast_Long  s    r  zaten::_cast_Halfc                 C  s   | j d|tjjdS r  )r"  r]  r^  ZFLOAT16r  r  r  r  
_cast_Half  s    r  zaten::_cast_Floatc                 C  s   | j d|tjjdS r  )r"  r]  r^  rd  r  r  r  r  _cast_Float  s    r  zaten::_cast_Doublec                 C  s   | j d|tjjdS r  )r"  r]  r^  rv  r  r  r  r  _cast_Double   s    r  zaten::_cast_Boolc                 C  s   | j d|tjjdS r  r  r  r  r  r  
_cast_Bool&  s    r  zaten::emptyc                 C  s   t | |||||S r  )r  )r%  r0  rj  layoutdevice
pin_memorymemory_formatr  r  r  rI   ,  s    rI   zaten::empty_likec                 C  s   t | |||||S r  )r  )r%  r,  rj  r  r  r  r  r  r  r  rH   :  s    rH   zaten::new_emptyc                 C  s2   t |}t |r |d ur |}t| |||||S r  )r
   r  r  rI   r%  r6  r0  rj  r  r  r  r  r  r  r  r   H  s    
r   zaten::scalar_tensorc                 G  s<   t |dd}|d u rtjj}| jd|t| d}|S )Nry  rj  r[  r\  )r
   r  r   r`  rd  r"  re  )r%  Zscalarrj  optionsr  r  r  r   R  s
    r   zaten::tensorc                 C  s  t |dd}t |r|d u r6tjt |d }g }t |D ]L}| jdt	dgd}t 
| ||}| jd|t| d}|| qD| jd	g|R d
diS |d u rtj|}t |rt |st |r| jd|ddd}| jd|t| dS )Nry  rj  r   rQ  r<  rR  r[  r\  r.  r/  ZConcatFromSequence)r/  Z
new_axis_i)r
   r  r  r   r`  ra  r  r"  rS  r  r4  re  r  Z_is_listr?  Z_is_scalar_list)r%  datarj  r  requires_gradrb  r   Zshape_referencer  r  r  r   [  s,    

r   zaten::as_tensorc                 C  s   t | |||S r  )r   )r%  r  rj  r  r  r  r  r   w  s    r   zaten::zerosc                 C  sz   |d u rt jj}n
t |}t|d}t|trZt|dkrZ| jdt	
g t	jd}| jd|t	j
dg| ddS )Nr  r   rQ  rR  ConstantOfShaperi  r   r`  rd  r
   r  r0  r1  rJ  r"  rS  r   r   rn  rj  r%  r0  rj  r  r  r  rg  sizes_r  r  r  r  |  s    

r  zaten::zeros_likec           	      C  sT   |  d|}t|r*tj|tjj}n
t|}| j d|tjdg|	 ddS )Nr)  r  r   ri  rR  
r"  r
   r  r   r`  ra  rd  rS  r   rj  	r%  r,  rj  r  r  r  r  r2  rg  r  r  r  r    s    

r  zaten::new_zerosc                 C  s2   t |}t |r |d ur |}t| |||||S r  )r
   r  r  r  r  r  r  r  r     s    
r   z
aten::zeroc                 C  s   t |}t| ||S r  )r
   r  r  )r%  r6  r  r  r  r  r    s    
r  z
aten::onesc                 C  sz   |d u rt jj}n
t |}t|d}t|trZt|dkrZ| jdt	
g t	jd}| jd|t	j
dg| ddS )Nr  r   rQ  rR  r  r<  ri  r  r  r  r  r  r     s    

r   zaten::ones_likec           	      C  sT   |  d|}t|r*tj|tjj}n
t|}| j d|tjdg|	 ddS )Nr)  r  r<  ri  rR  r  r  r  r  r  r     s    

r   zaten::new_onesc                 C  s2   t |}t |r |d ur |}t| |||||S r  )r
   r  r  r   r  r  r  r  r     s    
r   z
aten::fullc              	   C  s   t |d}t |rX|d u r&tjjn|}t| ||||}t| ||| jdt	
ddS t |dd}|d u rxtjj}	n
t|}	t |d}
t|
trt|
dkr| jdt	
g t	jd}| jd	||d|	 dS d S )
Nr   rQ  r<  rR  ry  rj  r  r   r  )r
   r  r>  r   r`  rd  r  r   r"  rS  r   r  r0  r1  rJ  r   rn  r  rj  )r%  r0  rT  rj  r  r  r  const_valuetmprg  r  r  r  r  rW     s"    


rW   zaten::full_likec              	   C  s   t |d}t |dd}|d u r6tj|tjj}n
t|}t |rt| ||||}	| j	d||
 d}t| |	|| j	dtddS | 	d	|}
| j	d
|
tj|g| ddS d S )NrN  ry  rj  r[  r\  rQ  r<  rR  r)  r  ri  )r
   r  r  r   r`  ra  rd  r>  r  r"  re  r   rS  r   rj  )r%  r,  
fill_valuerj  r  r  r  r  rg  r   r2  r  r  r  rV     s"    

rV   zaten::new_fullc           	      C  s4   t |}t |r |d ur |}t| ||||||S r  )r
   r  r  rW   )	r%  r6  r   r!  rj  r  r  r  r  r  r  r  r   %  s    
r   	aten::eyec                 G  s   t |dkrX|\}}}}}t| |dg}| jd||dd}t| ||||}	| d|	S t |dkr|\}}
}}}}| jdt| |dgt| |
dgdd}t| ||||}	| d|	S tddt | d	S )
Nrr  r   r.  r  ZEyeLiker  r"  with 
 arguments)rJ  r
   r  r"  r  r  )r%  rL  r&  rj  r  r  Z_pin_memoryr  r2  r   mr  r  r  rO   6  s"    rO   aten::slicec                 G  s2  t |dkrr|\}}}}t|d}|dkr:td||  dkoXt| t	j
}|  dkoxt| t	j
}|  dk}	|  dk}
|s|	r|s|
r|  dkrtjtjjkrtd|nBt| |dg}t| |dg}t| |dg}| d	||||S nT|r&dn
t|d}|r>tjn
t|d}t|d}tj| ||g|g|gd
S nt |dkr|\}}}d}|  dkot| t	j
}|  dkot| t	j
}|rdn
t|d}|rtjn
t|d}tj| ||g|g|gd
S tddt | dS )Nrq  ry  r<  z"step!=1 is currently not supportedr!  r  zUnsupported: ONNX export of Slice with dynamic inputs. DynamicSlice is a deprecated experimental op. Please use statically allocated variables or export to a higher opset version.r   ZDynamicSlicer  r2  r&  r#  r$  )rJ  r
   rW  r	   rY  r  r  r0  r\  r   ZNoneTyper   operator_export_typer]  ZOperatorExportTypesZONNXr  r"  r   r  r  r  )r%  r6  rL  rA   rh  ri  r  Zis_start_noneZis_end_noneZis_start_onnx_constZis_end_onnx_constZstart_unsqueezedZend_unsqueezedZdim_unsqueezedr  r  r  r   N  s    








r   zaten::hardtanhr%  r6  Zmin_valZmax_valc                 C  s   t j| d|||ddS )Nr  r  r  r  r(  r  r  r  rc     s    rc   zaten::hardswishc                 C  s   t | |}| d||S r  )ra   r"  )r%  r6  Zhsr  r  r  rb     s    
rb   zaten::hardsigmoidc                 C  s   | j d|ddS )NHardSigmoidgUUUUUU?r  r*  rx  r  r  r  ra     s    ra   zaten::tanhshrinkc                 C  s   |  d|t| |S )NrE  )r"  r   rx  r  r  r  r     s    r   zaten::hardshrinkc                 C  sx   t j|t jj}| jdtj|| dd}t| t	| ||t
| |t| |}| d||| jdtjd| ddS NrQ  ri  rR  r  r   )r   r`  ra  rd  r"  rS  r   rj  r   r^   r   r   )r%  r6  lambdrg  lambd_opr  r  r  r  r`     s"    "r`   zaten::softshrinkc           	      C  s   t j|t jj}| jdtj|| dd}t| ||}| d|t	| ||| jdtjd| dd}t
| |t| |}| d|t| ||| jdtjd| dd}t| ||S r*  )r   r`  ra  rd  r"  rS  r   rj  r^   r   r   r   r   )	r%  r6  r+  rg  r,  Zgt_condZgt_outZlt_condZlt_outr  r  r  r     s8    
	
	r   zaten::aliasc                 C  s   |S r  r  rx  r  r  r  r     s    r   zaten::unsqueezec                 C  s~   |dk rlt |}|dur^tdt| d d d t|| d  d d	  || d }nt d
d|S t j| ||gdS )zbImplement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`r   Nz)ONNX export unsqueeze with negative axis r  r  r  r<  r   r  r  r  r  )r
   r{  r  r  r  r  r  r  r  r  r  r    s2    

r  z
aten::sortc                 C  sn   |d urt dd| t |}z|| }W n tyB   d }Y n0 |d u rZt dd|S | jd|||ddS )NZSortz'Out parameter is not supported for sortr$  TopKr  Zk_ir/  r  )r
   r  r  rX  r"  )r%  r6  rA   Z	decendingrf  Z
self_sizesr  r  r  r  r     s    

r   zaten::numelc                 C  s   t | |S r  )r
   Z_numel_helperrx  r  r  r  r   %  s    r   z
aten::topkc                 C  s<   |d urt dd| |s(t dd| | jd|||ddS )Nr-  z'Out parameter is not supported for topkzAscending TopK is not supportedr  r.  )r
   r  r"  )r%  r6  rY  rA   Zlargestr  rf  r  r  r  r   *  s    r   zprim::convert_element_typec                 G  s,   t |d dd}| jd|t| dS )Nr   ry  rj  r[  r\  )r
   r  r"  r   r`  re  )r%  r6  rL  rj  r  r  r  r:   8  s    r:   zaten::toc                 G  s  dd }||r|S t |dkr|d }t|d r|d   dkrt|d  d}t|tjrt |j	dkr|
 }t|}n|}t|st|tjrtj|d }| jd|| dS | jd|t| dS nt |d	krt|d
 dd}| jd|t| dS t |dkrXt|d dd}| jd|t| dS t |dkrt|d dd}| jd|t| dS td|S )Nc                 S  s   t | dkrL| d   dkpJ| d  tj pJt| d  tj	S t | dkrrt
| d dd}|d u S t | dv rt
| d dd}|d u S d	S )
Nrq  r   prim::devicerr  r<  ry  rj  )r     F)rJ  r  r  r\  isSubtypeOfr   ListTypeofIntsr0  r  r
   r  )rL  rj  r  r  r  is_aten_to_device_only@  s    z"to.<locals>.is_aten_to_device_onlyrq  r   r  rT  r[  r\  rr  r<  ry  rj  r  r0  zUnknown aten::to signature)rJ  r
   r>  r  r  r  r0  rS  r  r2  rn   r&  r   r`  ra  r"  re  r  r  )r%  r6  rL  r4  rj  Ztvalr  r  r  r   >  sB    
r   zaten::repeatc                 C  s0   t jj}t| ||}| d||}| d||S )Nr  ZTile)r   r`  r_  r   r"  )r%  r6  repeatsrj  Zshape_r  r  r  r     s    r   zaten::repeat_interleavec              
   C  s  t |}t |}t |}|d u r2td||d u rFtd||d u rZtd|t |rt | || jdt	dgd}tj	dtj
d}n
t |}|dk r|t|7 }| }t|D ] \}	}
|
d u rd	\||	< ||	< q|dks|d
kr8|d d
kr8|| dkr(t dddd|S t | |||S |d
kr|| dkrbt dddd|S |d d u rt dddd|S |d || ksJ d|d }ntd|g }t | ||d}t | |||}d\||< ||< t|D ]\}	}t| ||	 |d
 }| jdt|d |d
  d|| jdt||d
 d  dg}| jdg|R ddi}t| ||d }t j| || jdt|ddd}|| q| jdg|R d|iS )NzGUnsupported: ONNX export of repeat_interleave for unknown repeats rank.zGUnsupported: ONNX export of repeat_interleave for unknown repeats size.zEUnsupported: ONNX export of repeat_interleave for unknown input size.rQ  r  rR  r   ri  )r   r  r<  r   r     z3Unsupported along dimension with unknown input sizez*Unsupported for cases with dynamic repeatsz2repeats must have the same size as input along dimz%repeats must be 0-dim or 1-dim tensor)r  r<  r.  r/  Z	allowzero)r
   r{  r  r	   rY  r  r4  r"  rS  r   rn  rB  rJ  r  	enumerater@  Z-_repeat_interleave_single_value_repeat_helperZ_repeat_interleave_split_helperr  r  rN   r  )r%  r6  r5  rA   rV  Zrepeats_dimZrepeats_sizesZinput_sizesZinput_sizes_temprf  r  ZrepsZfinal_splitsZr_splitsZi_splitsZr_splitZi_splitZr_concatr  r  r  r     s    




"

r   zaten::pixel_shufflec           	      C  s  t |}t|dkr$t dd|S tdd |dd  D rt j| t | |ddg| jd	t	d
d||d
d
gdd
d}| jd|g dd}t j| || jd	t	g ddd
d}t j| || jd	t	g ddd
d}t 
| |ddgS |d | | }t j| || jd	t	d||||d |d gdd
d}| jd|g dd}t j| || jd	t	d||d | |d | gdd
dS d S )Nrq  r   only support 4d inputc                 s  s   | ]}|d u V  qd S r  r  r"  r  r  r  r    r#  z pixel_shuffle.<locals>.<genexpr>r<  r  r2  rQ  r   r  rR  r7  r  )r   r<  rq  r  rr  r2  r  )r   r   r  r<  r   r   )r   r   r   r   r  r<  rr  r
   r  rJ  r  r/  r4  r  r"  rS  r   r  )	r%  r6  Zupscale_factorr  
after_viewafter_transpose	reshape_h	reshape_woutput_channelr  r  r  r     s~    
	

r   zaten::pixel_unshufflec           
      C  s  t |}t|dkr$t dd|S tdd |dd  D rt j| t | |dg| jdt	d	d	d
|d	gdd	d}t j| || jdt	d	d	d	d	d
|gdd	d}| jd|g dd}t j| || jdt	g ddd	d}t 
| |ddgS |d | | }t j| || jdt	d
|d |d | ||d | |gdd	d}	| jd|	g dd}t j| || jdt	d
||d | |d | gdd	dS d S )Nrq  r   r9  c                 s  s   | ]}|d u V  qd S r  r  r"  r  r  r  r  I  r#  z"pixel_unshuffle.<locals>.<genexpr>r<  r2  rQ  r   r  rR  r7  r  )r   r<  r2  rr  r  rq  r  )r   r  r<  r<  r   r   r  r:  )
r%  r6  Zdownscale_factorr  r=  r>  r<  Zfinal_reshaper?  r;  r  r  r  r   A  sx    




r   c           *   
     s  t d d d d d  g d}ttdd |D |}|rFd	nd
dkrxt  d|	  krxtdd|S t  d|	  ksJ  fddtdt D |
r̈jd|g dd}|r|rtdd|S 	dr|d	d  
  }d d }t|dd u r2tdd|S |	 }|}g }dksTdkrZ|}ndkrp|\}}g }|d u rtn|}dkrg dndkrg ddd fdd}fdd}fd d!}tD ]\}|r>d	kr||\}}}n||\}}t}||d f}nd	kr|d
| \}} }!|d
| d \}"}#}$jd"|!|$dd#}n,|d
| \}} |d
| d \}"}#t}jd"||"dd#}jd"| |#dd#}d
| d
| d
 f}|||||g}%|%||g|R   dkr,|%||g|R   |r6i nd$d%i}&dkr|	rX||g}'n|g}'jdg|%R d
|'d&|&\}}(n^dkrjdg|%R d
dd'|&\}}(n.dkrjdg|%R d(d)|&\}}(})|	r$jd|g d*d}tj|jd+tg d,d-dd.}nt|dg}||( dkr||) q|
rpjd|g dd}dkr~|(njd"g|R d/di}dksdkr||fS dkrdkr|)njd"g|R d/di}|||fS d S )0NzVExporting a model to ONNX with a batch_size other than 1, with a variable length with z can cause an error z9when running the ONNX model with a different batch size. z4Make sure to save the model with a batch size of 1, z=or define the initial states (h0/c0) as inputs of the model. )r	  r  r  ZAffiner  ZThresholdedReluZ
ScaledTanhr)  r  ZSoftsignr  c                 S  s   g | ]}|  qS r  )lower)r}  Zact_funr  r  r  r    r#  z _generic_rnn.<locals>.<listcomp>rq  r  LSTMr<  zLSTMs with projectionsc                   s   g | ]} ||  qS r  r  r"  )all_weightsweights_per_layerr  r  r    s   r   r  r,  r  zRNN/GRU/LSTMzdropout in training modeRNNzunknown hidden sizeGRU))r<  r  r   r<  )r  r2  )rF  )r2  rq  )r<  r2  c                   s.    fdd|D } j dg|R ddiS )Nc              	     s2   g | ]*\}}t j d g| g| gdqS )r   r  r  )r}  xyr%  r&  wr  r  r    s   z8_generic_rnn.<locals>.reform_weights.<locals>.<listcomp>r.  r/  r   r*  )r%  rJ  r&  Z	intervalsZslicesr  rI  r  reform_weights  s    z$_generic_rnn.<locals>.reform_weightsc                   s`   |  }dkr|\}}n,dks*dkrF fdd|D \}}t  fdd||fD S )NrD  rE  rA  c                 3  s   | ]} |V  qd S r  r  r}  rJ  r%  hidden_sizereform_permutationrK  r  r  r    s   zB_generic_rnn.<locals>.transform_weights_no_bias.<locals>.<genexpr>c                 3  s   | ]}t  |d gV  qdS rz  r  r}  rG  r'  r  r  r    s   )r=  )layer_indexweights	weight_ih	weight_hhr%  rN  layer_weightsrO  rK  variantr  r  transform_weights_no_bias  s    

z/_generic_rnn.<locals>.transform_weights_no_biasc                   s|   |  }dkr|\}}}}n0dks.dkrN fdd|D \}}}} j d||dd}t fd	d|||fD S )
NrD  rE  rA  c                 3  s   | ]} |V  qd S r  r  rL  rM  r  r  r    s   z:_generic_rnn.<locals>.transform_weights.<locals>.<genexpr>r.  r   r  c                 3  s   | ]}t  |d gV  qdS rz  r  rP  r'  r  r  r    s   )r"  r=  )rQ  rR  rS  rT  Zbias_ihZbias_hhbias_concatrU  r  r  transform_weights  s    z'_generic_rnn.<locals>.transform_weightsc                   s&   dkr| S t j | dg|g|gdS )Nr<  r   r  r  )rG  rh  ri  )r%  
num_layersr  r  retrieve_state  s    z$_generic_rnn.<locals>.retrieve_stater.  r  Zdirection_sbidirectional)r  hidden_size_iZactivations_s)r  r^  Zlinear_before_reset_ir2  )r  r^  )r   r  r<  r2  rQ  )r   r   r  rR  r7  r/  )r  r  dictr>  rJ  r
   r  r  r"  
startswithr@  r  r  r  r4  rS  r  r  )*r%  rW  r,  Zinitial_statesrB  
has_biasesr[  rD   r  r]  batch_firstbatch_sizesZonnxActivationsZvariantToOnnxActivationMapZnonlinearityw_hhZunidirectionalZprev_outputh_outsZh0Zc0c_outsZsequence_lensrX  rZ  r\  ry  rS  rT  rY  Zstate_indicesZweight_ih_fZweight_hh_fZbias_fZweight_ih_bZweight_hh_bZbias_binputsextra_kwargsZ
activationZh_outZc_outr  )	rB  r%  rN  rV  r[  rO  rK  rW  rC  r  _generic_rnn  s   





	








&
&ri  c
                 C  s2   t |t | }
}t| d||
|||||||	S )NrA  r
   r  ri  )r%  r,  hidden_vweight_vra  r[  rD   r  r]  rb  hiddenr  r  r  r  
_lstm_fullh  s     rn  c
                 C  s4   t |t | }
}t| d||
||||||	|dS )NrA  rc  rj  )r%  r,  rc  rk  rl  ra  r[  rD   r  r]  rm  r  r  r  r  _lstm_packed  s     rp  z
aten::lstmc                 G  s2   t |d rt| g|R  S t| g|R  S d S Nr2  )r
   r?  rp  rn  r%  rL  r  r  r  r     s    r   zaten::lstm_cellc                   s   t  |dg}t |} fdd|D }t |rB||||fn||f}t |rXdnd}	t d||||	dddddd\}
}}t  |dgt  |dgfS )	Nr   c                   s   g | ]}t  |d gqS rP  r  rP  r'  r  r  r    r#  zlstm_cell.<locals>.<listcomp>TFrA  r<  )r[  rD   r  r]  rb  )r
   r  r  Z
_is_tensorri  r  )r%  r6  rm  Zw_ihrd  Zb_ihZb_hhr,  r  ra  r;  re  rf  r  r'  r  r     s0    
r   z	aten::grurE  Zgruzaten::rnn_tanhZRNN_TANHZrnn_tanhzaten::rnn_reluZRNN_RELUZrnn_relur  c                   s^   t ddddddddd	fdd t ddddddddd	fdd fdd	}|S )
NrM  ry  rN  c
                   s&   t |}
t|  |||
||||||	S r  rj  )r%  r,  rm  rl  ra  r[  rD   r  r]  rb  r  rs  r  r  	_rnn_full  s    
z"_one_hidden_rnn.<locals>._rnn_fullc
                   s(   t |}
t|  |||
|||||	|dS )Nro  rj  )r%  r,  rc  rm  rl  ra  r[  rD   r  r]  r  rs  r  r  _rnn_packed  s    
z$_one_hidden_rnn.<locals>._rnn_packedc                   s2   t |d r| g|R  S  | g|R  S d S rq  )r
   r?  rr  )rt  ru  r  r  symbolic  s    z!_one_hidden_rnn.<locals>.symbolicr   )r  rv  r  )rt  ru  r  r  _one_hidden_rnn  s    rw  zaten::_dim_arangec                 C  s@   |  d|}| j d|| j dt|ddd}t| |dd d d S )Nr)  r  rQ  rR  r   r  rq  )r"  rS  r   r   )r%  likerA   Z
like_shapestopr  r  r  _dim_arange  s
    rz  zaten::detachc                 C  s   |S r  r  r+  r  r  r  r@   #  s    r@   zaten::contiguousc                 C  s   |dkrt d||S )Nr  z-onnx memory_format support is not implemented)r	   rY  )r%  r,  r  r  r  r  r2   )  s
    r2   zaten::_pack_padded_sequencec                 C  sz   |r| j d|g dd}| tjj s:td|t	j
|t	j
jt	j
jkrh| j d|tjjd}| j d||dd	S )
Nr  r,  r  z*'lengths' must be a Tensor for ONNX exportr[  r\  zprim::PackPaddedr  r  )r"  r\  r1  rS  r   Z
TensorTypegetr	   rY  r   r`  ra  rb  r  r]  r^  r  )r%  r,  lengthsrb  r  r  r  _pack_padded_sequence3  s    r~  zaten::_pad_packed_sequencec                 C  s6   | j d||dd\}}|r.| j d|g dd}||fS )Nzprim::PadPackedr  r{  r  r,  r  r*  )r%  r  rc  rb  Zpadding_valuetotal_lengthr}  r  r  r  _pad_packed_sequenceL  s    r  zaten::randintc                 G  s  t |dd}t |dd}t |dd}|d u r<tjj}n
t|}|d u rZt d||d u rnt d|t |d}	t |	r| jd|t	j
dgt	jd	d
}
| jd|
||d}n| jd|	||d}tjj}| jd|| d}||kr| jd|| d}|S )Nry  rj  r  highr   r  r  r   ri  rR  RandomUniformLikelow_fhigh_fRandomUniform)shape_ir  r  r[  r\  )r
   r  r   r`  r_  r  r  r>  r"  rS  r   rt  re  )r%  r  r  shapesrj  r  low_ihigh_irg  r2  shape_constr   	int_dtyper   r  r  r  r   _  sD    



r   zaten::randint_likec                 G  s   t |dd}t |dd}t |dd}|d u r<tjj}n
t|}|d u rZt d||d u rnt d|| jd|||d}	tjj}
| jd|	|
 d	}|
|kr| jd|| d	}|S )
Nry  rj  r  r  r   r  r  r[  r\  )r
   r  r   r`  r_  r  r"  re  )r%  r6  r  r  rj  r  r  r  rg  r   r  r   r  r  r  r     s*    

r   zaten::randnc                 G  s   t |dd}|d u r tjj}n
t|}t |d}t |rr| jd|tj	dgtj
dd}| jd|| d	S | jd
|| dS )Nry  rj  r  r  r   ri  rR  RandomNormalLikedtype_iZRandomNormalr  r  r
   r  r   r`  rd  r  r>  r"  rS  r   rt  re  r%  r  rj  r  rg  r2  r  r  r  r  r     s*    


r   z
aten::randc                 G  s   t |dd}|d u r tjj}n
t|}t |d}t |rr| jd|tj	dgtj
dd}| jd|| d	S | jd
|| dS )Nry  rj  r  r  r   ri  rR  r  r  r  r  r  r  r  r  r  r     s*    


r   zaten::randn_likec                 C  sH   t |dd}|d u r*tj|tjj}n
t|}| jd|| dS )Nry  rj  r  r  r
   r  r   r`  ra  rd  r"  re  )r%  r6  rj  r  r  r  r  rg  r  r  r  r     s    

r   zaten::rand_likec                 C  sB   t |dd}|d u r(tj|tjj}| jd|t| dS )Nry  rj  r  r  r  )r%  r6  rj  r  r  r  r  r  r  r  r     s    
r   zaten::rreluc                 C  s@   |s || d }| j d||dS | j d|||d}|  d||S )Nr  r  r  r  )r  r  r  r*  )r%  r,  r@  upperr  r  r  r  r  r  r  r     s
    r   zaten::bernoullic           	      C  s   |d ur t |s t dd| |d ur@t |s@t dd| tj|tjj}|tjjkrlt dd|S | jd|dd| d}|d urt |s|n|}| d	||}| jd
|| dS )NZ	Bernoulliz,out parameter is not supported for bernoulliz(generator is not supported for bernoulliinput dtype not accessibler  rO  r  )r  r  r  r  r[  r\  )	r
   r  r  r   r`  ra  rb  r"  re  )	r%  r,  r  r  rf  rj  ZrandsZprobrR  r  r  r  r#     s2    r#   zaten::log_sigmoidc                 C  s   |  d|}|  d|S )Nr  r  r*  )r%  r,  r  r  r  r  r{   ,  s    r{   z	aten::erfc                 C  s   |  d|S )NErfr*  r+  r  r  r  rK   3  s    rK   zaten::flattenc                 C  s   t |}|d u r t dd|S |dkr8t | |dgS |dkrL| d|S |dk r\|| }|dkr||d kr| jd||dS |dkr||d kr| jd||d dS t | ||||S )	NrA   r  r   r<  r  Flattenr  r  )r
   r{  r  r4  r"  Z_flatten_helper)r%  r,  Z	start_dimZend_dimrA   r  r  r  rQ   9  s$    
rQ   zaten::nonzeroc                 C  s   t | | d|S )z/Emitted from `torch.nonzero(x, as_tuple=False)`ZNonZero)r   r"  r+  r  r  r  r   V  s    r   zaten::nonzero_numpyc                 C  s   t | t| |d|dS )Nr<  )r  )r   r   )r%  r,  r  r  r  r  r   ]  s    r   zaten::isnanc                 C  s   |  d|}|S )NZIsNaNr*  )r%  r,  rR  r  r  r  rm   c  s    rm   z	aten::anyc              	   G  s   t |dkr|d }d\}}n6|\}}}t|d}dd |dD }t|d}| jd	|tjjd
}tj| |||d}t	| || jdt
jdt
jddS )Nr<  r   rI  r   c                 S  s   g | ]}t |qS r  r-  )r}  r  r  r  r  r  u  r#  z_any.<locals>.<listcomp>r  ry  r[  r\  r  rQ  ri  rR  )rJ  r
   rW  r  r"  r]  r^  r_  r  r^   rS  r   r  )r%  rL  r,  rA   r  Z	input_sumr  r  r  _anyj  s    

r  z	aten::allc              	   G  sP   |  d|d }t|dkr.|  dt| |S |  dt| ||d |d S d S )Nrl  r   r<  r  )r"  rJ  r  )r%  rL  r,  r  r  r  _all~  s    r  zaten::narrowc                 C  s   t j| ||g|g|| gdS )Nr  r  )r%  r,  rA   rh  lengthr  r  r  r     s    r   zaten::argmaxztorch._C.Valuer%  r,  rA   r  c                 C  s   t | |||dS )NZArgMaxr
   Z_argmin_argmax_helperr  r  r  r  r     s    r   zaten::argminc                 C  s   t | |||dS )NZArgMinr  r  r  r  r  r     s    r   zaten::scatterc                 C  s   t j|t jj}t|}t|r:| jd||||dS t j|}||krb| jd|| d}| jd||t	| |||dS d S )NZScatterr  r[  r\  )
r   r`  ra  rb  r
   rB  r>  r"  re  rM   )r%  r6  rA   ri   srcZsrc_typer  r  r  r  r     s    

r   zaten::scatter_addc                 C  sz   t |}|d u r t dd|S t j|dd}|rP| jdtj|| dd}nt| ||}t 	| ||||}t
| ||S )Nr   r  F)Zallow_nonstaticrQ  ri  rR  )r
   r  r  r  r"  rS  r  rj  r  Z_scatter_helperr   )r%  r6  rA   ri   r  rg  r0  Zto_addr  r  r  r     s    
r   z
aten::log2c              	   C  s(   d}|  dt| || j dt|dS )Ng9B.?rZ  rQ  rR  r  )r%  r6  Z_ln2r  r  r  r     s    r   zaten::is_floating_pointc                 C  s6   t |r | jdtdgdS | jdtdgdS NrQ  r<  rR  r   )r
   rc  r"  rS  
BoolTensorrx  r  r  r  rk     s    
rk   zaten::__is_c                 C  sL   t |r@t |r*| jdtdgdS | jdtdgdS t| ||S r  )r
   r  r"  rS  r  rJ   rH  r  r  r  __is_  s
    

r  zaten::__isnot_c                 C  s   t | ||S r  )r  rH  r  r  r  __isnot_  s    r  zaten::one_hotc                 C  sn   | j dtddgd}tj|tjjtjjtjjtjj	tjj
hv rZ| j d|tjjd}| j d|||dd	S )
NrQ  r   r<  rR  r[  r\  OneHotr  r  )r"  rS  r  r   r`  ra  rb  r  r  r  r  r]  r^  r_  )r%  r6  Znum_classesr  r  r  r  r     s    r   zaten::gatherc           	   	   C  s   t |drt dd|S tj|}| jdtddgd}t	| || jdt|gd}| jd| jd	||||d
|
 d}| dt | ||d g|}t j| ||gddS )Nry  rX   zsparse_grad == TruerQ  r   r<  rR  r[  r  r  r\  r=  r  )r
   r  r  r   r`  ra  r"  rS  r  r   re  r  r  )	r%  r6  rA   ri   Zsparse_gradrg  r  depthr   r  r  r  rX     s    rX   c                 C  s   t | ||||S r  )r
   Z_var_mean_helper)r%  r,  rA   Z
correctionr  r  r  r  	_var_mean	  s    r  z	aten::stdc                 G  s"   t | |g|R  \}}| d|S Nr  r  r"  r%  r,  rL  r	  r;  r  r  r  r     s    r   z	aten::varc                 G  s   t | |g|R  \}}|S r  )r  r  r  r  r  r	    s    r	  zaten::var_meanc                 G  s6   t |dkr t| |d |d d S t| |g|R  S d S )Nr<  r   )rJ  r  )r%  r,  rL  r  r  r  r    s    r  zaten::std_meanc                 G  s&   t | |g|R  \}}| d||fS r  r  )r%  r,  rL  r	  r  r  r  r  r   "  s    r   zaten::logsumexpc                 C  s   | j d|||dS )NZReduceLogSumExpr  r*  r  r  r  r  r   (  s    r   aten::arangec           
        s  dd } fdd}t |dks,t |dkrt |dkr>d }n||d }tj |d |d	\}}}}t |dg}||}t t t ||d d dg}	 jd
|	t	|
 dS t |dkst |dkrt |dkrd }n||d }tj |d |d |d |d\}}}}t |dg}t |dg}t |dg}| d d|||}t t t |d d d dg}	 d d|	||}	 jd
|	t	|
 dS t |dkrz||d }tj |d |d |d\}}}}t |dg}t |dg}| d||} dt t t ||g|dd  R  dg|}	 jd
|	t	|
 dS tddt | dS )Nc                 S  s   t | d} | S )Nry  )r
   r  ri  r  r  r  _get_arange_dtype0  s    z!arange.<locals>._get_arange_dtypec                   s.   t | r* jd d| tjj d} | S )Nr[  r  r\  )r
   rc  r"  r   r`  r_  re  )range_tensorr'  r  r  _float_step_convert4  s    


z#arange.<locals>._float_step_convertr  rr  r<  r   )ri  rj  r[  r\  rq  r0  r2  )rh  ri  r  rj  rZ  rE  r:  r=  r  )rh  ri  rj  r  r#  r$  )rJ  r
   Z_arange_cast_helperr  r  r   r   r"  r   r`  re  r  )
r%  rL  r  r  rj  ri  rh  r  r  Zarange_tensorr  r'  r  r   .  sl    	
&r   zaten::linspacec           
      C  sT   t | |d }t| t| ||t| || jdtjdtjdd}	t| t	| ||	|S )NrQ  r<  ri  rR  )
r
   Z_arange_helperrB   r   r"  rS  r   rn  r   r   )
r%  rh  ri  Zstepsrj  r  r  r  r  r  r  r  r  rz   {  s    
 rz   z
aten::liftc                 C  s   |S r  r  rx  r  r  r  rt     s    rt   zaten::masked_fillc                 C  s6   | j d|tjjd}t|}|  d|t|||S )zImplement the masked_fill functionality available for a pytorch tensor in ONNX.

    Fills elements of the input tensor with `value` where `mask` is True.
    r[  r\  r  )r"  r]  r^  r  r
   rB  r  r%  r6  maskrT  r  r  r  r     s    
r   zaten::masked_fill_c                 C  s   t | |||S r  )r   r  r  r  r  r     s    r   aten::indexc                   s  t |rt |}n|g}fddfdd|D }t|dkr`t jd|d ddS d	d t|D  t dkrS t dkrt d | d  S t }|d u rt d
dS t	
dtj d t }tfddt|D jd  fddt|D  djd|d| d  } d  }t|d ddD ]@}d| |  |}	d||	}d| |  }qdtd|t|}
 tt d  d d krֈjdtdgdg fddt|D  }jdg|R ddi}t |ttd d d dg tt d d || d  }jd|dfddt d D |
g  fddt d |D  }jdg|R ddi}n.jd|
g fddt|D R ddi}t |S d S ) Nc                   sh   t | sdtj| tjjtjjks.t | rd jdk rDt	
dtd t  t | dg} | S )Nr  z?Exporting masked indices are only supported after ONNX opset 9.zExporting aten::index operator with indices of type Byte. Only 1-D indices are supported. In any other case, this will produce an incorrect ONNX graph.r<  )r
   r  r   r`  ra  rb  r  rG  r  r	   rY  r  r  r  r   )ri   rx  r  r  try_mask_to_index  s&    

z index.<locals>.try_mask_to_indexc                   s   g | ]} |qS r  r  )r}  rf  )r  r  r  r    r#  zindex.<locals>.<listcomp>r<  r   F)Zapply_reshapec                 S  s   g | ]\}}t |s|qS r  )r
   r  )r}  ry  rf  r  r  r  r    s   r  z9operator of advanced indexing on tensor of unknown rank. z=Exporting aten::index operator of advanced indexing in opset z is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.c              
     s0   g | ](} j d  j dt|gdddqS )r  rQ  rR  r   r  )r"  rS  r  r}  rA   )r%  shape_tensorr  r  r    s   r  c                   s   g | ]}| vr|qS r  r  r"  )adv_idx_indicesr  r  r    r#  r  r  r  r  r  r=  r:  rQ  rR  c                   s   g | ]}| vr| qS r  r  r"  r  dim_tensor_listr  r  r    s   r.  r/  c                   s   g | ]} | qS r  r  r"  )r  r  r  r  *  r#  c                   s   g | ]}| vr| qS r  r  r"  r  r  r  r  ,  s   c                   s   g | ]}| vr| qS r  r  r"  r  r  r  r  7  s   )r
   r  r  rJ  r  r8  rh   r{  r  r  r  r   r  r-  r  r"  r1  rS  r  r4  )r%  r6  ri   r  r  Zadv_idx_countZcum_adv_indexZ
multiplierry  Z	adv_indexZcum_adv_index_shape_tensorZfolded_adv_idx_shape_listZfolded_adv_idx_shapeZadv_idx_permuteZfinal_shape_listZfinal_shaper  )r  r  r%  r6  r  r  r  ri     s    




	ri   zaten::linalg_normzSequence[int] | Noner%  r6  ordrA   r  rj  c                 C  s   d }|d u r|t |r<t | |dg}| jdtdgd}t |}|d u r\t dd|S |dkrrt |d}qd	dg}n8t	|dkrt |r| jdtdgd}t |d}|rt
| |||||S t| |||||S )
Nr  rQ  r  rR  rA   (Input rank must be known at export time.r<  rN  r   )r
   r  r4  r"  rS  r  r{  r  rW  rJ  rx   rv   )r%  r6  r  rA   r  rj  	ord_valueself_dimr  r  r  rw   B  s(    



rw   zaten::linalg_vector_normc                 C  s   t | |||||S r  )r
   Z_linalg_vector_norm_helperr  r  r  r  rx   e  s    
rx   zaten::linalg_matrix_normz	list[int]c              	   C  s  t |d}|dkr"t| |||S |dkr8t dd|S t |d}|d u rZt| |||S |dksj|dkrxt dd	|S t |}|d u rt dd
|S |d dk r|d  |7  < |d dk r|d  |7  < |tjks|tj kr|d |d  |d< |d< |d |d kr*|s*|d  d8  < t j| | d||d g|d}|dkrt	| || jdt
|d gd|d\}	}
n*t| || jdt
|d gd|d\}	}
|	S d S )NrU  ZfroZnuczlinalg.matrix_normzord==nucrN  r  r_  zord==2r  r   r<  r  r  rQ  rR  )r  r  )r
   rW  rU   r  r{  r  infr  r"  r   rS  r  r   )r%  r6  r  rA   r  rj  r  r  r  r  Z_indicesr  r  r  rv   r  sP    


rv   zaten::linalg_crossr  c                 C  s   t | |||S r  )r>   )r%  r,  r9  rA   r  r  r  ru     s    ru   zaten::frobenius_normc                 C  s,   |  d||}tj| |||d}|  d|S )Nr=  r  r  )r"  r
   r  )r%  r6  rA   r  ZsqrZsumsqrr  r  r  rU     s    rU   zaten::multinomialc                 C  sZ   |d ur t |s t dd| |s:|dkr:t dd| t| |}| jd|tjj|dS )NZMultinomialz*generator is not supported for multinomialr<  zGreplacement=False when num_samples > 1 is not supported for multinomial)r  Zsample_size_i)r
   r  r  r}   r"  r]  r^  r_  )r%  r,  Znum_samplesreplacementr  Z	log_inputr  r  r  r     s"    
r   zaten::baddbmmc           
      C  s\   t j|}t| ||}t| || jd|| d}t| || jd|| d}	t| ||	S r  )r   r`  ra  r   r   r"  re  r   )
r%  r6  Zbatch1Zbatch2r  rD  rg  Z	batch_mulZmul_aZmul_br  r  r  r!     s    r!   zaten::meshgridz
str | None)r%  indexingc                   s<  |d u rd}n|dvr(t d| |t|}|dkrP|dd d |d d<  fdd	|D } fd
d	|D } jdg|R ddi}g }t|D ]l\}}	 jdtjdtjddgt	| }
|| |
|< t
 |	 jdg|
R ddi}| d|| q|dkr*|d |d  |d< |d<  jdg|R  S )Nij>   xyr  zUnsupported indexing: r  r<  r  r  c                   s,   g | ]$}t  | jd tdgdqS )rQ  r  rR  )r
   r4  r"  rS  r  r|  r'  r  r  r    s   zmeshgrid.<locals>.<listcomp>c                   s   g | ]}  d |qS )r)  r*  r|  r'  r  r  r    r#  r.  r/  r   rQ  ri  rR  r  prim::ListConstruct)r	   rY  r
   r  r"  r8  rS  r   rn  rJ  r3  r  )r%  r  r  Zunpacked_tensor_listr  Ztensors_shapeZ	out_shaperf  ry  r   r  Z
t_reshapedr  r'  r  r     s2    


 
r   zaten::remainderc                 C  s(   t | ||}| d||}| d||S )Nr=  rE  )rW  r"  )r%  r,  r9  rB   Zquor  r  r  r     s    r   z
aten::gelu)r%  r6  approximatec                 C  s&  |dkrt dt j }d}tj|tjd}tj|tjd}tjdtjd}tjdtjd}t| |t| ||}	t| |t| |t| ||	}
t| |t| |t| || d|
S d}| d	| d
|tj|tjd}t| || jdtjdtjdd}t| t| ||| jdtjdtjddS d S )Nr   r  gHm?ri  rO        ?r  g;f?r  rZ  rQ  r<  rR  )	r  r   r  rS  r   ru  r   r   r"  )r%  r6  r  ZkBetaZkKappar  kapparq  ZhalfZ	self_cubeinnerZ_sqrt2rK   Zerf_plusoner  r  r  rZ     s(    $"
rZ   zaten::group_normc              
   C  s  t |d}|d ur$|| dks$J t |}|d u rDt dd|S d|dg}	t | || jdt|	d}
| jdtjdg| t	j
| d	d}| jdtjd
g| t	j
| d	d}| jd|
|||d}t | || d|}|d u s|  r,tjdgt	j
| d	}| jd|d}|d u sD|  rntjd
gt	j
| d	}| jd|d}ttd|d }t| t| |t | ||t | ||S )Nr<  r   r]   zunknown input rankr  rQ  rR  rO  ri  r  r  r  r)  )r
   r  r{  r  r4  r"  rS  r  r   r   r`  ra  rj  r  
mustBeNoner1  r  r   r   r  )r%  r,  Z
num_groupsr  r  r  r  r  Z
input_rankr2  r  r  r  Znorm_reshapedr   r  r  r  r  r  r  r]   +  sX    


r]   zaten::_weight_normc                 C  s   t |}|d urttt|}|d urH|dk r6||7 }|dkrH|| t| |d|d}| d||}| d||S td|d S )Nr  r  r<  rZ  r=  zDUnsupported: ONNX export of _weight_norm for tensor of unknown rank.)	r
   r{  r1  r  remover   r"  r	   rY  )r%  rl  Zweight_grA   r  r  Znorm_vrB   r  r  r  _weight_normh  s    

r  z	aten::dimc                 C  s   |  d|}|  d|S )zFImplement the dim functionality available for a pytorch tensor in ONNXr)  Sizer*  r5  r  r  r  rA     s    rA   zaten::__contains_c                 C  sd   t |}tdd |D rTt |rT| jdtt | ddd |D v dS t	
d|d S )Nc                 s  s   | ]}t |V  qd S r  )r
   r  rP  r  r  r  r    s   z__contains_.<locals>.<genexpr>rQ  rT  c                 s  s   | ]}t | d V  qdS )rT  N)r
   r  r  rP  r  r  r  r    r#  rR  zJUnsupported: ONNX export of __contains__ for non-constant list or element.)r
   r  r  r  r"  rS  r   r  r  r	   rY  )r%  r6  elementZunpacked_listr  r  r  __contains_  s$    
r  zaten::__getitem_c                 C  s    t | || jdtdgd|S r  )r   r"  rS  r   )r%  r6  ry  r  r  r  
__getitem_  s    r  z
aten::itemc                 C  s   |S r  r  rx  r  r  r  rn     s    rn   z
aten::takec              
   C  sD   t | || jdtjdgtjdd}t| |d|}t| ||}|S )NrQ  r  ri  rR  r   )r
   r4  r"  rS  r   rn  rh   r   )r%  r6  ri   Zself_flattenedrf  r  r  r  r     s    r   c                 C  s&   t | ||}t| |}t| ||}|S r  )r   rL   r   )r%  r,  targetdiff_Zexp_rR  r  r  r  _kl_div_log_target_impl  s    
r  c           	      C  sZ   t | |}t| ||}t| ||}t| |}t| || jdtdd}t| |||}|S r  )	r}   r   r   r  r^   r"  rS  r   r  )	r%  r,  r  Zlog_r  Z
output_posZzeros_Zmask_rR  r  r  r  _kl_div_non_log_target_impl  s    

r  zaten::kl_divc                 C  sj   |rt | ||}nt| ||}|dkr*|S |dkrB| jd|ddS |dkrZtj| |ddS td|S d S )Nr   r<  r  r  r  z4kl_div with reduction other than none, mean, or sum.)r  r  r"  r
   r  r  )r%  r,  r  	reductionZ
log_targetrR  r  r  r  ro     s    ro   zaten::mse_lossc                 C  sh   t | t| ||t| ||}|dkr(|S |dkr@| jd|ddS |dkrXtj| |ddS td|S d S )Nr   r<  r  r  r  z6mse_loss with reduction other than none, mean, or sum.)r   r   r"  r
   r  r  )r%  r,  r  r  rR  r  r  r  r     s    r   zaten::as_stridedc                 C  s  t |d}t|}t | || jdtjdgtjdd}t |stjdgtj	d}t
t||D ]6\}\}	}
dg| }d||< |t|	||
  }qd|r|| }| d|| jd|dS d }t
|D ]\}}
dg| }d||< t| || jdtdgd| jdt|d}	t | t| |	d	d d d | jdt|d}| d
|| jdt|
gd}|d u rr|}q| d||}q|r| d|| dt|g}| d||S d S )Nr  rQ  r  ri  rR  r   r<  r  rq  r=  r:  )r
   r  rJ  r4  r"  rS  r   rn  r>  r  r8  r>  r   r  r   )r%  r6  r0  stridesoffsetr  Zself_1dindry  r   r)  Zr_sizeZtmp_indr  r  r  r     sL    



r   zaten::__derive_indexc              	   C  s   |  d||  d||S )Nr:  r=  r*  )r%  ri   rh  r  r  r  r  __derive_index  s    r  zaten::__range_lengthc                 C  s6   |  d||}|  dt| ||}| j d|tjjdS )NrE  r  r[  r\  )r"  r   r]  r^  r_  )r%  lor  r  r   rB   r  r  r  __range_length  s    
r  zaten::linearc                 C  s   t |}t| |}|dkrp|  sp| jdtjdtjdd}| jdtjdtjdd}t	| |||||}n$t
| ||}|  st| ||}|S )Nr  rQ  r<  ri  rR  )r
   r{  r   r  r  r"  rS  r   rn  r   r   r   )r%  r,  r  r  r  rD  r  rR  r  r  r  ry   )  s    

zaten::hann_windowz
int | None)r%  rj  c              	   C  s   |d u r.t  }|r|js t j}tj|}	n
t|}	t| |dd d d }
| jd|
t	j
jd}t| | jdt jtjt jdd|}|du rt| || jdt jdt jdd}t| ||}| jdt| t| ||	 d}|S )	Nrq  r[  r\  rQ  ri  rR  Fr<  )rS  rs  rk   rt  r   r`  Z
from_dtyper   r"  r]  r^  rd  r   r   r  r  r   r&  rB   r   r   re  )r%  Zwindow_lengthZperiodicrj  r  r  r  r  Zdtype_rg  Zn_arrayrR  r  r  r  r_   9  s,    

r_   zaten::mvc                 C  s   t | ||S r  r   )r%  r6  Zvecr  r  r  r   a  s    r   z	aten::dotc                 C  s   t | ||S r  r  rH  r  r  r  rC   f  s    rC   zaten::movedimc           
      C  s   | d}| d}| | ks(J ||k r8|S t|}|d usNJ tt|}| }| }t|	 |	 D ] \}}	|||	< d||< d||	< q|dd |D }dd |D }t||D ]\}}	|||	< q| j
d||dS )Nr  c                 S  s   g | ]}|d kr|qS r  r  r  r  r  r  r    r#  zmovedim.<locals>.<listcomp>c                 S  s   g | ]}|d kr|qS r  r  r  r  r  r  r    r#  r  r  )r  r   r  r
   r{  r1  r  r  r>  tolistr"  )
r%  r6  r  Zdestinationr  r  Zsrc_dimsZdst_dimsr  dstr  r  r  r   k  s&    




r   z
aten::fillc                 C  s    t j|t jj}t| |||S r  )r   r`  ra  rd  rV   )r%  r6  rT  rg  r  r  r  rP     s    rP   zaten::index_addc                   s  t d |r0tt|dkr0tdd|S t d  d u rPtd|t	|}t	|}|d u st|d u rtd|||kr|| }t
|D ]}	t| |t	|g}qt| }
t| }|
d ur|d ur|
|krtd|tt
|}d	d
 t
|D } fdd
t
|D }tj| ||||d}t| ||}t
 D ]}	t| |dg}qLt
|  d D ]}	t| |t	|g}qtt| | t| |||S )NzyWarning: ONNX export does not support duplicated values in 'index' field, this will cause the ONNX model to be incorrect.r<  rd   z
alpha != 1ry  zXONNX export does NOT support exporting 'index_add_()' function with unknown 'dim' value.z~ONNX export does NOT support exporting 'index_add_()' function while the rank of self tensor or tensor to be added is unknown.zoONNX export does not support exporting 'index_add_()' function with duplicated values in 'index' parameter yet.c                 S  s   g | ]}d qS rP  r  r"  r  r  r  r    r#  zindex_add.<locals>.<listcomp>c                   s   g | ]}| krt jnd qS r*  )sysmaxsizer"  r^  r  r  r    r#  r  r   )r  r  r
   rA  rB  r  r  r	   rY  r{  r  r  r  r1  r  rM   r   )r%  r6  rA   ri   r9  rD  Zself_dim_rankZother_dim_rankdeltary  Zother_dim_sizeZself_dim_sizeZnew_shape_axesZnew_shape_startsZnew_shape_endsr  r  r^  r  rd     s\    


rd   z
aten::rollc                 C  s   t |t |ksJ |}tt |D ]}g }tj| ||| g||  gtjgd}|| tj| ||| gdg||  gd}|| | jdg|R d|| i}q$|S )Nr  r   r.  r/  )rJ  r  r
   r  r  r  r  r"  )r%  r6  Zshiftsr  r  ry  r  r2  r  r  r  r     s    

r   zaten::crossc                 C  sp   t ||}t| |dg|g}t| |dg|g}t| |dg|g}t| |dg|g}t| t| ||t| ||S )Nr  r<  )r
   Z_get_dim_for_crossr   r   r   )r%  r,  r9  rA   Zroll_x_1Zroll_y_1Zroll_x_2Zroll_y_2r  r  r  r>     s    r>   zaten::cdistr  #use_mm_for_euclid_dist_if_necessaryc                 C  s   t |d}t |d}|d us$J |d us0J t |d}t |d}|dkr||dksp|d u r||dkr||dkr|t| ||S t |}|d usJ t | ||d g}	t | ||d g}
t| |	|
|dd	d
S )Nr_  rN  ry  r  r<     r  gư>F)r  r  )r
   r  rW  _euclidean_distr{  r  r   )r%  r  r  r  Zcompute_modeZrow_size_x1Zrow_size_x2Zp_floatr  Zbroadcasted_x1Zbroadcasted_x2r  r  r  r+     s.    
r+   c              	   C  s  t |}|d usJ t j| t| |t | ddgdd}t| |}t j| t| |t | ddgdd}t| |}| jdgt| t | d|||gR ddi}| jdg|||gR ddi}	t| |t	| |	dd}
t
j|
}| jd	t | d
| d}t j| d|
|dd}
t| |
}
|
S )Nr  r  Tr  r.  g       r/  r_  r[  r  r\  r  r  r  )r
   r{  r  r   r  r   r"  r   r   r   r   r`  ra  re  r  r   )r%  r  r  r  Zx1_normZx1_padZx2_normZx2_padZx1_Zx2_r  rj  r   r  r  r  r  $  sJ    


	

r  z
aten::lerpc                 C  sx   |  d||}t| |  d|| j dtdd|  d||  d|||  d||  d||  d| j dtdd|S )	NrE  r  rQ  r  rR  r:  r=  rO  )r"  r  rS  r   )r%  r6  ri  r  diffr  r  r  rs   N  s    rs   zaten::broadcast_tensorsc                   sT   t |}t |d |D ]}t |q fdd|D } jdg|R  S )Nr   c                   s   g | ]}t  |qS r  )rM   r|  r%  Zt_with_final_shaper  r  r  m  r#  z%broadcast_tensors.<locals>.<listcomp>r  )r
   r  r  r   r"  )r%  r6  Zall_tensorsr   Zt_listr  r  r  r'   c  s    
r'   zaten::is_pinnedc                 C  s   d S r  r  )r%  r6  r  r  r  r  rl   q  s    rl   prim::ConstantSplitc                 C  s^   t ||}|d u r"t dd|S |g||  }|| }|rF|| | jd|||t|dS )Nr  r  r  r  )r
   r  r  r  r"  rJ  )r%  r6  r  rA   r   r  r  r  r  r  r   w  s    
r   prim::ConstantChunkc                 C  s@   t ||}|d u r"t dd|S || d | }t| |||S )Nr  r  r<  )r
   r  r  r   )r%  r6  r  rA   r  r  r  r  r  r     s    r   zprim::shapec                 C  s   |  d|S r(  r*  rx  r  r  r  r     s    r   z	prim::maxc                 C  s   t j| d||ddS )Nr  r  r  r  rH  r  r  r  r     s    
r   z	prim::minc                 C  sB   |s6t |r,t| || jdtdgd}t| |S t| ||S r  )r
   r  r   r"  rS  r   r   rH  r  r  r  r     s
    

r   z
prim::datac                 C  s   |S r  r  rx  r  r  r  r     s    r   zprim::layoutc                 C  s   | j dtddS r  r  rx  r  r  r  r     s    r   r  c                 O  s   d S r  r  r%  rg  r@  r  r  r  r     s    r   zprim::ListUnpackzlist[_C.Value] | None)r%  r  c                 O  s2   t |dkr.|d   dkr.t|d S d S )Nr<  r   r  )rJ  r  r  r
   r  r  r  r  r  r     s     r   zprim::TupleConstructc                 O  s   d S r  r  r  r  r  r  r     s    r   zprim::Uninitializedc                 O  s   d S r  r  r  r  r  r  r     s    r   zprim::unchecked_castc                 C  s   |S r  r  rx  r  r  r  r     s    r   zprim::dtypec                 C  s.   t |}|d u rtjj}| jdt|dS rP  )r
   r  r   r`  rd  r"  rS  r   )r%  r6  rg  r  r  r  r     s    
r   prim::tolistc                 C  s&   t |d}|dkr"t dd|S |S )ztolist is currently supported only for 1D input tensors.

    dim_val and elem_ty_val represent dimension and type annotations
    that need to match dimension and type of the input tensor.
    ry  r<  r  zdim_val > 1)r
   r  r  )r%  r,  Zdim_valZelem_ty_valrA   r  r  r  r     s    r   r/  Nonec                 O  s>   | j   }t|tjrd S tdd|  d| j  S )Nr/  z,output type should be 'DeviceObjType', not '')	original_noderR  r\  r0  r   r  r
   r  r  )r%  rg  r@  output_typer  r  r  r     s    r   z
prim::Loopzlist[_C.Value]c              	   O  s*  | j }| j}| j}| j}tj}tj}t| }	t	j
| dg|R | t|	d\}
}}t|	|D ]\}}t| D ]l\}}|dkr|t|k r|||   |dkrx|d t|k rxt| tjsx|||d    qxtj||j|||d qdtj||}tjr&tj||| |S )NZLoopr  Zn_blocksr   r<  F)r  envvalues_in_envparams_dictr   r'  r  r=  blocksr   add_op_with_blocksoutputsSizerJ  r>  r8  rg  r#  r\  r0  r   r$  rS  _jit_pass_onnx_blockblock%_jit_pass_fixup_onnx_controlflow_nodeonnx_shape_inference(_jit_pass_onnx_node_shape_type_inference)r%  rg  attrsr  r  r  r  r'  opset_version
old_blocks_new_op_outputsnew_block_contextsnew_node	old_blocknew_block_contextry  Zb_infixed_outputsr  r  r  r     sR    r   zprim::Ifc              	   O  s  | j }| j}| j}| j}| j}tj}tj}	|d  	 dk}
|
rt
|d  d }t|trnt|nt|}|r~dnd}t| | }tj|||||d}t| }t| }g }tt|D ]B}|| |vrtd||  d|| |||  }|| q|S t| }tj| dg|R | t|d	\}}}t||D ]"\}}tj||j|||d
 qXtj ||	}tj!rtj"|||	 |S d S )Nr   r  rT  r<  TzThe sub block ATen output z is not in env.Ifr  F)#r  r  r  r  r  r   r'  r  r  r  r
   r  r  r0  r1  r  r  r  rS  r   r  r  r  rJ  r	   rY  r  r=  r   r  r  r>  r  r  r  )r%  rg  r  r&  r  r  r  r  r'  r  Z	static_ifZ
input_flagr  Z	block_idxZ	current_bZif_output_listZcurrent_b_listZfinal_b_listrf  Zonnx_br  r  r  r  r  r  r  r  r  r  r   4  sv    r   r!  c                   s.   j }| rd S t|  tjr*d S |ddkrN jdt	
|ddS |ddkrr jdt	
|ddS |  tj s|  tj r jdtt	
|ddS |  tj r fddt	
|dD } jd	g|R  S td
|d dtj d| d S )NrT  r   rQ  rR  rU  Zvalue_sc                   s   g | ]} j d |dqS )rQ  r  r*  )r}  rU  r'  r  r  r    s   z!prim_constant.<locals>.<listcomp>r  z"Unsupported prim::Constant kind: 'z'. Please send a bug report at .)r  r  r0  rR  r\  r   r  r  r"  r
   r  r1  r2  r3  ZofFloatsrS  r   Z	ofStringsr	   rY  r   ZPYTORCH_GITHUB_ISSUES_URL)r%  rg  r  r  Zstr_constantsr  r'  r  r     s8    

r   
prim::type)r%  device_valuec                 O  sJ   |   dkr<t|   }|d ur<| jdt|dS tdd|S )Nr/  rQ  r  r  z,Device type cannot be statically determined.)	r  r  r   Zget_device_from_valuer,  r"  r  r
   r  )r%  r  rL  r@  r  r  r  r  r     s    r   zonnx::Placeholderc                 O  s*   | j }| j}| j}| j}tj||||S r  )r  r  r  r  rS  r   Z'_jit_onnx_convert_pattern_from_subblock)r%  rg  r  r  r  r  r  r  r  r  r     s    r   zaten::resolve_conjzaten::resolve_negr+  c                 C  s   |S r  r  r+  r  r  r  r    s    r  zaten::_conjzaten::conj_physicalc                 C  s    t |rt d|S t| |S )Nz aten::_conj, aten::conj_physical)r
   Zis_complex_valuer  r  r+  r  r  r  r    s    
r  zaten::logit)r%  r6  r  c                 C  s   | j dtdd}t|s| j d|tj| d}|  d||}|  d||}|  d|||}|  d	||}|  d|||}n|}|  d||}	|  d
||	}
|  d|
S )NrQ  rO  rR  r[  r\  rE  r  r  r  rZ  r  )	r"  rS  r   r
   r  r   r`  ra  re  )r%  r6  r  rq  Zone_sub_epsZself_less_equal_one_sub_epsZtemporary_selfZtemporary_self_less_epszr   rB   r  r  r  r     s    
r   )N)N)N)rO  )T)N)N)N)N)N)N)r   N)N)F)N)N)NNN)N)N)FF)NN)NN)N)FN)NNNFN)F)NNF)NN)F)NNNFN)F)F)NNNFN)F)F)NNNFN)F)N)N)NN)NN)NNFN)NNFN)NNN)N)F)r  )NF)FN)N)r  )N)TNNNNF)N)N)r  r  )N)N(o  __doc__
__future__r   rd  r  r  r  r  typingr   r   Ztyping_extensionsr   rS  Ztorch._C._onnxr   Z_onnxr]  Ztorch.nn.modules.utilsZ
torch.onnxr   r   r	   r
   Ztorch.onnx._globalsr   Ztorch.onnx._internalr   r   collections.abcr   Ztorch.typesr   r  partialZonnx_symbolicZ_onnx_symbolicr   r  r-  r3  rG  r   r   r   r   r   r   rB   rH  r   rK  rX  rW  rR   rT   r   r   r*   r   r  r   r&   r   r   r   r   r   r   r   r<   r   r   r   r   r    r   r   r  Z_apply_paramsr  r?   r  r  r   r   rN   r(   rM   rG   rF   r   r   r   r  r
  r  r   r  r   r  r   r   r   r   r   r   r   r   r   r,   rS   r  r   rr   r\   r   r   r[   nnmodulesutilsZ_singleZ_pairZ_triplerI  r   r   r   rT  r]  ra  rc  r1   rk  r   r   r   ry  rz  r$   r%   r  r  r  rJ   r   r^   r  r   r  rY   rq   r  r  r  r   r   r   r   r  r  r  r|   r  r  r  r;   r7   r8   r9   r4   r5   r6   r"   r   rp   rj   r   rE   r   rh   rg   rf   re   r)   r   r=   r   r0   r   r}   r   r~   r   r/   r.   r-   r   r   r   r   r   r   r   rL   rD   r  r   r3   r  r  r  r	  r
  r  r  r  r  r  r  rI   rH   r   r   r   r   r  r  r   r  r   r   r   rW   rV   r   rO   r   rc   rb   ra   r   r`   r   r   r  r   r   r   r:   r   r   r   r   r   ri  rn  rp  r   r   rw  rz  r@   r2   r~  r  r   r   r   r   r   r   r   r#   r{   rK   rQ   r   r   rm   r  r  r   r   r   r   r   r   rk   r  r  r   rX   r  r   r	  r  r   r   r   rz   rt   r   r   ri   rw   rx   rv   ru   rU   r   r!   r   r   rZ   r]   r  rA   r  r  rn   r   r  r  ro   r   r   r  r  ry   r_   r   rC   r   rP   rd   r   r>   r+   r  rs   r'   rl   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r   r  r  r  r  <module>   s^    
&
5


>	&

	>5			:



7			


6)		&&







C	7###72L4$_"#



	     	    
    	    	H&


Jg
F
N  cB
	
*      
$	$	L
 $,!,,;
"
$:	
,	      & 
E *
4Z"	