a
    h`                     @   s  d dl Z d dlmZmZ d dlmZmZ d dlmZm	Z	m
Z
 d dlmZ d dlmZ e	rld dlZd dlmZ g dZe jG d	d
 d
Ze jG dd dZe jG dd dZe jG dd dZe jG dd dZe jG dd dZe jG dd dZe
eeeeeeef ZG dd deZe jG dd dZG dd deZe jG dd dZe jG dd  d Z e jG d!d" d"Z!d#d$ Z"ed%d&d'Z#d(d)e$e% d"d*d+d,Z&dS )-    N)
CollectionMapping)autoEnum)OptionalTYPE_CHECKINGUnionFakeScriptObject)is_fake)GraphSignature)ConstantArgumentCustomObjArgumentExportBackwardSignatureExportGraphSignature	InputKind	InputSpec
OutputKind
OutputSpecSymIntArgumentSymFloatArgumentSymBoolArgumentTensorArgumentc                   @   s   e Zd ZU eed< dS )r   nameN__name__
__module____qualname__str__annotations__ r    r    J/var/www/auris/lib/python3.9/site-packages/torch/export/graph_signature.pyr      s   
r   c                   @   s   e Zd ZU eed< dS )TokenArgumentr   Nr   r    r    r    r!   r"   $   s   
r"   c                   @   s   e Zd ZU eed< dS )r   r   Nr   r    r    r    r!   r   )   s   
r   c                   @   s   e Zd ZU eed< dS )r   r   Nr   r    r    r    r!   r   .   s   
r   c                   @   s   e Zd ZU eed< dS )r   r   Nr   r    r    r    r!   r   3   s   
r   c                   @   s.   e Zd ZU eed< eed< dZee ed< dS )r   r   	class_fqnNfake_val)r   r   r   r   r   r$   r   r
   r    r    r    r!   r   8   s   
r   c                   @   s,   e Zd ZU eed< eeeeedf ed< dS )r   r   Nvalue)	r   r   r   r   r   r   intfloatboolr    r    r    r!   r   ?   s   
r   c                   @   s0   e Zd Ze Ze Ze Ze Ze Ze Z	dS )r   N)
r   r   r   r   
USER_INPUT	PARAMETERBUFFERCONSTANT_TENSOR
CUSTOM_OBJTOKENr    r    r    r!   r   P   s   r   c                   @   sJ   e Zd ZU eed< eed< ee ed< dZee	 ed< dd Z
dd	 ZdS )
r   kindargtargetN
persistentc              	   C   sP   | j tjkr| jd usJ dt| jtttt	t
ttfsLJ dt| j d S )Nz,Failed to specify persistent flag on BUFFER.zgot )r/   r   r+   r2   
isinstancer0   r   r   r   r   r   r   r"   typeselfr    r    r!   __post_init__`   s"    zInputSpec.__post_init__c                 C   s\   | j d u rdnd| j  d}| jd u r*dn
d| j }t| jj dt| jj | | S )N 	 target=''z persistent=: )r1   r2   r   r0   r   r/   )r6   r1   r2   r    r    r!   __str__r   s    zInputSpec.__str__)r   r   r   r   r   ArgumentSpecr   r   r2   r(   r7   r<   r    r    r    r!   r   Y   s   
r   c                   @   s6   e Zd Ze Ze Ze Ze Ze Ze Z	e Z
dS )r   N)r   r   r   r   USER_OUTPUTLOSS_OUTPUTBUFFER_MUTATIONGRADIENT_TO_PARAMETERGRADIENT_TO_USER_INPUTUSER_INPUT_MUTATIONr.   r    r    r    r!   r   x   s   r   c                   @   s:   e Zd ZU eed< eed< ee ed< dd Zdd Z	dS )	r   r/   r0   r1   c              	   C   s(   t | jtttttttfs$J | jd S N)	r3   r0   r   r   r   r   r   r"   r   r5   r    r    r!   r7      s    zOutputSpec.__post_init__c                 C   s>   | j d u rdnd| j  d}t| jj dt| jj | S )Nr8   r9   r:   r;   )r1   r   r0   r   r/   )r6   r1   r    r    r!   r<      s    zOutputSpec.__str__N)
r   r   r   r   r   r=   r   r   r7   r<   r    r    r    r!   r      s
   
