a
    hT                     @   s  d dl Z d dlZd dlZd dlZd dlmZ d dlmZmZm	Z	m
Z
mZmZ d dlmZmZ d dlZd dlm  mZ d dlmZ d dlmZmZmZ d dlmZmZmZ d dlm Z  d d	l!m"Z" d d
l#m$Z$ edej%ej&Z'e(e)dddZ*e j+G dd dZ,G dd dej-Z.G dd dZ/e	de0Z1G dd dej2Z3G dd dZ4G dd dZ5G dd dZ6G dd dZ7G d d! d!Z8G d"d# d#Z9G d$d% d%Z:G d&d' d'e:Z;G d(d) d)e:Z<G d*d+ d+e:Z=G d,d- d-e:Z>G d.d/ d/e:Z?G d0d1 d1Z@G d2d3 d3ZAdS )4    N)abstractmethod)AnyCallableNewTypeOptionalTypeVarUnion)overrideSelf)TracingContext)
FakeTensorFakeTensorModeTensor)MetaConverterMetaTensorDescMetaTensorDescriberSymNode)ShapeEnv)no_dispatch	_SymNodeTnamereturnc                 C   s
   |  dS )z
    An ops filter which allows pickle-safe ops. Pickle-safe ops are built-in
    ones where it will be possible to unpickle on any machine which has PyTorch.
    )ztorch.ops.atenztorch.ops.fbgemm)
startswithr    r   E/var/www/auris/lib/python3.9/site-packages/torch/fx/_graph_pickler.py_ops_filter_safe   s    r   c                   @   s(   e Zd ZU eZeeegef  e	d< dS )Options
ops_filterN)
__name__
__module____qualname__r   r    r   r   strbool__annotations__r   r   r   r   r   (   s   
r   c                       s   e Zd ZdZdejee dd fddZe	e
eedef eedf f ddd	Ze	e
ee dd
dZede
ee edddZeeee
dddZ  ZS )GraphPicklerzb
    GraphPickler is a Pickler which helps pickling fx graph - in particular
    GraphModule.
    N)fileoptionsr   c                    s4   t  | |pt | _tt | _tdd| _d S )NF)Z	copy_data)	super__init__r   r)   _UnpickleStateTokenobject_unpickle_stater   _meta_tensor_describer)selfr(   r)   	__class__r   r   r+   5   s    zGraphPickler.__init__.)objr   c                 C   s   t |trt| |S t |tjjr0t| |S t |tjj	tjj
frRt| |S t |trht| |S t |tjrt| |S t |tjjrt| |S t |tjjrJ t| | }r|S tS d S N)
isinstancer   _TensorPickleDatareduce_helpertorchfxGraphModule_GraphModulePickleData_opsZOperatorBaseOpOverloadPacket_OpPickleDatar   _ShapeEnvPickleDataSymInt_SymNodePickleData_guardsr   _TracingContextPickleDataNode_TorchNumpyPickleDataNotImplemented)r0   r3   reducer   r   r   reducer_overrideC   s     

zGraphPickler.reducer_overridec                 C   s   || j u rdS d S d S )Nunpickle_state)r.   )r0   r3   r   r   r   persistent_idn   s    
zGraphPickler.persistent_id)r3   r)   r   c                 C   sH   t  ,}| ||}|| | W  d   S 1 s:0    Y  dS )z#
        Pickle an object.
        N)ioBytesIOdumpgetvalue)clsr3   r)   streampicklerr   r   r   dumpsu   s    


zGraphPickler.dumps)data	fake_moder   c                 C   sH   t |}t| "}t||}| W  d   S 1 s:0    Y  dS )z%
        Unpickle an object.
        N)_UnpickleStaterK   rL   _GraphUnpicklerload)rS   rT   staterP   Z	unpicklerr   r   r   loads   s    
