o
    Zh{X                     @   s  d dl Z d dlmZmZ d dlmZmZ d dlmZm	Z	m
Z
 d dlmZ d dlmZ e	r6d dlZd dlmZ g dZe jG d	d
 d
Ze jG dd dZe jG dd dZe jG dd dZe jG dd dZe jG dd dZe jG dd dZe
eeeeeeef ZG dd deZe jG dd dZG dd deZe jG dd dZe jG dd  d Z e jG d!d" d"Z!d#d$ Z"d%efd&d'Z#d(d)d*d+d,e$e% d%d"fd-d.Z&dS )/    N)
CollectionMapping)autoEnum)OptionalTYPE_CHECKINGUnionFakeScriptObject)is_fake)GraphSignature)ConstantArgumentCustomObjArgumentExportBackwardSignatureExportGraphSignature	InputKind	InputSpec
OutputKind
OutputSpecSymIntArgumentSymFloatArgumentSymBoolArgumentTensorArgumentc                   @      e Zd ZU eed< dS )r   nameN__name__
__module____qualname__str__annotations__ r!   r!   K/var/www/auris/lib/python3.10/site-packages/torch/export/graph_signature.pyr         
 r   c                   @   r   )TokenArgumentr   Nr   r!   r!   r!   r"   r$   $   r#   r$   c                   @   r   )r   r   Nr   r!   r!   r!   r"   r   )   r#   r   c                   @   r   )r   r   Nr   r!   r!   r!   r"   r   .   r#   r   c                   @   r   )r   r   Nr   r!   r!   r!   r"   r   3   r#   r   c                   @   s.   e Zd ZU eed< eed< dZee ed< dS )r   r   	class_fqnNfake_val)r   r   r   r   r    r&   r   r
   r!   r!   r!   r"   r   8   s   
 r   c                   @   s,   e Zd ZU eed< eeeeedf ed< dS )r   r   Nvalue)	r   r   r   r   r    r   intfloatboolr!   r!   r!   r"   r   ?   s   
 r   c                   @   s0   e Zd Ze Ze Ze Ze Ze Ze Z	dS )r   N)
r   r   r   r   
USER_INPUT	PARAMETERBUFFERCONSTANT_TENSOR
CUSTOM_OBJTOKENr!   r!   r!   r"   r   P   s    
r   c                   @   sB   e Zd ZU eed< eed< ee ed< dZee	 ed< dd Z
dS )r   kindargtargetN
persistentc              	   C   sP   | j tjkr| jd usJ dt| jtttt	t
ttfs&J dt| j d S )Nz,Failed to specify persistent flag on BUFFER.zgot )r1   r   r-   r4   
isinstancer2   r   r   r   r   r   r   r$   typeselfr!   r!   r"   __post_init__`   s"   zInputSpec.__post_init__)r   r   r   r   r    ArgumentSpecr   r   r4   r*   r9   r!   r!   r!   r"   r   Y   s   
 r   c                   @   s6   e Zd Ze Ze Ze Ze Ze Ze Z	e Z
dS )r   N)r   r   r   r   USER_OUTPUTLOSS_OUTPUTBUFFER_MUTATIONGRADIENT_TO_PARAMETERGRADIENT_TO_USER_INPUTUSER_INPUT_MUTATIONr0   r!   r!   r!   r"   r   s   s    
r   c                   @   s2   e Zd ZU eed< eed< ee ed< dd ZdS )r   r1   r2   r3   c              	   C   s(   t | jtttttttfsJ | jd S N)	r5   r2   r   r   r   r   r   r$   r   r7   r!   r!   r"   r9      s   zOutputSpec.__post_init__N)	r   r   r   r   r    r:   r   r   r9   r!   r!   r!   r"   r   }   s
   
 r   c                   @   s6   e Zd ZU eeef ed< eeef ed< eed< dS )r   gradients_to_parametersgradients_to_user_inputsloss_outputN)r   r   r   dictr   r    r!   r!   r!   r"   r      s   
 r   c                	   @   s  e Zd ZU dZee ed< ee ed< ede	e
 fddZede	e
 fddZede	e
 fd	d
Zede	e
 fddZede	e
 fddZede	eeeede
f  fddZede	eeeede
f  fddZedee
e
f fddZedee
e
f fddZedee
e
f fddZedee
e
f fddZedee
e
f fddZedee
e
f fddZedee fd d!Zedeeee
f  fd"d#Z ede	e
 fd$d%Z!ede	e
 fd&d'Z"d1d(d)Z#d*e
