o
    wZh^                     @  s  d Z ddlmZ ddlZddlZddlZddlmZ ddlZddlm	Z	 ddl
mZ ddlmZmZmZmZmZmZ ddlmZmZ erJdd	lmZ g d
ZejejddZedede ddddddZ!eddddZ"ede dddddZ#ed e dddd!d"Z$ed#dd$d%Z%ed&ede dd'ddd(d)Z&ed*	+ddd,d-Z'ed.e dd'dd/d0Z(ed1e)d2d3d4gd5ed6e)d7d8d4gd5ed9e)d:d;d4gd5ed<e)d=d3d>gd5ed?e)d@d8d>gd5edAe)dBd;d>gd5edCe)dDd8dEgd5ddKdLZ*edMedd+d+d+d+d+d+ddNdOZ+edPe dd'dddddQdRZ,edSe dd'ddddTdUZ-edVe dd'dWdddXdYZ.edZdd[d\Z/ed]dd^d_Z0ed`ddadbZ1edcddddeZ2edfddgdhZ3ediddjdkZ4edldddmdnZ5edoddpdqZ6edrddsdtZ7eduddvdwZ8edxedddydzZ9ed{dd|d}Z:ed~e dd'd'd'dddZ;ede dd'd'd'd'dddZ<ede ddd'd'd'dWddddZ=ede dd'd'dWddddZ>ede dd'd'dWddddZ?ede dd'ddddZ@eddddZAede ddd'd'ddddZBede ddd'd'ddddZCede dd'd'ddddZDdddZEedddddZFedededdddZGedededdddZHeddddZIeddddZJeddddZKeddddZLede dd'dddZMedejdd+dddddZNedddddÄZOedădddƄZPedǃdddɄZQedʃddd̄ZRed̓dddτZSedЃddd҄ZTedӃedԃdddքZUed׃ed؃dddڄZVddd܄ZWdddބZXdddZYede ddddddddZZeddddZ[ededd+d+e dd'd'dddZ\ede ddddddddZ]ede dddd'd'd'dd'd'	dddZ^ede dddddddZ_eddddZ`ed						ddddZaeddd dZbeddddZceddddZdeddd	d
ZeeddddZfeddddZgdS (  z(This file exports ONNX ops for opset 11.    )annotationsN)TYPE_CHECKING)_C)_onnx)_type_utilserrorssymbolic_helpersymbolic_opset10symbolic_opset9utils)	jit_utilsregistration)Sequence)9addappendarangeargsort
atleast_1d
atleast_2d
atleast_3dcatchunk	clamp_max	clamp_minclampconstant_pad_ndcumsumDeleteembedding_bagembedding_renormflattengatherhardtanhhstackim2col
index_fillindex
index_copy	index_putinsert
linalg_detlinalg_vector_normlogdetmasked_scattermasked_selectmmnarrownormalpadpixel_shufflepopprim_constant_chunkreflection_padrelu6	remainderreplication_padroundscatterselectsizesortsplit_with_sizessplitsqueezestacktopkunbind
unique_dim	unsqueezevstack   )Zopsetzaten::hardtanhTvfgjit_utils.GraphContextself_C.Valuemin_valfloatmax_valc                 C  s`   t j|t jj}| jdtj|| dd}| jdtj|| dd}tj	| d|||ddS )NConstantdtypeZvalue_tClip   Zopset_before)
r   JitScalarType
from_valueFLOAToptorchtensorrT   r   _op_with_optional_float_cast)rK   rM   rO   rQ   scalar_type ra   J/var/www/auris/lib/python3.10/site-packages/torch/onnx/symbolic_opset11.pyr"   `   s   r"   zaten::clampc                   s    fdd}t j|t jj}|t jjkr|||}|||}t|r*t ||S t|r5t ||S t|dkrNt|dkrNtj	 d|||ddS t t |||S )Nc                   s*   | d urt | s jd| | dS | S )NCastZto_i)r   _is_noner\   	onnx_type)r^   rT   rK   ra   rb   _cast_if_not_nonev   s   z clamp.<locals>._cast_if_not_noner   rV   rW   rX   )
r   rY   rZ   	UNDEFINEDr   re   r   r   _get_tensor_rankr_   )rK   rM   minmaxrh   r`   ra   rg   rb   r   t   s"   




r   zaten::clamp_minc                 C  s^   | j d|tj| d}t|dkr%t| }tj	| d|||ddS tj	| d||ddS )Nrc   rd   r   rV   rW   rX   ZMax
r\   r   rY   rZ   rf   r   rj   opset9Zunusedr_   )rK   rM   rk   rl   ra   ra   rb   r         

r   zaten::clamp_maxc                 C  s^   | j d|tj| d}t|dkr%t| }tj	| d|||ddS tj	| d||ddS )Nrc   rd   r   rV   rW   rX   ZMinrm   )rK   rM   rl   rk   ra   ra   rb   r      ro   r   zaten::relu6c                 C  sX   t j|t jj}| jdtjd| dd}| jdtjd| dd}t| |||S )NrR   r   rS   rU      )	r   rY   rZ   r[   r\   r]   r^   rT   r   )rK   inputr`   rO   rQ   ra   ra   rb   r7      s   r7   zaten::selectic                 C  s   | j d|||dS )NGatheraxis_ir\   )rK   rM   dimr&   ra   ra   rb   r<      s   r<   zaten::index_putFc                   s  t |rt |}n|g}t |d}t|dkr|S t|dkrmtt|D ]}t || r;d|| ||< q(|d }|dd  D ]	}t	||}qFd|  fdd|D }jdg|R d	d
i}nW|d }|}	t |	rt 
|}
|
d ur|
dkrt||	|S t 
|	}t 
|}|d ur|d ur||krt |	tt||}	t||	|S d| t |d
g}t jd|dgt|gtjgd}jd |dd}t 
|}
|
d ur|
dkrt||d }t ||}tj|tjj}|tjjkr&tj|tjj}||kr%jd|| d}n	|r/td||rVjdd|tjdg| dd}d|||}t	||}|S d|||}|S )Nbr      ZNonZeroShapec                   s(   g | ]}t t| d dgqS )N)r   _unsqueeze_helperrn   expand).0indZbroadcast_index_shaperK   ra   rb   
<listcomp>   s    zindex_put.<locals>.<listcomp>Concatru   r{   ZaxesZstartsZendsrt   rc   rd   z'self does not have a valid scalar type.ConstantOfShaperS   rU   	ScatterND)r   _is_packed_list_unpack_list
_parse_arglenrange_is_boolr\   rn   r   rj   Zmasked_fillr|   listr-   _slice_helpersysmaxsizer}   _reshape_helperr   rY   rZ   ri   rf   r   SymbolicValueErrorr]   r^   rT   )rK   rM   Zindices_list_valuevalues
accumulateZindices_listZidx_r&   r   Zbool_inprankZ	mask_rankZ	self_rankZsub_data_shapeZvalues_shapeZself_scalar_typeZvalues_scalar_typeZzerosresultra   r   rb   r(      s   
(






r(   zaten::pixel_shufflec                 C  s8   t |}|d ur|dkrt ddS | jd||ddS )N   r3   zonly support 4d inputZDepthToSpaceZCRD)Zblocksize_imode_s)r   rj   _unimplementedr\   )rK   rM   Zupscale_factorr   ra   ra   rb   r3   I  s   
r3   zaten::upsample_nearest1dZupsample_nearest1d   Znearest)Zdecoratezaten::upsample_nearest2dZupsample_nearest2dr   zaten::upsample_nearest3dZupsample_nearest3d   zaten::upsample_linear1dZupsample_linear1dZlinearzaten::upsample_bilinear2dZupsample_bilinear2dzaten::upsample_trilinear3dZupsample_trilinear3dzaten::upsample_bicubic2dZupsample_bicubic2dZcubicnamestrrw   intinterpolate_modec                 C  s   t | ||S N)r   Z_interpolate_helper)r   rw   r   ra   ra   rb   _interpolateR  s   r   zaten::__interpolatec              	   C  s   t | ||||||S r   )r   Z__interpolate_helper)rK   rq   r=   Zscale_factormodeZalign_cornersZrecompute_scale_factorZ	antialiasra   ra   rb   __interpolater  s   r   zaten::gatherc                 C  s*   t |drt ddS | jd|||dS )Nrr   r!   zsparse_grad == TrueZGatherElementsrt   )r   _maybe_get_constr   r\   )rK   rM   rw   r&   Zsparse_gradra   ra   rb   r!     s   r!   zaten::scatterc              	   C  s~   t j|}t|}t|r| jd||||dS t j||kr0| jd|t j| d}| jd||t	| |||dS )NZScatterElementsrt   rc   rd   )
r   rY   rZ   r   _maybe_get_scalar	_is_valuer\   rf   rn   	expand_as)rK   rM   rw   r&   srcZsrc_typera   ra   rb   r;     s   

