o
    Zh                     @   sf  d Z ddlmZmZ ddlZddlm  m  mZ	 ddl
m  m  m  mZ ddlmZ ddlm  mZ ddlmZ ddlmZ ddlmZmZmZ ddlmZ dd	lmZmZ g d
ZddhZde e! de e! fddZ"G dd deZ#G dd de#Z$G dd de#Z%G dd de#Z&G dd de#Z'G dd de'Z(G dd de'Z)G dd  d e'Z*dS )!zQuantized convolution modules.    )ClassVarOptionalN)ops)	_size_1_t)_pair_single_triple)fuse_conv_bn_weights   )_quantize_weightWeightedQuantizedModule)Conv1dConv2dConv3dConvTranspose1dConvTranspose2dConvTranspose3dzerosreflectpaddingreturnc                    s>   g }t  t D ]| fddtdD  q
|S )Nc                 3   s     | ]}  d   V  qdS )r
   N ).0_Nidxr   r   Q/var/www/auris/lib/python3.10/site-packages/torch/ao/nn/quantized/modules/conv.py	<genexpr>#   s    z*_reverse_repeat_padding.<locals>.<genexpr>   )lenrangeextend)r    _reversed_padding_repeated_twicer   r   r   _reverse_repeat_padding   s
   "r$   c                       s   e Zd Z								d&ddZ			d'	d( fd	d
Zdd Zdd Zdd Zdd Z fddZ	e
jjdd Z fddZe
jjdd Zdd Zdd Zed)dd Zed*d"d#Zed$d% Z  ZS )+_ConvNdr
   r   Tr   Nc                 C      t NNotImplementedError)selfin_channelsout_channelskernel_sizestrider   dilationgroupsbiaspadding_modedevicedtyper   r   r   __init__(   s   z_ConvNd.__init__r   c                    s0  ||d}t    ||	 dkrtd||	 dkrtd|| _|| _|| _|| _|| _|| _|| _	|| _
|	| _|tvrEtd| d|| _| j	rS||| j g}n||| j g}tj|t| fddtjdd	d
 | D }|
rtj|fdtjidd
 | D nd }| || d| _d| _d S )Nr3   r4   r   z'in_channels must be divisible by groupsz(out_channels must be divisible by groupsz'padding_mode' z* is not supported by quantized convolutionr
   )scale
zero_pointr4   c                 S      i | ]\}}|d kr||qS r4   r   r   kvr   r   r   
<dictcomp>h       z!_ConvNd._init.<locals>.<dictcomp>r4   c                 S   r9   r:   r   r;   r   r   r   r>   n   r?   g      ?)superr5   
ValueErrorr+   r,   r-   r.   r   r/   
transposedoutput_paddingr0   _SUPPORTED_PADDINGr2   torchZ_empty_affine_quantizedlistqint8itemsr   floatset_weight_biasr7   r8   )r*   r+   r,   r-   r.   r   r/   rB   rC   r0   r1   r2   r3   r4   factory_kwargsZweight_shapeqweight
bias_float	__class__r   r   _init9   sZ   





z_ConvNd._initc                 C   r&   r'   r(   )r*   rL   rM   r   r   r   rJ   x      z_ConvNd.set_weight_biasc                 C   r&   r'   r(   r*   r   r   r   r1   {   rQ   z_ConvNd.biasc                 C   r&   r'   r(   rR   r   r   r   _weight_bias~   rQ   z_ConvNd._weight_biasc                 C   s   d}| j dt| j  kr|d7 }| jdt| j kr|d7 }| jdt| j kr,|d7 }| jdkr5|d7 }|  d u r?|d	7 }|jd
i | jS )Nzq{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}, scale={scale}, zero_point={zero_point})r   z, padding={padding})r
   z, dilation={dilation}z!, output_padding={output_padding}r
   z, groups={groups}z, bias=Falser   )r   r    r/   rC   r0   r1   format__dict__)r*   sr   r   r   