d+e
fd,d-Z$d2d/d0Z%dS )3r   a  
    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
    which is a fx.Graph with stronger invariants gurantees.

    Export Graph is functional and does not access "states" like parameters
    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
    gurantees that parameters, buffers, and constant tensors are lifted out of
    the graph as inputs.  Similarly, any mutations to buffers are not included
    in the graph either, instead the updated values of mutated buffers are
    modeled as additional outputs of Export Graph.

    The ordering of all inputs and outputs are::

        Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
        Outputs = [*mutated_inputs, *flattened_user_outputs]

    e.g. If following module is exported::

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super(CustomModule, self).__init__()

                # Define a parameter
                self.my_parameter = nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer('my_buffer1', torch.tensor(3.0))
                self.register_buffer('my_buffer2', torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0) # In-place addition

                return output

    Resulting Graph would be::

        graph():
            %arg0_1 := placeholder[target=arg0_1]
            %arg1_1 := placeholder[target=arg1_1]
            %arg2_1 := placeholder[target=arg2_1]
            %arg3_1 := placeholder[target=arg3_1]
            %arg4_1 := placeholder[target=arg4_1]
            %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
            %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
            %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
            %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
            %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
            return (add_tensor_2, add_tensor_1)

    Resulting ExportGraphSignature would be::

        ExportGraphSignature(
            input_specs=[
                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
            ],
            output_specs=[
                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
            ]
        )
    input_specsoutput_specsreturnc                 C      t dd | jD S )Nc                 s   .    | ]}|j tjkrt|jtr|jV  qd S rA   )r1   r   r,   r5   r3   r   .0sr!   r!   r"   	<genexpr>       
z2ExportGraphSignature.parameters.<locals>.<genexpr>tuplerF   r7   r!   r!   r"   
parameters      zExportGraphSignature.parametersc                 C   rI   )Nc                 s   rJ   rA   )r1   r   r-   r5   r3   r   rK   r!   r!   r"   rN      rO   z/ExportGraphSignature.buffers.<locals>.<genexpr>rP   r7   r!   r!   r"   buffers   rS   zExportGraphSignature.buffersc                 C   rI   )Nc                 s   s8    | ]}|j tjkr|jd u rt|jtr|jV  qdS )FN)r1   r   r-   r4   r5   r3   r   rK   r!   r!   r"   rN      s    

z>ExportGraphSignature.non_persistent_buffers.<locals>.<genexpr>rP   r7   r!   r!   r"   non_persistent_buffers   rS   z+ExportGraphSignature.non_persistent_buffersc                 C   rI   )Nc                 s   rJ   rA   )r1   r   r.   r5   r3   r   rK   r!   r!   r"   rN     rO   z?ExportGraphSignature.lifted_tensor_constants.<locals>.<genexpr>rP   r7   r!   r!   r"   lifted_tensor_constants  rS   z,ExportGraphSignature.lifted_tensor_constantsc                 C   rI   )Nc                 s   rJ   rA   )r1   r   r/   r5   r3   r   rK   r!   r!   r"   rN     rO   z:ExportGraphSignature.lifted_custom_objs.<locals>.<genexpr>rP   r7   r!   r!   r"   lifted_custom_objs  rS   z'ExportGraphSignature.lifted_custom_objsNc                 C   sv   g }| j D ]1}|jtjkrqt|jtttt	t
fr!||jj qt|jtr/||jj qt|j dt|S )Nz is not a valid user inputs)rF   r1   r   r+   r5   r2   r   r   r   r   r   appendr   r   r'   RuntimeErrorrQ   )r8   user_inputsrM   r!   r!   r"   rZ     s$   

z ExportGraphSignature.user_inputsc                 C   s   g }| j D ]A}|jtjtjfvrqt|jttt	t
fr#||jj qt|jtr1||jj qt|jtr?||jj qt|j dt|S )Nz is not a valid user output)rG   r1   r   r;   r<   r5   r2   r   r   r   r   rX   r   r   r'   r   rY   rQ   )r8   user_outputsrM   r!   r!   r"   r[   0  s$   

z!ExportGraphSignature.user_outputsc                 C   rI   )Nc                 s   B    | ]}|j tjkrt|jtrt|jtr|jj|jfV  qd S rA   )	r1   r   r,   r5   r2   r   r3   r   r   rK   r!   r!   r"   rN   K      