r;   zaten::cumsumnonec                 C  sn   | j dtj|tjdd}|r,|  dkr,t|dd}| j d|t	|
 d}n|}|  d	||}|S )
NrR   rS   rU   zprim::Constantrr   rT   rc   rd   ZCumSum)r\   r]   r^   r   nodekindr   
_get_constr   rY   rf   )rK   rM   rw   rT   Z
dim_tensorZparsed_dtypecastZcsumra   ra   rb   r     s   r   zaten::masked_selectc                 C  s$   t | t | ||}| d||S )NGatherND)rn   nonzeror   r\   )rK   rM   maskr&   ra   ra   rb   r.     s   r.   zaten::masked_scatterc                 C  sr   t | t | ||}t| |tdg}tj| |tdgtdgt | |tdgd}| 	d|||S )Nr{   r   r   r   )
rn   r   r   r   r   r]   
LongTensorr   r=   r\   )rK   rM   r   sourcer&   ra   ra   rb   r-     s   

r-   z	aten::lenc                 C  sT   t |s|  dkr| d|S t| || jdtdgd}t | |dgS )Nzonnx::SplitToSequenceZSequenceLengthrR   r   rU   )	r   _is_tensor_listr   r   r\   r=   r]   r   _squeeze_helper)rK   rM   Zsz_0ra   ra   rb   _len  s   r   zaten::__getitem_c                 C  s0   t |r| d||S ddlm} || ||S )N
SequenceAtr   )
__getitem_)r   r   r\   Ztorch.onnx.symbolic_opset9r   )rK   rM   rr   getitemra   ra   rb   r     s   
r   zaten::_set_itemc                 C  s   |  d||}|  d|||S )NSequenceEraseSequenceInsertrv   )rK   tensor_listrr   rI   ra   ra   rb   	_set_item  s   r   zaten::appendc                 C     |  d||S Nr   rv   )rK   rM   r^   ra   ra   rb   r        r   z	aten::addc                 C  sn   t |r/t |r/| }| dkrt ddS t |}|}|D ]	}| d||}q#|S t	| |||S )Nzprim::ListConstructr   z6does not support adding dynamic tensor list to anotherr   )
r   r   r   r   r   r   r   r\   rn   r   )rK   rM   otheralphaZtensor_list_nodeZtensorsltra   ra   rb   r     s   
r   zaten::insertc                 C  s   |  d|||S r   rv   )rK   rM   posr^   ra   ra   rb   r)     s   r)   z	aten::popc                 C  r   Nr   rv   rK   r   rw   ra   ra   rb   r4     r   r4   zaten::Deletec                 C  r   r   rv   r   ra   ra   rb   r     r   r   z	aten::catc                 C  s6   t |rt| ||S t |dd}| jd||dS )Nrr   rw   ConcatFromSequencert   )r   r   rn   r   r   r\   r   ra   ra   rb   r     s   
r   zaten::stackc                 C  s8   t |rt| ||S t |dd}| jd||ddS )Nrr   rw   r   ry   ru   Z
new_axis_i)r   r   rn   rB   r   r\   r   ra   ra   rb   rB     s   
rB   zaten::_unique2c           	      C  s$   | j d||dd\}}}}|||fS )NUniquer   )sorted_ioutputsrv   )	rK   rM   sortedreturn_inversereturn_countsu_indicesinverse_indicescountsra   ra   rb   _unique2  s   
r   zaten::unique_dimc           
      C  s&   | j d|||dd\}}}}	|||	fS )Nr   r   )ru   r   r   rv   )
