a
    h                     @   s  U d dl Z d dlZd dlZd dlZd dlmZmZ d dlmZm	Z	m
Z
mZmZmZ d dlmZ d dlZd dlZd dlZd dlZd dlZd dlZd dlmZ d dlmZ d dlmZ d dlmZ d d	lmZ d d
l m!Z! er d dl"Z"d dl#Z#d dl$m%Z& d dl'Zd dl(Zd dl)Zd dl*Zd dl+Zda,ee- e.d< g dZ/e-dddZ0i Z1e2e3e4f e.d< d\e5eej6j7 e3dddZ8ee3 dddZ9e3dddZ:dd Z;e3dddZ<e=e>Z?G dd  d eZ@ej6j7dd!d"d#ZAeBe3d$f dd%d&ZCej6j7eBed$f d!d'd(ZDej6j7ed!d)d*ZEej6j7eBe3d$f d!d+d,ZFeBe3d$f eBe3d$f d-d.d/ZGeBeejHejIe4ejJeKejLe-f d$f eBd0 d1d2d3ZMeBejHd$f eBd0 eBejHd$f d4d5d6ZNeejHd7d8d9ZOeejHe4eKe-f d:ejHd;d<d=ZPejHeejHejIe4ejJeKejLe-f eejHe4eKe-f d>d?d@ZQdAeBe3d$f eBejHd$f eBd0 eBe3d$f eBejHd$f eBd0 e-eBdB eBeejHejIe4ejJeKejLe-f d$f eBeejHe4eKe-f d$f dCdDdEZRdAeBe3d$f eBejHd$f eBd0 eBe3d$f eBejHd$f eBd0 e-eBdB eBeejHejIe4ejJeKejLe-f d$f eBeejHe4eKe-f d$f dCdFdGZSeTdHdIdJZUG dKdL dLZVe jWG dMdN dNZXee3eBe3ee3ef f f ZYee.dO< e jWdPdQedRdSG dTdU dUZZedRdSG dVdW dWZ[edRdSddXej6j7eeeZee3ef f  dYdZd[Z\dS )]    N)MappingSequence)AnyCallableFinalOptionalTYPE_CHECKINGUnion)	TypeAlias)
FakeTensor)compatibility)FakeTensorProp)OperatorSupport)CALLABLE_NODE_OPS)_pytree_pybind_state_SUPPORT_ONNXRT)is_onnxrt_backend_supportedtorch_compile_backendOrtExecutionProviderOrtBackendOptions
OrtBackendreturnc                  C   sz   t du rvzVtd td td ddl} ddl} ddl} ddlm}m}m	}m
} da W n tyt   da Y n0 t S )	a!  Returns ``True`` if ONNX Runtime dependencies are installed and usable
    to support TorchDynamo backend integration; ``False`` otherwise.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> if torch.onnx.is_onnxrt_backend_supported():
        ...     @torch.compile(backend="onnxrt")
        ...     def f(x):
        ...             return x * x
        ...     print(f(torch.randn(10)))
        ... else:
        ...     print("pip install onnx onnxscript onnxruntime")
        ...
    Nonnxruntimezonnxruntime.capi._pybind_stateZ
