a
    h@n                     @   s   d dl Z d dlmZ d dlmZ d dlZd dlmZ ddlmZ	m
Z
 g dZeeee
eee  dd	d
ZG dd deZG dd dejjZG dd dejjZG dd dejjZG dd dejjZdS )    N)Enum)Optional)Tensor   )
functionalInterpolationMode)AutoAugmentPolicyAutoAugmentRandAugmentTrivialAugmentWideAugMiximgop_name	magnitudeinterpolationfillc                 C   s   |dkr>t j| dddgdtt|dg||ddgd} n|dkr|t j| dddgddtt|g||ddgd} n|dkrt j| dt|dgd|ddg|d} nP|d	krt j| ddt|gd|ddg|d} n |d
krt j| |||d} n|dkrt | d| } n|dkr2t | d| } n|dkrNt 	| d| } n|dkrjt 
| d| } n|dkrt | t|} nv|dkrt | |} n^|dkrt | } nH|dkrt | } n2|dkrt | } n|dkrntd| d| S )NShearX        r         ?)angle	translatescaleshearr   r   centerShearY
TranslateX)r   r   r   r   r   r   
TranslateYRotater   r   
BrightnessColorContrast	Sharpness	PosterizeSolarizeAutoContrastEqualizeInvertIdentityzThe provided operator  is not recognized.)FZaffinemathdegreesatanintrotateZadjust_brightnessZadjust_saturationZadjust_contrastZadjust_sharpnessZ	posterizeZsolarizeZautocontrastZequalizeinvert
ValueErrorr    r3   P/var/www/auris/lib/python3.9/site-packages/torchvision/transforms/autoaugment.py	_apply_op   s    





	

	









r5   c                   @   s   e Zd ZdZdZdZdZdS )r   zoAutoAugment policies learned on different datasets.
    Available policies are IMAGENET, CIFAR10 and SVHN.
    ZimagenetZcifar10ZsvhnN)__name__
__module____qualname____doc__IMAGENETCIFAR10SVHNr3   r3   r3   r4   r   ]   s   r   c                	       s   e Zd ZdZejejdfeeee	e
  dd fddZee	eeee
ee f eee
ee f f  dddZeeeef eeeeef f d	d
dZeeeeeef dddZeedddZedddZ  ZS )r	   a?  AutoAugment data augmentation method based on
    `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        policy (AutoAugmentPolicy): Desired policy enum defined by
            :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    N)policyr   r   returnc                    s,   t    || _|| _|| _| || _d S N)super__init__r=   r   r   _get_policiespolicies)selfr=   r   r   	__class__r3   r4   rA   y   s
    
zAutoAugment.__init__)r=   r>   c                 C   sJ   |t jkrg dS |t jkr$g dS |t jkr6g dS td| dd S )N)))r$   皙?   )r   333333?	   )r%   rI      r&   rI   Nr'   皙?Nr'   rI   N))r$   rI      )r$   rI      r'   rG   N)r%   皙?   )rU   r   rP   rH   ))r%   rI      rQ   ))r$   rP   rL   r'   r   N))r   rV   rY   )r%   rI   rH   )rQ   )r$   rG   rS   )rX   r!   rG   r   ))r   rG   rJ   rQ   ))r'   r   NrO   r(   rI   NrZ   )r!   rI   rW   )r"   r   rH   )rX   )r!   r      ))r!   rP   rH   )r%   rP   rR   ))r#   rG   rR   r]   ))r   rI   rL   rZ   )r[   rQ   rT   rK   r\   r^   rN   ))r(   皙?N)r"   rV   rS   ))r   ffffff?r_   )r   333333?rJ   ))r#   rP   r   )r#   ?rY   ))r         ?rH   r   rb   rJ   ))r&   re   Nr'   rd   N))r   rV   rR   )r$   rc   rR   ))r!   rG   rY   )r    rI   rR   ))r#   rc   rJ   )r    rb   rJ   )rQ   )r'   re   N))r"   rI   rR   )r#   rI   rL   ))r!   rb   rR   )r   re   rH   ))r'   rc   N)r&   rG   N))r   rG   rY   )r#   rV   rS   ))r    rd   rS   )r!   rV   rH   ))r%   re   r_   )r(   r   N)r'   rV   NrM   )rh   rQ   ))r!   rd   rJ   rQ   )r&   rP   N)r%   rV   rH   ))r    ra   rY   )r!   rb   r   ))r%   rG   rL   r&   rd   N))r   rd   rJ   rf   )rj   )r%   rP   rY   )rO   r`   )rf   rj   ))r   rd   rW   )r(   rV   N)r   rd   rH   r(   rb   N)rQ   )r%   rI   rS   r(   rd   NrQ   rQ   )r   rd   rY   )rk   ri   )rl   )r(   rG   N))r   rd   rL   )r%   rV   rS   )ro   ri   rp   )rk   )r%   rc   rY   ))r   rP   rH   rm   )rg   )r   rI   rS   rn   ))r"   rc   rY   r   rP   rW   )r(   rP   N)r   r   r_   ))r   rb   rS   )r%   rG   rH   )r]   rq   ))r   rc   rR   )r   rd   rY   ))r   ra   rS   r]   ))r%   rb   r_   )r   rI   rR   ))r   rP   rW   rr   ))r   rb   rJ   )r   rP   rY   ))r   rP   rL   )r&   rb   N))r   rb   r_   r`   zThe provided policy r*   )r   r:   r;   r<   r2   )rD   r=   r3   r3   r4   rB      s    