rK   rM   rw   r   r   r   r   r   r   r   ra   ra   rb   rE   #  s   

rE   z
aten::topkc              	   C  s   t j| ||||||dS )N)largestr   out)r   Z_topk_helper)rK   rM   krw   r   r   r   ra   ra   rb   rC   .  s   rC   z
aten::sortc                 C  s   t j| ||||dS N)	decendingr   r   Z_sort_helper)rK   rM   rw   r   r   ra   ra   rb   r>   6  s   r>   zaten::argsortc                 C  s   t j| ||||d\}}|S r   r   )rK   rM   rw   r   r   _indicesra   ra   rb   r   <  s   

r   zaten::roundc                 C  sz   t |s|S |dkr| d|S | d|| jdttd|d}| d|}| d|| jdttdd| dS )Nr   ZRoundMulrR   
   rU   r{   )r   _is_fpr\   r]   r^   pow)rK   rM   Zdecimalsmulr:   ra   ra   rb   r:   E  s   
$ r:   zaten::remainderc                 C  s4   t |s
t |rt| ||S | jd||ddS )NModr   )Zfmod_i)r   r   rn   r8   r\   )rK   rq   r   ra   ra   rb   r8   S  s   r8   zaten::splitc              
     s  t ||sy jd|||d|d u rS t |rmtt ||krm fddt |D } jdtjdgtjdd} jdtj|gtjdd}g }t	|D ]}	 d	|||	 }
