a
    hZ                     @   s0  d Z ddlZddlZddlZddlmZmZmZ ddlm	Z	m
Z
mZ ddlmZmZ ddlmZmZmZmZmZ dd	lmZmZ dd
lmZ erddlmZ ddlmZ dZG dd deZG dd deZ G dd de Z!G dd de Z"G dd de Z#G dd de Z$G dd de$Z%G dd de Z&dS )a  
This module provides iterator-related variable tracking functionality for Dynamo.
It implements variable classes for handling Python iterators and itertools functions
during symbolic execution and tracing.

The module includes:
- Base iterator variable classes for tracking iterator state
- Implementations of built-in iterators (zip, map, filter)
- Support for itertools functions (product, accumulate, combinations, etc.)
- Mutation tracking and reconstruction capabilities for iterator operations

These classes integrate with Dynamo's variable tracking system to enable proper
handling of iterator operations during code transformation and optimization.
    N)OptionalTYPE_CHECKINGUnion   )graph_break_hints	polyfills	variables)create_call_functioncreate_instruction)handle_observed_exceptionObservedUserStopIterationraise_observed_exceptionunimplemented_v2	UserError   )ValueMutationNewVariableTracker)ConstantVariable)	PyCodegen)InstructionTranslatori  c                       sP   e Zd Zdd fddZedddZdd Zd	d
ddd fddZ  ZS )ItertoolsVariableNreturnc                    s   t  jf i | || _d S N)super__init__value)selfr   kwargs	__class__ J/var/www/auris/lib/python3.9/site-packages/torch/_dynamo/variables/iter.pyr   -   s    zItertoolsVariable.__init__c                 C   s   d| j  dS )NzItertoolsVariable()r   r   r!   r!   r"   __repr__1   s    zItertoolsVariable.__repr__c                 C   s   | j S r   r$   r%   r!   r!   r"   as_python_constant4   s    z$ItertoolsVariable.as_python_constantr   zlist[VariableTracker]zdict[str, VariableTracker]r   )txargsr   r   c                    s  j tju r\s\tfdd D r\fdd D }dd tj| D }tj|t dS j tju rddlm	} t
d	d  D rtd
d d  d ddt ddh  g tjd t dv r d r d }dv r t dkr d j}n^t dkr: d j}nDt dkrV|tjj}n(tdd d  d dg tjd n8tdd d  d d  d ddgtjd g }d}	|	d ur||	 |D ]}
|	d u r|
}	nxz||	|
gi }	W nb tyh } zHtdd d  d d| d|
 |	 dg tj|d  W Y d }~n
d }~0 0 ||	 qtj|t dS j tju rst dkr d r d  r d } d  }g }t||D ]}
|tt |
 qtj|t dS j tj!u rt
