o
    Zhs                  	   @   s4  d Z ddlZddlZddlmZmZ ddlZddlZddl	Zddlm
Z
mZ ddlmZmZmZ ddlmZ ddlmZmZmZmZ dd	lmZ dd
lmZmZ ddlmZ ddlmZ e e!Z"d>deee#f fddZ$G dd dej%Z&G dd dej'Z(G dd dej)Z*G dd dej+Z,G dd dej)Z-d?dej
de.d e#dej
fd!d"Z/G d#d$ d$ej)Z0d@d&d'Z1G d(d) d)ej)Z2G d*d+ d+ej)Z3G d,d- d-ej)Z4G d.d/ d/ej)Z5G d0d1 d1ej)Z6eG d2d3 d3eZ7eG d4d5 d5e7Z8ed6d7G d8d9 d9e7Z9ed:d7G d;d< d<e7eZ:g d=Z;dS )Az9PyTorch BiT model. Also supports backbone for ViT hybrid.    N)OptionalTuple)Tensornn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)BackboneOutputBaseModelOutputWithNoAttention(BaseModelOutputWithPoolingAndNoAttention$ImageClassifierOutputWithNoAttention)PreTrainedModel)auto_docstringlogging)BackboneMixin   )	BitConfig   returnc                 C   s   d}| du r|d ||d   d } | |fS t | tr_|  } | dkrI|dkrA||d  d dkrA|d ||d   d } | |fS d} d}| |fS | dkrSd} | |fS |d ||d   d } | |fS )	al  
    Utility function to get the tuple padding value given the kernel_size and padding.

    Args:
        padding (Union[`str`, `int`], *optional*):
            Padding value, can be either `"same"`, `"valid"`. If a different value is provided the default padding from
            PyTorch is used.
        kernel_size (`int`, *optional*, defaults to 7):
            Kernel size of the convolution layers.
        stride (`int`, *optional*, defaults to 1):
            Stride value of the convolution layers.
        dilation (`int`, *optional*, defaults to 1):
            Dilation value of the convolution layers.
    FNr      Zsamer   TZvalid)
isinstancestrlower)paddingkernel_sizestridedilationZdynamic r   S/var/www/auris/lib/python3.10/site-packages/transformers/models/bit/modeling_bit.pyget_padding_value+   s$   
r!   c                       s6   e Zd ZdZ						d
 fdd	Zdd	 Z  ZS )WeightStandardizedConv2dzConv2d with Weight Standardization. Includes TensorFlow compatible SAME padding. Used for ViT Hybrid model.

    Paper: [Micro-Batch Training with Batch-Channel Normalization and Weight
    Standardization](https://arxiv.org/abs/1903.10520v2)
    r   SAMEFư>c
              
      sT   t ||||d\}}
t j||||||||d |
r"t|||| _nd | _|	| _d S )N)r   r   )r   r   r   groupsbias)r!   super__init__DynamicPad2dpadeps)selfZ
in_channelout_channelsr   r   r   r   r%   r&   r+   Z
is_dynamic	__class__r   r    r(   [   s   

z!WeightStandardizedConv2d.__init__c              	   C   sj   | j d ur
|  |}tjj| jd| jdd d dd| jd| j}tj	||| j
| j| j| j| j}|S )Nr   T        )trainingZmomentumr+   )r*   r   
functionalZ
batch_normweightZreshaper-   r+   Z
reshape_asZconv2dr&   r   r   r   r%   )r,   hidden_stater4   r   r   r    forwardx   s   

z WeightStandardizedConv2d.forward)r   r#   r   r   Fr$   __name__
__module____qualname____doc__r(   r6   __classcell__r   r   r.   r    r"   T   s    r"   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )	BitGroupNormActivationzQ
    A module that combines group normalization with an activation function.
    h㈵>Tc                    s<   t t| j|j|||d |rt|j | _d S t | _d S )N)r+   affine)	r'   r=   r(   
num_groupsr
   
