a
    h:                     @   s  d dl mZ d dlmZ d dlmZmZ d dlZd dlmZ d dl	m
Z ddlmZ d	d
lmZmZmZ d	dlmZ d	dlmZmZmZ d	dlmZmZmZ d	dlmZmZmZm Z m!Z! ddlm"Z" ddl#m$Z$ g dZ%G dd de"Z&G dd dej'Z(G dd dej'Z)G dd dej'Z*G dd dej+Z,ee-ee. e&dddZ/ed d!d"Z0G d#d$ d$eZ1G d%d& d&eZ2G d'd( d(eZ3ee-ee. e&dd)d*Z4e ed+e1j5fd,e!j6fd-dd.dde!j6d/ee1 e.ee- ee. ee! ee&d0d1d2Z7e ed+e2j5fd,ej6fd-dd.ddej6d/ee2 e.ee- ee. ee ee&d0d3d4Z8e ed+e3j5fd,ej6fd-dd.ddej6d/ee3 e.ee- ee. ee ee&d0d5d6Z9dS )7    )Sequence)partial)AnyOptionalN)nn)
functional   )SemanticSegmentation   )register_modelWeightsWeightsEnum)_VOC_CATEGORIES)_ovewrite_value_paramhandle_legacy_interfaceIntermediateLayerGetter)mobilenet_v3_largeMobileNet_V3_Large_WeightsMobileNetV3)ResNet	resnet101ResNet101_Weightsresnet50ResNet50_Weights   )_SimpleSegmentationModel)FCNHead)	DeepLabV3DeepLabV3_ResNet50_WeightsDeepLabV3_ResNet101_Weights$DeepLabV3_MobileNet_V3_Large_Weightsdeeplabv3_mobilenet_v3_largedeeplabv3_resnet50deeplabv3_resnet101c                   @   s   e Zd ZdZdS )r   a  
    Implements DeepLabV3 model from
    `"Rethinking Atrous Convolution for Semantic Image Segmentation"
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    N)__name__
__module____qualname____doc__ r(   r(   W/var/www/auris/lib/python3.9/site-packages/torchvision/models/segmentation/deeplabv3.pyr      s   r   c                       s.   e Zd Zdeeee dd fddZ  ZS )DeepLabHead      $   N)in_channelsnum_classesatrous_ratesreturnc                    sB   t  t||tjddddddtdt td|d d S )N   r   r   F)paddingbias)super__init__ASPPr   Conv2dBatchNorm2dReLU)selfr/   r0   r1   	__class__r(   r)   r7   2   s    zDeepLabHead.__init__)r+   )r$   r%   r&   intr   r7   __classcell__r(   r(   r=   r)   r*   1   s   r*   c                       s(   e Zd Zeeedd fddZ  ZS )ASPPConvN)r/   out_channelsdilationr2   c                    s6   t j||d||ddt |t  g}t j|  d S )Nr   F)r4   rC   r5   )r   r9   r:   r;   r6   r7   )r<   r/   rB   rC   modulesr=   r(   r)   r7   =   s
    zASPPConv.__init__)r$   r%   r&   r?   r7   r@   r(   r(   r=   r)   rA   <   s   rA   c                       s:   e Zd Zeedd fddZejejdddZ  ZS )ASPPPoolingN)r/   rB   r2   c              	      s4   t  tdtj||dddt|t  d S )Nr   Fr5   )r6   r7   r   ZAdaptiveAvgPool2dr9   r:   r;   )r<   r/   rB   r=   r(   r)   r7   G   s    zASPPPooling.__init__xr2   c                 C   s2   |j dd  }| D ]}||}qtj||dddS )NZbilinearF)sizemodeZalign_corners)shapeFZinterpolate)r<   rH   rJ   modr(   r(   r)   forwardO   s    
zASPPPooling.forward)	r$   r%   r&   r?   r7   torchTensorrO   r@   r(   r(   r=   r)   rE   F   s   rE   c                       sB   e Zd Zd	eee edd fddZejejdddZ  Z	S )
r8   r3   N)r/   r1   rB   r2   c              
      s   t    g }|ttj||dddt|t  t|}|D ]}|t	||| qF|t
|| t|| _ttjt| j| |dddt|t td| _d S )Nr   FrF   g      ?)r6   r7   appendr   
Sequentialr9   r:   r;   tuplerA   rE   Z
ModuleListconvslenZDropoutproject)r<   r/   r1   rB   rD   ZratesZrater=   r(   r)   r7   W   s     
$zASPP.__init__rG   c                 C   s6   g }| j D ]}||| q
tj|dd}| |S )Nr   )Zdim)rU   rR   rP   catrW   )r<   rH   Z_resconvresr(   r(   r)   rO   m   s
    
zASPP.forward)r3   )
r$   r%   r&   r?   r   r7   rP   rQ   rO   r@   r(   r(   r=   r)   r8   V   s   r8   )backboner0   auxr2   c                 C   sH   ddi}|rd|d< t | |d} |r.td|nd }td|}t| ||S )NZlayer4outr\   Zlayer3return_layersi   i   )r   r   r*   r   )r[   r0   r\   r_   aux_classifier
classifierr(   r(   r)   _deeplabv3_resnetu   s    
rb   )r   r   z
        These weights were trained on a subset of COCO, using only the 20 categories that are present in the Pascal VOC
        dataset.
    )
categoriesZmin_sizeZ_docsc                
   @   sD   e Zd Zedeeddi edddddd	id
dddZeZdS )r   zHhttps://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth  Zresize_sizeijzVhttps://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_resnet50COCO-val2017-VOC-labelsgP@皙W@ZmiouZ	pixel_accgvWf@gGzd@Z
num_paramsZrecipeZ_metricsZ_ops
_file_sizeurlZ
transformsmetaN	r$   r%   r&   r   r   r	   _COMMON_METACOCO_WITH_VOC_LABELS_V1DEFAULTr(   r(   r(   r)   r      s$   
r   c                
   @   sD   e Zd Zedeeddi edddddd	id
dddZeZdS )r   zIhttps://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pthrd   re   ijzQhttps://github.com/pytorch/vision/tree/main/references/segmentation#fcn_resnet101rf   gP@rg   rh   gS+p@gm&m@ri   rk   Nrn   r(   r(   r(   r)   r      s$   
r   c                
   @   sD   e Zd Zedeeddi edddddd	id
dddZeZdS )r    zMhttps://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pthrd   re   iPK z`https://github.com/pytorch/vision/tree/main/references/segmentation#deeplabv3_mobilenet_v3_largerf   gfffff&N@gV@rh   gCl$@gJ+&E@ri   rk   Nrn   r(   r(   r(   r)   r       s$   
r    c                 C   s   | j } dgdd t| D  t| d g }|d }| | j}|d }| | j}t|di}|rld|t|< t| |d	} |rt||nd }	t||}
t| |
|	S )
Nr   c                 S   s    g | ]\}}t |d dr|qS )Z_is_cnF)getattr).0ibr(   r(   r)   
<listcomp>       z*_deeplabv3_mobilenetv3.<locals>.<listcomp>r   r]   r\   r^   )	features	enumeraterV   rB   strr   r   r*   r   )r[   r0   r\   Zstage_indicesZout_posZout_inplanesZaux_posZaux_inplanesr_   r`   ra   r(   r(   r)   _deeplabv3_mobilenetv3   s    &


r}   Z
pretrainedZpretrained_backbone)weightsweights_backboneT)r~   progressr0   aux_lossr   )r~   r   r0   r   r   kwargsr2   c                 K   s   t | } t|}| durDd}td|t| jd }td|d}n|du rPd}t|g dd}t|||}| dur|| j	|dd	 |S )
ad  Constructs a DeepLabV3 model with a ResNet-50 backbone.

    .. betastatus:: segmentation module

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_ResNet50_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for the
            backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet50_Weights
        :members:
    Nr0   rc   r   T   FTTr~   Zreplace_stride_with_dilationr   Z
check_hash)
r   verifyr   r   rV   rm   r   rb   load_state_dictget_state_dictr~   r   r0   r   r   r   r[   modelr(   r(   r)   r"      s    %

r"   c                 K   s   t | } t|}| durDd}td|t| jd }td|d}n|du rPd}t|g dd}t|||}| dur|| j	|dd	 |S )
ai  Constructs a DeepLabV3 model with a ResNet-101 backbone.

    .. betastatus:: segmentation module

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_ResNet101_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.ResNet101_Weights`, optional): The pretrained weights for the
            backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_ResNet101_Weights
        :members:
    Nr0   rc   r   Tr   r   r   r   )