r   c                   @   s6   e Zd ZU eeef ed< eeef ed< eed< dS )r   gradients_to_parametersgradients_to_user_inputsloss_outputN)r   r   r   dictr   r   r    r    r    r!   r      s   
r   c                   @   s  e Zd ZU dZee ed< ee ed< ee	e
 dddZee	e
 dddZee	e
 dd	d
Zee	e
 dddZee	e
 dddZee	eeeede
f  dddZee	eeeede
f  dddZeee
e
f dddZeee
e
f dddZeee
e
f dddZeee
e
f dddZeee
e
f dddZeee
e
f dddZeee dd d!Zeeeee
f  dd"d#Z ee	e
 dd$d%Z!ee	e
 dd&d'Z"ddd(d)Z#e
e
d*d+d,Z$d2d.d/Z%d0d1 Z&dS )3r   a,  
    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
    which is a fx.Graph with stronger invariants gurantees.

    Export Graph is functional and does not access "states" like parameters
    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
    gurantees that parameters, buffers, and constant tensors are lifted out of
    the graph as inputs.  Similarly, any mutations to buffers are not included
    in the graph either, instead the updated values of mutated buffers are
    modeled as additional outputs of Export Graph.

    The ordering of all inputs and outputs are::

        Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
        Outputs = [*mutated_inputs, *flattened_user_outputs]

    e.g. If following module is exported::

        class CustomModule(nn.Module):
            def __init__(self) -> None:
                super(CustomModule, self).__init__()

                # Define a parameter
                self.my_parameter = nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer("my_buffer1", torch.tensor(3.0))
                self.register_buffer("my_buffer2", torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (
                    x1 + self.my_parameter
                ) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0)  # In-place addition

                return output


        mod = CustomModule()
        ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))

    Resulting Graph is non-functional::

        graph():
            %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
            %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
            %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
            %x1 : [num_users=1] = placeholder[target=x1]
            %x2 : [num_users=1] = placeholder[target=x2]
            %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
            %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
            %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
            %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
            %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
            return (add_1,)

    Resulting ExportGraphSignature of the non-functional Graph would be::

        # inputs
        p_my_parameter: PARAMETER target='my_parameter'
        b_my_buffer1: BUFFER target='my_buffer1' persistent=True
        b_my_buffer2: BUFFER target='my_buffer2' persistent=True
        x1: USER_INPUT
        x2: USER_INPUT

        # outputs
        add_1: USER_OUTPUT

    To get a functional Graph, you can use :func:`run_decompositions`::

        mod = CustomModule()
        ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0)))
        ep = ep.run_decompositions()

    Resulting Graph is functional::

        graph():
            %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
            %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
            %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
            %x1 : [num_users=1] = placeholder[target=x1]
            %x2 : [num_users=1] = placeholder[target=x2]
            %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
            %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
            %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
            %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
            %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
            return (add_2, add_1)

    Resulting ExportGraphSignature of the functional Graph would be::

        # inputs
        p_my_parameter: PARAMETER target='my_parameter'
        b_my_buffer1: BUFFER target='my_buffer1' persistent=True
        b_my_buffer2: BUFFER target='my_buffer2' persistent=True
        x1: USER_INPUT
        x2: USER_INPUT

        # outputs
        add_2: BUFFER_MUTATION target='my_buffer2'
        add_1: USER_OUTPUT

    input_specsoutput_specsreturnc                 C   s   t dd | jD S )Nc                 s   s,   | ]$}|j tjkrt|jtr|jV  qd S rD   )r/   r   r*   r3   r1   r   .0sr    r    r!   	<genexpr>  s   z2ExportGraphSignature.parameters.<locals>.<genexpr>tuplerI   r5   r    r    r!   