hidden_act
activationr   Identity)r,   confignum_channelsr+   r?   apply_activationr.   r   r    r(      s   zBitGroupNormActivation.__init__c                 C   s*   t j|| j| j| j| j}| |}|S N)r   r3   Z
group_normr@   r4   r&   r+   rB   )r,   r5   r   r   r    r6      s   
zBitGroupNormActivation.forward)r>   TTr7   r   r   r.   r    r=      s    r=   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )r)   z
    A module that wraps dynamic padding of any input, given the parameters of the convolutional layer and the input
    hidden states.
    r   c                    sj   t    t|tr||f}t|tr||f}t|tr ||f}|| _|| _|| _|| _dd }|| _d S )Nc                 S   s0   t t| | d | |d |  d |  dS )Nr   r   )maxmathceil)xr   r   r   r   r   r    compute_padding   s   0z.DynamicPad2d.__init__.<locals>.compute_padding)	r'   r(   r   intr   r   r   valuerL   )r,   r   r   r   rN   rL   r.   r   r    r(      s   




zDynamicPad2d.__init__c                 C   s   |  dd  \}}| || jd | jd | jd }| || jd | jd | jd }|dks4|dkrNtjj||d ||d  |d ||d  g| jd}|S )Nr   r   r   )rN   )	sizerL   r   r   r   r   r3   r*   rN   )r,   inputZinput_heightZinput_widthpadding_heightpadding_widthr   r   r    r6      s   ""


zDynamicPad2d.forward)r   r7   r   r   r.   r    r)      s    r)   c                       s<   e Zd ZdZ						ddef fd	d
Zdd Z  ZS )BitMaxPool2dz1Tensorflow like 'SAME' wrapper for 2D max poolingNr   Fr   r   r   Tr   c                    s   t |tjjr	|n||f}t |tjjr|n||f}t |tjjr#|n||f}t ||||| |r=t||||| _d S t	 | _d S rG   )
r   collectionsabcIterabler'   r(   r)   r*   r   rC   )r,   r   r   r   	ceil_moder   Zpadding_valueuse_dynamic_paddingr.   r   r    r(      s   
zBitMaxPool2d.__init__c                 C   s*   |  |}tj|| j| j| j| j| jS rG   )	r*   r   r3   Z
max_pool2dr   r   r   r   rY   r,   hidden_statesr   r   r    r6      s   
zBitMaxPool2d.forward)Nr   FrU   r   T)r8   r9   r:   r;   rM   r(   r6   r<   r   r   r.   r    rT      s    rT   c                       s8   e Zd ZdZdef fddZdedefddZ  ZS )	BitEmbeddingszL
    BiT Embeddings (stem) composed of a single aggressive convolution.
    rD   c                    s   t    t|j|jddd|jd| _tdd|jd| _	|jd ur.|j
 dkr.t | _ntjdd	d
| _|jdksDt||jd| _nt | _|j| _d S )Nr   r   :0yE>)r   r   r+   r   r	   )r   r   rZ   r#   )r   r   r   r   r1   )r   rN   preactivationrE   )r'   r(   r"   rE   embedding_sizeglobal_paddingconvolutionrT   Zembedding_dynamic_paddingpoolerupperr   rC   r*   ZConstantPad2d
layer_typer=   normr,   rD   r.   r   r    r(      s"   
	

zBitEmbeddings.__init__pixel_valuesr   c                 C   sH   |j d }|| jkrtd| |}| |}| |}| |}|S )Nr   zeMake sure that the channel dimension of the pixel values match with the one set in the configuration.)shaperE   
ValueErrorrc   r*   rg   rd   )r,   ri   rE   Z	embeddingr   r   r    r6     s   





zBitEmbeddings.forward)	r8   r9   r:   r;   r   r(   r   r6   r<   r   r   r.   r    r]      s    r]   r1   FrQ   	drop_probr2   c                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r1   r   r   )r   )dtypedevice)rj   ndimtorchZrandrm   rn   Zfloor_div)rQ   rl   r2   Z	keep_probrj   Zrandom_tensoroutputr   r   r    	drop_path  s   
rs   c                       sT   e Zd ZdZddee ddf fddZdejdejfdd	Z	de
fd
dZ  ZS )BitDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nrl   r   c                    s   t    || _d S rG   )r'   r(   rl   )r,   rl   r.   r   r    r(   1  s   