onnxscriptr   )decomposition_tablefx_onnx_interpreterpasses
type_utilsTF)r   	importlibimport_module
torch.onnxtorch.onnx._internal%torch.onnx._internal._exporter_legacytorch.onnx._internal.fxr   r   r   r   ImportError)torchr   r   r   r    r(   N/var/www/auris/lib/python3.9/site-packages/torch/onnx/_internal/onnxruntime.pyr   /   s    



r   _dumped_onnx_model)model_stringgraph_moduler   c                 C   s   t jdd}|sdS t|dd }| | d}t|d}||  W d   n1 s^0    Y  |t|< |dur| | d}t|d	d
d }|t|j W d   n1 s0    Y  |S )a  Stores the onnx model into a file.
    The name is "{ONNXRT_DUMP_PATH}{N}.onnx"
    where *N* is the number of files already stored with
    this prefix.
    If graph_module is not None, the graph is stored as a string with
    the same filename except the extension (.txt).
    ONNXRT_DUMP_PATHN    z.onnxwbz.txtwzutf-8)encoding)osenvirongetr*   openwritestrgraph)r+   r,   prefixnfilenamefZfilename_txtr(   r(   r)   _dump_onnx_modelb   s    
(.r?   c                   C   s   dgS )NCPUExecutionProviderr(   r(   r(   r(   r)   _infer_default_eps{   s    rA   namec                 C   s   t j rt jj|  dS )zIf PyTorch is installed with CUDA support, this starts NVTX range.

    Check torch.cuda.nvtx.range_push's document for more details.
    N)r'   cudais_availablenvtxZ
range_pushrB   r(   r(   r)   _nvtx_range_push   s    
rG   c                   C   s   t j rt jj  dS )zIf PyTorch is installed with CUDA support, this terminates NVTX range.

    Check torch.cuda.nvtx.range_pop's document for more details.
    N)r'   rD   rE   rF   Z	range_popr(   r(   r(   r)   _nvtx_range_pop   s    
rH   )device_typec                 C   sR   ddl m} | dkr|j S | dkr0|j S | dkrB|j S td|  d S )Nr   r   rD   cpuZmaiazUnsupported device type: )onnxruntime.capir   	OrtDevicerD   rJ   Znpu
ValueError)rI   ORTCr(   r(   r)   _get_ort_device_type   s    


rO   c                       sZ   e Zd ZdZee eeef d fddZe	ee
jjf e
jjed fddZ  ZS )OrtOperatorSupporta0  Operator support for ONNXRuntime backend.

    It has two-level of support decision. One is via support_dict and the other one
    is via extra_support_dict. The logic of using support_dict is implemented in
    OrtOperatorSupport and extra_support_dict is used by OperatorSupport.is_node_supported.
    )support_dictextra_support_dictc                    s   t  | || _d S N)super__init___onnx_support_dict)selfrQ   rR   	__class__r(   r)   rU      s    zOrtOperatorSupport.__init__)
submodulesnoder   c                    s   |j tvrdS |j dkr>|j| jv r>td|jt|j dS t ||rftd|jt|j dS t	d|jt|j dS )NFcall_functionz0support_dict supports node.target: %s (type: %s)Tz6extra_support_dict supports node.target: %s (type: %s)zLsupport_dict and extra_support_dict don't support node.target: %s (type: %s))
opr   targetrV   loggerinfotyperT   is_node_supportedwarning)rW   rZ   r[   rX   r(   r)   rb      s,    
z$OrtOperatorSupport.is_node_supported)__name__
__module____qualname____doc__setr   dictr9   rU   r   r'   nnModulefxNodeboolrb   __classcell__r(   r(   rX   r)   rP      s
    	rP   r,   r   c                 C   sh   | j }g }d}|jD ].}|jdkr,|| |du r|jdkr|}q|du rPdS |D ]}|| qTdS )z
    In torch.fx.Graph, placeholder is a special assignment node. If it's not
    executed in the beginning, it could overwrite values computed by upstream
    nodes.
    Nplaceholder)r:   nodesr]   appendprepend)r,   r:   placeholdersZfirst_not_placeholderr[   rq   r(   r(   r)   _move_placeholder_to_front   s    


rv   .c                  G   sP   g }| D ]>}t |dr|j}|jdkr2|d q|jdkr|d qt|S )zBReturn the first valid device (i.e., GPU or CPU) in argument list.devicerD   CUDAExecutionProviderrJ   r@   )hasattrrw   ra   rs   tuple)argsepsargrw   r(   r(   r)   _infer_ep_from_device   s    


r~   c                 C   sX   g }| j jD ]B}|jdkrt|drDd|jv rDt|jd tjsDJ || qt	|S )Nrq   metaval)
r:   rr   r]   ry   r   
isinstancer'   Tensorrs   rz   )r,   ru   r[   r(   r(   r)   _extract_graph_module_inputs   s    
r   c                 C   s2   | j jD ]}|jdkr|jd   S qtddS )zHCollect "val" fields from outputs metadata in this torch.fx.GraphModule.outputr   z2No output node found in this torch.fx.GraphModule.N)r:   rr   r]   r{   rM   )r,   r[   r(   r(   r)   _extract_graph_module_outputs  s    
r   c                 C   s(   t t| \}}dd |D }t| S )z[Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.c                 S   s*   g | ]"}t |d rd|jv r|jd qS )r   r   ry   r   ).0Z
output_argr(   r(   r)   
<listcomp>  s   z/_infer_ep_from_graph_module.<locals>.<listcomp>)r   Ztree_flattenr   r~   )r,   Zflattened_output_args_Zselected_output_argsr(   r(   r)   _infer_ep_from_graph_module  s    r   )r|   r   c                 C   s*   t tddd}t| }tt||ddS )z:Sort execution providers in eps based on pre-set priority.)epr   c                 S   s   | dkrdS | dkrdS dS )Nr@      rx   r0   r   r(   )r   r(   r(   r)   get_execution_provider_priority!  s
    z2_sort_eps.<locals>.get_execution_provider_priorityT)keyreverse)r9   intrh   rz   sorted)r|   r   Z
