o
    GZh[>                     @   s  d Z ddlmZ ddlmZ ddlmZ ddlmZ ddl	m
Z
mZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZ ddlmZmZ ddlmZmZmZ ddl m!Z!m"Z" ddl#m$Z$m%Z% ddl&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z-m.Z. ddl/m0Z0 ddl1m2Z2 ddl3m4Z4 G dd de'Z5G dd dZ6G dd dZ7G dd dZ8e6e8e8e7dZ9G dd de)e(Z:G d d! d!e+Z;G d"d# d#e)Z<d$S )%zq
Joint Random Variables Module

See Also
========
sympy.stats.rv
sympy.stats.frv
sympy.stats.crv
sympy.stats.drv
    )prod)Basic)Lambda)S)DummySymbol)sympify)
ProductSetIndexed)Product)Sum	summation)Tuple)Integral	integrate)ImmutableMatrixmatrix2numpy
list2numpy)SingleContinuousDistributionSingleContinuousPSpace)SingleDiscreteDistributionSingleDiscretePSpace)ProductPSpaceNamedArgsMixinDistributionProductDomainRandomSymbolrandom_symbolsSingleDomain_symbol_converter)iterable)
filldedent)import_modulec                   @   s   e Zd ZdZdd Zedd Zedd Zedd	 Zed
d Z	edd Z
edd Zedd Zdd Zdd Zd$ddZdd Zdd Zd%d d!Zd"d# ZdS )&JointPSpacezt
    Represents a joint probability space. Represented using symbols for
    each component and a distribution.
    c                 C   s>   t |tr
t||S t |trt||S t|}t| ||S N)
isinstancer   r   r   r   r    r   __new__)clssymdist r+   C/var/www/auris/lib/python3.10/site-packages/sympy/stats/joint_rv.pyr'   )   s   



zJointPSpace.__new__c                 C   s   | j jS r%   )domainsetselfr+   r+   r,   r.   1   s   zJointPSpace.setc                 C   
   | j d S )Nr   argsr/   r+   r+   r,   symbol5      
zJointPSpace.symbolc                 C   r1   N   r2   r/   r+   r+   r,   distribution9   r5   zJointPSpace.distributionc                 C   s   t | j| S r%   )JointRandomSymbolr4   r/   r+   r+   r,   value=      zJointPSpace.valuec                 C   s>   | j j}t|trtt|jS t|tr|jd d S tj	S )Nr   )
r8   r.   r&   r	   r   lenr3   r   limitsZOne)r0   _setr+   r+   r,   component_countA   s   

zJointPSpace.component_countc                    s"    fddt  jD } j| S )Nc                       g | ]}t  j|qS r+   r   r4   .0ir/   r+   r,   
<listcomp>L       z#JointPSpace.pdf.<locals>.<listcomp>)ranger@   r8   r0   r)   r+   r/   r,   pdfJ   s   
zJointPSpace.pdfc                 C   s0   t | j}|st| j| jjS tdd |D  S )Nc                 S      g | ]}|j jqS r+   )pspacer-   rD   rvr+   r+   r,   rF   T       z&JointPSpace.domain.<locals>.<listcomp>)r   r8   r   r4   r.   r   r0   rvsr+   r+   r,   r-   O   s   
zJointPSpace.domainc                 C   s   | j j| S r%   )r.   r3   )r0   indexr+   r+   r,   component_domainV   s   zJointPSpace.component_domainc           
         s   j }|trtd fddt|D }dd |D }tt||}t fdd|D fdd|D }d}t|D ]}||vr]||  j	j
j|  t|| ||< |d	7 }q? j	jrqtt j	| g|R  }	n j	jrtt j	| g|R  }	|	|S )
Nz_Marginal distributions cannot be computed for symbolic dimensions. It is a work under progress.c                    rA   r+   rB   rC   r/   r+   r,   rF   ^   rG   z5JointPSpace.marginal_distribution.<locals>.<listcomp>c                 S   s   g | ]}t t|qS r+   )r   strrC   r+   r+   r,   rF   _   rG   c                 3   s$    | ]}t tt j|V  qd S r%   )r   rT   r   r4   rC   r/   r+   r,   	<genexpr>a   s   " z4JointPSpace.marginal_distribution.<locals>.<genexpr>c                    s   g | ]	}| vr|gqS r+   r+   rC   )r)   r+   r,   rF   b       r   r7   )r@   Zatomsr   
ValueErrorrH   dictziptupleappendr8   r.   r3   is_Continuousr   r   is_Discreter   xreplace)
r0   indicescountorigZall_symsZreplace_dictr>   rR   rE   fr+   rI   r,   marginal_distributionY   s(   

z!JointPSpace.marginal_distributionNFc                    s   t fddtjD } p| t fdd|D s|S |j } D ]%}t|tr>||tt|j	|j
d i}q&t|trK|||ji}q&jt|v rYttdt fdd|D }t|g|R  S )Nc                 3   s    | ]} j | V  qd S r%   )r:   rC   r/   r+   r,   rU   p       z2JointPSpace.compute_expectation.<locals>.<genexpr>c                 3   s    | ]}| v V  qd S r%   r+   rC   rQ   r+   r,   rU   r   s    r7   zq
            Expectations of expression with unindexed joint random symbols
            cannot be calculated yet.c                 3   s:    | ]}t t|j|jd   jjj|jd   fV  qdS )r7   N)r   rT   baser3   r8   r.   rM   r/   r+   r,   rU   ~   s    )rZ   rH   r@   anyrJ   r&   r   r^   rT   rf   r3   r   r4   r:   r   NotImplementedErrorr"   r   )r0   exprrQ   evaluatekwargssymsrN   r>   r+   )rQ   r0   r,   compute_expectationo   s"   