zBitDropPath.__init__r\   c                 C   s   t || j| jS rG   )rs   rl   r2   r[   r   r   r    r6   5     zBitDropPath.forwardc                 C   s   d | jS )Nzp={})formatrl   )r,   r   r   r    
extra_repr8  s   zBitDropPath.extra_reprrG   )r8   r9   r:   r;   r   floatr(   rp   r   r6   r   rw   r<   r   r   r.   r    rt   .  s
    rt      c                 C   s:   |}t |t| |d  | | }|d|  k r||7 }|S )Nr   g?)rH   rM   )rN   ZdivisorZ	min_value	new_valuer   r   r    make_div<  s
   r{   c                       :   e Zd ZdZ								d fdd	Zd	d
 Z  ZS )BitPreActivationBottleneckLayera  Pre-activation (v2) bottleneck block.
    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.
    N      ?r   r1   Fc              	      s   t    |p|}|p|}t|| }|
r t||||dd| _nd | _t||| _t||dd|jd| _	t||d| _
t||d||d|jd| _t||| _t||dd|jd| _|	d	krdt|	| _d S t | _d S )
NTr   preactr   r^   r+   r   r`   r	   )r   r%   r+   r   r   )r'   r(   r{   BitDownsampleConv
downsampler=   norm1r"   rb   conv1norm2conv2norm3conv3rt   r   rC   rs   )r,   rD   in_channelsr-   bottle_ratior   r   first_dilationr%   drop_path_rateis_first_layerZmid_channelsr.   r   r    r(   L  s,   

$z(BitPreActivationBottleneckLayer.__init__c                 C   s^   |  |}|}| jd ur| |}| |}| | |}| | |}| |}|| S rG   )r   r   r   r   r   r   r   rs   )r,   r\   Zhidden_states_preactshortcutr   r   r    r6   x  s   




z'BitPreActivationBottleneckLayer.forwardNr~   r   r   Nr   r1   Fr7   r   r   r.   r    r}   D  s    ,r}   c                       r|   )BitBottleneckLayerz\Non Pre-activation bottleneck block, equivalent to V1.5/V1b bottleneck. Used for ViT Hybrid.Nr~   r   r1   Fc              
      s   t    |p|}|p|}t|| }|
r t||||dd| _nd | _t||dd|jd| _t||d| _	t||d|||d|jd| _
t||d| _t||dd|jd| _t||dd	| _|	d
kret|	nt | _t|j | _d S )NFr   r   r^   r   r`   r	   )r   r   r%   r+   r   rE   rF   r   )r'   r(   r{   r   r   r"   rb   r   r=   r   r   r   r   r   rt   r   rC   rs   r
   rA   rB   )r,   rD   r   r-   r   r   r   r   r%   r   r   Zmid_chsr.   r   r    r(     s<   


zBitBottleneckLayer.__init__c                 C   sp   |}| j d ur|  |}| |}| |}| |}| |}| |}| |}| |}| || }|S rG   )	r   r   r   r   r   r   r   rs   rB   )r,   r\   r   r   r   r    r6     s   








zBitBottleneckLayer.forwardr   r7   r   r   r.   r    r     s    1r   c                       s*   e Zd Z		d fdd	Zdd Z  ZS )r   r   Tc                    sH   t    t||d|d|jd| _|rt | _d S t||dd| _d S )Nr   r^   )r   r+   r   Fr   )	r'   r(   r"   rb   convr   rC   r=   rg   )r,   rD   r   r-   r   r   r.   r   r    r(     s   
zBitDownsampleConv.__init__c                 C   s   |  | |S rG   )rg   r   )r,   rK   r   r   r    r6     ru   zBitDownsampleConv.forward)r   T)r8   r9   r:   r(   r6   r<   r   r   r.   r    r     s
    r   c                       s@   e Zd ZdZ		d fdd	Zdd Zded	efd
dZ  ZS )BitStagez7
    A ResNet v2 stage composed by stacked layers.
    r~   Nc	                    s   t    |dv rdnd}	|jdkrt}
