a
    hS>                     @   s|  d Z ddlZddlZddlZddlZddlZddlZddlZddlZddl	Z	ddl
Z
ddlZddlmZ ddlmZmZmZ ddlZddlZddlmZmZ ddlmZ ddlmZ eeZejd	d
G dd dZejG dd dZ ede!Z"ede!Z#ejG dd dZ$ejG dd dZ%ej&G dd dee% Z'G dd dZ(ej&G dd dee Z)G dd dZ*dS )a  
This module provides the infrastructure for creating and managing compile package
for torch.compile. We mainly have two abstractions here:
  - CompilePackage: Overarching data structure for store and lookup a list of compiled codes.
  - CodeCacheEntry: Data structure for a single code being compiled by torch.compile.
The caching behavior is always under user control explicitly so that a stronger guarantee can
be provided about cache hit for a specific compiled model. Users can load the compile package
from a different process or host.
    N)	Generator)AnyNewTypeOptional)PrecompileCacheArtifactPrecompileContext)CacheArtifactFactory   )get_code_keysT)frozenc                   @   s&  e Zd ZU eed< eed< eed< eed< eed< eed< eed< eedf ed	< eedf ed
< eedf ed< eed< eed< eed< eedf ed< eedf ed< dZ	e
e ed< dZe
e ed< dZe
e ed< dZe
e ed< eejejd dddZeejd ejdddZdS )SerializedCodeco_argcountco_posonlyargcountco_kwonlyargcount
co_nlocalsco_stacksizeco_flagsco_code.	co_constsco_namesco_varnamesco_filenameco_nameco_firstlinenoco_cellvarsco_freevarsNco_linetableco_qualnameco_exceptiontable	co_lnotabcodereturnc                    s@   fddt  D }t fdd|d D |d<  f i |S )Nc                    s   i | ]}|t  |qS  getattr.0key)r!   r#   C/var/www/auris/lib/python3.9/site-packages/torch/_dynamo/package.py
<dictcomp>=       z3SerializedCode.from_code_object.<locals>.<dictcomp>c                 3   s(   | ] }t |tjr |n|V  qd S N)
isinstancetypesCodeTypefrom_code_objectr'   cclsr#   r)   	<genexpr>>   s   z2SerializedCode.from_code_object.<locals>.<genexpr>r   )r
   tuple)r4   r!   kwargsr#   )r4   r!   r)   r0   :   s
    zSerializedCode.from_code_object)serialized_coder"   c                    s@   fddt  D }t fdd|d D |d< tj|  S )Nc                    s   i | ]}|t  |qS r#   r$   r&   )r8   r#   r)   r*   G   r+   z1SerializedCode.to_code_object.<locals>.<dictcomp>c                 3   s&   | ]}t |tr |n|V  qd S r,   )r-   r   to_code_objectr1   r3   r#   r)   r5   H   s   z0SerializedCode.to_code_object.<locals>.<genexpr>r   )r
   r6   r.   r/   values)r4   r8   r7   r#   )r4   r8   r)   r9   D   s    zSerializedCode.to_code_object)__name__
__module____qualname__int__annotations__bytesr6   r   strr   r   r   r   r   classmethod	functoolscacher.   r/   r0   r9   r#   r#   r#   r)   r   $   s2   
r   c                   @   s"   e Zd ZU dZeed< eed< dS )_GuardedCodeCacheEntrya  
    Contains the serializable information associated with a single compilation in dynamo.
    To restore an execution of compiled code, we will need to serialize the following data:
      - Dynamo bytecode for mapping Python inputs/outputs.
      - Dynamo guards.
    guards_statedynamo_codeN)r;   r<   r=   __doc__r@   r?   r   r#   r#   r#   r)   rE   Q   s   