|
 d
|||
| |
}qQ|S  fddt	|D S t ||||S )NSplitToSequencert   c                   s   g | ]
}t  |d gqS r   )r   r|   )r~   rI   rg   ra   rb   r   f  s    zsplit.<locals>.<listcomp>rR   r   rS   rU   AddSlicec                   s2   g | ]}  d  j dtj|gtjddqS )r   rR   rS   rU   )r\   r]   r^   long)r~   rr   rK   Z	split_outra   rb   r   t  s    )r   Z_is_split_staticr\   r   r   r   r]   r^   r   r   r   rn   r@   )rK   rM   Zsplit_size_or_sizesrw   _outputssplit_sizesstartaxisresrr   endra   r   rb   r@   Z  s0   

	r@   zaten::split_with_sizesc                 C  s   t | ||||S r   )r@   )rK   rM   r   rw   r   ra   ra   rb   r?     s   r?   zaten::unbindc              	   C  sB   |d u r| j d|| j dtjdtjdd|ddS t| |||S )Nr   rR   ry   rS   rU   r   )ru   
keepdims_i)r\   r]   r^   r   rn   rD   )rK   rM   rw   r   ra   ra   rb   rD     s   rD   c                 C  sz  t |st |rt |r| jd|ddd}t| || jdtdgd}t 	|}|du r<| d| d	|}n| jdtj|tj
d
d}| d| d|| jdtjdtj
d
d|}| jd|tjjd}| jd|| jd|tjdgtj
d
ddd}t | || jdtddgd}| jdt| |dgddgd}t | || jdtdgd}| jd|tjjd}|S )a!  Generate paddings in ONNX order based on pad in pytorch.

    Args:
        input: the input tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
            where m is in range [0, n].
    r   r   ry   r   rR   rU   NSizerz   rS   Subr      rc   rd   r   r   rt   r{   	TransposeZperm_i)r   r   Z_is_listZ_is_scalar_listr\   rn   r=   r]   r^   rj   int64_C_onnxTensorProtoDataTypeZINT64r   opset10flip)rK   rq   r2   Zpad_lenr   	extensionpaddingsZ	padding_cra   ra   rb   _prepare_onnx_paddings  sF   
 