nt}
|}t | _t|D ]$}| 	|||\}}}| j
t||
|||||||	||d	 |}|}	q"d S )N)r   r   r   r   Z
bottleneck)r   r   r   r   r   r   )r'   r(   rf   r   r}   r   
Sequentiallayersrange_get_updated_hyperparameters
add_moduler   )r,   rD   r   r-   r   r   depthr   layer_dropoutr   Z	layer_clsprev_chs	layer_idxr   r   r.   r   r    r(     s8   



zBitStage.__init__c                 C   s0   |r|| }nd}|dkrd}|dk}|||fS )zt
        Get the new hyper-parameters with respect to the previous ones and the index of the current layer.
        r1   r   r   r   )r,   r   r   r   r   r   r   r   r    r     s   

z%BitStage._get_updated_hyperparametersrQ   r   c                 C   s$   |}t | jD ]\}}||}q|S rG   )	enumerater   )r,   rQ   r5   _layerr   r   r    r6   +  s   
zBitStage.forward)r~   N)	r8   r9   r:   r;   r(   r   r   r6   r<   r   r   r.   r    r     s    .r   c                	       sH   e Zd Zdef fddZdd Z	dded	ed
edefddZ	  Z
S )
BitEncoderrD   c              
      s   t    tg | _|j}d}d}dd tt	d|j
t|j|jD }tt|j|j|D ]-\}\}}}	| |||||\}
}}t|||
||||	d}|
}||9 }| jt|| q3d S )N   r   c                 S   s   g | ]}|  qS r   )tolist).0rK   r   r   r    
<listcomp>=  s    z'BitEncoder.__init__.<locals>.<listcomp>r   )r   r   r   r   )r'   r(   r   Z
ModuleListstagesra   rp   r   npZlinspacer   sumZdepthssplitr   ziphidden_sizesr   r   r   r   )r,   rD   r   current_strider   Zlayer_dropouts	stage_idxZcurrent_depthcurrent_hidden_sizer   r-   r   stager.   r   r    r(   3  s6   
"


zBitEncoder.__init__c                 C   s>   t ||j }|dkrdnd}||jkr||9 }d}|||fS )Nr   r   r   )r{   Zwidth_factorZoutput_stride)r,   r   r   r   r   rD   r-   r   r   r   r    r   Y  s   

z'BitEncoder._get_updated_hyperparametersFTr5   output_hidden_statesreturn_dictr   c                 C   sb   |rdnd }| j D ]}|r||f }||}q	|r||f }|s+tdd ||fD S t||dS )Nr   c                 s   s    | ]	}|d ur|V  qd S rG   r   )r   vr   r   r    	<genexpr>p  s    z%BitEncoder.forward.<locals>.<genexpr>)last_hidden_stater\   )r   tupler   )r,   r5   r   r   r\   Zstage_moduler   r   r    r6   a  s   



zBitEncoder.forward)FT)r8   r9   r:   r   r(   r   r   boolr   r6   r<   r   r   r.   r    r   2  s    &	r   c                   @   s&   e Zd ZeZdZdZdgZdd ZdS )BitPreTrainedModelbitri   r]   c                 C   s   t |tjrtjj|jddd d S t |tjrMtjj|jt	dd |j
d urKtj|j\}}|dkr=dt	| nd}tj|j
| | d S d S t |tjtjfrhtj|jd tj|j
d d S d S )NZfan_outZrelu)modeZnonlinearity   )ar   r   )r   r   Conv2dinitZkaiming_normal_r4   LinearZkaiming_uniform_rI   sqrtr&   Z_calculate_fan_in_and_fan_outZuniform_ZBatchNorm2d	GroupNormZ	constant_)r,   moduleZfan_inr   boundr   r   r    _init_weights  s   
z BitPreTrainedModel._init_weightsN)	r8   r9   r:   r   Zconfig_classZbase_model_prefixZmain_input_nameZ_no_split_modulesr   r   r   r   r    r   x  s    r   c                
       F   e Zd Z fddZe	d
dedee dee defdd	Z	  Z
S )BitModelc                    sd   t  | || _t|| _t|| _|jdkr!t||j	d dnt
 | _t