"
zJointPSpace.compute_expectationc                 C      t  r%   rh   r0   	conditionr+   r+   r,   where      zJointPSpace.wherec                 C   rn   r%   ro   )r0   ri   r+   r+   r,   compute_density   rs   zJointPSpace.compute_densityr+   scipyc                 C   s   t | j| | jj|||diS )zo
        Internal sample method

        Returns dictionary mapping RandomSymbol to realization value.
        )libraryseed)r   r4   r8   sample)r0   sizerv   rw   r+   r+   r,   rx      s   zJointPSpace.samplec                 C   rn   r%   ro   rp   r+   r+   r,   probability   rs   zJointPSpace.probability)NFr+   ru   N)__name__
__module____qualname____doc__r'   propertyr.   r4   r8   r:   r@   rJ   r-   rS   rc   rm   rr   rt   rx   rz   r+   r+   r+   r,   r$   $   s0    








	r$   c                   @   &   e Zd ZdZdddZedd ZdS )SampleJointScipyz7Returns the sample from scipy of the given distributionNc                 C      |  |||S r%   )_sample_scipyr(   r*   ry   rw   r+   r+   r,   r'         zSampleJointScipy.__new__c           	         s   ddl }|du st|tr|jj|d n| ddlm  fdd fdd fddd	}d
d dd dd d	}| }|jj	|vrJdS ||jj	 ||}|
|||jj	 | S )zSample from SciPy.r   Nrw   )statsc                    s$   j jt| j t| j| dS )N)meancovry   random_state)multivariate_normalrQ   r   muflattensigmar*   ry   
rand_stateZscipy_statsr+   r,   <lambda>   s    z0SampleJointScipy._sample_scipy.<locals>.<lambda>c                    s   j jt| jt | dS )N)alphary   r   )	dirichletrQ   r   r   floatr   r   r   r+   r,   r      s    c                    s&   j jt| jt| jt | dS )N)npry   r   )multinomialrQ   intr   r   r   r   r   r   r   r+   r,   r      s    ZMultivariateNormalDistributionZMultivariateBetaDistributionZMultinomialDistributionc                 S      t | j jS r%   r   r   r   shaper*   r+   r+   r,   r          c                 S   r   r%   r   r   r   r   r   r+   r+   r,   r      r   c                 S   r   r%   r   r   r   r   r   r+   r+   r,   r      r   )numpyr&   r   randomdefault_rngru   r   keys	__class__r|   reshape)	r(   r*   ry   rw   r   Zscipy_rv_mapsample_shape	dist_listsamplesr+   r   r,   r      s$   zSampleJointScipy._sample_scipyr%   )r|   r}   r~   r   r'   classmethodr   r+   r+   r+   r,   r      s
    
r   c                   @   r   )SampleJointNumpyz7Returns the sample from numpy of the given distributionNc                 C   r   r%   )_sample_numpyr   r+   r+   r,   r'      r   zSampleJointNumpy.__new__c           	         s   ddl }|du st|tr|jj|d n|  fdd fdd fddd}d	d d
d dd d}| }|jj|vrAdS ||jj |t|}|	|||jj | S )zSample from NumPy.r   Nr   c                    s$    j t| jt t| jt|dS )N)r   r   ry   )r   r   r   r   r   r   r   r   r+   r,   r      s    z0SampleJointNumpy._sample_numpy.<locals>.<lambda>c                    s    j t| jt |dS )N)r   ry   )r   r   r   r   r   r   r   r+   r,   r      s    c                    s"    j t| jt| jt |dS )N)r   Zpvalsry   )r   r   r   r   r   r   r   r   r   r+   r,   r      s    r   c                 S   r   r%   r   r   r+   r+   r,   r      r   c                 S   r   r%   r   r   r+   r+   r,   r      r   c                 S   r   r%   r   r   r+   r+   r,   r      r   )
r   r&   r   r   r   r   r   r|   r   r   )	r(   r*   ry   rw   r   Znumpy_rv_mapr   r   r   r+   r   r,   r      s"   


zSampleJointNumpy._sample_numpyr%   )r|   r}   r~   r   r'   r   r   r+   r+   r+   r,   r      
    