zAutoAugment._get_policiesnum_bins
image_sizer>   c                 C   s   t dd|dft dd|dft dd|d  |dft dd|d  |dft dd|dft dd|dft dd|dft dd|dft dd|dfd	t ||d d
     dft dd|dft ddft ddft ddfdS )Nr   rc   Tt ?r   r         >@rd   rH   rW   F     o@)r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   r'   r(   )torchlinspacearangeroundr/   tensorrD   rt   ru   r3   r3   r4   _augmentation_space   s    $zAutoAugment._augmentation_space)transform_numr>   c                 C   s4   t t| d }td}tdd}|||fS )zGet parameters for autoaugment transformation

        Returns:
            params required by the autoaugment transformation
        r   )r_   r_   )r/   ry   randintitemZrand)r   Z	policy_idprobssignsr3   r3   r4   
get_params   s    
zAutoAugment.get_paramsr   r>   c                 C   s   | j }t|\}}}t|trTt|ttfr>t|g| }n|durTdd |D }| t| j	\}}}| 
d||f}	t| j	| D ]n\}
\}}}||
 |kr|	| \}}|durt||  nd}|r||
 dkr|d9 }t|||| j|d}q|S )	z
            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: AutoAugmented image.
        Nc                 S   s   g | ]}t |qS r3   float.0fr3   r3   r4   
<listcomp>      z'AutoAugment.forward.<locals>.<listcomp>
   r   r         r   )r   r+   get_dimensions