extra_repr   s   
z_ConvNd.extra_reprc                    s`   t  ||| |  \}}|||d < |||d < t| j||d < t| j||d < d S )Nweightr1   r7   r8   )r@   _save_to_state_dictrS   rE   Ztensorr7   r8   )r*   destinationprefixZ	keep_varswbrN   r   r   rY      s   z_ConvNd._save_to_state_dictc                 C   sH   |   \}}| j| j| j| j| j| j| j| j| j	| j
||| j| j| jfS r'   )rS   r+   r,   r-   r.   r   r/   rB   rC   r0   r2   r7   r8   trainingr*   r\   r]   r   r   r   __getstate__   s"   z_ConvNd.__getstate__c              	      s   |  ||d  ||d   ||d  ||d  t||d  | _||d  t||d  | _||d  t |||d||| d S )NrX   r1   r7   r8   F)rJ   poprI   r7   intr8   r@   _load_from_state_dict)r*   Z
state_dictr[   Zlocal_metadatastrictZmissing_keysZunexpected_keys
error_msgsrN   r   r   rc      s    
z_ConvNd._load_from_state_dictc                 C   s   |d | _ |d | _|d | _|d | _|d | _|d | _|d | _|d | _|d	 | _|d
 | _	| 
|d |d  |d | _|d | _|d | _d S )Nr   r
   r                     	   
               )r+   r,   r-   r.   r   r/   rB   rC   r0   r2   rJ   r7   r8   r^   )r*   stater   r   r   __setstate__   s   











z_ConvNd.__setstate__c                 C   s6   t | t | }tjj| |  }|| |S r'   )type__new__rE   nnModuler5   r`   rs   )r*   memoZnew_instancerr   r   r   r   __deepcopy__   s
   
z_ConvNd.__deepcopy__c                 C   s
   |  i S r'   )ry   rR   r   r   r   __copy__      
z_ConvNd.__copy__c              
   C   s   |du r	|j  }||j |jtjksJ dt|j |}| |j|j|j	|j
|j|j|j|jdu|j	}|||j |du sH|jtjkrJ|S | \}}t||_t||_|S )z&Creates a qconv object and returns it.N*Weight observer must have a dtype of qint8)qconfigrX   r4   rE   rG   r   rI   r+   r,   r-   r.   r   r/   r0   r1   r2   rJ   calculate_qparamsr7   rb   r8   )clsmodactivation_post_processweight_post_processrL   qconv	act_scaleact_zpr   r   r   	get_qconv   s4   



z_ConvNd.get_qconvFc                 C   s  t |dr6t|| jkr&t|j|j|jj|jj|jj	|jj|jj\|_|_t |ds/J d|j
}|j}nDt|| jksRJ d| j d | jj d tt| t |ds[J dt |dsbd n|j}t|| j| j| jfv ru|d	 }|j }| |||S )
Nweight_fake_quantr   z,Input QAT module must have observer attached nnq..from_float only works for z	 but got:r}   -Input float module must have qconfig defined.r   )hasattrrt   _NNIQAT_CONV_BN_MODULEr	   rX   r1   ZbnZrunning_meanZrunning_varZepsr   r   _FLOAT_MODULE__name__str_NNI_CONV_RELU_MODULE_NNI_CONV_ADD_MODULE_NNI_CONV_ADD_RELU_MODULEr}   r   )r   r   use_precomputed_fake_quantr   r   r   r   r   
from_float  s`   
	

z_ConvNd.from_floatc                 C   sj   | |j |j|j|j|j|j|j|jdu|j|j	j
|j	jd}| }|||j t||_t||_|S )a  Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconv (Module): a reference quantized  module, either produced by torch.ao.quantization
                                utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        Nr6   )r+   r,   r-   r.   r   r/   r0   r1   r2   rX   r3   r4   get_quantized_weightrJ   rI   r7   rb   r8   )r   Z	ref_qconvoutput_scaleoutput_zero_pointr   rL   r   r   r   from_referenceD  s$   	

z_ConvNd.from_referencer
   r   r
   r
   Tr   NN)r   NN)r   Nr'   F)r   
__module____qualname__r5   rP   rJ   r1   rS   rW   rY   rE   jitZexportr`   rc   rs   ry   rz   classmethodr   staticmethodr   r   __classcell__r   r   rN   r   r%   '   sD    
?

!-r%   c                       s2  e Zd ZU dZejZeeej  e	d< e
jZeeeej   e	d< ejZeeeej   e	d< dZeeeej   e	d< dZeeeej   e	d< 						
			d)dededededededededef fddZdd Zdejdeej ddfddZdd Zd d! Zd"d# Zd$d% Z e!d*d'd(Z"  Z#S )+r   a`  Applies a 1D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
        ...                                     dtype=torch.quint8)
        >>> output = m(q_input)

    r   r   r   Nr   r   r
   r   Tr   r+   r,   r-   r.   r   r/   r0   r1   r2   c                    sh   |