r   c                   @   r   )SampleJointPymcz6Returns the sample from pymc of the given distributionNc                 C   r   r%   )_sample_pymcr   r+   r+   r,   r'      r   zSampleJointPymc.__new__c           	   	      s  zddl  W n ty   ddl Y nw  fdd fdd fddd}dd d	d d
d d}| }|jj|vr>dS ddl}|d|j	  
 # ||jj |  jt|dd|ddddd d }W d   n1 suw   Y  ||||jj | S )zSample from PyMC.r   Nc                    s2    j dt| jt t| jtd| jjd fdS )NXr7   r   )r   r   r   )ZMvNormalr   r   r   r   r   r   r   pymcr+   r,   r      s   z.SampleJointPymc._sample_pymc.<locals>.<lambda>c                    s    j dt| jt dS )Nr   )a)Z	Dirichletr   r   r   r   r   r   r+   r,   r      s   c                    s.    j dt| jt| jt dt| jfdS )Nr   r7   )r   r   r   )ZMultinomialr   r   r   r   r   r   r=   r   r   r+   r,   r      s   r   c                 S   r   r%   r   r   r+   r+   r,   r      r   c                 S   r   r%   r   r   r+   r+   r,   r      r   c                 S   r   r%   r   r   r+   r+   r,   r      r   pymc3r7   F)ZdrawschainsZprogressbarZrandom_seedZreturn_inferencedataZcompute_convergence_checksr   )r   ImportErrorr   r   r   r|   logging	getLoggersetLevelERRORZModelrx   r   r   )	r(   r*   ry   rw   Zpymc_rv_mapr   r   r   r   r+   r   r,   r      s.   



(zSampleJointPymc._sample_pymcr%   )r|   r}   r~   r   r'   r   r   r+   r+   r+   r,   r      r   r   )ru   r   r   r   c                   @   sN   e Zd ZdZdZdd Zedd Zedd Zd	d
 Z	dddZ
dd ZdS )JointDistributionz
    Represented by the random variables part of the joint distribution.
    Contains methods for PDF, CDF, sampling, marginal densities, etc.
    rJ   c                 G   sP   t tt|}tt|D ]}t|| t rt|| ||< qtj| g|R  S r%   )	listmapr   rH   r=   r&   r   r   r'   )r(   r3   rE   r+   r+   r,   r'     s   zJointDistribution.__new__c                 C   s
   t | jS r%   )r   symbolsr/   r+   r+   r,   r-   %  r5   zJointDistribution.domainc                 C   s   | j jd S r6   )densityr3   r/   r+   r+   r,   rJ   )  r;   zJointDistribution.pdfc                 C   s   t |tstd|t|f | }| jjj}| t	dd | j
D }tt|D ]/}|| jrDt||| || j|||  f}q+|| jrZt||| || j|||  f}q+|S )Nz!%s should be of type dict, got %sc                 s   s    | ]}|j d  V  qdS r   Nr2   rC   r+   r+   r,   rU   2  rd   z(JointDistribution.cdf.<locals>.<genexpr>)r&   rX   rW   typer   r-   r.   ZsetsrJ   rZ   r   rH   r=   r\   r   infr]   r   )r0   otherrQ   r?   ri   rE   r   r+   r+   r,   cdf-  s    





zJointDistribution.cdfr+   ru   Nc                 C   sb   d}||vrt dt| t|std| t| | ||d}|dur'|S t d| jj|f )z, A random realization from the distribution )ru   r   r   r   z&Sampling from %s is not supported yet.zFailed to import %sr   Nz4Sampling for %s is not currently implemented from %s)rh   rT   r#   rW   _get_sample_class_jrvr   r|   )r0   ry   rv   rw   Z	librariesZsampsr+   r+   r,   rx   <  s   
zJointDistribution.samplec                 G   
   | j | S r%   r   r0   r3   r+   r+   r,   __call__O     
zJointDistribution.__call__r{   )r|   r}   r~   r   Z	_argnamesr'   r   r-   rJ   r   rx   r   r+   r+   r+   r,   r     s    


r   c                   @   s   e Zd ZdZdd ZdS )r9   zg
    Representation of random symbols with joint probability distributions
    to allow indexing."
    c                 C   sD   t | jtr | jj|kdkrtd| j| jjd f t| |S d S )NTz$Index keys for %s can only up to %s.r7   )r&   rL   r$   r@   rW   namer   )r0   keyr+   r+   r,   __getitem__W  s   