isinstancer   r/   r   r   lenrC   r   	enumerater   r5   r   )rD   r   r   channelsheightwidthZtransform_idr   r   op_metair   pZmagnitude_id
magnitudessignedr   r3   r3   r4   forward   s"    
zAutoAugment.forwardr>   c                 C   s   | j j d| j d| j dS )Nz(policy=, fill=))rF   r6   r=   r   )rD   r3   r3   r4   __repr__  s    zAutoAugment.__repr__)r6   r7   r8   r9   r   r:   r   NEARESTr   listr   rA   tuplestrr/   rB   dictr   boolr   staticmethodr   r   r   __classcell__r3   r3   rE   r4   r	   h   s$   
*Z*r	   c                       s   e Zd ZdZdddejdfeeeeeee	  dd fddZ
eeeef eeeeef f d	d
dZeedddZedddZ  ZS )r
   a~  RandAugment data augmentation method based on
    `"RandAugment: Practical automated data augmentation with a reduced search space"
    <https://arxiv.org/abs/1909.13719>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_ops (int): Number of augmentation transformations to apply sequentially.
        magnitude (int): Magnitude for all the transformations.
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    r_   rJ      N)num_opsr   num_magnitude_binsr   r   r>   c                    s,   t    || _|| _|| _|| _|| _d S r?   )r@   rA   r   r   r   r   r   )rD   r   r   r   r   r   rE   r3   r4   rA   2  s    
zRandAugment.__init__rs   c                 C   s   t ddft dd|dft dd|dft dd|d  |dft dd|d  |dft dd|dft dd	|dft dd	|dft dd	|dft dd	|dfd
t ||d d     dft dd|dft ddft ddfdS )Nr   Frc   Trv   r   r   rw   rd   rH   rW   rx   r)   r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   r'   ry   r}   rz   r{   r|   r/   r~   r3   r3   r4   r   A  s    $zRandAugment._augmentation_spacer   c                 C   s   | j }t|\}}}t|trTt|ttfr>t|g| }n|durTdd |D }| | j||f}t	| j
D ]}ttt|d }t| | }	||	 \}
}|
jdkrt|
| j  nd}|rtddr|d9 }t||	|| j|d	}qp|S )

            img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        Nc                 S   s   g | ]}t |qS r3   r   r   r3   r3   r4   r   a  r   z'RandAugment.forward.<locals>.<listcomp>r   r   r   r_   r   r   )r   r+   r   r   r   r/   r   r   r   ranger   ry   r   r   r   r   keysndimr   r5   r   )rD   r   r   r   r   r   r   _op_indexr   r   r   r   r3   r3   r4   r   T  s"    
 zRandAugment.forwardr   c                 C   s:   | j j d| j d| j d| j d| j d| j d}|S )Nz	(num_ops=z, magnitude=z, num_magnitude_bins=, interpolation=r   r   )rF   r6   r   r   r   r   r   rD   sr3   r3   r4   r   o  s    
	zRandAugment.__repr__)r6   r7   r8   r9   r   r   r/   r   r   r   rA   r   r   r   r   r   r   r   r   r   r3   r3   rE   r4   r
     s"   
*r
   c                       s|   e Zd ZdZdejdfeeeee	  dd fddZ
eeeeeef f ddd	Zeed
ddZedddZ  ZS )r   a  Dataset-independent data-augmentation with TrivialAugment Wide, as described in
    `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        num_magnitude_bins (int): The number of different magnitude values.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    r   N)r   r   r   r>   c                    s    t    || _|| _|| _d S r?   )r@   rA   r   r   r   )rD   r   r   r   rE   r3   r4   rA     s    