z<ExportGraphSignature.inputs_to_parameters.<locals>.<genexpr>_immutable_dictrF   r7   r!   r!   r"   inputs_to_parametersI  rS   z)ExportGraphSignature.inputs_to_parametersc                 C   rI   )Nc                 s   r\   rA   )	r1   r   r-   r5   r2   r   r3   r   r   rK   r!   r!   r"   rN   W  r]   z9ExportGraphSignature.inputs_to_buffers.<locals>.<genexpr>r^   r7   r!   r!   r"   inputs_to_buffersU  rS   z&ExportGraphSignature.inputs_to_buffersc                 C   rI   )Nc                 s   r\   rA   )	r1   r   r=   r5   r2   r   r3   r   r   rK   r!   r!   r"   rN   c  r]   z9ExportGraphSignature.buffers_to_mutate.<locals>.<genexpr>r_   rG   r7   r!   r!   r"   buffers_to_mutatea  rS   z&ExportGraphSignature.buffers_to_mutatec                 C   rI   )Nc                 s   r\   rA   )	r1   r   r@   r5   r2   r   r3   r   r   rK   r!   r!   r"   rN   m  r]   z=ExportGraphSignature.user_inputs_to_mutate.<locals>.<genexpr>rb   r7   r!   r!   r"   user_inputs_to_mutatek  rS   z*ExportGraphSignature.user_inputs_to_mutatec                 C   rI   )Nc                 s   r\   rA   )	r1   r   r.   r5   r2   r   r3   r   r   rK   r!   r!   r"   rN   x  r]   zIExportGraphSignature.inputs_to_lifted_tensor_constants.<locals>.<genexpr>r^   r7   r!   r!   r"   !inputs_to_lifted_tensor_constantsv  rS   z6ExportGraphSignature.inputs_to_lifted_tensor_constantsc                 C   rI   )Nc                 s   r\   rA   )	r1   r   r/   r5   r2   r   r3   r   r   rK   r!   r!   r"   rN     r]   zDExportGraphSignature.inputs_to_lifted_custom_objs.<locals>.<genexpr>r^   r7   r!   r!   r"   inputs_to_lifted_custom_objs  rS   z1ExportGraphSignature.inputs_to_lifted_custom_objsc                 C   s   d }i }i }| j D ]V}|jtjkr$|d u sJ t|jtsJ |jj}q	|jtjkrBt|j	t
s2J t|jts:J |j	||jj< q	|jtjkr_t|j	t
sPJ t|jtsXJ |j	||jj< q	|d u rfd S t|||dS )N)rD   rB   rC   )rG   r1   r   r<   r5   r2   r   r   r>   r3   r   r?   r   )r8   rD   rB   rC   specr!   r!   r"   backward_signature  s0   

z'ExportGraphSignature.backward_signaturec                 C   s   d S rA   r!   r7   r!   r!   r"   assertion_dep_token  s   z(ExportGraphSignature.assertion_dep_tokenc                 C   B   g }| j D ]}|jtjkrt|jtsJ ||jj qt	|S rA   )
rF   r1   r   r0   r5   r2   r$   rX   r   rQ   )r8   input_tokensrM   r!   r!   r"   rk        
z!ExportGraphSignature.input_tokensc                 C   rj   rA   )
rG   r1   r   r0   r5   r2   r$   rX   r   rQ   )r8   output_tokensrM   r!   r!   r"   rm     rl   z"ExportGraphSignature.output_tokensc                 C   sR   | j }|d u r	d S t|dksJ tt| }t| jt| j |ks'J d S )N   )ri   lennextiterkeysr[   rc   )r8   ri   Zassertion_dep_token_indexr!   r!   r"   r9     s   z"ExportGraphSignature.__post_init__oldnewc                 C   s   t |tsJ t |tsJ ttttttf}| jD ]}t |j	|r+|j	j
|kr+||j	_
q| jD ]}t |j	|rA|j	j
|krA||j	_
q/dS )zR
        Replace all uses of the old name with new name in the signature.
        N)r5   r   r   r   r   r   r   r$   rG   r2   r   rF   )r8   rs   rt   Z	arg_typesoir!   r!   r"   replace_all_uses  s(   

z%ExportGraphSignature.replace_all_usesFc                    s    fdd}|S )Nc                    s@   |j dkr| j|  r| j dkr| j| d S d S d S )Noutputplaceholder)oprw   r   )rs   rt   userreplace_inputsr8   r!   r"   _  s
   