rE   
_BackendId_FunctionIdc                   @   sV   e Zd ZU dZeed< eed< ee ed< ee	 ed< e
eef ed< ee ed< dS )	_DynamoCodeCacheEntrya  
    Contains the serializable information associated with a single code object
    in dynamo. To restore an execution of compiled code, we will need the following
    ingredients:
      1. The "original" code object, which serves as the entry point for eager
         execution, i.e. the code only executed when there's no cache entry hit.
      2. The python module name this code object belongs to, for identifying the
         enclosing global scope to inject compiled and resume functions.
      3. A list of function names that pointing to this code object. There could be
         multiple function objects pointing to the same code such as recursive functions.
      4. A list of guarded code that eval frame dispatches to.
      5. A list of imported module objects unioned from all compiled branches.
      6. A list of "backends" (compiled fx graph) unioned from all compield branches.
    python_codepython_modulefunction_namesguarded_codesimport_sourcesbackend_idsN)r;   r<   r=   rH   r   r?   rA   listrJ   rE   dictrI   r#   r#   r#   r)   rK   b   s   
rK   c                   @   sN   e Zd ZU ee ed< e Zeed< e	j
Zeed< eee dddZdS )_DynamoCacheEntrycodespython_versiontorch_versionr"   c                 C   s   dd | j D S )Nc                 S   s   h | ]}|j D ]}|qqS r#   )rQ   )r'   r!   
backend_idr#   r#   r)   	<setcomp>   r+   z0_DynamoCacheEntry.backend_ids.<locals>.<setcomp>rU   selfr#   r#   r)   rQ      s    z_DynamoCacheEntry.backend_idsN)r;   r<   r=   rR   rK   r?   platformrV   rA   torch__version__rW   propertysetrI   rQ   r#   r#   r#   r)   rT   {   s
   
rT   c                   @   s,   e Zd ZeedddZedddZdS )_DynamoCacheArtifactrX   c                   C   s   dS )NZprecompile_dynamor#   r#   r#   r#   r)   type   s    z_DynamoCacheArtifact.typec                 C   s   t | jS r,   pickleloadscontentr\   r#   r#   r)   after_deserialization   s    z*_DynamoCacheArtifact.after_deserializationN)r;   r<   r=   staticmethodrA   rd   rT   ri   r#   r#   r#   r)   rc      s   rc   c                   @   sV  e Zd ZdZd+eee ddddZd,eee ddddZd-e	j
eee ddd	d
Zeeeef dddZejedddZeje	j
ed dddZee	j
ddddZe	j
eee ddddZeeddddZd.eee ddddZdddd Ze	jeedd!d"d#Z ddd$d%Z!eeef dd&d'd(Z"edd)d*Z#dS )/CompilePackagea  
    CompilePackage is considered a low level component and should not be directly exposed to
    end users. It has the following interface:

    1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states.
        a. when `dynamo` argument is None, it will construct a brand new CompilePackage object.
        b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state.
    2. `package.save()` which dumps the dynamo and backend states to a DynamoCacheEntry object.
    3. `package.install(backends) which will handle all the side-effectful global scope
        updates with compiled functions and resume functions.
    N)fndynamor"   c                 C   s>   d | _ i | _d | _i | _i | _| || |   |   d S r,   )_innermost_fn_codes_current_entry_installed_globals_cached_backends_initialize	uninstallvalidate)r]   rl   rm   r#   r#   r)   __init__   s    zCompilePackage.__init__c                 C   s   ddl m} ||| _| jd us$J |d urt|ts:J |jt krXtd|j |jt	j
krttd|j |j^}}| jj|i| _|D ]}|| jt|j< qn| | jj| jj d S )Nr	   )innermost_fnz=Compile package was created with a different Python version: z>Compile package was created with a different PyTorch version: )Z
eval_framerw   rn   r-   rT   rV   r^   RuntimeErrorrW   r_   r`   rU   __code__ro   r   r9   rL   _add_functionr<   )r]   rl   rm   rw   mainrU   r!   r#   r#   r)   rs      s(    