d| _|   d S )Nr_   r0   r`   )r   r   )r'   r(   rD   r]   embedderr   encoderrf   r=   r   r   rC   rg   ZAdaptiveAvgPool2drd   	post_initrh   r.   r   r    r(     s   


zBitModel.__init__Nri   r   r   r   c                 C   s   |d ur|n| j j}|d ur|n| j j}| |}| j|||d}|d }| |}| |}|s;||f|dd   S t|||jdS )Nr   r   r   r   )r   pooler_outputr\   )	rD   r   use_return_dictr   r   rg   rd   r   r\   )r,   ri   r   r   Zembedding_outputZencoder_outputsr   pooled_outputr   r   r    r6     s"   


zBitModel.forwardNN)r8   r9   r:   r(   r   r   r   r   r   r6   r<   r   r   r.   r    r     s    r   z
    BiT Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    )Zcustom_introc                       s\   e Zd Z fddZe				ddeej deej dee	 dee	 de
f
d	d
Z  ZS )BitForImageClassificationc                    s^   t  | |j| _t|| _tt |jdkr#t|j	d |jnt
 | _|   d S )Nr   r0   )r'   r(   
num_labelsr   r   r   r   ZFlattenr   r   rC   
classifierr   rh   r.   r   r    r(     s   
$z"BitForImageClassification.__init__Nri   labelsr   r   r   c                 C   sb  |dur|n| j j}| j|||d}|r|jn|d }| |}d}|dur| j jdu rP| jdkr6d| j _n| jdkrL|jtj	ksG|jtj
krLd| j _nd| j _| j jdkrnt }	| jdkrh|	| | }n+|	||}n%| j jdkrt }	|	|d| j|d}n| j jdkrt }	|	||}|s|f|dd  }
|dur|f|
 S |
S t|||jd	S )
a0  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   Z
regressionZsingle_label_classificationZmulti_label_classificationr0   r   )losslogitsr\   )rD   r   r   r   r   Zproblem_typer   rm   rp   longrM   r   Zsqueezer   viewr   r   r\   )r,   ri   r   r   r   outputsr   r   r   Zloss_fctrr   r   r   r    r6     s6   


"


z!BitForImageClassification.forward)NNNN)r8   r9   r:   r(   r   r   rp   ZFloatTensorZ
LongTensorr   r   r6   r<   r   r   r.   r    r     s$    r   zL
    BiT backbone, to be used with frameworks like DETR and MaskFormer.
    c                
       r   )BitBackbonec                    s>   t  | t  | t|| _|jg|j | _|   d S rG   )	r'   r(   Z_init_backboner   r   ra   r   Znum_featuresr   rh   r.   r   r    r(     s
   
zBitBackbone.__init__Nri   r   r   r   c           
      C   s   |dur|n| j j}|dur|n| j j}| j|ddd}|j}d}t| jD ]\}}|| jv r6||| f7 }q&|sF|f}	|rD|	|jf7 }	|	S t||rP|jddS dddS )aN  
        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("google/bit-50")
        >>> model = AutoBackbone.from_pretrained("google/bit-50")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```NTr   r   )feature_mapsr\   Z
attentions)	rD   r   r   r   r\   r   Zstage_namesZout_featuresr   )
r,   ri   r   r   r   r\   r   idxr   rr   r   r   r    r6     s.   
zBitBackbone.forwardr   )r8   r9   r:   r(   r   r   r   r   r   r6   r<   r   r   r.   r    r     s    
r   )r   r   r   r   )Nr   r   r   )r1   F)ry   )<r;   rV   rI   typingr   r   numpyr   rp   Ztorch.utils.checkpointr   r   Ztorch.nnr   r   r   Zactivationsr
   Zmodeling_outputsr   r   r   r   Zmodeling_utilsr   utilsr   r   Zutils.backbone_utilsr   Zconfiguration_bitr   Z
get_loggerr8   loggerr   r!   r   r"   r   r=   Moduler)   Z	MaxPool2drT   r]   rx   rs   rt   r{   r}   r   r   r   r   r   r   r   r   __all__r   r   r   r    <module>   sV   
)03 3
DIJF1@<