parameters  s    zExportGraphSignature.parametersc                 C   s   t dd | jD S )Nc                 s   s,   | ]$}|j tjkrt|jtr|jV  qd S rD   )r/   r   r+   r3   r1   r   rM   r    r    r!   rP     s   z/ExportGraphSignature.buffers.<locals>.<genexpr>rQ   r5   r    r    r!   buffers  s    zExportGraphSignature.buffersc                 C   s   t dd | jD S )Nc                 s   s6   | ].}|j tjkr|jd u rt|jtr|jV  qdS )FN)r/   r   r+   r2   r3   r1   r   rM   r    r    r!   rP   (  s
   
z>ExportGraphSignature.non_persistent_buffers.<locals>.<genexpr>rQ   r5   r    r    r!   non_persistent_buffers&  s    z+ExportGraphSignature.non_persistent_buffersc                 C   s   t dd | jD S )Nc                 s   s,   | ]$}|j tjkrt|jtr|jV  qd S rD   )r/   r   r,   r3   r1   r   rM   r    r    r!   rP   3  s   z?ExportGraphSignature.lifted_tensor_constants.<locals>.<genexpr>rQ   r5   r    r    r!   lifted_tensor_constants1  s    z,ExportGraphSignature.lifted_tensor_constantsc                 C   s   t dd | jD S )Nc                 s   s,   | ]$}|j tjkrt|jtr|jV  qd S rD   )r/   r   r-   r3   r1   r   rM   r    r    r!   rP   <  s   z:ExportGraphSignature.lifted_custom_objs.<locals>.<genexpr>rQ   r5   r    r    r!   lifted_custom_objs:  s    z'ExportGraphSignature.lifted_custom_objsNc                 C   sx   g }| j D ]d}|jtjkrq
t|jtttt	t
frB||jj q
t|jtr^||jj q
t|j dq
t|S )Nz is not a valid user inputs)rI   r/   r   r)   r3   r0   r   r   r   r   r   appendr   r   r%   RuntimeErrorrR   )r6   user_inputsrO   r    r    r!   rZ   D  s$    