" r   zaten::constant_pad_ndc                 C  s:   d}t |}t ||}t| ||}| jd||||dS )NconstantPadr   )r   r   _if_scalar_type_asr   r\   )rK   rq   paddingvaluer   r2   ra   ra   rb   r     s
   
r   zaten::reflection_pad1dzaten::reflection_pad2dzaten::reflection_pad3dc                 C  "   d}t | ||}| jd|||dS )Nreflectr   r   r   r\   rK   rq   r   r   r   ra   ra   rb   r6        r6   zaten::replication_pad1dzaten::replication_pad2dzaten::replication_pad3dc                 C  r  )Nedger   r   r  r  ra   ra   rb   r9     r  r9   z	aten::padrq   r2   r   r   c                 C  sr   t |d}|dkrt| ||S |dkrt| ||S |dkr%t| |||S |dkr0t| ||S td| |)NsZ	replicater  r   ZcircularzUnrecognized padding mode )	r   r   r9   r6   r   rn   Z_pad_circularr   r   )rK   rq   r2   r   r   ra   ra   rb   r2     s   zaten::linalg_detc                 C  s   |  d|S )NZDetrv   )rK   rM   ra   ra   rb   r*     s   r*   zaten::logdetc                 C  s   t | t| |S r   )rn   logr*   )rK   rq   ra   ra   rb   r,     s   r,   aten::arangec                 G  s  dd }t |dkrFtdd |D rFtj}| jdtj|d |dd	}| jdtj|d
 |dd	}| jdtjd
|dd	}| d|||S t |dksRt |dkrt |dkr[d }n||d
 }tj| |d |d\}}}}| jdtjd| dd	}	| jdtjd
| dd	}| d|	||S t |dkst |dkrt |dkrd }n||d }tj| |d |d
 |d |d\}
}}}| d|||S t |dkr||d }tj| |d |d
 |d\}}}}| jdtjd
| dd	}| d|||S t	ddt | dS )Nc                 S  s   t | d} | S )Nrr   )r   r   rS   ra   ra   rb   _get_arange_dtype
  s   z!arange.<locals>._get_arange_dtyper   c                 s  s    | ]}t |tV  qd S r   )
isinstancer   )r~   valra   ra   rb   	<genexpr>  s    zarange.<locals>.<genexpr>rR   r   rS   rU   ry   Ranger   )r   rT   r      r   )r   r   steprT   rp   )r   r   rT   r	  zwith z
 arguments)
r   allr]   r   r\   r^   r   Z_arange_cast_helperrT   r   )rK   argsr
  rT   r   r   Zdelta_defaulttype_r  Zstart_defaultr   ra   ra   rb   r     sj   
r   zaten::_dim_arangec                 C  s@   |  d|}| j d|| j dt|ddd}t| |dd d d S )Nrz   rs   rR   rU   r   rt   r   )r\   r]   r^   r   )rK   likerw   Z
like_shapestopra   ra   rb   _dim_arangeP  s
   r  z
aten::size)Zquantize_outputc                 C  s"   |d u r
|  d|S t| ||S )Nrz   )r\   r   _size_helperrK   rM   rw   ra   ra   rb   r=   Z  s   r=   zaten::squeezec                 C  sx  |d u r
|  d|S t|st| ||gS t|dd}t|}|}|d ur1|dk r1||7 }t||}|dk r?|d u sC|d u r| j dt|gd}t	| ||}| j dtj
dtjdd}|  d	||}	tj| d
|	dd\}
\}}}t|||g}t|j| | d|}t|j| |
S |}|dkrtdt| d d t| d d d  |S t| ||gS )NZSqueezerr   rw   r   rR   rU   ry   rS   EqualIfr   )n_blocksZIdentityz5This model contains a squeeze operation on dimension z. The size of z%this dimension in the given input is z. The model will zWbe exported without the squeeze node. If the model is intended to be used with dynamic z7input shapes, please export with dynamic_axes argument.)r\   r   _is_constantr   r   rj   Z_get_tensor_dim_sizer]   r^   r  Zonesr   r   add_op_with_blocksr   _add_output_to_blockblockwarningswarnr   )rK   rM   rw   Z
input_rankZadjusted_dimdim_sizeZdim_constantr=   	const_oneZcondZif_opZ
if_contextZelse_contextr   Zsqueeze_Z	identity_ra   ra   rb   rA   b  sX   