unique_epsr(   r(   r)   	_sort_eps  s    r   zORTC.OrtDevice.)valuesr   c                    s   ddl m  ttdddttjtjttjttj	t
f td fddt| dkrrtfd	d
| D }|S dfS d S )Nr   r   )	device_idr   c                 S   s   | pdS )Nr   r(   )r   r(   r(   r)   _device_id_or_zero:  s    z-_get_onnx_devices.<locals>._device_id_or_zero)valuer   c                    s|   t | tjr0 t| jj j | jjS t | tj	t
tjttjtfrd td j dS tdtt|  d S )NrJ   r   zUnsupported value type: )r   r'   r   rL   rO   rw   ra   Zdefault_memoryindexSymIntr   SymFloatfloatSymBoolrn   rM   r9   r   )rN   r   r(   r)   _map_tensor_or_sym_to_device=  s    

z7_get_onnx_devices.<locals>._map_tensor_or_sym_to_devicec                 3   s   | ]} |V  qd S rS   r(   )r   r   )r   r(   r)   	<genexpr>R      z$_get_onnx_devices.<locals>.<genexpr>r0   )rK   r   r   r	   r'   r   r   r   r   r   rn   lenrz   )r   Zort_devicesr(   )rN   r   r   r)   _get_onnx_devices0  s    r   )tensorsdevicesr   c           
      C   s   dd l }ddlm} tj|jtj|jtj|jtj|jtj|jtj	|j	tj
|j
tj|jtj|ji	}| }|t|  g }g }g }| D ]0}	|||	j  ||	  ||	  q|| |||| |S )Nr   r   )numpyrK   r   r'   float16float32float64uint8int8int16int32int64Zlonglongrn   Zbool_OrtValueVectorZreserver   rs   dtypesizeZdata_ptrpush_back_batch)
r   r   nprN   Ztorch_dtype_to_numpy_dtypeZ	ortvaluesZdtypesZshapesZ	data_ptrstensorr(   r(   r)   !_get_ortvalues_from_torch_tensorsX  s.    r   )r   r   c                 C   s*   | j rtdtj|  | j| jd}|S )Nz#sparse tensor is not yet supported.)r   rw   )Z	is_sparserM   r'   emptyr   r   rw   )r   outr(   r(   r)   _to_real_tensorx  s    r   onnx.ValueInfoProto)dynamo_value
value_infor   c                 C   s   t | tjr4t|jjjjdkr4| jdkr4t| S t | t	rNtj
| tjdS t | trhtj
| tjdS t | trtj
| tjdS t | tjsJ |  S dS )z9Helper function to wrap PyTorch variables as torch.Tensorr   )r0   )r   N)r   r'   r   r   ra   tensor_typeshapedimZsqueezer   r   r   r   r   rn   
contiguous)r   r   r(   r(   r)   _adjust_scalar_from_fx_to_onnx  s    





r   )r   
prim_valuer   c                 C   s<   t | tjsJ dt |tjttjttjtfr8| 	 S | S )zFHelper function to wrap ORT-produced torch.Tensor as PyTorch variableszORT's output must be tensor.)
r   r'   r   r   r   r   r   r   rn   item)r   r   r(   r(   r)   _adjust_scalar_from_onnx_to_fx  s    r   onnxruntime.InferenceSessionr   .)sessinput_namesinputsinput_devicesoutput_namesoutputsoutput_devicespreallocate_outputinput_value_infosnormalized_prim_outputsr   c
                 C   s&  dd l }