z ExportGraphSignature.user_inputsc                 C   s   g }| j D ]}|jtjtjfvr"q
t|jttt	t
frF||jj q
t|jtrb||jj q
t|jtr~||jj q
t|j dq
t|S )Nz is not a valid user output)rJ   r/   r   r>   r?   r3   r0   r   r   r   r   rX   r   r   r%   r   rY   rR   )r6   user_outputsrO   r    r    r!   r[   ^  s$    

z!ExportGraphSignature.user_outputsc                 C   s   t dd | jD S )Nc                 s   s@   | ]8}|j tjkrt|jtrt|jtr|jj|jfV  qd S rD   )	r/   r   r*   r3   r0   r   r1   r   r   rM   r    r    r!   rP   y  s
   z<ExportGraphSignature.inputs_to_parameters.<locals>.<genexpr>_immutable_dictrI   r5   r    r    r!   inputs_to_parametersw  s    z)ExportGraphSignature.inputs_to_parametersc                 C   s   t dd | jD S )Nc                 s   s@   | ]8}|j tjkrt|jtrt|jtr|jj|jfV  qd S rD   )	r/   r   r+   r3   r0   r   r1   r   r   rM   r    r    r!   rP     s
   z9ExportGraphSignature.inputs_to_buffers.<locals>.<genexpr>r\   r5   r    r    r!   inputs_to_buffers  s    z&ExportGraphSignature.inputs_to_buffersc                 C   s   t dd | jD S )Nc                 s   s@   | ]8}|j tjkrt|jtrt|jtr|jj|jfV  qd S rD   )	r/   r   r@   r3   r0   r   r1   r   r   rM   r    r    r!   rP     s
   z9ExportGraphSignature.buffers_to_mutate.<locals>.<genexpr>r]   rJ   r5   r    r    r!   buffers_to_mutate  s    z&ExportGraphSignature.buffers_to_mutatec                 C   s   t dd | jD S )Nc                 s   s@   | ]8}|j tjkrt|jtrt|jtr|jj|jfV  qd S rD   )	r/   r   rC   r3   r0   r   r1   r   r   rM   r    r    r!   rP     s
   z=ExportGraphSignature.user_inputs_to_mutate.<locals>.<genexpr>r`   r5   r    r    r!   user_inputs_to_mutate  s    z*ExportGraphSignature.user_inputs_to_mutatec                 C   s   t dd | jD S )Nc                 s   s@   | ]8}|j tjkrt|jtrt|jtr|jj|jfV  qd S rD   )	r/   r   r,   r3   r0   r   r1   r   r   rM   r    r    r!   rP     s
   zIExportGraphSignature.inputs_to_lifted_tensor_constants.<locals>.<genexpr>r\   r5   r    r    r!   !inputs_to_lifted_tensor_constants  s    z6ExportGraphSignature.inputs_to_lifted_tensor_constantsc                 C   s   t dd | jD S )Nc                 s   s@   | ]8}|j tjkrt|jtrt|jtr|jj|jfV  qd S rD   )	r/   r   r-   r3   r0   r   r1   r   r   rM   r    r    r!   rP     s
   zDExportGraphSignature.inputs_to_lifted_custom_objs.<locals>.<genexpr>r\   r5   r    r    r!   inputs_to_lifted_custom_objs  s    z1ExportGraphSignature.inputs_to_lifted_custom_objsc                 C   s   d }i }i }| j D ]}|jtjkrH|d u s.J t|jts>J |jj}q|jtjkrt|j	t
sdJ t|jtstJ |j	||jj< q|jtjkrt|j	t
sJ t|jtsJ |j	||jj< q|d u rd S t|||dS )N)rG   rE   rF   )rJ   r/   r   r?   r3   r0   r   r   rA   r1   r   rB   r   )r6   rG   rE   rF   specr    r    r!   backward_signature  s.    

z'ExportGraphSignature.backward_signaturec                 C   s   d S rD   r    r5   r    r    r!   assertion_dep_token  s    z(ExportGraphSignature.assertion_dep_tokenc                 C   sB   g }| j D ].}|jtjkr
t|jts*J ||jj q
t	|S rD   )
rI   r/   r   r.   r3   r0   r"   rX   r   rR   )r6   input_tokensrO   r    r    r!   rh     s    
z!ExportGraphSignature.input_tokensc                 C   sB   g }| j D ].}|jtjkr
t|jts*J ||jj q
t	|S rD   )
rJ   r/   r   r.   r3   r0   r"   rX   r   rR   )r6   output_tokensrO   r    r    r!   ri     s    
z"ExportGraphSignature.output_tokensc                 C   sR   | j }|d u rd S t|dks"J tt| }t| jt| j |ksNJ d S )N   )rg   lennextiterkeysr[   ra   )r6   rg   Zassertion_dep_token_indexr    r    r!   r7     s    z"ExportGraphSignature.__post_init__)oldnewc                 C   s   t |tsJ t |tsJ ttttttf}| jD ]$}t |j	|r2|j	j
|kr2||j	_
q2| jD ]$}t |j	|r^|j	j
|kr^||j	_
q^dS )zR
        Replace all uses of the old name with new name in the signature.
        N)r3   r   r   r   r   r   r   r"   rJ   r0   r   rI   )r6   ro   rp   Z	arg_typesoir    r    r!   replace_all_uses  s"    


z%ExportGraphSignature.replace_all_usesFc                    s    fdd}|S )Nc                    s8   |j dkr| j|  r4| j dkr4| j| d S )Noutputplaceholder)oprs   r   )ro   rp   userreplace_inputsr6   r    r!   _  s    
z0ExportGraphSignature.get_replace_hook.<locals>._r    )r6   ry   rz   r    rx   r!   get_replace_hook  s    z%ExportGraphSignature.get_replace_hookc                 C   s>   d dd | jD }d dd | jD }d| d| dS )N
c                 s   s   | ]}t |V  qd S rD   r   rM   r    r    r!   rP         z/ExportGraphSignature.__str__.<locals>.<genexpr>c                 s   s   | ]}t |V  qd S rD   r}   rM   r    r    r!   rP     r~   z

# inputs
z

# outputs
)joinrI   rJ   )r6   rI   rJ   r    r    r!   r<     s    zExportGraphSignature.__str__)F)'r   r   r   __doc__listr   r   r   propertyr   r   rS   rT   rU   rV   rW   r   r&   r'   r(   rZ   r[   r   r^   r_   ra   rb   rc   rd   r   r   rf   rg   rh   ri   r7   rs   r{   r<   r    r    r    r!   r      sR   
k	
	""	
		
	r   c                 C   s   ddl m} |t| S )z
    Creates a mapping where items cannot be added, deleted, or updated.
    NOTE: The immutability is shallow (like tuple is an immutable collection).
    r   )MappingProxyType)typesr   rH   )itemsr   r    r    r!   r]     s    r]   rK   c                 C   sV  ddl m}m}m}m} ddlm} t| tt	t
td tfrHtd| dS d| jv s`J |  d| jd }| j|v rt| jdS t|rt| jdS t||rt| jdS t||rt| jdS t||rt| jdS t|| rt| j|  d	S t||rt| j|j|d
S t|tt	tt
td fr>t| j|dS tdt| dd S )Nr   )ScriptObjectSymBoolSymFloatSymIntr	   r8   )r   r%   valz8 is not a constant or a node with a 'val' metadata field)r   )r   r#   )r   r#   r$   z*Encountered an unsupported object of type z0 while writing the metadata for exported program)torchr   r   r   r   "torch._library.fake_class_registryr
   r3   r&   r(   r'   r4   r   r   metar   r"   r   r   r   r   r   r   _typeZqualified_nameZscript_class_nameAssertionError)nodeZtoken_namesr   r   r   r   r
   r   r    r    r!   _make_argument_spec(  s:    





r   r   ztorch.fx.GraphModule)graph_signaturegmrU   rL   c           	         s.  ddl m} | jd u}t| j| j| jt| j| j | j	|rN| jj
ni |r^| jjni |rn| jjnd | j| jfdd|jjD }fdd|ttt|jjjD }ttdfdd	tttd	 fd
d
	fdd|D }