d!d  D rtd"d d  d d#dt d$h  g tjd  fd%d&t dkrĈ d rĈ d }n8td'd d  d d(  d d)d*gtjd d$v rfd+d,}nfd-d,}g }zbtj!||d.D ]N\}}|tjtj"#|rbtj"$|n|tjt |t dgt d q8W nP ty } z6td/d d  d d0g tj|d  W Y d }~n
d }~0 0 tj|t dS j tj%u r2t dk rtj& d1t iS 't()t*j% S j tj+u rRtj, d1t iS j tj-u rrtj. d1t iS t/  S d S )2Nc                 3   s   | ]}|  V  qd S r   )has_unpack_var_sequence.0argr(   r!   r"   	<genexpr>B       z2ItertoolsVariable.call_function.<locals>.<genexpr>c                    s   g | ]}|  qS r!   )unpack_var_sequencer+   r.   r!   r"   
<listcomp>D   r0   z3ItertoolsVariable.call_function.<locals>.<listcomp>c                 S   s   g | ]}t t|qS r!   r   TupleVariablelist)r,   itemr!   r!   r"   r2   E   s   )mutation_typer   )BuiltinVariablec                 s   s   | ]}|d vV  qdS ))initialfuncNr!   )r,   keyr!   r!   r"   r/   N   r0   z+Unsupported kwargs for itertools.accumulatecall_function  z,Expected kwargs: 'initial', 'func', but got ,r9   r:   gb_typecontextexplanationhints)r   r   r   r   z*Unsupported `func` in itertools.accumulatezDynamo does not know how to get the function to use for itertools.accumulate. itertools.accumulate expects the `func` as the second argument or as a keyword argument.z.Unsupported arguments for itertools.accumulatezBDynamo does not know how to trace itertools.accumulate with args: z and kwargs: z. itertools.accumulate expects an iterable, an optional binary function for accumulation, and an optional initial value to set the starting state.z<Make sure the arguments to itertools.accumulate are correct.z:Unexpected failure during itertools.accumulate() iterationzOUnexpected failure in invoking function during accumulate. Failed running func (r#   )r@   rA   rB   rC   Zfrom_excc                 s   s   | ]}|d kV  qdS )r;   Nr!   )r,   kwr!   r!   r"   r/      r0   z(Unsupported kwargs for itertools.groupbyz Expected kwargs: 'key', but got r;   c                    sd   t | tjr|  S t | tjr(|  S tdd d  d dtt|  dg t	j
d d S )Nz*Unsupported key type for itertools.groupbyr<   r=   zCDynamo does not know how to trace itertools.groupby with key type: zJ. We only support grouping keys that are constants (int, float, str, etc.)r?   )
isinstancer   ZSymNodeVariableZevaluate_exprr   r'   r   strtyper   SUPPORTABLEr;   )r)   r   r   r!   r"   retrieve_const_key   s    
z;ItertoolsVariable.call_function.<locals>.retrieve_const_keyz+Unsupported arguments for itertools.groupbyz?Dynamo does not know how to trace itertools.groupby with args: ze. itertools.groupby expects an iterable to group and an optional key function to determine groupings.z9Make sure the arguments to itertools.groupby are correct.c                    s     d| gi S )Nr;   )getcall_functionx)r   rK   r(   r!   r"   keyfunc   s    z0ItertoolsVariable.call_function.<locals>.keyfuncc                    s    | S r   r!   rN   )rK   r!   r"   rP      s    rJ   z7Unexpected failure during itertools.groupby() iterationz6Unexpected failure in invoking function during groupbyr7   )0r   	itertoolsproductallr   ZListIteratorVariabler   
accumulatebuiltinr8   anykeysr   joinsetr   Z
USER_ERRORlenr*   r1   rM   operatoraddrI   rL   append	ExceptionZ	DIFFICULTcombinationsZis_python_constantr'   r4   r5   groupbyr   Z
is_literalcreaterepeatRepeatIteratorVariableZinline_user_function_returnr   buildr   countCountIteratorVariablecycleCycleIteratorVariabler   )r   r(   r)   r   Zseqsitemsr8   seqr:   accr6   eiterablerrP   resultkvr   )r)   r   rK   r   r(   r"   rM   7   s\   	











zItertoolsVariable.call_function)	__name__
__module____qualname__r   rG   r&   r'   rM   __classcell__r!   r!   r   r"   r   ,   s   r   c                       sX   e Zd Zdd fddZdd Zee dddZddd	d
Ze	dddZ
  ZS )IteratorVariableNr   c                    s   t  jf i | d S r   )r   r   )r   r   r   r!   r"   r     s    zIteratorVariable.__init__c                 C   s"   t dd|  ddg tjd d S )NzUnimplemented next() callnext(r#   z(This abstract method must be implementedr?   )r   r   Z
DYNAMO_BUGr   r(   r!   r!   r"   next_variable  s    
zIteratorVariable.next_variablec                 C   s   g }|  ||j |S r   )force_apply_to_var_sequencer]   )r   r(   ro   r!   r!   r"   force_unpack_var_sequence  s    z*IteratorVariable.force_unpack_var_sequencec                 C   s8   z||  | W q  ty0   t| Y q4Y q 0 q d S r   )ry   r   r   )r   r(   fnr!   r!   r"   rz     s
    z,IteratorVariable.force_apply_to_var_sequencec                 C   s   dS )NTr!   rx   r!   r!   r"   has_force_unpack_var_sequence&  s    z.IteratorVariable.has_force_unpack_var_sequence)rr   rs   rt   r   ry   r5   r   r{   rz   boolr}   ru   r!   r!   r   r"   rv     s
   
rv   c                       s:   e Zd Zedd fddZdd Zddd	d
Z  ZS )rc   N)r6   r   c                    s   t  jf i | || _d S r   )r   r   r6   )r   r6   r   r   r!   r"   r   +  s    zRepeatIteratorVariable.__init__c                 C   s   | j S r   )r6   rx   r!   r!   r"   ry   0  s    z$RepeatIteratorVariable.next_variabler   codegenc                    s0      fdd  | j  tdd d S )Nc                      s      t dgS )Nrb   extend_outputZcreate_load_python_modulerQ   Zcreate_load_attrr!   r   r!   r"   <lambda>5  s   z4RepeatIteratorVariable.reconstruct.<locals>.<lambda>r   F)add_push_nullr6   r   r	   r   r   r!   r   r"   reconstruct3  s
    

z"RepeatIteratorVariable.reconstruct)rr   rs   rt   r   r   ry   r   ru   r!   r!   r   r"   rc   *  s   rc   c                       s>   e Zd Zdeedd fddZdd Zd	d
ddZ  ZS )rf   r   r   N)r6   stepr   c                    sJ   t  jf i | t|ts&t|}t|ts:t|}|| _|| _d S r   )r   r   rF   r   r   ra   r6   r   )r   r6   r   r   r   r!   r"   r   A  s    