rA   zaten::unsqueezec                 C  s(   t |rt |dd}t | ||gS )Nrr   rw   )r   r  r   r|   r  ra   ra   rb   rF     s   
rF   zaten::mmc                 C  s   | j d||dddS )NZGemmg        g      ?)Zbeta_fZalpha_frv   )rK   rM   r   ra   ra   rb   r/     s   r/   zaten::indexc                 C  s   t |rt |}n|g}t|dkr9|d }t |s9t |s,tj|tjj	kr9t
| |}| d||S t
| ||S )Nry   r   r   )r   r   r   r   re   r   r   rY   rZ   UINT8rn   r   r\   r&   )rK   rM   r&   r   ra   ra   rb   r&     s   


r&   zaten::index_fillc                 C  sJ   t | |||\}}t |}t ||}t| ||d }t| ||||S r   )r   _index_fill_reshape_helperr   r   rn   r}   r;   )rK   rM   rw   r&   r   Zexpanded_index_shapeexpanded_indexZexpanded_valuera   ra   rb   r%     s   
r%   zaten::index_copyc                 C  s$   t | |||\}}t| ||||S r   )r   r%  r;   )rK   rM   rw   r&   r   Z_expanded_index_shaper&  ra   ra   rb   r'     s   r'   zaten::bitwise_right_shiftzaten::__rshift_c                 C     t j|t jjt j|kr| jd|t j| d}t j|t jjt jjkr3| jd||ddS | jdtjdtj	dd	}t
|sO| jd|tjjd}| d
||}| jd|t j| d}| d||}|S )Nrc   rd   BitShiftRIGHTZdirection_srR   r   rS   rU   PowDivr   rY   rZ   ri   r\   rf   r$  r]   r^   Zfloat32r   r   r   r   r[   )rK   rM   r   twotwo_powrshiftra   ra   rb   	__rshift_  2   