r   r   r   r   rV   rm   r   rb   r   r   r   r(   r(   r)   r#     s    %

r#   c                 K   s   t | } t|}| durDd}td|t| jd }td|d}n|du rPd}t|dd}t|||}| dur|| j	|dd |S )	ak  Constructs a DeepLabV3 model with a MobileNetV3-Large backbone.

    Reference: `Rethinking Atrous Convolution for Semantic Image Segmentation <https://arxiv.org/abs/1706.05587>`__.

    Args:
        weights (:class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        aux_loss (bool, optional): If True, it uses an auxiliary loss
        weights_backbone (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The pretrained weights
            for the backbone
        **kwargs: unused

    .. autoclass:: torchvision.models.segmentation.DeepLabV3_MobileNet_V3_Large_Weights
        :members:
    Nr0   rc   r   Tr   )r~   Zdilatedr   )
r    r   r   r   rV   rm   r   r}   r   r   r   r(   r(   r)   r!   T  s    #

r!   ):collections.abcr   	functoolsr   typingr   r   rP   r   Ztorch.nnr   rM   Ztransforms._presetsr	   Z_apir   r   r   Z_metar   _utilsr   r   r   Zmobilenetv3r   r   r   Zresnetr   r   r   r   r   r   Zfcnr   __all__r   rS   r*   rA   rE   Moduler8   r?   boolrb   ro   r   r   r    r}   rp   ZIMAGENET1K_V1r"   r#   r!   r(   r(   r(   r)   <module>   s   
 
33