zCompilePackage._initialize)rL   rM   namer"   c                 C   s`   || j vr0tt||g g i g d}|| j |< n| j | }|j|ksHJ |d ur\|j| d S )N)rL   rM   rN   rO   rP   rQ   )ro   rK   r   r0   rM   rN   append)r]   rL   rM   r|   r!   r#   r#   r)   rz      s    

zCompilePackage._add_functionrX   c                 C   s   | j S r,   )rr   r\   r#   r#   r)   cached_backends   s    zCompilePackage.cached_backendsc                 C   sH   | j d usJ t }|| j j  |t| j jj  |	 S r,   )
rn   hashlibsha256updater=   encoderA   ry   r   	hexdigest)r]   Zsha256_hashr#   r#   r)   	source_id   s
    zCompilePackage.source_id)NNNr    c                 c   s<   | j d u sJ | j| }|| _ zd V  W d | _ nd | _ 0 d S r,   )rp   ro   )r]   r!   entryr#   r#   r)   code_context   s    
zCompilePackage.code_context)rF   rG   r"   c                 C   s2   | j d usJ t|t|d}| j j| d S )N)rF   rG   )rp   rE   r   r0   rO   r}   )r]   rF   rG   Zguarded_code_entryr#   r#   r)   add_guarded_code   s    zCompilePackage.add_guarded_codec                 C   s   |  |||rt|nd  d S r,   )rz   rJ   )r]   rL   rM   r|   r#   r#   r)   add_resume_function   s    z"CompilePackage.add_resume_function)aliasmodule_namer"   c                 C   s   | j d usJ || j j|< d S r,   )rp   rP   )r]   r   r   r#   r#   r)   add_import_source  s    z CompilePackage.add_import_sourcerY   backendr"   c                 C   sH   | j d usJ |dsJ t|}| j j| |d urD|| j|< d S )NZ__compiled_fn_)rp   
startswithrI   rQ   r}   rr   )r]   rY   r   r#   r#   r)   add_backend_id  s    zCompilePackage.add_backend_idc                 C   s:   | j d u sJ | jd usJ tt| j| jju s6J d S r,   )rp   rn   nextiterro   ry   r\   r#   r#   r)   ru     s    zCompilePackage.validate)moduler|   valuer"   c                 C   s"   ||j |< | j|g | d S r,   )__dict__rq   
setdefaultr}   )r]   r   r|   r   r#   r#   r)   _install_global  s    
zCompilePackage._install_globalc                 C   sZ   ddl m} | jd usJ | j D ]\}}|D ]}|j| q0q$i | _|| jj d S )Nr   )_reset_precompile_entries)torch._C._dynamo.eval_framer   rn   rq   itemsr   popry   )r]   r   r   namesr|   r#   r#   r)   rt     s    zCompilePackage.uninstall)backendsr"   c              	   C   s>  ddl m} |   | j D ]\}}tj|j }|j D ]\}}| 	||t
| q<|jD ]"}t||j|}	| 	|||	 q`|jD ]:}
|
|vrtd|
 d||
 }| 	||
tj| qq| j D ]f\}}|jD ]V}t|j}t|tjjjsJ tjjj||jd|jd}|||jt !|j" qqdS )a3  
        Sync the package states to the compiled function. This includes the following actions:
          1. Clean up the previously installed states.
          2. Install the compiled functions to global scopes.
          3. Install the precompiled cache entries to ExtraStates on the code object.
        r   )_load_precompile_entryBackend # is not found in the given backendsload)Zguards_serialization_modeshape_code_partsN)#r   r   rt   ro   r   sysmodulesrM   rP   r   	importlibimport_modulerN   r.   FunctionTyper   rQ   rx   r_   Z_dynamodisablerO   rf   rg   rF   r-   ZguardsZGuardsStateZCheckFunctionManagerZoutput_graphr   Zguard_managerr   r9   rG   )r]   r   r   r!   r   r   r   r   Zfunction_namerl   rY   r   Zguarded_coderF   Zcheck_fn_managerr#   r#   r)   install)  sJ    