r1  zaten::bitwise_left_shiftzaten::__lshift_c                 C  r'  )Nrc   rd   r(  LEFTr*  rR   r   rS   rU   r+  r   r-  )rK   rM   r   r.  r/  lshiftra   ra   rb   	__lshift_  r2  r5  c                 C  s   |  d|| j dt|d d}|  d|| j dt||d  d}|  d| j dtdd|| j dt|d}td|| |}| j d|dd}t| |dg}t| || j dtd	dgd}	|  d||	}
|
S )
Nr   rR   r   rU   r   ry   r  r   r{   )r\   r]   r^   r   rF   r   r|   r   )rK   Zinput_dZkernel_size_dZ
dilation_dZ	padding_dZstride_dZblocks_dZblocks_d_indicesZkernel_gridZkernel_maskZ
block_maskra   ra   rb   _get_im2col_indices_along_dim  s0   	r6  c                 C  s.   | j dtdd||gd d}|  d||S )NrR   r   r   rU   r   )r\   r]   r   )rK   rq   	padding_h	padding_wr2   ra   ra   rb   _get_im2col_padded_input8  s    r9  c              
   C  s   t | || jdtdd}t | || jdtdd}| d|| jdt|| d}| jdt| |dgt| |dg| jdtdgdddS )	NrR   r   rU   ry   r   r   r{   rt   )r=   r\   r]   r^   r   r|   )rK   rq   kernel_hkernel_wZ	batch_dimZchannel_dimZchannel_unfoldedra   ra   rb   _get_im2col_output_shape@  s   r<  zaten::im2colisc                 C  s  t | || jdtdd}t | || jdtdd}|d |d }}	|d |d }
}|d |d }}|d |d }}t| ||||
|}t| |||||	}t| |||}t| ||
|}| jd||dd}| jd||d	d}| jd
|g dd}t| ||S )NrR   r   rU   r   r   ry   rs   rt   r   r   )r   ry   r   r   r   r   r   )	r=   r\   r]   r^   r6  r<  r9  r   r   )rK   rq   Zkernel_sizeZdilationr   ZstrideZinput_hZinput_wZstride_hZstride_wr7  r8  Z
dilation_hZ
dilation_wr:  r;  Zblocks_row_indicesZblocks_col_indicesZoutput_shapeZpadded_inputoutputra   ra   rb   r$   P  s$   r$   zaten::narrowc                 C  s"   |  d||}tj| ||||dS )Nr   r   )r\   r   r   )rK   rq   rw   r   lengthr   ra   ra   rb   r0     s   r0   zaten::flattenc                 C  s   t |}|dkr|S |dkr&|dks|d ur%||d kr%| jd||dS n|dkrB|dks8|d urB||d krB| jd||d dS |d u rLt dd	S |dk rT|| }t | ||||S )
Nry   r{   ZFlattenrt   r   r   rw   zfONNX and PyTorch use different strategies to split the input. Input rank must be known at export time.)r   rj   r\   r   Z_flatten_helper)rK   rq   Z	start_dimZend_dimrw   ra   ra   rb   r      s$   
r    zaten::linalg_vector_normrx   Sequence[int] | Nonekeepdimboolc                 C  s   t | |||||S r   )r   Z_linalg_vector_norm_helper)rK   rM   ordrw   rB  rT   ra   ra   rb   r+     s   
r+   zaten::embedding_bagc
           
      C  s   t | |||||||||	
S r   )r   Z_embedding_bag_helper)
rK   Zembedding_matrixr   offsetsZscale_grad_by_freqr   sparseZper_sample_weightsZinclude_last_offsetZpadding_idxra   ra   rb   r     s   r   zaten::embedding_renormc              	   C  s   |  d|}|  d||}t|}|dkrd}n|dkrd}n
td| d|| j ||dgdd	}|  d
|| j dtdd}	t|}|  d||	}
|  d||
}|  d|  d||||}|  d|t| |dg|S )Nr   rs   ry   ZReduceL1r   ZReduceL2z8Unsupported: ONNX export of embedding_renorm with norm: z. Only 1. and 2. are supported.)axes_ir   r   rR   gHz>rU   r,  r   ZWhereZGreaterr   )r\   r   r   r   r]   r^   r   r|   )rK   weightr   Zmax_normZ	norm_typeZunique_indicesZpartial_weightZnorm_iZpartial_weight_normZpartial_weight_norm_scalesZpartial_weight_renormra   ra   rb   r     s<   

r   zaten::chunkc              
   C  s   | j d|  d||dd}|  d|| j dtjdgtjdd	}|  d
|  d|||}t| ||d |  d||  d||g}| j dg|R ddi}t| |||S )Nrs   rz   r   rt   r   rR   ry   rS   rU   r,  r   r   r   ru   )r\   r]   r^   r   rn   r}   r@   )rK   rM   chunksrw   r"  Zchunk_size_s
chunk_sizeZ	chunk_vecra   ra   rb   r     s   r   zaten::normalc	           
      C  sD   |d urt |st| ||d }t| || d|}	t| |	|S )NZRandomNormalLike)r   re   rn   r}   r   r\   r   )
rK   meanZstdsizes	generatorrT   ZlayoutZdeviceZ
pin_memoryr   ra   ra   rb   r1     s   r1   zaten::atleast_1dtorch._C.Valuec              
   C  s   t |r?t |r?t |}g }|D ]"}|}t |}|dkr0t | || jdtdgd}|	| q| jdg|R  S t |}|dkrXt | || jdtdgd}|S )Nr   rR   ry   rU   SequenceConstruct)