ddlm} td tdd t||D }t  td t||}|rntdd |D }t||}n| }t  td |
	 }|
d	d
 | |||||| t  |rtd tdd t||	D }t  |S dd l}
td |
jjj|}tdd t||	D }t  |S d S )Nr   r   r   c                 s   s   | ]\}}t ||V  qd S rS   r   r   r}   r   r(   r(   r)   r     s   z8_run_onnx_session_with_ortvaluevector.<locals>.<genexpr>r   c                 s   s$   | ]}t |trt|n|V  qd S rS   )r   r   r   )r   tr(   r(   r)   r     s   run_with_ortvaluevectorZ'disable_synchronize_execution_providers1zafter run_with_ortvaluevectorc                 s   s   | ]\}}t ||V  qd S rS   r   r   onnx_outputprim_outputr(   r(   r)   r     s   c                 s   s   | ]\}}t ||V  qd S rS   r   r   r(   r(   r)   r     s   )r   rK   r   rG   rz   ziprH   r   r   Z
RunOptionsZadd_run_config_entryr   Zonnxruntime.trainingZtrainingZ	ortmodule_utilsZ_ortvalues_to_torch_tensor)r   r   r   r   r   r   r   r   r   r   r   rN   Z
ort_inputspth_outputsort_outputsZrun_optionsr(   r(   r)   %_run_onnx_session_with_ortvaluevector  sP    

r   c
                    s`   dd l  tdd t||D } fddt||D }
| ||
}tdd t||	D }|S )Nr   c                 s   s   | ]\}}t ||V  qd S rS   r   r   r(   r(   r)   r   !  s   z/_run_onnx_session_with_fetch.<locals>.<genexpr>c                    s&   i | ]\}}| j |  qS r(   )ZOrtValueZortvalue_from_numpyrJ   r   )r   rC   r   r   r(   r)   
<dictcomp>%  s   z0_run_onnx_session_with_fetch.<locals>.<dictcomp>c                 s   s"   | ]\}}t t||V  qd S rS   )r   r'   Z
from_numpy)r   r   r   r(   r(   r)   r   *  s
   )r   rz   r   run)r   r   r   r   r   r   r   r   r   r   feedr   r   r(   r   r)   _run_onnx_session_with_fetch  s    
r   )ra   c                 C   s.   ddl }t|jjt|jjt|jji}|| S )a=  
    Converts a Python type to the corresponding ONNX tensor element type.
    For example, `_from_python_type_to_onnx_tensor_element_type(float)` returns
    `onnx.TensorProto.FLOAT`.

    Args:
      type (type): The Python type to convert.

    Returns:
      int: The corresponding ONNX tensor element type.

    r   N)	onnxr   TensorProtoFLOATr   INT64rn   BOOLr6   )ra   r   Z(_PYTHON_TYPE_TO_ONNX_TENSOR_ELEMENT_TYPEr(   r(   r)   -_from_python_type_to_onnx_tensor_element_type4  s    r   c                   @   sh   e Zd ZdZdeedf ed eedf ed ed ed eeejdf ejf dddZ	d	d
 Z
dS )OrtExecutionInfoPerSessionzWInformation required to execute torch.fx.GraphModule using onnxruntime.InferenceSessionr   .r   r   sessionr   r   r   output_value_infosr   r   example_outputsc	           	      C   s4   || _ || _|| _|| _|| _|| _|| _|| _d S rS   r   )	rW   r   r   r   r   r   r   r   r   r(   r(   r)   rU   N  s    z#OrtExecutionInfoPerSession.__init__c           
      G   s  dd l }|jjtj|jjtj|jjtj|jj	tj
|jjtj|jjtj|jjtj|jjtj|jjtj|jjtj|jjtj|jjtj|jjtji}dd | D }t|t| jkrdS t || jD ]\}}t!|tj"t#t$fs dS t!|t$t#tfr2t%t&|}||j&j'j(kr dS t|j&j'j)j*dkr dS q||j+ }||j&j'j(krR dS t |j)|j&j'j)j*D ]T\}}	t!|t$r|	j,|ksf|	j-rqfn$t!|tj.r|	j-rqfn  dS qfqdS )Nr   c                 S   s   i | ]\}}||qS r(   r(   )r   r   r   r(   r(   r)   r     s   z;OrtExecutionInfoPerSession.is_supported.<locals>.<dictcomp>FT)/r   r   r   r'   r   ZFLOAT16r   Z
FLOAT8E5M2Zfloat8_e5m2ZFLOAT8E5M2FNUZZfloat8_e5m2fnuzZFLOAT8E4M3FNZfloat8_e4m3fnZFLOAT8E4M3FNUZZfloat8_e4m3fnuzZDOUBLEr   r   rn   ZUINT8r   ZINT8r   ZINT16r   ZINT32r   r   r   itemsr   r   r   r   r   r   r   r   ra   r   Z	elem_typer   r   r   Z	dim_valueZ	dim_paramr   )
rW   r{   r   Z(_onnx_tensor_element_type_to_torch_dtypeZ(_torch_dtype_to_onnx_tensor_element_typer}   r   Z
onnx_dtyper   Zonnx_dimr(   r(   r)   is_supportedo  sX    













z'OrtExecutionInfoPerSession.is_supportedN)rd   re   rf   rg   rz   r9   r	   r'   r   rU   r   r(   r(   r(   r)   r   K  s   

!r   c                   @   s@   e Zd ZddddZejjdddZejjedd	d
Z	dS )"OrtExecutionInfoForAllGraphModulesNr   c                 C   s
   i | _ d S rS   )execution_info_per_graph_module)rW   r(   r(   r)   rU     s    z+OrtExecutionInfoForAllGraphModules.__init__r,   c                 G   s8   || j vrd S | j | }|D ]}|j| r|  S qd S rS   )r   r   )rW   r,   r{   
candidates	candidater(   r(   r)   &search_reusable_session_execution_info  s    



zIOrtExecutionInfoForAllGraphModules.search_reusable_session_execution_info)r,   r`   c                 C   s,   || j vr|g| j |< n| j | | d S rS   )r   rs   )rW   r,   r`   r(   r(   r)   cache_session_execution_info  s    
z?OrtExecutionInfoForAllGraphModules.cache_session_execution_info)
rd   re   rf   rU   r'   rl   GraphModuler   r   r   r(   r(   r(   r)   r     s
   r   r   T)frozenF)Zis_backward_compatiblec                   @   s   e Zd ZU dZdZeee  ed< dZ	e
ed< dZeee  ed< dZe
ed< dZe
ed	< dZed
 ed< dZeeedgdf   ed< dS )r   aJ  Options for constructing an ``OrtBackend``, the ONNX Runtime
    backend (``"onnxrt"``) for ``torch.compile``.

    Example::

        >>> @torch.compile(
        ...     backend="onnxrt",
        ...     options=torch.onnx._OrtBackendOptions(...),
        ... )
        ... def ort_function(x):
        ...     return x ** x
    Npreferred_execution_providersTinfer_execution_providersdefault_execution_providersFr   use_aot_autogradzonnxruntime.SessionOptionsort_session_optionszonnx.ModelProtopre_ort_model_transforms)rd   re   rf   rg   r   r   r   r   __annotations__r   rn   r   r   r   r   r   r   r(   r(   r(   r)   r     s   

r   c                   @   s   e Zd ZU dZdee dddZejj	e
eeeeef f  dddZejj	d	d
dZejj	ejj	dddZejj	ejj	dddZdZeed< g Zeed   ed< edeeeeeef f  d dddZedd Zedd ZdS )r   a	  A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls.

    The compiler entry point is OrtBackend.compile, which
        1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported
           sub-graphs.
        2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call.
        3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
    Noptionsc                 C   s   ddl m} dd l}dd l}dd l}|d u r2t n|| _|jjj	
 | _|jjjj| jj}d d d d d d}t||| _i | _t | _d| _d| _t|jdrtnt| _d S )Nr   r   )getattrz_operator.getitemz_operator.mulz_operator.addz_operator.subFr   )rK   r   r"   r$   +torch.onnx._internal.fx.decomposition_tabler   _optionsr   	_internalZ_exporter_legacyZResolvedExportOptions_resolved_onnx_exporter_optionsrl   r   Z'_create_onnx_supports_op_overload_tableonnx_registryrP   _supported_ops_partitioner_cacher   _all_ort_execution_info_assert_allclose_to_baselineexecution_country   r   r   r   r   )rW   r   rN   r'   rQ   rR   r(   r(   r)   rU   !  s2    
zOrtBackend.__init__rp   c                 G   s   d}| j jr.t|  }r|}nt| }r.|}g }g | j jp>g t|| j jpTt R D ]T}t|t	rr|i f}n"t|t
r|d d u r|d i f}|d urZ||vrZ|| qZ|S )Nr(   r0   r   )r  r   r~   r   r   r   r   rA   r   r9   rz   rs   )rW   r,   r{   Zinferred_epsZeps_from_argsZeps_from_graph_moduleZselected_epsr   r(   r(   r)   _select_epsu  s*    



zOrtBackend._select_epsr   c                 O   s  ddl }ddlm}m} | jj|g|R  }|rd|j}|j}	|j}
|j	}|j
}|j}|j}|j}n|| }| jjrd| _t|}dd }t||}n@zt|j|i |}W n& ty   td| d| _ Y n0 | }|| }|j|| jjd}|j| jjj d	}| j!j"r:| j!j"D ]}|| q*|# }t$j%&d
dr^t'||d |j(|| j!j)| j*|g|R  d}t+dd |j,j-D }	t+dd |j,j.D }
t/|}t0|t+rt/|}n
t/|f}t+dd |j,j-D }t+dd |j,j.D }t1||	||
||||d}| j2|| |  j3d7  _3t0|t4j5}|rL|fn|}t0|t+s`J t6dd |D sxJ t7d | ||	|||
||| j!j||
}t8  | j9rt4j:j;j<|g|R ddi}|r|fn|}t=||D ]\}}t4j>?|| q|r|d S |S )a  This function replaces GraphModule._wrapped_call in compiled model.

        The _wrapped_call is the underlying implementation of forward method. Replacing
        it means we delegate the computation to _ort_acclerated_call and therefore
        onnxruntime.InferenceSession.
        r   N)r   r   Fc                 S   s&   t | drd| jv r| jd S | S d S )Nr   r   r   r   r(   r(   r)   maybe_map_to_meta_val  s    
z>OrtBackend._ort_acclerated_call.<locals>.maybe_map_to_meta_valzFakeTensorProb failed for %s)Zfx_graph_moduleonnxfunction_dispatcher)opset_versionr-   r   )Zpath_or_bytesZsess_options	providersc                 s   s   | ]}|j V  qd S rS   rB   r   inputr(   r(   r)   r     r   z2OrtBackend._ort_acclerated_call.<locals>.<genexpr>c                 s   s   | ]}|j V  qd S rS   rB   r   r   r(   r(   r)   r     r   c                 s   s   | ]
}|V  qd S rS   r(   r  r(   r(   r)   r     r   c                 s   s   | ]
}|V  qd S rS   r(   r  r(   r(   r)   r     r   r   r0   c                 s   s"   | ]}t |tjtjtfV  qd S rS   )r   r'   r   r   r   )r   elemr(   r(   r)   r   7  s   Z$run_onnx_session_with_ortvaluevectorexecutorZaten)@r   r%   r   r   r	  r   r   r   r   r   r   r   r   r   ZMovePlaceholderToFrontr   r  Zdynamic_shapesr   r   r   Ztree_mapr   	propagate	Exceptionr_   rc   ZFxOnnxInterpreterZInsertTypePromotionr  Zto_model_protor  r  r  r   ZSerializeToStringr4   r5   r6   r?   ZInferenceSessionr   r  rz   r:   r  r   r   r   r   r   r  r'   r   allrG   rH   r
  Z_primsr  executer   ZtestingZassert_close)rW   r,   r{   kwargsr   r   r   Z!cached_execution_info_per_sessionZonnx_sessionr   r   r   r   r   r   Zprim_outputsZextracted_outputsr  Zfx_interpreterZexportedZ
onnx_modelZ	transformZonnx_model_bytesZexecution_info_per_sessionZis_single_tensor_outputr   Zonnx_outputsZbaseline_outputsZnormalized_baseline_ouptutsr   Zbaseline_outputr(   r(   r)   _ort_acclerated_call  s    
	

	

zOrtBackend._ort_acclerated_callc           	      C   s   ddl m} || jv r"| j| }n\|}||| jdd}| }|| j|< |jjD ],}|jdkrPd|jv rPt	||j}| j
|_qP|S )Nr   )CapabilityBasedPartitionerT)Zallows_single_node_partitionZcall_moduleZfused_)Z!torch.fx.passes.infra.partitionerr  r  r  Zpartition_and_fuser:   rr   r]   rC   r  r  Z_wrapped_call)	rW   r,   r{   r  Zpartitioned_prim_graph_moduleZprim_graph_moduleZpartitionerr[   Zfused_moduler(   r(   r)   compileZ  s     


zOrtBackend.compilec                 C   sF   | j jr:ddlm} ddlm} || j|| jjd||S | ||S )zIf ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `auto_autograd` compiler
        will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise,
        the ``compile`` method is invoked directly.r   )#min_cut_rematerialization_partition)aot_autograd)Zfw_compilerZpartition_fnZdecompositions)	r  r   Zfunctorch.compiler  Ztorch._dynamo.backends.commonr  r  r  r   )rW   r,   r{   r  r  r(   r(   r)   __call__  s    zOrtBackend.__call__   %_OrtBackend__instance_cache_max_count_OrtBackend__instance_cache)r   r   c                    s   t t dddt t s,t f i  p&i  t fddtjD d}|du rttjtjk s~J dtj dt d	t d