z0ExportGraphSignature.get_replace_hook.<locals>._r!   )r8   r}   r~   r!   r|   r"   get_replace_hook  s   z%ExportGraphSignature.get_replace_hook)rH   N)F)&r   r   r   __doc__listr   r    r   propertyr   r   rR   rT   rU   rV   rW   r   r(   r)   r*   rZ   r[   r   r`   ra   rc   rd   re   rf   r   r   rh   ri   rk   rm   r9   rw   r   r!   r!   r!   r"   r      sR   
 F	
	""	
		
r   c                 C   s   ddl m} |t| S )z
    Creates a mapping where items cannot be added, deleted, or updated.
    NOTE: The immutability is shallow (like tuple is an immutable collection).
    r   )MappingProxyType)typesr   rE   )itemsr   r!   r!   r"   r_     s   r_   rH   c                 C   sL  ddl m}m}m}m} ddlm} t| tt	t
td tfr$td| dS d| jv s0J |  d| jd }| j|v r@t| jdS t|rJt| jdS t||rUt| jdS t||r`t| jdS t||rkt| jdS t||r{t| j|  d	S t||rt| j|j|d
S t|tt	tt
td frt| j|dS tdt| d)Nr   )ScriptObjectSymBoolSymFloatSymIntr	    )r   r'   valz8 is not a constant or a node with a 'val' metadata field)r   )r   r%   )r   r%   r&   z*Encountered an unsupported object of type z0 while writing the metadata for exported program)torchr   r   r   r   "torch._library.fake_class_registryr
   r5   r(   r*   r)   r6   r   r   metar   r$   r   r   r   r   r   r   _typeZqualified_nameZscript_class_nameAssertionError)nodeZtoken_namesr   r   r   r   r
   r   r!   r!   r"   _make_argument_spec  s:   







r   graph_signaturer   gmztorch.fx.GraphModulerU   c           	         s4  ddl m} | jd u}t| j| j| jt| j| j | j	|r'| jj
ni |r/| jjni |r7| jjnd | j| jfdd|jjD }fdd|ttt|jjjD }dtdtffdd		d
tdtdtf fdd
	fdd|D }
fddt|D }t||dS )Nr   )_pytreec                    s    g | ]}|j d krt| qS )ry   )rz   r   rL   r   )rk   r!   r"   
<listcomp>0  s
    
z6_convert_to_export_graph_signature.<locals>.<listcomp>c                    s   g | ]}t | qS r!   )r   r   )rm   r!   r"   r   5  s    inprH   c                    s   t | trttj| d dS t | tsttj| d dS | j}|v r)ttj| d dS |v r7ttj| | dS | v rJttj	|  |  | vdS t
d| )Nr1   r2   r3   )r1   r2   r3   r4   zUnknown tensor input kind: )r5   r$   r   r   r0   r   r+   r   r,   r-   r   )r   r   )ra   r`   rU   rZ   r!   r"   to_input_spec:  s*   


z9_convert_to_export_graph_signature.<locals>.to_input_specidxru   c                    s  t |trttj|d dS t |tsttj|d dS |j}| t t t k rN| v r9ttj	| | dS |v rGttj
|| dS td| |v rZttj|d dS |v rhttj|| dS |v rvttj|| dS |krttj|d dS td| )Nr   zUnknown tensor mutation kind: zUnknown tensor output kind: )r5   r$   r   r   r0   r   r;   r   ro   r=   r@   r   r>   r?   r<   )r   ru   r   )buffer_mutationsgrad_paramsgrad_user_inputsrD   rm   user_input_mutationsr[   r!   r"   to_output_specS  sH   

z:_convert_to_export_graph_signature.<locals>.to_output_specc                    s   g | ]} |qS r!   r!   )rL   r   )r   r!   r"   r     s    c                    s   g | ]	\}} ||qS r!   r!   )rL   r   ru   )r   r!   r"   r     s    )rF   rG   )Ztorch.utilsr   rh   setrZ   r`   ra   r[   rc   rd   Zgradients_to_parameterrC   rD   rk   rm   graphnodesZtree_leavesrp   rq   reversedargsr:   r   r(   r   	enumerater   )	r   r   rU   ZpytreeZis_jointZinputsZoutputsrF   rG   r!   )r   r   r   rk   ra   r`   rD   rU   rm   r   r   r   rZ   r[   r"   "_convert_to_export_graph_signature  s0   




&,r   )'dataclassescollections.abcr   r   enumr   r   typingr   r   r   r   r
   Ztorch._subclasses.fake_tensorr   r   Z&torch._functorch._aot_autograd.schemasr   __all__	dataclassr   r$   r   r   r   r   r   r:   r   r   r   r   r   r   r_   r   r   r   r   r!   r!   r!   r"   <module>   sr   	
  S
%