r   r   r   r   rj   r   r\   r]   r^   r   rK   rM   r   Znew_tensor_listr^   Z
new_tensorZtensor_rankra   ra   rb   r     s$   


r   zaten::atleast_2dc                 C  s   t |rNt |rNt |}g }|D ]1}|}t |}|dkr2t | || jdtddgd}n|dkr?t j	| |dgd}|
| q| jdg|R  S t |}|dkrjt | || jdtddgd}|S |dkrwt j	| |dgd}|S )Nr   rR   ry   rU   rG  rP  r   r   r   r   rj   r   r\   r]   r^   r|   r   rQ  ra   ra   rb   r   8  s2   


r   zaten::atleast_3dc                 C  sP  t |ret |ret |}g }|D ]H}|}t |}|dkr2t | || jdtg dd}n$|dkrIt j	| |dgd}t j	| |dgd}n|dkrVt j	| |dgd}|
| q| jd	g|R  S t |}|dkrt | || jdtg dd}|S |dkrt j	| |dgd}t j	| |dgd}|S |dkrt j	| |dgd}|S )
Nr   rR   )ry   ry   ry   rU   ry   rR  r{   r   rP  rS  rQ  ra   ra   rb   r   Y  sH   


r   zprim::ConstantChunkc              
   C  s  |  d|}| j dtj|gtjdd}| j d||dd}| j dtjdgtjdd}| j dtj|gtjdd}| j dtj|d gtjdd}	|  d	||	}
|  d
|
|}g }t|D ]'}| j dtj|d gtjdd}|  d||}||  d|||| |}q]|S )Nrz   rR   rS   rU   rs   r   rt   ry   r   r,  r   r   )r\   r]   r^   r   r   r   )rK   rM   rJ  rw   Zinput_shaper   Zinput_shape_dimr   rK  Zchunk_size_minus_1Zinput_shape_dim_shiftZ	chunk_dimr   rr   r&   r   ra   ra   rb   r5     s"    r5   zaten::hstackr   c              
   C  s   t | |}| d|| jdtjdtjdd}| d|}| d|}| jdtjdtjdd}| d	||}tj| d
|ddd\}\}}	}
|jd|ddd}t|j	| |	jd|ddd}t|	j	| |
  }|S )Nr   rR   r   rS   rU   rz   r   ry   r  r  r   )r  r   r   r   )r   r\   r]   r^   r   r   r  r   r  r  r   r>  )rK   r   Zfirst_tensorZfirst_tensor_shapeZfirst_tensor_dimr#  Zequal_to_oneZif_op_greaterZif_context_equalZelse_context_equalr   Z	result_ifZresult_elser   ra   ra   rb   r#     s2   
r#   zaten::vstackc                 C  s   t | |}| jd|dddS )Nr   r   r   )r   r\   )rK   r   ra   ra   rb   rG     s   
rG   )rK   rL   rM   rN   rO   rP   rQ   rP   )rK   rL   )F)r   r   rw   r   r   r   r   r   )r   N)
rK   rL   rq   rN   r2   rN   r   rN   r   rN   )rK   rL   rw   rA  rB  rC  )NNNNNN)rK   rL   rM   rO  )rK   rL   r   rN   )h__doc__
__future__r   	functoolsr   r   typingr   r]   r   Ztorch._Cr   r   Z
torch.onnxr   r   r   r	   r   r
   rn   r   Ztorch.onnx._internalr   r   collections.abcr   __all__partialZonnx_symbolicZ_onnx_symbolicZquantized_args
parse_argsr"   r   r   r   r7   r<   r(   r3   Z_apply_paramsr   r   r!   r;   r   r.   r-   r   r   r   r   r   r)   r4   r   r   rB   r   rE   rC   r>   r   r:   r8   r@   r?   rD   r   r   r6   r9   r2   r*   r,   r   r  r=   rA   rF   r/   r&   r%   r'   r1  r5  r6  r9  r<  r$   r0   r    r+   r   r   r   r1   r   r   r   r5   r#   rG   ra   ra   ra   rb   <module>   s   <"

{


	
$9G
2
  +3% +