zCountIteratorVariable.__init__c                 C   s<   |   sJ | j}|jj|  | j|d| jgi | _|S )N__add__)
is_mutabler6   outputside_effectsmutationZcall_methodr   )r   r(   Zold_itemr!   r!   r"   ry   J  s
    z#CountIteratorVariable.next_variabler   r   c                    s:      fdd  | j  | j  tdd d S )Nc                      s      t dgS )Nre   r   r!   r   r!   r"   r   S  s   z3CountIteratorVariable.reconstruct.<locals>.<lambda>r   F)r   r6   r   r   r	   r   r!   r   r"   r   Q  s    


z!CountIteratorVariable.reconstruct)r   r   )rr   rs   rt   intr   ry   r   ru   r!   r!   r   r"   rf   @  s   	rf   c                       s@   e Zd Zdeeee  eee dd fddZdd Z	  Z
S )	rh   Nr   )iteratorsavedsaved_indexr6   r   c                    s:   |d u rg }t  jf i | || _|| _|| _|| _d S r   )r   r   r   r   r   r6   )r   r   r   r   r6   r   r   r!   r"   r   `  s    zCycleIteratorVariable.__init__c                 C   s   |   sJ | jd urzv| j|}t| jtkrPtdd|  ddt g d |jj	|  | j
| || _| jd u r| |W S | jW S  ty   t| d | _| | Y S 0 nBt| jdkr|jj	|  | jd t| j | _| jS tt| d S )Nz4input iterator to itertools.cycle has too many itemsrw   r#   z0Has reached internal Dynamo max iterator limit: r?   r   r   )r   r   ry   rZ   r   MAX_ITERATOR_LIMITr   r   r   r   r]   r6   r   r   r   r   StopIteration)r   r(   Znew_itemr!   r!   r"   ry   p  s4    


z#CycleIteratorVariable.next_variable)Nr   N)rr   rs   rt   rv   r   r5   r   r   r   ry   ru   r!   r!   r   r"   rh   _  s      
rh   c                       s   e Zd ZdZddhejZdeeee ef  e	dd fddZ
d	d
 Ze	dddZed dddZdd ZddddZddddZ  ZS )ZipVariablez$
    Represents zip(*iterables)
    indexstrictFN)	iterablesr   r   c                    s6   t  jf i | t|ts J || _d| _|| _d S Nr   )r   r   rF   r5   r   r   r   )r   r   r   r   r   r!   r"   r     s
    zZipVariable.__init__c                 C   s   t S r   )zipr%   r!   r!   r"   python_type  s    zZipVariable.python_typer   c                    s   t  fdd| jD S )Nc                 3   s"   | ]}t |tp| V  qd S r   )rF   r5   r*   )r,   itr.   r!   r"   r/     s   z6ZipVariable.has_unpack_var_sequence.<locals>.<genexpr>)rS   r   rx   r!   r.   r"   r*     s    z#ZipVariable.has_unpack_var_sequencer   c                 C   s~   |  |sJ g }| jD ]4}t|tr<||| jd   q||| q| jr^d| jini }t|i |}dd |D S )Nr   c                 S   s   g | ]}t t|qS r!   r3   )r,   varr!   r!   r"   r2     r0   z3ZipVariable.unpack_var_sequence.<locals>.<listcomp>)	r*   r   rF   r5   r]   r   r1   r   r   )r   r(   r   r   r   Zzippedr!   r!   r"   r1     s    

zZipVariable.unpack_var_sequencec                    s   |   sJ | j g } fdd}z&t| jD ]\}}||| q0W nt ty   | jr|dkr| jD ]4}z|| W n ty   t Y qlY n0  qql t tt	dd  Y n0 j
j|  |  jd7  _t|S )Nc                    s6   t | tr( t| kr tt |   S | S d S r   )rF   r5   rZ   r   r   ry   )r   	old_indexr(   r!   r"   get_item  s
    

z+ZipVariable.next_variable.<locals>.get_itemr   z3zip() has one argument of len differing from othersr   )r   r   	enumerater   r]   r   r   r   r   
ValueErrorr   r   r   r   r4   )r   r(   r)   r   idxr   r!   r   r"   ry     s:    

zZipVariable.next_variabler   r   c                 C   sR   | j D ]F}t|trD|| jd  }|| |tdt|d q|| qd S NBUILD_TUPLEr-   )r   rF   r5   r   foreachappend_outputr
   rZ   )r   r   r   remaining_itemsr!   r!   r"   reconstruct_items  s    