zJointRandomSymbol.__getitem__N)r|   r}   r~   r   r   r+   r+   r+   r,   r9   R  s    r9   c                   @   sX   e Zd ZdZdd Zdd Zedd Zedd	 Zd
d Z	dd Z
dd Zdd ZdS )MarginalDistributionz
    Represents the marginal distribution of a joint probability space.

    Initialised using a probability distribution and random variables(or
    their indexed components) which should be a part of the resultant
    distribution.
    c                 G   s   t |dkrt|d rt|d }tdd |D s!ttdtdd |D }t|t	s:t t
|dkr:|S t| ||S )Nr7   r   c                 s   s    | ]
}t |ttfV  qd S r%   )r&   r   r   rM   r+   r+   r,   rU   l  s    z/MarginalDistribution.__new__.<locals>.<genexpr>zMarginal distribution can be
             intitialised only in terms of random variables or indexed random
             variablesc                 s   s    | ]}|V  qd S r%   r+   rM   r+   r+   r,   rU   p  s    )r=   r!   rZ   allrW   r"   r   Zfromiterr&   r   r   r   r'   )r(   r*   rQ   r+   r+   r,   r'   i  s   zMarginalDistribution.__new__c                 C   s   d S r%   r+   r/   r+   r+   r,   checku  s   zMarginalDistribution.checkc                 C   s&   dd | j d D }tdd |D  S )Nc                 S   s   g | ]	}t |tr|qS r+   )r&   r   rC   r+   r+   r,   rF   z  rV   z,MarginalDistribution.set.<locals>.<listcomp>r7   c                 S   rK   r+   )rL   r.   rM   r+   r+   r,   rF   {  rO   )r3   r	   rP   r+   r+   r,   r.   x  s   zMarginalDistribution.setc                 C   s   | j d }dd |D S )Nr7   c                 S   s   h | ]}|j jqS r+   )rL   r4   rM   r+   r+   r,   	<setcomp>  rO   z/MarginalDistribution.symbols.<locals>.<setcomp>r2   rP   r+   r+   r,   r   }  s   
zMarginalDistribution.symbolsc                    s   | j d | j d }  fddt|D }t|tr8t|jj }tdddtfdd	|D }||}n	td
d	  D }t	|| 
|| S )Nr   r7   c                    s   g | ]}| vr|qS r+   r+   rC   re   r+   r,   rF     rG   z,MarginalDistribution.pdf.<locals>.<listcomp>xT)realc                 3   s    | ]}t  |V  qd S r%   r
   rC   )r   r+   r,   rU     rd   z+MarginalDistribution.pdf.<locals>.<genexpr>c                 s   s,    | ]}t |tr|jjn|jd  V  qdS r   )r&   r   rL   r4   r3   rM   r+   r+   r,   rU     s   * )r3   r   r&   r   r=   r-   r   rZ   rJ   r   compute_pdf)r0   r   ri   marginalise_outr`   rl   r+   )rQ   r   r,   rJ     s   
zMarginalDistribution.pdfc                 C   s4   |D ]}d}t |tr|jj}| || |}q|S r6   )r&   r   rL   rJ   r   )r0   ri   rQ   rN   Zlpdfr+   r+   r,   r     s   
z MarginalDistribution.compute_pdfc                 C   s   ddl m} t|tr|jj}nt|tr"|j|j|j	d }|
||jji}|jjr:t||jj|f}|S |jjrW|tjtjtjfv rN|j|jf}|||jj|f}|S )Nr   )r   r7   )sympy.concrete.summationsr   r&   r   rL   r.   r   rf   rS   r3   r^   r4   r\   r   r]   r   ZIntegersZNaturalsZ	Naturals0r   sup)r0   ri   rN   r   domr+   r+   r,   r     s    


z$MarginalDistribution.marginalise_outc                 G   r   r%   r   r   r+   r+   r,   r     r   zMarginalDistribution.__call__N)r|   r}   r~   r   r'   r   r   r.   r   rJ   r   r   r   r+   r+   r+   r,   r   `  s    

r   N)=r   mathr   Zsympy.core.basicr   Zsympy.core.functionr   Zsympy.core.singletonr   Zsympy.core.symbolr   r   Zsympy.core.sympifyr   Zsympy.sets.setsr	   Zsympy.tensor.indexedr   Zsympy.concrete.productsr   r   r   r   Zsympy.core.containersr   Zsympy.integrals.integralsr   r   Zsympy.matricesr   r   r   Zsympy.stats.crvr   r   Zsympy.stats.drvr   r   Zsympy.stats.rvr   r   r   r   r   r   r   r    Zsympy.utilities.iterablesr!   Zsympy.utilities.miscr"   Zsympy.externalr#   r$   r   r   r   r   r   r9   r   r+   r+   r+   r,   <module>   s@    
(q''-<