zCompilePackage.installc                 C   s   |    tt| j dS )Nr[   )ru   rT   rR   ro   r:   r\   r#   r#   r)   cache_entryY  s    zCompilePackage.cache_entry)N)N)N)N)$r;   r<   r=   rH   r   r   rT   rv   rs   r.   r/   rA   rJ   rz   ra   rS   rI   r~   rC   cached_propertyr   
contextlibcontextmanagerr   r   r@   r   r   r   r   ru   
ModuleTyper   rt   r   r   r#   r#   r#   r)   rk      s@    
0rk   c                   @   s,   e Zd ZeedddZedddZdS )EagerCacheArtifactrX   c                   C   s   dS )NZprecompile_eagerr#   r#   r#   r#   r)   rd   `  s    zEagerCacheArtifact.typec                 C   s   t | jS r,   re   r\   r#   r#   r)   ri   d  s    z(EagerCacheArtifact.after_deserializationN)r;   r<   r=   rj   rA   rd   r   ri   r#   r#   r#   r)   r   ^  s   r   c                   @   sf   e Zd ZdZeddddZeeddddZee	dd	d
dZ
ee	eeeeef f dddZdS )DynamoStorezg
    A DynamoStore tracks active CompilePackages, and provides methods to store and retrieve them.
    N)packager"   c                 C   s,   |  }t|}tjt |j|d dS )zKRecords a package to PrecompileContext, so that it can be serialized later.r(   rh   N)r   rf   dumpsr   record_artifactrc   rd   r   )r]   r   r   pickled_resultr#   r#   r)   record_packagem  s
    
zDynamoStore.record_packager   c                 C   s"   t |}tjt ||d dS )zBRecords eager fx graphs to PrecompileContext for testing purposes.r   N)rf   r   r   r   r   rd   )r]   rY   r   r   r#   r#   r)   record_eager_backendu  s    

z DynamoStore.record_eager_backend)r   pathr"   c           
   
   C   s  i }|  }|jD ].}t|}|du r8td| d|||< qzttj|dd}t	
|| W d   n1 sz0    Y  ttj|dd}t	
|| W d   n1 s0    Y  W n< ty }	 z"td| d|	 |	W Y d}	~	n
d}	~	0 0 dS )	zGSaves a package to a given path. Grabs backends from PrecompileContext.Nr   r   rm   wbr   zFailed to save package to : )r   rQ   r   Zserialize_artifact_by_keyrx   openosr   joinrf   dump	Exception)
r]   r   r   backend_contentr   rY   Zserialized_backenddynamo_pathbackend_pather#   r#   r)   save_package|  s     



*.zDynamoStore.save_package)rl   r   r"   c              
   C   s   zt tj|dd}t|}W d   n1 s60    Y  t tj|dd}t|}W d   n1 st0    Y  W n: ty } z"td| d| |W Y d}~n
d}~0 0 | D ]\}}	|		 ||< qt
||}
|
|fS )zULoads a package from a given path and returns it plus a list of deserialized backendsrm   rbNr   z!Failed to load package from path r   )r   r   r   r   rf   r   r   rx   r   ri   rk   )r]   rl   r   r   r   r   r   r   rY   r   r   r#   r#   r)   load_package  s    (,,
zDynamoStore.load_package)r;   r<   r=   rH   rk   r   rI   r   r   rA   r   r6   rS   r   r#   r#   r#   r)   r   h  s   r   )+rH   r   ZdataclassesrC   r   r   loggingr   rf   r^   r   r.   collections.abcr   typingr   r   r   r_   Ztorch._inductor.packageZ torch._dynamo.precompile_contextr   r   Ztorch.compiler._cacher   Zbytecode_transformationr
   	getLoggerr;   loggerZ	dataclassr   rE   rA   rI   rJ   rK   rT   registerrc   rk   r   r   r#   r#   r#   r)   <module>   sH   


,


	 O	