fddt|D }t||dS )Nr   )_pytreec                    s    g | ]}|j d krt| qS )ru   )rv   r   rN   r   )rh   r    r!   
<listcomp>o  s   
z6_convert_to_export_graph_signature.<locals>.<listcomp>c                    s   g | ]}t | qS r    )r   r   )ri   r    r!   r   t  s   )inprL   c                    s   t | trttj| d dS t | ts4ttj| d dS | j}|v rRttj| d dS |v rnttj| | dS | v rttj	|  |  | vdS t
d| d S )Nr/   r0   r1   )r/   r0   r1   r2   zUnknown tensor input kind: )r3   r"   r   r   r.   r   r)   r   r*   r+   r   )r   r   )r_   r^   rU   rZ   r    r!   to_input_specy  s*    


z9_convert_to_export_graph_signature.<locals>.to_input_spec)idxrq   rL   c                    s  t |trttj|d dS t |ts4ttj|d dS |j}| t t t k r| v rrttj	| | dS |v rttj
|| dS td| nx|v rttj|d dS |v rttj|| dS |v rttj|| dS |krttj|d dS td| d S )Nr   zUnknown tensor mutation kind: zUnknown tensor output kind: )r3   r"   r   r   r.   r   r>   r   rk   r@   rC   r   rA   rB   r?   )r   rq   r   )buffer_mutationsgrad_paramsgrad_user_inputsrG   ri   user_input_mutationsr[   r    r!   to_output_spec  sH    


z:_convert_to_export_graph_signature.<locals>.to_output_specc                    s   g | ]} |qS r    r    )rN   r   )r   r    r!   r     r~   c                    s   g | ]\}} ||qS r    r    )rN   r   rq   )r   r    r!   r     r~   )rI   rJ   )Ztorch.utilsr   rf   setrZ   r^   r_   r[   ra   rb   Zgradients_to_parameterrF   rG   rh   ri   graphnodesZtree_leavesrl   rm   reversedargsr=   r   r&   r   	enumerater   )	r   r   rU   ZpytreeZis_jointinputsoutputsrI   rJ   r    )r   r   r   rh   r_   r^   rG   rU   ri   r   r   r   rZ   r[   r!   "_convert_to_export_graph_signatureM  sB    







",r   )'Zdataclassescollections.abcr   r   enumr   r   typingr   r   r   r   r
   Ztorch._subclasses.fake_tensorr   r   Z&torch._functorch._aot_autograd.schemasr   __all__Z	dataclassr   r"   r   r   r   r   r   r=   r   r   r   r   r   r   r]   r   r   r   r   r    r    r    r!   <module>   sh   	
  }
&