import decimal

import numpy as np
import torch
from torch import nn
from torch.autograd import Function

from ...utils import logging


logger = logging.get_logger(__name__)

„Z‡  ZS )ÚQuantEmbeddingaÞ  
    Quantized version of `torch.nn.Embedding`. Adds quantization-specific arguments on top of `torch.nn.Embedding`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    Nç       @Fé   çffffffî?c                    s”   t ƒ  ¡  || _|| _|| _|| _|| _|| _|| _t	 
t ||g¡¡| _|  dt d¡¡ |  dt | j¡¡ |	| _|
| _|| _d| _tj| _d S )NÚweight_scaling_factoré   Úweight_integerF)ÚsuperÚ__init__Znum_ÚdimÚpadding_idxÚmax_normÚ	norm_typeÚscale_grad_by_freqÚsparser   Ú	ParameterÚtorchÚzerosÚweightÚregister_bufferÚ
zeros_likeÚ
weight_bitÚmomentumÚ
quant_modeÚpercentile_modeÚSymmetricQuantFunctionÚapplyÚweight_function)ÚselfZnum_embeddingsZembedding_dimr   r   r   r   r   Z_weightr   r   r   ©Ú	__class__© úV/var/www/auris/lib/python3.10/site-packages/transformers/models/ibert/quant_modules.pyr   ,   s    
zQuantEmbedding.__init__c           	   	   C   sº   | j stj || j| j| j| j| j| j	¡d fS | j}|j
 ¡ }| ¡  d¡}| ¡  d¡}t| j||dƒ| _|  | j| j| j| j¡| _tj || j| j| j| j| j| j	¡}|| j | jfS )Nr   F)r   r   Ú
functionalZ	embeddingr   r   r   r   r   r   ÚdataÚdetachÚminÚexpandÚmaxÚ$symmetric_linear_quantization_paramsr   r
   r!   r   r   )	r"   ÚxZ	positionsZincremental_stateÚwÚw_transformÚw_minÚw_maxZemb_intr%   r%   r&   ÚforwardM   s<   ù	ö
ÿù	zQuantEmbedding.forward)	NNr   FFNr   r	   F©NN)Ú__name__Ú
__module__Ú__qualname__Ú__doc__r   r3   Ú__classcell__r%   r%   r#   r&   r      s    ô!r   c                       s>   e Zd ZdZd‡ fdd„	Zdd„ Z					dd	d
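# Illustrative usage sketch (not part of the original module; names and sizes are arbitrary):
# `QuantEmbedding` behaves like `nn.Embedding` but additionally returns the scaling factor of
# its integer output, which downstream integer-only layers consume.
def _demo_quant_embedding():
    emb = QuantEmbedding(num_embeddings=100, embedding_dim=16, weight_bit=8, quant_mode=True)
    emb.weight.data.normal_()  # stand-in for trained weights
    token_ids = torch.randint(0, 100, (2, 5))
    out, scale = emb(token_ids)  # out == (8-bit integer embedding) * scale
    return out, scale
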
„Z‡  ZS )ÚQuantActap  
    Quantizes the given activation.

    Args:
        activation_bit (`int`):
            Bitwidth for the quantized activation.
        act_range_momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether to or not use channel-wise quantization.
        channel_len (`int`, *optional*):
            Specify the channel length when set the *per_channel* True.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    r	   FNc                    s”   t ƒ  ¡  || _|| _|| _|| _d| _tj| _	| jsF|  
dt d¡¡ |  
dt d¡¡ |  
dt d¡¡ |  jd8  _|  jd7  _d S tdƒ‚)NFÚx_minr   Úx_maxÚact_scaling_factorgñhãˆµøä>ú;per-channel mode is not currently supported for activation.)r   r   Úactivation_bitÚact_range_momentumr   Úper_channelÚ
percentiler   r    Úact_functionr   r   r   r;   r<   ÚNotImplementedError)r"   r?   r@   rA   Zchannel_lenr   r#   r%   r&   r   ƒ   s   
zQuantAct.__init__c              
   C   s:   | j j› d| j› d| j› d| j ¡ d›d| j ¡ d›d
S )Nz(activation_bit=z, quant_mode: z, Act_min: z.2fz, Act_max: ú))r$   r5   r?   r   r;   Úitemr<   )r"   r%   r%   r&   Ú__repr__–   s   ÿ
ÿ
þÿzQuantAct.__repr__c                 C   s¦  |d u r|n|| }| j r†| jrJ dƒ‚| jrJ dƒ‚|j ¡ }|j ¡ }	|	 ¡  ¡ dkr5| ¡  ¡ dks9J dƒ‚| j ¡ dkrT| j	 ¡ dk rT| j| | _| j	|	 | _	n2| j
dkrjt | j|¡| _t | j	|	¡| _	n| j| j
 |d| j
   | _| j	| j
 |	d| j
   | _	| js|d fS |d u r”| jn|}|d u r| j	n|}	t| j||	| jd	| _|d u rº|  || j| j| j¡}
nt ||| j| j||¡}
| j d¡}|
| | jfS )
Nz:percentile mode is not currently supported for activation.r>   r   z5NaN detected when computing min/max of the activationg¢&ú|”ç¾g¢&ú|”ç>éÿÿÿÿr   )rA   )ÚtrainingrB   rA   r(   r*   r,   ÚisnanÚsumr;   r<   r@   r   r   r-   r?   r=   rC   ÚFixedPointMulr    Úview)r"   r.   Úpre_act_scaling_factorÚidentityÚidentity_scaling_factorZspecified_minZspecified_maxZx_actr;   r<   Zquant_act_intZcorrect_output_scaler%   r%   r&   r3      sH   	

"ÿ
ÿú	zQuantAct.forward)r	   FNF)NNNNN©r5   r6   r7   r8   r   rG   r3   r9   r%   r%   r#   r&   r:   r   s    
ùr:   c                       s:   e Zd ZdZ	d‡ fdd„	Z‡ fdd	„Zddd„Z‡  ZS )ÚQuantLineara8  
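# Illustrative usage sketch (arbitrary shapes): QuantAct is driven in two phases. In training
# mode it tracks the activation range with an EMA; with `quant_mode=True` it also emits the
# (dequantized) quantized activation together with its scaling factor.
def _demo_quant_act():
    act = QuantAct(activation_bit=8, act_range_momentum=0.95, quant_mode=True)
    act.train()
    for _ in range(10):  # calibration passes update act.x_min / act.x_max
        act(torch.randn(4, 16))
    act.eval()
    y, scale = act(torch.randn(4, 16))  # y lies on the 8-bit grid {k * scale}
    return y, scale
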
class QuantLinear(nn.Module):
    """
    Quantized version of `torch.nn.Linear`. Adds quantization-specific arguments on top of `torch.nn.Linear`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        bias_bit (`int`, *optional*, defaults to `32`):
            Bitwidth for the quantized bias.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))
        self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
            self.register_buffer("bias_integer", torch.zeros_like(self.bias))
        else:
            # keep the attributes defined so that forward() works when the layer has no bias
            self.bias = None
            self.bias_integer = None

        self.weight_bit = weight_bit
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.bias_bit = bias_bit
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def __repr__(self):
        s = super().__repr__()
        s = f"({s} weight_bit={self.weight_bit}, quant_mode={self.quant_mode})"
        return s

    def forward(self, x, prev_act_scaling_factor=None):
        if not self.quant_mode:
            return nn.functional.linear(x, weight=self.weight, bias=self.bias), None

        # assert that prev_act_scaling_factor is a scalar tensor
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = False` before this QuantLinear layer"
        )

        w = self.weight
        w_transform = w.data.detach()
        if self.per_channel:
            w_min, _ = torch.min(w_transform, dim=1, out=None)
            w_max, _ = torch.max(w_transform, dim=1, out=None)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)

        self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor

        if self.bias is not None:
            self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)

        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
            bias_scaling_factor,
        )

class IntGELU(nn.Module):
    """
    Quantized version of `torch.nn.GELU`. Adds quantization-specific arguments on top of `torch.nn.GELU`.

    Args:
        quant_mode (`bool`, *optional*, defaults to `True`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "gelu" or "nonlinear" is given.
    """

    def __init__(self, quant_mode=True, force_dequant="none"):
        super().__init__()
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "gelu"]:
            logger.info("Force dequantize gelu")
            self.quant_mode = False

        if not self.quant_mode:
            self.activation_fn = nn.GELU()

        self.k = 1.4142
        self.const = 14  # dummy integer constant
        self.coeff = [-0.2888, -1.769, 1]  # a(x + b)**2 + c
        self.coeff[2] /= self.coeff[0]

    def int_erf(self, x_int, scaling_factor):
        b_int = torch.floor(self.coeff[1] / scaling_factor)
        c_int = torch.floor(self.coeff[2] / scaling_factor**2)
        sign = torch.sign(x_int)

        abs_int = torch.min(torch.abs(x_int), -b_int)
        y_int = sign * ((abs_int + b_int) ** 2 + c_int)
        scaling_factor = scaling_factor**2 * self.coeff[0]

        # avoid overflow
        y_int = floor_ste.apply(y_int / 2**self.const)
        scaling_factor = scaling_factor * 2**self.const

        return y_int, scaling_factor

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            return self.activation_fn(x), None

        x_int = x / scaling_factor
        sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)

        shift_int = 1.0 // sigmoid_scaling_factor

        x_int = x_int * (sigmoid_int + shift_int)
        scaling_factor = scaling_factor * sigmoid_scaling_factor / 2

        return x_int * scaling_factor, scaling_factor

class IntSoftmax(nn.Module):
    """
    Quantized version of `torch.nn.Softmax`. Adds quantization-specific arguments on top of `torch.nn.Softmax`.

    Args:
        output_bit (`int`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "softmax" or "nonlinear" is given.
    """

    def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.output_bit = output_bit
        self.max_bit = 32
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "softmax"]:
            logger.info("Force dequantize softmax")
            self.quant_mode = False

        self.act = QuantAct(16, quant_mode=self.quant_mode)
        self.x0 = -0.6931  # -ln2
        self.const = 30  # dummy integer constant
        self.coef = [0.35815147, 0.96963238, 1.0]  # ax**2 + bx + c
        self.coef[1] /= self.coef[0]
        self.coef[2] /= self.coef[0]

    def int_polynomial(self, x_int, scaling_factor):
        with torch.no_grad():
            b_int = torch.floor(self.coef[1] / scaling_factor)
            c_int = torch.floor(self.coef[2] / scaling_factor**2)
        z = (x_int + b_int) * x_int + c_int
        scaling_factor = self.coef[0] * scaling_factor**2
        return z, scaling_factor

    def int_exp(self, x_int, scaling_factor):
        with torch.no_grad():
            x0_int = torch.floor(self.x0 / scaling_factor)
        x_int = torch.max(x_int, self.const * x0_int)

        q = floor_ste.apply(x_int / x0_int)
        r = x_int - x0_int * q
        exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
        exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
        scaling_factor = exp_scaling_factor / 2**self.const
        return exp_int, scaling_factor

    def forward(self, x, scaling_factor):
        if not self.quant_mode:
            return nn.functional.softmax(x, dim=-1), None

        x_int = x / scaling_factor

        x_int_max, _ = x_int.max(dim=-1, keepdim=True)
        x_int = x_int - x_int_max
        exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)

        # Avoid overflow
        exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
        exp_int = exp / exp_scaling_factor

        exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
        factor = floor_ste.apply(2**self.max_bit / exp_int_sum)
        exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
        scaling_factor = 1 / 2**self.output_bit
        return exp_int * scaling_factor, scaling_factor

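# Illustrative usage sketch (arbitrary values): IntSoftmax maps quantized attention scores
# to probabilities quantized to `output_bit` bits, with a fixed output scale of
# 2 ** -output_bit, so each row sums to ~1 up to that resolution.
def _demo_int_softmax():
    softmax = IntSoftmax(output_bit=8, quant_mode=True)
    scale = torch.tensor([0.05])
    scores = (torch.randn(2, 4, 4) / scale).round() * scale  # fake-quantized scores
    probs, probs_scale = softmax(scores, scale)  # probs_scale == 1 / 256
    return probs.sum(dim=-1)  # ~1 per row
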
„ Zddd„Z‡  ZS )ÚIntLayerNormaû  
    Quantized version of `torch.nn.LayerNorm`. Adds quantization-specific arguments on top of `torch.nn.LayerNorm`.

    Args:
        output_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "layernorm" or "nonlinear" is given.
    r   Frb   c                    s’   t ƒ  ¡  || _|| _t t |¡¡| _t t |¡¡| _	|| _
|dv r,t d¡ d| _
|  dt d¡¡ || _d| _d | _t| j| j
d| _d S )N)rc   Z	layernormzForce dequantize layernormFÚshiftr   rS   r{   )r   r   Únormalized_shapeÚepsr   r   r   r   r   rX   r   rf   rg   r   r}   r~   Údim_sqrtr:   Z
activation)r"   r’   r“   r}   r   rl   r#   r%   r&   r   ¹  s   

zIntLayerNorm.__init__c                 C   sž   t  ¡ A |d }t j|ddd}t  t  |d| j  ¡¡ ¡  ¡ }| j}t  | j|¡| _t	 
dt|ƒ› dt| jƒ› ¡ W d   ƒ d S 1 sHw   Y  d S )Nre   T©Zaxisr   zDynamic shift adjustment: z -> )r   r‚   rK   Úlog2Úsqrtr~   Úceilr,   r‘   rf   rg   Úint)r"   ru   Úy_sq_intÚvar_intr‘   Z	shift_oldr%   r%   r&   Ú	set_shiftÌ  s   
"""úzIntLayerNorm.set_shiftc                 C   s:   |   |¡ t |d| j  ¡}|d }tj|ddd}|S )z±
        This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
        to avoid overflow in the subsequent runs.
        re   Tr•   )rœ   rq   r    r‘   r   rK   )r"   ru   Úy_int_shiftedrš   r›   r%   r%   r&   Úoverflow_fallbackÕ  s
   
zIntLayerNorm.overflow_fallbackNc                 C   s¬  | j s.|jddd}|| }tj|d ddd}|t | j| ¡ }|| j | j }|d fS | jd u rHtj|j	d tj
d}t |¡ |j¡| _|| }t |jddd¡}|| }	t |	d| j  ¡}
|
d }tj|ddd}| jr| ¡ d| j kr|  |	¡}| ¡ d| j d k sJ dƒ‚t t |¡¡d| j  }t d| ¡}t |	| d ¡}	| jd }| jj ¡ | jj ¡  }t || ¡}|	| }	|| j }|	| }||fS )	Nre   Tr•   )Zdtypegš™™™™™¹?zfError detected in overflow handling: `var_int` exceeds `self.max_bit` (the maximum possible bit width)l        i   @)r   Úmeanr   r—   r“   r   rX   r”   Útensorr]   ÚfloatÚtoÚdeviceÚ	round_ster    rq   r‘   rK   rI   r,   r~   rž   r(   r)   )r"   r.   rr   rŸ   ÚyÚvarÚnr_   Zmean_intru   r   rš   r›   Zstd_intr   rX   Zbias_intr%   r%   r&   r3   à  s@   

ÿ

zIntLayerNorm.forward)r   Frb   r`   )	r5   r6   r7   r8   r   rœ   rž   r3   r9   r%   r%   r#   r&   r   ¬  s    	r   Fc           	      C   s€   | j d }t|d|d   ƒ}t|| d ƒ}tj| |dj}|dkr(|d }n
tj|  |dj }|s<| ¡ }| ¡ }||fS )aÆ  
    Calculate the percentile max and min values in a given tensor

    Args:
        input (`torch.Tensor`):
            The target tensor to calculate percentile max and min.
        lower_percentile (`float`):
            If 0.1, means we return the value of the smallest 0.1% value in the tensor as percentile min.
        upper_percentile (`float`):
            If 99.9, means we return the value of the largest 0.1% value in the tensor as percentile max.
        output_tensor (`bool`, *optional*, defaults to `False`):
            If True, this function returns tensors, otherwise it returns values.

    Returns:
        `Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of *input*
    r   r   g{®Gáz„?)ri   )r]   Úroundr   ZkthvalueÚvaluesrF   )	ÚinputZlower_percentileZupper_percentileZoutput_tensorZinput_lengthZlower_indexZupper_indexÚupper_boundÚlower_boundr%   r%   r&   Úget_percentile_min_max  s   

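# Illustrative usage sketch: with 1000 samples, (0.1, 99.9) clips roughly the most extreme
# 0.1% on each side, which is more outlier-robust than a plain min/max range.
def _demo_percentile_range():
    t = torch.randn(1000)
    lo, hi = get_percentile_min_max(t, 0.1, 99.9)
    return lo, hi
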
def linear_quantize(input, scale, zero_point, inplace=False):
    """
    Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.

    Args:
        input (`torch.Tensor`):
            Single-precision input tensor to be quantized.
        scale (`torch.Tensor`):
            Scaling factor for quantization.
        zero_point (`torch.Tensor`):
            Shift for quantization.
        inplace (`bool`, *optional*, defaults to `False`):
            Whether to compute inplace or not.

    Returns:
        `torch.Tensor`: Linearly quantized value of *input* according to *scale* and *zero_point*.
    """
    # reshape scale and zeropoint for convolutional weights and activation
    if len(input.shape) == 4:
        scale = scale.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    # reshape scale and zeropoint for linear weights
    elif len(input.shape) == 2:
        scale = scale.view(-1, 1)
        zero_point = zero_point.view(-1, 1)
    else:
        scale = scale.view(-1)
        zero_point = zero_point.view(-1)
    # quantized = float / scale + zero_point
    if inplace:
        input.mul_(1.0 / scale).add_(zero_point).round_()
        return input
    return torch.round(1.0 / scale * input + zero_point)

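# Worked example (illustrative): with scale 0.1 and zero_point 0, the value 0.74 maps to
# round(0.74 / 0.1) = 7; dequantizing gives 7 * 0.1 = 0.7, a quantization error of 0.04.
def _demo_linear_quantize():
    q = linear_quantize(torch.tensor([0.74]), torch.tensor([0.1]), torch.tensor([0.0]))
    return q  # tensor([7.])
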
def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
    """
    Compute the scaling factor with the given quantization range for symmetric quantization.

    Args:
        saturation_min (`torch.Tensor`):
            Lower bound for quantization range.
        saturation_max (`torch.Tensor`):
            Upper bound for quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.

    Returns:
        `torch.Tensor`: Scaling factor that linearly quantizes the given range between *saturation_min* and
        *saturation_max*.
    """
    # in this part, we do not need any gradient computation;
    # in order to enforce this, we put torch.no_grad()
    with torch.no_grad():
        n = 2 ** (num_bits - 1) - 1

        if per_channel:
            scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
            scale = torch.clamp(scale, min=1e-8) / n
        else:
            scale = max(saturation_min.abs(), saturation_max.abs())
            scale = torch.clamp(scale, min=1e-8) / n

    return scale

class SymmetricQuantFunction(Function):
    """
    Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
    """

    @staticmethod
    def forward(ctx, x, k, percentile_mode, scale):
        """
        Args:
            x (`torch.Tensor`):
                Floating point tensor to be quantized.
            k (`int`):
                Quantization bitwidth.
            percentile_mode (`bool`):
                Whether or not to use percentile calibration.
            scale (`torch.Tensor`):
                Pre-calculated scaling factor for *x*. Note that the current implementation of SymmetricQuantFunction
                requires a pre-calculated scaling factor.

        Returns:
            `torch.Tensor`: Symmetric-quantized value of *input*.
        """
        zero_point = torch.tensor(0.0, device=scale.device)

        n = 2 ** (k - 1) - 1
        new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
        new_quant_x = torch.clamp(new_quant_x, -n, n - 1)

        ctx.scale = scale
        return new_quant_x

    @staticmethod
    def backward(ctx, grad_output):
        scale = ctx.scale
        if len(grad_output.shape) == 4:
            scale = scale.view(-1, 1, 1, 1)
        # reshape scale for linear weights
        elif len(grad_output.shape) == 2:
            scale = scale.view(-1, 1)
        else:
            scale = scale.view(-1)

        # one gradient per forward input (x, k, percentile_mode, scale)
        return grad_output.clone() / scale, None, None, None

class floor_ste(Function):
    """
    Straight-through Estimator (STE) for torch.floor()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

class round_ste(Function):
    """
    Straight-through Estimator (STE) for torch.round()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()

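# Illustrative sketch: the straight-through estimator passes gradients through the
# non-differentiable rounding unchanged, so floor()/round() act as identity in backward.
def _demo_ste_gradient():
    x = torch.tensor([1.7], requires_grad=True)
    floor_ste.apply(x).backward()
    return x.grad  # tensor([1.]), although d/dx floor(x) is 0 almost everywhere
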
dƒ}| |¡ qt |¡}t|ƒ| }t |¡ | j¡ |¡t |¡ | j¡ |¡fS )zü
    Decompose the scaling factor into mantissa and twos exponent.

    Args:
        scaling_factor (`torch.Tensor`):
            Target scaling factor to decompose.

    Returns:
        ``Tuple(torch.Tensor, torch.Tensor)`: mantisa and exponent
    rH   re   Ú1)Úrounding)ÚsizerM   ÚnpÚfrexpÚcpuÚnumpyr™   ÚdecimalÚDecimalÚquantizeÚROUND_HALF_UPÚappendÚarrayr¡   r   Z
from_numpyr¢   r£   )Zinputsr~   Zshape_of_inputZoutput_mZoutput_eZtmp_mÚmZint_m_shiftedr%   r%   r&   Úbatch_frexpÁ  s   
"ÿ
class FixedPointMul(Function):
    """
    Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.

    Args:
        pre_act (`torch.Tensor`):
            Input tensor.
        pre_act_scaling_factor (`torch.Tensor`):
            Scaling factor of the input tensor *pre_act*.
        bit_num (`int`):
            Quantization bitwidth.
        z_scaling_factor (`torch.Tensor`):
            Scaling factor of the output tensor.
        identity (`torch.Tensor`, *optional*):
            Identity tensor, if exists.
        identity_scaling_factor (`torch.Tensor`, *optional*):
            Scaling factor of the identity tensor *identity*, if exists.

    Returns:
        `torch.Tensor`: Output tensor (*pre_act* if *identity* is not given, otherwise the addition of *pre_act* and
        *identity*), whose scale is rescaled to *z_scaling_factor*.
    """

    @staticmethod
    def forward(ctx, pre_act, pre_act_scaling_factor, bit_num, z_scaling_factor, identity=None, identity_scaling_factor=None):
        if len(pre_act_scaling_factor.shape) == 3:
            reshape = lambda x: x  # noqa: E731
        else:
            reshape = lambda x: x.view(1, 1, -1)  # noqa: E731
        ctx.identity = identity

        n = 2 ** (bit_num - 1) - 1

        with torch.no_grad():
            pre_act_scaling_factor = reshape(pre_act_scaling_factor)
            if identity is not None:
                identity_scaling_factor = reshape(identity_scaling_factor)

            ctx.z_scaling_factor = z_scaling_factor

            z_int = torch.round(pre_act / pre_act_scaling_factor)
            _A = pre_act_scaling_factor.type(torch.double)
            _B = (z_scaling_factor.type(torch.float)).type(torch.double)
            new_scale = _A / _B
            new_scale = reshape(new_scale)

            m, e = batch_frexp(new_scale)

            output = z_int.type(torch.double) * m.type(torch.double)
            output = torch.round(output / (2.0**e))

            if identity is not None:
                # needs addition of the identity activation
                wx_int = torch.round(identity / identity_scaling_factor)

                _A = identity_scaling_factor.type(torch.double)
                _B = (z_scaling_factor.type(torch.float)).type(torch.double)
                new_scale = _A / _B
                new_scale = reshape(new_scale)

                m1, e1 = batch_frexp(new_scale)
                output1 = wx_int.type(torch.double) * m1.type(torch.double)
                output1 = torch.round(output1 / (2.0**e1))

                output = output1 + output

            return torch.clamp(output.type(torch.float), -n - 1, n)

    @staticmethod
    def backward(ctx, grad_output):
        identity_grad = None
        if ctx.identity is not None:
            identity_grad = grad_output.clone() / ctx.z_scaling_factor
        # one gradient per forward input
        return grad_output.clone() / ctx.z_scaling_factor, None, None, None, identity_grad, None