|d}t |}t |}t|tr|nt |}t |}t j||||||dt d|||	fi | d S Nr6   Fr   )r   
isinstancer   r@   rP   r*   r+   r,   r-   r.   r   r/   r0   r1   r2   r3   r4   rK   rN   r   r   r5     s(   

zConv1d.__init__c                 C      dS )NZQuantizedConv1dr   rR   r   r   r   	_get_name  rQ   zConv1d._get_namer\   r]   r   c                 C   X   | j dkrtjj||| j| j| j| j| _	d S tjj||| jt
d| j| j| _	d S Nr   r   )r2   rE   r   	quantizedZconv1d_prepackr.   r   r/   r0   _packed_paramsr   r_   r   r   r   rJ        


zConv1d.set_weight_biasc                 C      t jj| j\}}||fS r'   )rE   r   r   Zconv1d_unpackr   r_   r   r   r   rS        zConv1d._weight_biasc                 C      |   d S Nr   rS   rR   r   r   r   rX        zConv1d.weightc                 C   r   Nr
   r   rR   r   r   r   r1     r   zConv1d.biasc                 C   s\   t |jdkrtd| jdkr"t| jd d }tj||| jd}tj	
|| j| j| jS )Nrf    Input shape must be `(N, C, L)`!r   r
   mode)r    shaperA   r2   r$   r   Fpadr   r   Zconv1dr   r7   r8   r*   inputr#   r   r   r   forward  s   
zConv1d.forwardFc                 C      t j| ||dS zCreates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        )r   r%   r   r   r   r   r   r   r   r        zConv1d.from_floatr   r   )$r   r   r   __doc__rv   r   r   r   rt   __annotations__nniqatZConvBn1dr   r   rw   nniZ
ConvReLU1dr   r   r   rb   r   boolr   r5   r   rE   TensorrJ   rS   rX   r1   r   r   r   r   r   r   rN   r   r   a  sT   
 "	
%
r   c                       s  e Zd ZU dZejZeeej  e	d< e
jZeeeej   e	d< ejZeeeej   e	d< ejZeeej  e	d< ejZeeej  e	d< 							
		d  fdd	Zdd Zdejdeej ddfddZdd Zdd Zdd Zdd Zed!ddZ   Z!S )"r   a  Applies a 2D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    r   r   r   r   r   r
   r   Tr   Nc                    sZ   |
|d}t |}t |}t |}t |}t j||||||dt d|||	fi | d S r   )r   r@   rP   r   rN   r   r   r5     s(   

zConv2d.__init__c                 C   r   )NZQuantizedConv2dr   rR   r   r   r   r   2  rQ   zConv2d._get_namer\   r]   r   c                 C   r   r   )r2   rE   r   r   Zconv2d_prepackr.   r   r/   r0   r   r   r_   r   r   r   rJ   5  r   zConv2d.set_weight_biasc                 C   
   | j  S r'   r   unpackrR   r   r   r   rS   ?  r{   zConv2d._weight_biasc                 C   r   r   r   rR   r   r   r   rX   B  r   zConv2d.weightc                 C   r   r   r   rR   r   r   r   r1   E  r   zConv2d.biasc                 C   T   t |jdkrtd| jdkrt| j}tj||| jd}tj	
|| j| j| jS )Nrg   #Input shape must be `(N, C, H, W)`!r   r   )r    r   rA   r2   r$   r   r   r   r   r   Zconv2dr   r7   r8   r   r   r   r   r   H     

zConv2d.forwardFc                 C   r   r   r   r   r   r   r   r   V  r   zConv2d.from_floatr   r   )"r   r   r   r   rv   r   r   r   rt   r   r   ZConvBn2dr   r   rw   r   Z
ConvReLU2dr   Z	ConvAdd2dr   ZConvAddReLU2dr   r5   r   rE   r   rJ   rS   rX   r1   r   r   r   r   r   r   rN   r   r     s0   
 %$
r   c                       s  e Zd ZU dZejZeeej  e	d< e
jZeeeej   e	d< ejZeeeej   e	d< dZeeeej   e	d< dZeeeej   e	d< 						
			d  fdd	Zdd Zdejdeej ddfddZdd Zdd Zdd Zdd Zed!ddZ  ZS )"r   a  Applies a 3D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    r   r   r   Nr   r   r
   r   Tr   c                    sj   |	dksJ d|
|d}t |}t |}t |}t |}t j||||||dt d|||	fi | d S )Nr   z*Conv3d does not support reflection paddingr6   Fr   )r   r@   rP   r   rN   r   r   r5     s*   