tjt  } |S )a  Returns a possibly cached instance of an ``OrtBackend``. If an existing
        backend was created previously through this function with the same options,
        it will be returned. Otherwise a new backend will be created, cached, and
        returned.

        Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend``
        will always be returned, since ``onnxruntime.SessionOptions`` cannot
        participate in caching.abc                 S   sh   | j |j ksH| j|jksH| j|jksH| j|jksH| j|jksH| j|jkrLdS | jd us`|jd urddS dS )NFT)r   r   r   r   r   r   r   r$  r(   r(   r)   reusable  s     





z<OrtBackend.get_cached_instance_for_options.<locals>.reusablec                 3   s   | ]}|j  r|V  qd S rS   )r  )r   r&  r   r'  r(   r)   r     r   z=OrtBackend.get_cached_instance_for_options.<locals>.<genexpr>NzNo more than z instances of z allowed. Please instantiate `z` explicitly to pass to `torch.compile`. See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 for discussion.)r   r   nextr   r#  r   r"  rs   )r   backendr(   r(  r)   get_cached_instance_for_options  s$    

	z*OrtBackend.get_cached_instance_for_optionsc                   C   s   t j  d S rS   )r   r#  clearr(   r(   r(   r)   clear_cached_instances  s    z!OrtBackend.clear_cached_instancesc                   C   s
   t tjS rS   )rz   r   r#  r(   r(   r(   r)   get_cached_instances  s    zOrtBackend.get_cached_instances)N)N)rd   re   rf   rg   r   r   rU   r'   rl   r   r   rz   r9   r   r   r  r  r  r   r"  r   r   r#  liststaticmethodr	   r+  r-  r.  r(   r(   r(   r)   r     s.   
	U! E7 6
r   r   )r,   r   c                C   s   t || |S rS   )r   r+  )r,   r{   r   r(   r(   r)   r     s    r   )N)]Zdataclassesr    loggingr4   collections.abcr   r   typingr   r   r   r   r   r	   Ztyping_extensionsr
   r'   Ztorch._CZ
torch._opsZtorch._prims.executorZtorch.fxZ!torch.onnx._internal._lazy_importZtorch._subclasses.fake_tensorr   Ztorch.fx._compatibilityr   Z torch.fx.passes.fake_tensor_propr   Z torch.fx.passes.operator_supportr   Ztorch.fx.passes.tools_commonr   Ztorch.utilsr   r   r   rK   r   rN   r"   r#   r$   r  Ztorch.onnx._internal.fx.passesr   rn   r   __all__r   r*   ri   r9   r   bytesrl   r   r?   rA   rG   rH   rO   	getLoggerrd   r_   rP   rv   rz   r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   ra   r   r   Z	dataclassr   r   r   r   r   r(   r(   r(   r)   <module>   s"  
 	0 		
2

 ) !	

T

&^!$
<   Q