zGraphPickler.loads)N)N)r!   r"   r#   __doc__rK   rL   r   r   r+   r	   r-   tupler   r   rH   r$   rJ   classmethodbytesrR   staticmethodr   rY   __classcell__r   r   r1   r   r'   /   s   *	r'   c                   @   s   e Zd ZeddddZdS )rU   N)rT   r   c                 C   s   || _ t | _d S r4   )rT   r   meta_converter)r0   rT   r   r   r   r+      s    z_UnpickleState.__init__)r!   r"   r#   r   r+   r   r   r   r   rU      s   rU   r,   c                       s<   e Zd Zejedd fddZeeedddZ	  Z
S )rV   N)rP   rI   r   c                    s   t  | || _d S r4   )r*   r+   r.   )r0   rP   rI   r1   r   r   r+      s    z_GraphUnpickler.__init__)pidr   c                 C   s   |dkr| j S tdd S )NrI   zInvalid persistent ID)r.   pickleUnpicklingError)r0   ra   r   r   r   persistent_load   s    z_GraphUnpickler.persistent_load)r!   r"   r#   rK   rL   rU   r+   r	   r-   rd   r_   r   r   r1   r   rV      s   rV   c                   @   sp   e Zd ZU eeef ed< eee	e
eeege	f e
eef f dddZe	ddddZee	d	d
dZdS )r?   rS   rQ   r3   r   c                 C   s   | j | ||jffS r4   unpickler.   rO   rQ   r3   r   r   r   r7      s    z!_ShapeEnvPickleData.reduce_helperN)envr   c                 C   s*   |j r
J |j | _| jd= | jd= d S )NZtracked_fakesZfake_tensor_cache)Z_translation_validation_enabled__dict__copyrS   )r0   ri   r   r   r   r+      s    
z_ShapeEnvPickleData.__init__rI   r   c                 C   sB   |j s
J |j jsJ | j D ]\}}t|j j|| q |j jS r4   )rT   	shape_envrS   itemssetattr)r0   rI   kvr   r   r   rg      s
    
z_ShapeEnvPickleData.unpickle)r!   r"   r#   dictr$   r-   r&   r\   r'   r   r[   r   r
   rU   r,   r7   r+   rg   r   r   r   r   r?      s   
	r?   c                   @   sn   e Zd Zeeeeeee	gef eee
f f dddZeddddZedd	d
Ze	ejdddZdS )rA   re   c                 C   s<   | |j |jf}t|tjr&tj|fS tdt| d S )NzUnhandled SymNode type )	noder.   r5   r8   r@   rA   unpickle_sym_intNotImplementedErrortype)rO   rQ   r3   argsr   r   r   r7      s    
z _SymNodePickleData.reduce_helperN)rs   r   c                 C   s$   |j | _|j| _|j| _|j| _d S r4   )Z_exprexprrm   pytypeZ_hinthint)r0   rs   r   r   r   r+      s    z_SymNodePickleData.__init__)r   c                 C   s0   ddl m} | jd usJ || j| j| j| jS )Nr   r   )torch.fx.experimental.sym_noder   rm   rx   ry   rz   )r0   r   r   r   r   _to_sym_node   s    z_SymNodePickleData._to_sym_noderl   c                 C   s   t |  S r4   )r8   r@   r|   r0   rI   r   r   r   rt      s    z#_SymNodePickleData.unpickle_sym_int)r!   r"   r#   r\   r'   r   r[   r   r
   rU   r,   r7   r   r+   r|   r8   r@   rt   r   r   r   r   rA      s   rA   c                   @   sn   e Zd ZU ee ed< eeeee	e
egef ee
ef f dddZeeddddZeed	d
dZdS )r6   metadatare   c                 C   s   | j | |j||jffS r4   )rg   r/   r.   rh   r   r   r   r7      s    
z_TensorPickleData.reduce_helperN)	describertr   c                 C   s|   | |}|jr&t|jtjjjs&J tj|d d| _	t
jD ]:}|dv rJq<t| j	|d u s<J d| dt| j	| q<d S )NrT   )rT   	view_funcz
not None: z: )Zdescribe_tensorr   r5   r8   Z_subclassesZ
meta_utilsZ_FakeTensorViewFuncdataclassesreplacer~   r   Z_UNSERIALIZABLEgetattr)r0   r   r   r~   rp   r   r   r   r+      s    