zConv3d.__init__c                 C   r   )NZQuantizedConv3dr   rR   r   r   r   r     rQ   zConv3d._get_namer\   r]   r   c                 C   r   r   )r2   rE   r   r   Zconv3d_prepackr.   r   r/   r0   r   r   r_   r   r   r   rJ     r   zConv3d.set_weight_biasc                 C   r   r'   r   rR   r   r   r   rS     r{   zConv3d._weight_biasc                 C   r   r   r   rR   r   r   r   rX     r   zConv3d.weightc                 C   r   r   r   rR   r   r   r   r1     r   zConv3d.biasc                 C   r   )Nrh   z&Input shape must be `(N, C, D, H, W)`!r   r   )r    r   rA   r2   r$   r   r   r   r   r   Zconv3dr   r7   r8   r   r   r   r   r     r   zConv3d.forwardFc                 C   r   r   r   r   r   r   r   r     r   zConv3d.from_floatr   r   ) r   r   r   r   rv   r   r   r   rt   r   r   ZConvBn3dr   r   rw   r   Z
ConvReLU3dr   r   r   r5   r   rE   r   rJ   rS   rX   r1   r   r   r   r   r   r   rN   r   r   c  s0   
 %%
r   c                	       s~   e Zd ZU eeejjj  e	d< 		d fdd	Z
dee dee dee dee fd	d
ZedddZedd Z  ZS )_ConvTransposeNdr   Nc                    sP   |dkrt d| jj ||d}t j|||||||||	|
|fi | d S )Nr   z+Only "zeros" padding mode is supported for r6   )rA   rO   r   r@   rP   )r*   r+   r,   r-   r.   r   r/   rB   rC   r0   r1   r2   r3   r4   rK   rN   r   r   r5     s(   

z_ConvTransposeNd.__init__r-   r/   r   r   c                 C   sN   t jtt g }tt|D ]}|| || d  ||  }|| q|S r   )rE   r   ZannotaterF   rb   r!   r    append)r*   r-   r/   r   resZkdxr   r   r   r   _input_padding  s
   z_ConvTransposeNd._input_paddingFc           	      C   s   d| j  d | jj  }t|| jksJ |t|dsJ d|j }||j |jtjks3J dt	|j
 |}| |j|j|j|j|j|j|j|jdu|j|j
}|||j t|drg|jjtj
kri|S |j \}}t
||_t||_|S )zCreates a quantized module from a float module or qparams_dict.
        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        r   r   r}   r   r|   Nr   )r   r   rt   r   r}   rX   r4   rE   rG   r   rI   r+   r,   r-   r.   r   rC   r0   r1   r/   r2   rJ   r   r~   r7   rb   r8   )	r   r   r   msgr   rL   r   r   r   r   r   r   r     sJ   	



z_ConvTransposeNd.from_floatc                 C   sn   | |j |j|j|j|j|j|j|jdu|j|j	|j
j|j
jd}| }|||j t||_t||_|S )a  Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconvt (Module): a reference quantized  module, either produced by torch.ao.quantization
                                 utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        Nr6   )r+   r,   r-   r.   r   rC   r0   r1   r/   r2   rX   r3   r4   r   rJ   rI   r7   rb   r8   )r   
ref_qconvtr   r   r   rL   r   r   r   r   I  s&   	

z_ConvTransposeNd.from_reference)NNr   )r   r   r   r   rt   rv   modulesconvr%   r   r5   rF   rb   r   r   r   r   r   r   r   r   rN   r   r     s$   
 &
	.r   c                          e Zd ZU dZejZeeej  e	d< 									d fdd		Z
d
d Zdejdeej ddfddZdd Zdd Zdd Zdd Zedd Z  ZS )r   a  Applies a 1D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose1d`.

    .. note:: Currently only the QNNPACK engine is implemented.
        Please, set the `torch.backends.quantized.engine = 'qnnpack'`

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With square kernels and equal stride
        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter)
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12])
    r   r
   r   Tr   Nc                    ^   ||d}t |}t |}t |}t |	}	t |}t j||||||	d||||
fi | d S Nr6   T)r   r@   r5   r*   r+   r,   r-   r.   r   rC   r0   r1   r/   r2   r3   r4   rK   rN   r   r   r5     *   

zConvTranspose1d.__init__c                 C   r   )NZQuantizedConvTranspose1dr   rR   r   r   r   r     rQ   zConvTranspose1d._get_namer\   r]   r   c              	   C   *   t jj||| j| j| j| j| j| _	d S r'   )
rE   r   r   Zconv_transpose1d_prepackr.   r   rC   r/   r0   r   r_   r   r   r   rJ        
zConvTranspose1d.set_weight_biasc                 C   r   r'   )rE   r   r   Zconv_transpose1d_unpackr   r_   r   r   r   rS     r   zConvTranspose1d._weight_biasc                 C      |   \}}|S r'   r   r*   r\   r   r   r   r   rX        zConvTranspose1d.weightc                 C      |   \}}|S r'   r   r*   r   r]   r   r   r   r1     r   zConvTranspose1d.biasc                 C   s0   t |jdkrtdtjj|| j| j| j	S )Nrf   r   )
r    r   rA   rE   r   r   Zconv_transpose1dr   r7   r8   r*   r   r   r   r   r     s
   zConvTranspose1d.forwardc                 C      t | |||S r'   r   r   r   r   r   r   r   r   r   r        zConvTranspose1d.from_reference	r
   r   r   r
   Tr
   r   NN)r   r   r   r   rv   r   r   r   rt   r   r5   r   rE   r   r   rJ   rS   rX   r1   r   r   r   r   r   r   rN   r   r   g  *   
 +%	r   c                       r   )r   a~  Applies a 2D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose2d`.

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # QNNPACK or FBGEMM as backend
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> # With square kernels and equal stride
        >>> import torch.ao.nn.quantized as nnq
        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter)
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])
    r   r
   r   Tr   Nc                    r   r   )r   r@   r5   r   rN   r   r   r5     r   zConvTranspose2d.__init__c                 C   r   )NZQuantizedConvTranspose2dr   rR   r   r   r   r   5  rQ   zConvTranspose2d._get_namer\   r]   r   c              	   C   r   r'   )