zZipVariable.reconstruct_itemsc              	      s    j  fdddd |    tdt| jd tjdkrt  	d 	| j
td	d
dtdd
dg n tddd d S )Nc                      s     ddS )Nbuiltinsr   Zload_import_fromr!   r   r!   r"   r     r0   z)ZipVariable.reconstruct.<locals>.<lambda>TZcall_function_exr   r   )   
   r   	BUILD_MAPr   CALL_FUNCTION_EXr   )r   r   r   r
   rZ   r   sysversion_infor   Zcreate_load_constr   r   r!   r   r"   r     s     




	zZipVariable.reconstruct)F)rr   rs   rt   __doc__rv   _nonvar_fieldsr5   r   r   r~   r   r   r*   r1   ry   r   r   ru   r!   r!   r   r"   r     s$   	 *r   c                       sj   e Zd ZdZeeeee ef  dd fddZdd Ze	dd	d
Z
 fddZddddZ  ZS )MapVariablez(
    Represents map(fn, *iterables)
    N)r|   r   r   c                    s   t  j|fi | || _d S r   )r   r   r|   )r   r|   r   r   r   r!   r"   r     s    zMapVariable.__init__c                 C   s   t S r   )mapr%   r!   r!   r"   r     s    zMapVariable.python_typer   c                 C   s   dS )NFr!   rx   r!   r!   r"   r*     s    z#MapVariable.has_unpack_var_sequencec                    s   t  |}| j||ji S r   )r   ry   r|   rM   ri   )r   r(   r)   r   r!   r"   ry     s    zMapVariable.next_variabler   r   c                    sV    j  fdddd  | j |    tdt| jd dtdd	dg d S )
Nc                      s     ddS )Nr   r   r   r!   r   r!   r"   r      r0   z)MapVariable.reconstruct.<locals>.<lambda>Tr   r   r   r   r   r   )r   r|   r   r   r
   rZ   r   r   r!   r   r"   r     s    


zMapVariable.reconstruct)rr   rs   rt   r   r   r5   r   r   r   r~   r*   ry   r   ru   r!   r!   r   r"   r     s   	r   c                       s   e Zd ZdZdhejZeeee ef dd fddZ	dd Z
ed	d
dZed d	ddZdd ZddddZddddZ  ZS )FilterVariablez)
    Represents filter(fn, iterable)
    r   N)r|   rm   r   c                    s(   t  jf i | || _|| _d| _d S r   )r   r   r|   rm   r   )r   r|   rm   r   r   r!   r"   r   6  s    zFilterVariable.__init__c                 C   s   t S r   )filterr%   r!   r!   r"   r   A  s    zFilterVariable.python_typer   c                 C   s   t | jtp| j|S r   )rF   rm   r5   r*   rx   r!   r!   r"   r*   D  s    z&FilterVariable.has_unpack_var_sequencer   c                 C   sZ   |  |sJ d }t| jtr0| j| jd  }n| j|}| j||i }t	|ggS r   )
r*   rF   rm   r5   r   r1   r|   rM   r   r4   )r   r(   r   filteredr!   r!   r"   r1   I  s    z"FilterVariable.unpack_var_sequencec                    s^    fdd}| }  j d7  _  j|gi }ttj|gi }| r|S qd S )Nc                     sD    j } t jtr4| t jkr*tt  j|  S  jS d S r   )r   rF   rm   r5   rZ   r   r   ry   )r   rx   r!   r"   _nextT  s    

z+FilterVariable.next_variable.<locals>._nextr   )r   r|   rM   r   ZUserFunctionVariabler   	predicater'   )r   r(   r   r6   resZpred_resr!   rx   r"   ry   S  s    zFilterVariable.next_variabler   r   c                 C   sL   t | jtr>| j| jd  }|| |tdt|d n
|| j d S r   )rF   rm   r5   r   r   r   r
   rZ   )r   r   r   r!   r!   r"   r   h  s    
z FilterVariable.reconstruct_itemsc                    s:      fdd  | j |    tdd d S )Nc                      s     ddS )Nr   r   r   r!   r   r!   r"   r   s  r0   z,FilterVariable.reconstruct.<locals>.<lambda>r   F)r   r|   r   r   r	   r   r!   r   r"   r   r  s    

zFilterVariable.reconstruct)rr   rs   rt   r   rv   r   r   r   r5   r   r   r~   r*   r1   ry   r   r   ru   r!   r!   r   r"   r   ,  s   

r   )'r   rQ   r[   r   typingr   r   r    r   r   r   Zbytecode_transformationr	   r
   excr   r   r   r   r   baser   r   Zconstantr   Ztorch._dynamo.codegenr   Ztorch._dynamo.symbolic_convertr   r   r   rv   rc   rf   rh   r   r   r   r!   r!   r!   r"   <module>   s,    \#0w&