z_TensorPickleData.__init__rl   c                    sT   t j| j jd}tg tjf ttjt	f t
d fdd} j| jj|d d S )Nr   )make_meta_tdevicer   c                    s:   t    t j|  |W  d    S 1 s,0    Y  d S r4   )r   r   rT   )r   r   rI   r   r   	with_fake  s    z-_TensorPickleData.unpickle.<locals>.with_fake)r   r   r~   rT   r   r8   r   r   r   r$   r   r`   Zmeta_tensorrm   )r0   rI   r~   r   r   r   r   rg     s    
z_TensorPickleData.unpickle)r!   r"   r#   r   r   r&   r\   r'   r[   r   r
   rU   r,   r7   r   r   r+   rg   r   r   r   r   r6      s   

r6   c                	   @   s   e Zd Zeeeeeee	e
gef ee	ef f  dddZeeddddZe
edef d	d
dZeeee	 dddZdS )rE   re   c                 C   s&   |  | }r| j||jffS d S d S r4   )from_objectrg   r.   )rO   rQ   r3   rS   r   r   r   r7     s    z#_TorchNumpyPickleData.reduce_helperN)modr   r   c                 C   s   || _ || _d S r4   )r   r   )r0   r   r   r   r   r   r+   ,  s    z_TorchNumpyPickleData.__init__.rl   c                 C   s&   t t| j| j}tjjj	 | S r4   )
r   	importlibimport_moduler   r   r8   _dynamo	variablesmiscZget_np_to_tnp_map)r0   rI   npr   r   r   rg   0  s    z_TorchNumpyPickleData.unpickle)tnpr   c                 C   s   t |sd S tjjj }z|| }s0W d S W n tyF   Y d S 0 t|dd  }s\d}t|dd  }spd S |tt	
||ksJ | ||S )Nr"   numpyr!   )callabler8   r   r   r   Zget_tnp_to_np_mapget	TypeErrorr   r   r   )rO   r   Z	tnp_to_npr   r   r   r   r   r   r   4  s    
z!_TorchNumpyPickleData.from_object)r!   r"   r#   r\   r'   r-   r   r[   r   r
   rU   r,   r7   r$   r+   rg   r   r   r   r   r   rE     s   rE   c                   @   sp   e Zd Zeeejjee	e
egejjf ee
ef f dddZejjeddddZeejjdd	d
ZdS )r;   re   c                 C   s   | j | ||j|jffS r4   )rg   r)   r.   rh   r   r   r   r7   K  s    
z$_GraphModulePickleData.reduce_helperN)gmr)   r   c                 C   sH   t |tjjjr| }n| }|j | _	| j	d= t
|j|| _d S )N_graph)r5   r8   r9   Z_lazy_graph_moduleZ_LazyGraphModuleZ_real_recompileZ	recompilerj   rk   gm_dict_GraphPickleDatar   graph)r0   r   r)   Z_python_coder   r   r   r+   W  s    
z_GraphModulePickleData.__init__rl   c                 C   s.   t jjt jj}| j|_| j|||_|S r4   )	r8   r9   r:   __new__r   rj   r   rg   r   )r0   rI   r   r   r   r   rg   a  s    z_GraphModulePickleData.unpickle)r!   r"   r#   r\   r'   r8   r9   r:   r[   r   r
   rU   r,   r7   r   r+   rg   r   r   r   r   r;   J  s   

r;   c                   @   sX   e Zd Zejjeejjd f eddddZejj	ed ejjf e
ejjdddZdS )_NodePickleDataN)rs   mappingr)   r   c                    sp   t tjj fdd|j| _t tjj fdd|j| _|j| _|j| _t	
|j|| _|j| _|j| _d S )Nc                    s    |  S r4   r   nr   r   r   <lambda>o      z*_NodePickleData.__init__.<locals>.<lambda>c                    s    |  S r4   r   r   r   r   r   r   q  r   )pytreetree_map_onlyr8   r9   rD   rw   kwargsr   opr>   rb   targetrv   meta)r0   rs   r   r)   r   r   r   r+   i  s    z_NodePickleData.__init__)r   r   rI   r   c                    sx   t t fdd| j}t t fdd| j}| j|}t|sRt|t	sRJ |
| j|||| j| j}| j|_|S )Nc                    s    |  S r4   r   r   r   r   r   r     r   z*_NodePickleData.unpickle.<locals>.<lambda>c                    s    |  S r4   r   r   r   r   r   r     r   )r   r   r   rw   r   r   rg   r   r5   r$   Zcreate_noder   r   rv   r   )r0   r   r   rI   rw   r   r   rs   r   r   r   rg     s    z_NodePickleData.unpickle)r!   r"   r#   r8   r9   rD   rr   r   r+   GraphrU   rg   r   r   r   r   r   h  s   r   c                   @   s   e Zd Zeeeeeegef ee	 f dddZ
eeed dddZeeeed ed f ed d	d
dZeeedddZeeedddZeeeedddZdS )r>   )rQ   r   r   c                 C   s   |  ||j}|j|jffS r4   )rb   r)   rg   r.   )rO   rQ   r   resultr   r   r   r7     s    z_OpPickleData.reduce_helper)r   r)   r   c                 C   s   t |trt|S tjj|}t |tjjr<| 	|t
|S t |tjjrX| 	|t|S |dr||dd\}}t||S |dr|dd\}}t|S tdt| d| d| d S )N)z	builtins.zmath.ztorch..   z	operator.zTARGET:  )r5   r$   _OpStrPickleDatar8   r9   rD   Z_pretty_print_targetr<   
OpOverload
_pickle_op_OpOverloadPickleDatar=   _OpOverloadPacketPickleDatar   split_OpBuiltinPickleData_OpOperatorPickleDataru   rv   )rO   r   r)   r   rootZdetail_r   r   r   rb     s    



z_OpPickleData.pickler   r   )r   dataclsr)   r   c                 C   s4   |j  }r,|| s,ddlm} |d|  || S )Nr   )BypassFxGraphCachez"Unable to pickle non-standard op: )r    Ztorch._inductor.codecacher   )r   r   r)   r    r   r   r   r   r     s    z_OpPickleData._pickle_oprl   c                 C   s   d S r4   r   r}   r   r   r   rg     s    z_OpPickleData.unpickler   c                 C   s<   d|v r.| dd\}}t | }| ||S t | S dS )zC
        Like `globals()[name]` but supports dotted names.
        r   r   N)r   globals_getattr_by_name)rO   r   r   restr   r   r   r   _lookup_global_by_name  s
    
z$_OpPickleData._lookup_global_by_namer   r   r   c                 C   s.   d|v r$| dd\}}t| |} q t| |S )zG
        Like `getattr(root, name)` but supports dotted names.
        r   r   )r   r   )r   r   r   r   r   r   r     s    z_OpPickleData._getattr_by_nameN)r!   r"   r#   r\   r'   r-   r[   r   rU   r,   r7   r   rb   r^   r$   r   rv   r   r   rg   r   r   r   r   r   r   r>     s(   r>   c                   @   s,   e Zd ZeddddZeedddZdS )r   Nr   c                 C   s
   || _ d S r4   r   r0   r   r   r   r   r+     s    z_OpStrPickleData.__init__rl   c                 C   s   | j S r4   r   r}   r   r   r   rg     s    z_OpStrPickleData.unpickle)r!   r"   r#   r$   r+   rU   rg   r   r   r   r   r     s   r   c                   @   s0   e Zd ZeddddZeejjdddZ	dS )r   Nr   c                 C   s
   || _ d S r4   r   r   r   r   r   r+     s    z_OpOverloadPickleData.__init__rl   c                 C   s"   |  | j}t|tjjsJ |S r4   )r   r   r5   r8   r<   r   r0   rI   r3   r   r   r   rg     s    z_OpOverloadPickleData.unpickle)
r!   r"   r#   r$   r+   rU   r8   r<   r   rg   r   r   r   r   r     s   r   c                   @   s0   e Zd ZeddddZeejjdddZ	dS )r   Nr   c                 C   s
   || _ d S r4   r   r   r   r   r   r+     s    z$_OpOverloadPacketPickleData.__init__rl   c                 C   s"   |  | j}t|tjjsJ |S r4   )r   r   r5   r8   r<   r=   r   r   r   r   rg     s    z$_OpOverloadPacketPickleData.unpickle)
r!   r"   r#   r$   r+   rU   r8   r<   r=   rg   r   r   r   r   r     s   r   c                   @   s.   e Zd ZeeddddZeedddZdS )r   Nr   c                 C   s   || _ || _d S r4   )r   r   )r0   r   r   r   r   r   r+     s    z_OpBuiltinPickleData.__init__rl   c                 C   sV   | j dkrt| jS | j dkr6dd l}| || jS | j dkrN| t| jS td S )Nbuiltinsmathr   r8   )r   __builtins__r   r   r   r   r8   ru   )r0   rI   r   r   r   r   rg     s    


z_OpBuiltinPickleData.unpickler!   r"   r#   r$   r+   rU   r-   rg   r   r   r   r   r     s   r   c                   @   s,   e Zd ZeddddZeedddZdS )r   Nr   c                 C   s
   || _ d S r4   r   r   r   r   r   r+     s    z_OpOperatorPickleData.__init__rl   c                 C   s   dd l }| || jS )Nr   )operatorr   r   )r0   rI   r   r   r   r   rg     s    z_OpOperatorPickleData.unpickler   r   r   r   r   r     s   r   c                   @   s<   e Zd ZejjeddddZejje	ejjdddZ
dS )r   N)r   r)   r   c                 C   sB   |j | _|j| _i }|jD ]}t|||||< qt| | _d S r4   )Z_tracer_cls
tracer_clsZ_tracer_extrastracer_extrasnodesr   r[   values)r0   r   r)   r   rs   r   r   r   r+     s    
z_GraphPickleData.__init__)r   rI   r   c                 C   s:   t j|| j| j}i }| jD ]}||||||< q|S r4   )r8   r9   r   r   r   r   rg   )r0   r   rI   r   r   Zndr   r   r   rg   !  s
    
z_GraphPickleData.unpickle)r!   r"   r#   r8   r9   r   r   r+   r:   rU   rg   r   r   r   r   r     s   r   c                   @   sf   e Zd Zeeejjee	e
egejjf ee
ef f dddZeddddZeedd	d
ZdS )rC   re   c                 C   s   | j | ||jffS r4   rf   rh   r   r   r   r7   .  s
    z'_TracingContextPickleData.reduce_helperN)contextr   c                 C   sL   |j | _ |j| _|j| _|j| _|j| _|j| _|j| _|j| _|j| _d S r4   )	module_contextframe_summary_stackloc_in_frameaot_graph_nameparams_flatparams_flat_unwrap_subclassesparams_unwrapped_to_flat_indexoutput_strides#force_unspec_int_unbacked_size_like)r0   r   r   r   r   r+   =  s    z"_TracingContextPickleData.__init__rl   c                 C   sV   t |j}| j|_| j|_| j|_| j|_| j|_| j|_| j|_| j	|_	| j
|_
|S r4   )r   rT   r   r   r   r   r   r   r   r   r   )r0   rI   r   r   r   r   rg   S  s    
z"_TracingContextPickleData.unpickle)r!   r"   r#   r\   r'   r8   rB   r   r[   r   r
   rU   r,   r7   r+   rg   r   r   r   r   rC   -  s   
rC   )Br   r   rK   rb   abcr   typingr   r   r   r   r   r   Ztyping_extensionsr	   r
   r8   Ztorch.utils._pytreeutilsZ_pytreer   Ztorch._guardsr   Ztorch._subclasses.fake_tensorr   r   r   Ztorch._subclasses.meta_utilsr   r   r   r{   r   Z%torch.fx.experimental.symbolic_shapesr   Ztorch.utils._mode_utilsr   r@   ZSymFloatr   r$   r%   r   Z	dataclassr   Picklerr'   rU   r-   r,   	UnpicklerrV   r?   rA   r6   rE   r;   r   r>   r   r   r   r   r   r   rC   r   r   r   r   <module>   sH    [
>,(E


 