rE   r   r   Zconv_transpose2d_prepackr.   r   rC   r/   r0   r   r_   r   r   r   rJ   8  r   zConvTranspose2d.set_weight_biasc                 C   r   r'   )rE   r   r   Zconv2d_unpackr   r_   r   r   r   rS   C  r   zConvTranspose2d._weight_biasc                 C   r   r'   r   r   r   r   r   rX   G  r   zConvTranspose2d.weightc                 C   r   r'   r   r   r   r   r   r1   K  r   zConvTranspose2d.biasc                 C   .   t |jdkrtdtj|| j| j| jS )Nrg   r   )	r    r   rA   r   r   Zconv_transpose2dr   r7   r8   r   r   r   r   r   O  
   zConvTranspose2d.forwardc                 C   r   r'   r   r   r   r   r   r   X  r   zConvTranspose2d.from_referencer   )r   r   r   r   rv   r   r   r   rt   r   r5   r   rE   r   r   rJ   rS   rX   r1   r   r   r   r   r   r   rN   r   r     s*   
 )%	r   c                       r   )r   a  Applies a 3D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose3d`.

    .. note:: Currently only the FBGEMM engine is implemented.
        Please, set the `torch.backends.quantized.engine = 'fbgemm'`

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'fbgemm'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With cubic kernels and equal stride
        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
        >>> input = torch.randn(20, 16, 50, 100, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # exact output size can be also specified as an argument
        >>> input = torch.randn(1, 16, 12, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter)
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12, 12])
    r   r
   r   Tr   Nc                    r   r   )r   r@   r5   r   rN   r   r   r5     r   zConvTranspose3d.__init__c                 C   r   )NZQuantizedConvTranspose3dr   rR   r   r   r   r     rQ   zConvTranspose3d._get_namer\   r]   r   c              	   C   r   r'   )
rE   r   r   Zconv_transpose3d_prepackr.   r   rC   r/   r0   r   r_   r   r   r   rJ     r   zConvTranspose3d.set_weight_biasc                 C   r   r'   )rE   r   r   Zconv3d_unpackr   r_   r   r   r   rS     r   zConvTranspose3d._weight_biasc                 C   r   r'   r   r   r   r   r   rX     r   zConvTranspose3d.weightc                 C   r   r'   r   r   r   r   r   r1     r   zConvTranspose3d.biasc                 C   r   )Nrh   z&Input shape must be `(N, C, T, H, W)`!)	r    r   rA   r   r   Zconv_transpose3dr   r7   r8   r   r   r   r   r     r   zConvTranspose3d.forwardc                 C   r   r'   r   r   r   r   r   r     r   zConvTranspose3d.from_referencer   )r   r   r   r   rv   r   r   r   rt   r   r5   r   rE   r   r   rJ   rS   rX   r1   r   r   r   r   r   r   rN   r   r   _  r   r   )+r   typingr   r   rE   Ztorch.ao.nn.intrinsicZaorv   Z	intrinsicr   Ztorch.ao.nn.intrinsic.qatZqatr   Ztorch.nnZtorch.nn.functionalZ
functionalr   Z
torch._opsr   Ztorch.nn.common_typesr   Ztorch.nn.modules.utilsr   r   r   Ztorch.nn.utilsr	   utilsr   r   __all__rD   rF   rb   r$   r%   r   r   r   r   r   r   r   r   r   r   r   <module>   s8   	  <   }{