zTrivialAugmentWide.__init__)rt   r>   c                 C   s   t ddft dd|dft dd|dft dd|dft dd|dft dd|dft dd|dft dd|dft dd|dft dd|dfdt ||d d	     dft d
d|dft ddft ddfdS )Nr   FgGz?Tg      @@g     `@rH   r   rS   rx   r   r   )rD   rt   r3   r3   r4   r     s    $z&TrivialAugmentWide._augmentation_spacer   c                 C   s   | j }t|\}}}t|trTt|ttfr>t|g| }n|durTdd |D }| | j}tt	
t|d }t| | }|| \}	}
|	jdkrt|	t	j
t|	dt	jd  nd}|
rt	
ddr|d	9 }t|||| j|d
S )r   Nc                 S   s   g | ]}t |qS r3   r   r   r3   r3   r4   r     r   z.TrivialAugmentWide.forward.<locals>.<listcomp>r   r   dtyper   r_   r   r   )r   r+   r   r   r   r/   r   r   r   ry   r   r   r   r   r   r   longr5   r   )rD   r   r   r   r   r   r   r   r   r   r   r   r3   r3   r4   r     s$    
$zTrivialAugmentWide.forwardr   c                 C   s*   | j j d| j d| j d| j d}|S )Nz(num_magnitude_bins=r   r   r   )rF   r6   r   r   r   r   r3   r3   r4   r     s    
zTrivialAugmentWide.__repr__)r6   r7   r8   r9   r   r   r/   r   r   r   rA   r   r   r   r   r   r   r   r   r   r3   r3   rE   r4   r   |  s   
 r   c                
       s   e Zd ZdZdddddejdfeeeeeee	e
e  dd fdd	Zeeeef eeeeef f d
ddZejjedddZejjedddZeedddZeedddZedddZ  ZS )r   a  AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    rY   r   TN)severitymixture_widthchain_depthalphaall_opsr   r   r>   c                    sn   t    d| _d|  kr&| jks@n td| j d| d|| _|| _|| _|| _|| _|| _	|| _
d S )Nr   r   z!The severity must be between [1, z]. Got z	 instead.)r@   rA   _PARAMETER_MAXr2   r   r   r   r   r   r   r   )rD   r   r   r   r   r   r   r   rE   r3   r4   rA     s    

zAugMix.__init__rs   c                 C   s
  t dd|dft dd|dft d|d d |dft d|d d |dft dd|dfdt ||d d     d	ft d
d|d	ft dd	ft dd	fd	}| jr|t dd|dft dd|dft dd|dft dd|dfd |S )Nr   rc   Tr   g      @r   rw   rW   Frx   )	r   r   r   r   r   r$   r%   r&   r'   rd   )r    r!   r"   r#   )ry   rz   r{   r|   r/   r}   r   update)rD   rt   ru   r   r3   r3   r4   r     s&    $zAugMix._augmentation_spacer   c                 C   s
   t |S r?   )r+   Zpil_to_tensorrD   r   r3   r3   r4   _pil_to_tensor  s    zAugMix._pil_to_tensor)r   c                 C   s
   t |S r?   )r+   Zto_pil_imager   r3   r3   r4   _tensor_to_pil  s    zAugMix._tensor_to_pil)paramsr>   c                 C   s
   t |S r?   )ry   _sample_dirichlet)rD   r   r3   r3   r4   r     s    zAugMix._sample_dirichlet)orig_imgr>   c              	   C   s|  | j }t|\}}}t|trZ|}t|ttfrBt|g| }qd|durddd |D }n
| |}| | j	||f}t
|j}|dgtd|j d | }	|	dgdg|	jd   }
| tj| j| jg|	jd|
d d}| tj| jg| j |	jd|
d d|dddf |
d dg }|dddf |
|	 }t| jD ]}|	}| jdkrn| jnttjddd	d
 }t|D ]}ttt|d	 }t
| | }|| \}}|jdkrt|tj| jd	tjd  nd}|rtdd	r|d9 }t|||| j |d}q|!|dd|f |
|  qT||j"|j#d}t|tsx| $|S |S )r   Nc                 S   s   g | ]}t |qS r3   r   r   r3   r3   r4   r   /  r   z"AugMix.forward.<locals>.<listcomp>r   rW   r   )devicer   r   )lowhighsizer   r   r_   r   r   )%r   r+   r   r   r   r/   r   r   r   r   r   shapeviewmaxr   r   r   ry   r}   r   r   expandr   r   r   r   r   r   r   r   r   r5   r   Zadd_tor   r   )rD   r   r   r   r   r   r   r   Z	orig_dimsbatchZ
batch_dimsmZcombined_weightsZmixr   augdepthr   r   r   r   r   r   r3   r3   r4   r   !  sR    


 "$*$$
zAugMix.forwardc                 C   sJ   | j j d| j d| j d| j d| j d| j d| j d| j d}|S )	Nz
(severity=z, mixture_width=z, chain_depth=z, alpha=z
, all_ops=r   r   r   )	rF   r6   r   r   r   r   r   r   r   r   r3   r3   r4   r   [  s"    
zAugMix.__repr__)r6   r7   r8   r9   r   ZBILINEARr/   r   r   r   r   rA   r   r   r   r   r   ry   ZjitZunusedr   r   r   r   r   r   r3   r3   rE   r4   r     s4   
*:r   )r,   enumr   typingr   ry   r    r   r+   r   __all__r   r   r   r5   r   nnModuler	   r
   r   r   r3   r3   r3   r4   <module>   s   P 8]V