o
    Zh                     @   s  d Z ddlZddlmZ ddlmZmZmZm	Z	m
Z
mZ ddlZddlZddlmZ ddlmZ ddlmZ dd	lmZmZmZ dd
lmZmZ ddlmZmZ ddlmZmZm Z m!Z! ddl"m#Z# ddl$m%Z% e &e'Z(eG dd deZ)eG dd deZ*G dd dej+Z,G dd dej+Z-G dd dej+Z.	dSdej+dej/dej/dej/deej/ d e0d!e0fd"d#Z1G d$d% d%ej+Z2G d&d' d'ej+Z3G d(d) d)ej+Z4G d*d+ d+ej+Z5G d,d- d-ej+Z6G d.d/ d/ej+Z7G d0d1 d1ej+Z8G d2d3 d3ej+Z9d4d5 Z:G d6d7 d7ej+Z;G d8d9 d9ej+Z<G d:d; d;ej+Z=G d<d= d=ej+Z>eG d>d? d?eZ?eG d@dA dAe?Z@G dBdC dCej+ZAG dDdE dEej+ZBG dFdG dGej+ZCedHdIG dJdK dKe?ZDG dLdM dMej+ZEG dNdO dOej+ZFeG dPdQ dQe?ZGg dRZHdS )TzPyTorch DPT (Dense Prediction Transformers) model.

This implementation is heavily inspired by OpenMMLab's implementation, found here:
https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py.

    N)	dataclass)CallableListOptionalSetTupleUnion)nn)CrossEntropyLoss   )ACT2FN)BaseModelOutputDepthEstimatorOutputSemanticSegmenterOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputauto_docstringlogging	torch_int)load_backbone   )	DPTConfigc                   @   s>   e Zd ZU dZdZeej ed< dZ	ee
ejdf  ed< dS )*BaseModelOutputWithIntermediateActivationsa#  
    Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
    in the context of Vision models.:

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
            Intermediate activations that can be used to compute hidden states of the model at various layers.
    Nlast_hidden_states.intermediate_activations)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r    r%   r%   S/var/www/auris/lib/python3.10/site-packages/transformers/models/dpt/modeling_dpt.pyr   +   s   
 r   c                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eeejdf  ed< dZeeejdf  ed< dZeeejdf  ed< dS )	4BaseModelOutputWithPoolingAndIntermediateActivationsa  
    Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
    activations that can be used by the model at later stages.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
            Intermediate activations that can be used to compute hidden states of the model at various layers.
    Nlast_hidden_statepooler_output.hidden_states
attentionsr   )r   r   r    r!   r(   r   r"   r#   r$   r)   r*   r   r+   r   r%   r%   r%   r&   r'   <   s   
 r'   c                	       sN   e Zd ZdZd fdd	ZdddZ	dd	ejd
ededejfddZ	  Z
S )DPTViTHybridEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    Nc           
         sj  t    |j|j}}|j|j}}t|tjj	r|n||f}t|tjj	r)|n||f}|d |d  |d |d   }t
|| _| jjd }t| jjdkr[tdt| jj ddg| _|d u rr|j}	|	dd  }|	d }nt|tjj	r{|n||f}| jjd }|| _|d | _|| _tj||dd| _ttdd|j| _ttd|d |j| _d S )Nr   r   r   z1Expected backbone to have 3 output features, got kernel_size)super__init__
image_size
patch_sizenum_channelshidden_size
isinstancecollectionsabcIterabler   backbonechannelslen
ValueErrorresidual_feature_map_indexZbackbone_featmap_shaper	   Conv2d
projection	Parameterr"   zeros	cls_tokenposition_embeddings)
selfconfigZfeature_sizer3   r4   r5   r6   num_patchesZfeature_dimZfeat_map_shape	__class__r%   r&   r2   g   s0   
 



 zDPTViTHybridEmbeddings.__init__r   c                 C   s   |d d d |f }|d|d f }t t|d }|d||ddddd}tjj|||fdd}|ddddd|| d}tj||gdd	}|S 
Nr         ?r   r-   r      bilinear)sizemodedim)	r   r=   reshapepermuter	   
functionalinterpolater"   catrF   ZposembZgrid_size_heightZgrid_size_widthstart_indexZ
posemb_tokZposemb_gridZold_grid_sizer%   r%   r&   _resize_pos_embed   s   z(DPTViTHybridEmbeddings._resize_pos_embedFpixel_valuesinterpolate_pos_encodingreturn_dictreturnc              
      s  |j \}}}}|| jkrtd|s7|| jd ks || jd kr7td| d| d| jd  d| jd  d	| | j|| j || j }| |  jd }	 fd	d
| j	D }
| 
|	ddd}| j|dd}tj||fdd}|| }|s||
fS t||
dS )NeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   zInput image size (*z) doesn't match model (z).r-   c                    s   g | ]} j | qS r%   )feature_maps).0indexZbackbone_outputr%   r&   
<listcomp>   s    z2DPTViTHybridEmbeddings.forward.<locals>.<listcomp>rM   rQ   )r   r   )shaper5   r>   r3   rZ   rE   r4   r;   ra   r?   rA   flatten	transposerD   expandr"   rW   r   )rF   r[   r\   r]   
batch_sizer5   heightwidthrE   featuresoutput_hidden_states
embeddings
cls_tokensr%   rd   r&   forward   s<   


zDPTViTHybridEmbeddings.forwardNr   )FF)r   r   r    r!   r2   rZ   r"   Tensorboolrq   __classcell__r%   r%   rI   r&   r,   `   s    
"r,   c                       s4   e Zd ZdZ fddZd
ddZddd	Z  ZS )DPTViTEmbeddingszB
    Construct the CLS token, position and patch embeddings.

    c                    sh   t    ttdd|j| _t|| _	| j	j
}ttd|d |j| _t|j| _|| _d S )Nr   )r1   r2   r	   rB   r"   rC   r6   rD   DPTViTPatchEmbeddingspatch_embeddingsrH   rE   Dropouthidden_dropout_probdropoutrG   )rF   rG   rH   rI   r%   r&   r2      s   


zDPTViTEmbeddings.__init__r   c                 C   s   |d d d |f }|d|d f }t |dd }|d||ddddd}tjj|||fdd}|ddddd|| d}tj||gdd	}|S rK   )	r   rO   rS   rT   r	   rU   rV   r"   rW   rX   r%   r%   r&   rZ      s   z"DPTViTEmbeddings._resize_pos_embedFc                 C   s   |j \}}}}| jj}| | j|| || }| |}	|	 \}}
}| j|dd}t	j
||	fdd}	|	| }	| |	}	|sB|	fS t|	dS )Nr-   r   rQ   )r   )rf   rG   r4   rZ   rE   ry   rO   rD   ri   r"   rW   r|   r   )rF   r[   r]   rj   r5   rk   rl   r4   rE   ro   Zseq_len_rp   r%   r%   r&   rq      s   


zDPTViTEmbeddings.forwardrs   )F)r   r   r    r!   r2   rZ   rq   rv   r%   r%   rI   r&   rw      s
    

rw   c                       s(   e Zd ZdZ fddZdd Z  ZS )rx   z$
    Image to Patch Embedding.

    c                    s   t    |j|j}}|j|j}}t|tjj	r|n||f}t|tjj	r)|n||f}|d |d  |d |d   }|| _|| _|| _|| _
tj||||d| _d S )Nr   r   )r0   stride)r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   rH   r	   r@   rA   )rF   rG   r3   r4   r5   r6   rH   rI   r%   r&   r2     s   
 zDPTViTPatchEmbeddings.__init__c                 C   s<   |j \}}}}|| jkrtd| |ddd}|S )Nr_   rM   r   )rf   r5   r>   rA   rg   rh   )rF   r[   rj   r5   rk   rl   ro   r%   r%   r&   rq     s   
zDPTViTPatchEmbeddings.forwardr   r   r    r!   r2   rq   rv   r%   r%   rI   r&   rx      s    rx           modulequerykeyvalueattention_maskscalingr|   c           
      K   s|   t ||dd| }tjj|dt jd|j}tjj	||| j
d}|d ur,|| }t ||}	|	dd }	|	|fS )Nr-   r.   )rR   dtype)ptrainingr   rM   )r"   matmulrh   r	   rU   ZsoftmaxZfloat32tor   r|   r   
contiguous)
r   r   r   r   r   r   r|   kwargsZattn_weightsZattn_outputr%   r%   r&   eager_attention_forward  s   r   c                
       sv   e Zd Zdeddf fddZdejdejfddZ		dd
eej de	de
eejejf eej f fddZ  ZS )DPTSelfAttentionrG   r^   Nc                    s   t    |j|j dkrt|dstd|j d|j d|| _|j| _t|j|j | _| j| j | _	|j
| _| jd | _d| _tj|j| j	|jd| _tj|j| j	|jd| _tj|j| j	|jd| _d S )	Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .g      F)bias)r1   r2   r6   num_attention_headshasattrr>   rG   intattention_head_sizeall_head_sizeZattention_probs_dropout_probdropout_probr   	is_causalr	   LinearZqkv_biasr   r   r   rF   rG   rI   r%   r&   r2   =  s"   

zDPTSelfAttention.__init__xc                 C   s6   |  d d | j| jf }||}|ddddS )Nr-   r   rM   r   r   )rO   r   r   viewrT   )rF   r   Znew_x_shaper%   r%   r&   transpose_for_scoresQ  s   
z%DPTSelfAttention.transpose_for_scoresF	head_maskoutput_attentionsc              
   C   s   |  | |}|  | |}|  | |}t}| jjdkr4| jjdkr.|r.td nt	| jj }|| ||||| j
| j| jsCdn| jd\}}	| d d | jf }
||
}|rc||	f}|S |f}|S )NeagerZsdpaz`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.r   )r   r   r|   r.   )r   r   r   r   r   rG   Z_attn_implementationloggerZwarning_oncer   r   r   r   r   rO   r   rS   )rF   r*   r   r   Z	key_layerZvalue_layerZquery_layerZattention_interfaceZcontext_layerZattention_probsZnew_context_layer_shapeoutputsr%   r%   r&   rq   V  s4   

zDPTSelfAttention.forwardNF)r   r   r    r   r2   r"   rt   r   r   ru   r   r   rq   rv   r%   r%   rI   r&   r   <  s    r   c                       sF   e Zd ZdZdeddf fddZdejdejdejfd	d
Z  Z	S )DPTViTSelfOutputz
    The residual connection is defined in DPTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    rG   r^   Nc                    s.   t    t|j|j| _t|j| _d S rr   )	r1   r2   r	   r   r6   denserz   r{   r|   r   rI   r%   r&   r2        
zDPTViTSelfOutput.__init__r*   input_tensorc                 C      |  |}| |}|S rr   r   r|   rF   r*   r   r%   r%   r&   rq        

zDPTViTSelfOutput.forward)
r   r   r    r!   r   r2   r"   rt   rq   rv   r%   r%   rI   r&   r   {  s    $r   c                       s~   e Zd Zdeddf fddZdee ddfddZ			dd
ej	de
ej	 dedeeej	ej	f eej	 f fddZ  ZS )DPTViTAttentionrG   r^   Nc                    s*   t    t|| _t|| _t | _d S rr   )r1   r2   r   	attentionr   outputsetpruned_headsr   rI   r%   r&   r2     s   


zDPTViTAttention.__init__headsc                 C   s   t |dkrd S t|| jj| jj| j\}}t| jj|| j_t| jj|| j_t| jj	|| j_	t| j
j|dd| j
_| jjt | | j_| jj| jj | j_| j|| _d S )Nr   r   rQ   )r=   r   r   r   r   r   r   r   r   r   r   r   r   union)rF   r   rc   r%   r%   r&   prune_heads  s   zDPTViTAttention.prune_headsFr*   r   r   c                 C   s4   |  |||}| |d |}|f|dd   }|S )Nr   r   )r   r   )rF   r*   r   r   Zself_outputsattention_outputr   r%   r%   r&   rq     s   zDPTViTAttention.forwardr   )r   r   r    r   r2   r   r   r   r"   rt   r   ru   r   r   rq   rv   r%   r%   rI   r&   r     s    r   c                       s<   e Zd Zdeddf fddZdejdejfddZ  ZS )	DPTViTIntermediaterG   r^   Nc                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S rr   )r1   r2   r	   r   r6   intermediate_sizer   r7   
hidden_actstrr   intermediate_act_fnr   rI   r%   r&   r2     s
   
zDPTViTIntermediate.__init__r*   c                 C   r   rr   )r   r   )rF   r*   r%   r%   r&   rq     r   zDPTViTIntermediate.forward	r   r   r    r   r2   r"   rt   rq   rv   r%   r%   rI   r&   r     s    r   c                       sB   e Zd Zdeddf fddZdejdejdejfdd	Z  ZS )
DPTViTOutputrG   r^   Nc                    s.   t    t|j|j| _t|j| _	d S rr   )
r1   r2   r	   r   r   r6   r   rz   r{   r|   r   rI   r%   r&   r2     r   zDPTViTOutput.__init__r*   r   c                 C   s    |  |}| |}|| }|S rr   r   r   r%   r%   r&   rq     s   

zDPTViTOutput.forwardr   r%   r%   rI   r&   r     s    $r   c                       sl   e Zd ZdZdeddf fddZ		ddejd	eej d
e	de
eejejf eej f fddZ  ZS )DPTViTLayerz?This corresponds to the Block class in the timm implementation.rG   r^   Nc                    sb   t    |j| _d| _t|| _t|| _t|| _	t
j|j|jd| _t
j|j|jd| _d S )Nr   Zeps)r1   r2   Zchunk_size_feed_forwardZseq_len_dimr   r   r   intermediater   r   r	   	LayerNormr6   layer_norm_epslayernorm_beforelayernorm_afterr   rI   r%   r&   r2     s   



zDPTViTLayer.__init__Fr*   r   r   c                 C   s`   | j | |||d}|d }|dd  }|| }| |}| |}| ||}|f| }|S )N)r   r   r   )r   r   r   r   r   )rF   r*   r   r   Zself_attention_outputsr   r   Zlayer_outputr%   r%   r&   rq     s   


zDPTViTLayer.forwardr   )r   r   r    r!   r   r2   r"   rt   r   ru   r   r   rq   rv   r%   r%   rI   r&   r     s    r   c                       sb   e Zd Zdeddf fddZ				ddejd	eej d
ededede	e
ef fddZ  ZS )DPTViTEncoderrG   r^   Nc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r%   )r   )rb   r}   rG   r%   r&   re     s    z*DPTViTEncoder.__init__.<locals>.<listcomp>F)	r1   r2   rG   r	   
ModuleListrangenum_hidden_layerslayergradient_checkpointingr   rI   r   r&   r2     s   
 
zDPTViTEncoder.__init__FTr*   r   r   rn   r]   c                 C   s   |rdnd }|r
dnd }t | jD ]8\}}	|r||f }|d ur$|| nd }
| jr6| jr6| |	j||
|}n|	||
|}|d }|rI||d f }q|rQ||f }|s_tdd |||fD S t|||dS )Nr%   r   r   c                 s   s    | ]	}|d ur|V  qd S rr   r%   )rb   vr%   r%   r&   	<genexpr>.  s    z(DPTViTEncoder.forward.<locals>.<genexpr>)r(   r*   r+   )	enumerater   r   r   Z_gradient_checkpointing_func__call__tupler   )rF   r*   r   r   rn   r]   Zall_hidden_statesZall_self_attentionsiZlayer_moduleZlayer_head_maskZlayer_outputsr%   r%   r&   rq   
  s6   

zDPTViTEncoder.forward)NFFT)r   r   r    r   r2   r"   rt   r   ru   r   r   r   rq   rv   r%   r%   rI   r&   r     s&    	
r   c                       sP   e Zd ZdZ fddZdd Zdd Zdd	eej	 d
eej	 fddZ
  ZS )DPTReassembleStagea@  
    This class reassembles the hidden states of the backbone into image-like feature representations at various
    resolutions.

    This happens in 3 stages:
    1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
       `config.readout_type`.
    2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
    3. Resizing the spatial dimensions (height, width).

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    c                    sB   t    || _t | _|jr| | n| | |j	| _	d S rr   )
r1   r2   rG   r	   r   layers	is_hybrid_init_reassemble_dpt_hybrid_init_reassemble_dptneck_ignore_stagesr   rI   r%   r&   r2   F  s   


zDPTReassembleStage.__init__c              	   C   s   t tt|j|jD ]#\}}|dkr| jt  q|dkr.| jt	||j| |d q|j
dkr=td|j
 dt | _t|}tt|jD ])}|dkr_| jtt  qM|dkrv| jttd| |t|j  qMdS )a   "
        For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
        implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
        for more details.
        r   r<   factorprojectzReadout type z! is not supported for DPT-Hybrid.rM   N)zipr   r=   neck_hidden_sizesreassemble_factorsr   appendr	   IdentityDPTReassembleLayerreadout_typer>   r   readout_projects_get_backbone_hidden_size
Sequentialr   r   r   )rF   rG   r   r   r6   r%   r%   r&   r   R  s&   

z.DPTReassembleStage._init_reassemble_dpt_hybridc              	   C   s   t tt|j|jD ]\}}| jt||j| |d q|jdkrIt	
 | _t|}tt|jD ]}| jt	t	d| |t|j  q3d S d S )Nr   r   rM   )r   r   r=   r   r   r   r   r   r   r	   r   r   r   r   r   r   r   )rF   rG   r   r   r6   r}   r%   r%   r&   r   l  s   

z'DPTReassembleStage._init_reassemble_dptNr*   r^   c                 C   sL  g }t |D ]\}}|| jvr|dddf |ddddf }}|j\}}	}
|dur9|dur9|||||
}nt|	d }|||||
}|dddd }|j}| jjdkr|	dd}|
d|}| j| t||fd	}|ddd|}n| jjd
kr|	d|
d	 }||}| j| |}|| q|S )z
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
                List of hidden states from the backbone.
        Nr   r   rL   r   rM   r   )r   rM   r   r-   add)r   r   rf   rS   r   rT   r   rG   r   rg   Z	unsqueezeZ	expand_asr   r"   rW   r   r   )rF   r*   patch_heightpatch_widthoutr   hidden_staterD   rj   Zsequence_lengthr5   rO   Zfeature_shapeZreadoutr%   r%   r&   rq   x  s,   
&
zDPTReassembleStage.forwardNN)r   r   r    r!   r2   r   r   r   r"   rt   rq   rv   r%   r%   rI   r&   r   6  s    (r   c                 C   s"   | j d ur| jdu r| j jS | jS r   )backbone_configr   r6   r   r%   r%   r&   r     s   r   c                       $   e Zd Z fddZdd Z  ZS )r   c                    s   t    t|}tj||dd| _|dkr#tj||||dd| _d S |dkr.t | _d S |dk rCtj||dt	d| dd| _d S d S )Nr   )Zin_channelsZout_channelsr0   r   r0   r~   paddingr   )
r1   r2   r   r	   r@   rA   ConvTranspose2dresizer   r   )rF   rG   r<   r   r6   rI   r%   r&   r2     s   
"zDPTReassembleLayer.__init__c                 C   r   rr   )rA   r   )rF   r   r%   r%   r&   rq     s   

zDPTReassembleLayer.forwardr   r   r    r2   rq   rv   r%   r%   rI   r&   r     s    r   c                       r   )DPTFeatureFusionStagec                    s<   t    t | _tt|jD ]
}| jt	| qd S rr   )
r1   r2   r	   r   r   r   r=   r   r   DPTFeatureFusionLayer)rF   rG   r}   rI   r%   r&   r2     s
   

zDPTFeatureFusionStage.__init__c                 C   sV   |d d d }g }d }t || jD ]\}}|d u r||}n|||}|| q|S )Nr-   )r   r   r   )rF   r*   Zfused_hidden_statesZfused_hidden_stater   r   r%   r%   r&   rq     s   

zDPTFeatureFusionStage.forwardr   r%   r%   rI   r&   r     s    r   c                       s6   e Zd ZdZ fddZdejdejfddZ  ZS )DPTPreActResidualLayerz
    ResidualConvUnit, pre-activate residual unit.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
    c                    s   t    |j| _|jd ur|jn| j }t | _tj|j	|j	ddd|d| _
t | _tj|j	|j	ddd|d| _| jrNt|j	| _t|j	| _d S d S )Nr   r   )r0   r~   r   r   )r1   r2   Z!use_batch_norm_in_fusion_residualuse_batch_normuse_bias_in_fusion_residualr	   ReLUactivation1r@   fusion_hidden_sizeconvolution1activation2convolution2BatchNorm2dbatch_norm1batch_norm2)rF   rG   r   rI   r%   r&   r2     s8   



		zDPTPreActResidualLayer.__init__r   r^   c                 C   sT   |}|  |}| |}| jr| |}| |}| |}| jr&| |}|| S rr   )r   r   r   r   r   r   r   rF   r   Zresidualr%   r%   r&   rq     s   





zDPTPreActResidualLayer.forward)	r   r   r    r!   r2   r"   rt   rq   rv   r%   r%   rI   r&   r     s    "r   c                       s,   e Zd ZdZd fdd	Zd	ddZ  ZS )
r   a3  Feature fusion layer, merges feature maps from different stages.

    Args:
        config (`[DPTConfig]`):
            Model configuration class defining the model architecture.
        align_corners (`bool`, *optional*, defaults to `True`):
            The align_corner setting for bilinear upsample.
    Tc                    s@   t    || _tj|j|jddd| _t|| _t|| _	d S )Nr   T)r0   r   )
r1   r2   align_cornersr	   r@   r   rA   r   residual_layer1residual_layer2)rF   rG   r   rI   r%   r&   r2     s
   

zDPTFeatureFusionLayer.__init__Nc                 C   st   |d ur#|j |j krtjj||j d |j d fddd}|| | }| |}tjj|dd| jd}| |}|S )NrM   r   rN   FrO   rP   r   Zscale_factorrP   r   )rf   r	   rU   rV   r   r  r   rA   r   r%   r%   r&   rq   $  s   


zDPTFeatureFusionLayer.forwardTrr   r   r%   r%   rI   r&   r     s    	
r   c                   @   s,   e Zd ZeZdZdZdZdZdZ	dd Z
dS )DPTPreTrainedModeldptr[   Tc                 C   s   t |tjtjtjfr"|jjjd| jj	d |j
dur!|j
j  nt |tjtjfr8|j
j  |jjd t |ttfrM|jj  |jj  dS dS )zInitialize the weightsr   )meanZstdNg      ?)r7   r	   r   r@   r   weightdataZnormal_rG   Zinitializer_ranger   Zzero_r   r   Zfill_rw   r,   rD   rE   )rF   r   r%   r%   r&   _init_weights>  s   
z DPTPreTrainedModel._init_weightsN)r   r   r    r   Zconfig_classZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_supports_sdpaZ_supports_flash_attn_2r
  r%   r%   r%   r&   r  5  s    r  c                       sz   e Zd Zd fdd	Zdd Zdd Ze				dd	ejd
e	ej de	e
 de	e
 de	e
 deeef fddZ  ZS )DPTModelTc                    sj   t  | || _|jrt|| _nt|| _t|| _t	j
|j|jd| _|r,t|nd| _|   dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        r   N)r1   r2   rG   r   r,   ro   rw   r   encoderr	   r   r6   r   	layernormDPTViTPoolerpooler	post_init)rF   rG   add_pooling_layerrI   r%   r&   r2   P  s   

zDPTModel.__init__c                 C   s   | j jr| jS | jjS rr   )rG   r   ro   ry   rF   r%   r%   r&   get_input_embeddingse  s   zDPTModel.get_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  r   r   r   )rF   Zheads_to_pruner   r   r%   r%   r&   _prune_headsk  s   zDPTModel._prune_headsNr[   r   r   rn   r]   r^   c                 C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}| || j j}| j||d}|s3|d n|j}| j|||||d}|d }	| 	|	}	| j
d urS| 
|	nd }
|sp|
d ur_|	|
fn|	f}||dd   |dd   S t|	|
|j|j|jdS )N)r]   r   r   r   rn   r]   r   )r(   r)   r*   r+   r   )rG   r   rn   use_return_dictZget_head_maskr   ro   r   r  r  r  r'   r*   r+   r   )rF   r[   r   r   rn   r]   Zembedding_outputZembedding_last_hidden_statesZencoder_outputsZsequence_outputpooled_outputZhead_outputsr%   r%   r&   rq   s  s6   	
zDPTModel.forwardr  )NNNN)r   r   r    r2   r  r  r   r"   r#   r   ru   r   r   r'   rq   rv   r%   r%   rI   r&   r  N  s,    
r  c                       s*   e Zd Zdef fddZdd Z  ZS )r  rG   c                    s,   t    t|j|j| _t|j | _	d S rr   )
r1   r2   r	   r   r6   Zpooler_output_sizer   r   Z
pooler_act
activationr   rI   r%   r&   r2     s   
zDPTViTPooler.__init__c                 C   s(   |d d df }|  |}| |}|S )Nr   )r   r  )rF   r*   Zfirst_token_tensorr  r%   r%   r&   rq     s   

zDPTViTPooler.forward)r   r   r    r   r2   rq   rv   r%   r%   rI   r&   r    s    r  c                       s@   e Zd ZdZ fddZd	deej deej fddZ  Z	S )
DPTNecka;  
    DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
    input and produces another list of tensors as output. For DPT, it includes 2 stages:

    * DPTReassembleStage
    * DPTFeatureFusionStage.

    Args:
        config (dict): config dict.
    c              
      sz   t    || _|jd ur|jjdv rd | _nt|| _t | _	|j
D ]}| j	tj||jdddd q$t|| _d S )N)Zswinv2r   r   Fr0   r   r   )r1   r2   rG   r   Z
model_typereassemble_stager   r	   r   convsr   r   r@   r   r   fusion_stage)rF   rG   ZchannelrI   r%   r&   r2     s   



 zDPTNeck.__init__Nr*   r^   c                    sn   t |ttfstdt|t jjkrtd jdur% |||} fddt	|D } 
|}|S )z
        Args:
            hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
                List of hidden states from the backbone.
        z2hidden_states should be a tuple or list of tensorszOThe number of hidden states should be equal to the number of neck hidden sizes.Nc                    s   g | ]\}} j | |qS r%   )r  )rb   r   featurer  r%   r&   re     s    z#DPTNeck.forward.<locals>.<listcomp>)r7   r   list	TypeErrorr=   rG   r   r>   r  r   r  )rF   r*   r   r   rm   r   r%   r  r&   rq     s   

zDPTNeck.forwardr   
r   r   r    r!   r2   r   r"   rt   rq   rv   r%   r%   rI   r&   r    s    (r  c                       s:   e Zd ZdZ fddZdeej dejfddZ  Z	S )DPTDepthEstimationHeada	  
    Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
    the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
    supplementary material).
    c                    s   t    || _d | _|jrtjdddddd| _|j}ttj||d ddddtj	ddd	d
tj|d dddddt
 tjddddddt
 | _d S )N   )r   r   )r   r   r   rM   r   r   rN   Tr      r   )r1   r2   rG   rA   Zadd_projectionr	   r@   r   r   Upsampler   headrF   rG   rm   rI   r%   r&   r2     s   

zDPTDepthEstimationHead.__init__r*   r^   c                 C   sF   || j j }| jd ur| |}t |}| |}|jdd}|S )Nr   rQ   )rG   head_in_indexrA   r	   r   r'  Zsqueeze)rF   r*   predicted_depthr%   r%   r&   rq     s   


zDPTDepthEstimationHead.forwardr"  r%   r%   rI   r&   r#    s    "r#  zu
    DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
    )Zcustom_introc                       sz   e Zd Z fddZe					ddejdeej deej dee	 dee	 d	ee	 d
e
eej ef fddZ  ZS )DPTForDepthEstimationc                    sj   t  | d | _|jdu r|jd us|jd urt|| _nt|dd| _t|| _	t
|| _|   d S NF)r  )r1   r2   r;   r   r   r   r  r  r  neckr#  r'  r  r   rI   r%   r&   r2     s   

zDPTForDepthEstimation.__init__Nr[   r   labelsr   rn   r]   r^   c                    s  d}|dur
t d|dur|n jj}|dur|n jj}|dur$|n jj} jdur: jj|||d}|j}	nF j|||d|d}|rI|j	n|d }	 jj
sa fddt|	dd D }	n|rf|jnt|d	 }
|
 fd
dt|	dd D  |
}	d\}} jjdur jj
du r|j\}}}} jjj}|| }|| } |	||}	 |	}|s|r|f|dd  }n	|f|dd  }|dur|f| S |S t|||r|j	nd|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
        >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # interpolate to original size
        >>> post_processed_output = image_processor.post_process_depth_estimation(
        ...     outputs,
        ...     target_sizes=[(image.height, image.width)],
        ... )

        >>> # visualize the prediction
        >>> predicted_depth = post_processed_output[0]["predicted_depth"]
        >>> depth = predicted_depth * 255 / predicted_depth.max()
        >>> depth = depth.detach().cpu().numpy()
        >>> depth = Image.fromarray(depth.astype("uint8"))
        ```NzTraining is not implemented yet)rn   r   Tr  r   c                        g | ]\}}| j jv r|qS r%   rG   Zbackbone_out_indicesrb   idxr  r  r%   r&   re   u      z1DPTForDepthEstimation.forward.<locals>.<listcomp>r-   c                 3   ,    | ]\}}| j jd d v r|V  qdS rM   Nr0  r1  r  r%   r&   r   z  s    z0DPTForDepthEstimation.forward.<locals>.<genexpr>r   FrM   )lossr*  r*   r+   )NotImplementedErrorrG   r  rn   r   r;   Zforward_with_filtered_kwargsra   r  r*   r   r   r   r   extendr   rf   r4   r-  r'  r   r+   )rF   r[   r   r.  r   rn   r]   r6  r   r*   backbone_hidden_statesr   r   r}   rk   rl   r4   r*  r   r%   r  r&   rq   ,  s`   .



zDPTForDepthEstimation.forward)NNNNN)r   r   r    r2   r   r"   r#   r   
LongTensorru   r   r   rt   r   rq   rv   r%   r%   rI   r&   r+    s.    r+  c                       s6   e Zd Z fddZdeej dejfddZ  ZS )DPTSemanticSegmentationHeadc                    sl   t    || _|j}ttj||ddddt|t t	|j
tj||jddtjdddd	| _d S )
Nr   r   Fr  r/   rM   rN   Tr  )r1   r2   rG   r   r	   r   r@   r   r   rz   Zsemantic_classifier_dropout
num_labelsr&  r'  r(  rI   r%   r&   r2     s   


z$DPTSemanticSegmentationHead.__init__r*   r^   c                 C   s   || j j }| |}|S rr   )rG   r)  r'  rF   r*   logitsr%   r%   r&   rq     s   
z#DPTSemanticSegmentationHead.forward)	r   r   r    r2   r   r"   rt   rq   rv   r%   r%   rI   r&   r;    s    "r;  c                       r   )DPTAuxiliaryHeadc                    sX   t    |j}ttj||ddddt|t tddtj||j	dd| _
d S )Nr   r   Fr  g?r/   )r1   r2   r   r	   r   r@   r   r   rz   r<  r'  r(  rI   r%   r&   r2     s   


zDPTAuxiliaryHead.__init__c                 C   s   |  |}|S rr   )r'  r=  r%   r%   r&   rq     s   
zDPTAuxiliaryHead.forwardr   r%   r%   rI   r&   r?    s    r?  c                       s   e Zd Z fddZe						ddeej deej deej dee	 dee	 d	ee	 d
e
eej ef fddZ  ZS )DPTForSemanticSegmentationc                    sN   t  | t|dd| _t|| _t|| _|jrt	|nd | _
|   d S r,  )r1   r2   r  r  r  r-  r;  r'  Zuse_auxiliary_headr?  auxiliary_headr  r   rI   r%   r&   r2     s   

z#DPTForSemanticSegmentation.__init__Nr[   r   r.  r   rn   r]   r^   c                    s  |dur|n j j}|dur|n j j}|dur" j jdkr"td j|||d|d}|r1|jn|d } j jsI fddt|dd D }n|rN|j	nt
|d }	|	 fd	d
t|dd D  |	} j|d} |}
d} jdur |d }d}|durtjj|
|jdd ddd}|durtjj||jdd ddd}t j jd}|||}|||}| j j|  }|s|r|
f|dd  }n	|
f|dd  }|dur|f| S |S t||
|r|jnd|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
        >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```Nr   z/The number of labels should be greater than oneTr  c                    r/  r%   r0  r1  r  r%   r&   re     r3  z6DPTForSemanticSegmentation.forward.<locals>.<listcomp>r-   c                 3   r4  r5  r0  r1  r  r%   r&   r     s    "z5DPTForSemanticSegmentation.forward.<locals>.<genexpr>)r*   r.   rN   Fr  )Zignore_indexrM   )r6  r>  r*   r+   )rG   r  rn   r<  r>   r  r*   r   r   r   r   r8  r-  r'  rA  r	   rU   rV   rf   r
   Zsemantic_loss_ignore_indexZauxiliary_loss_weightr   r+   )rF   r[   r   r.  r   rn   r]   r   r*   r9  r>  Zauxiliary_logitsr6  Zupsampled_logitsZupsampled_auxiliary_logitsZloss_fctZ	main_lossZauxiliary_lossr   r%   r  r&   rq     sf    




z"DPTForSemanticSegmentation.forward)NNNNNN)r   r   r    r2   r   r   r"   r#   r:  ru   r   r   rt   r   rq   rv   r%   r%   rI   r&   r@    s0    r@  )r+  r@  r  r  )r   )Ir!   collections.abcr8   dataclassesr   typingr   r   r   r   r   r   r"   Ztorch.utils.checkpointr	   Ztorch.nnr
   Zactivationsr   Zmodeling_outputsr   r   r   Zmodeling_utilsr   r   Zpytorch_utilsr   r   utilsr   r   r   r   Zutils.backbone_utilsr   Zconfiguration_dptr   Z
get_loggerr   r   r   r'   Moduler,   rw   rx   rt   floatr   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r#  r+  r;  r?  r@  __all__r%   r%   r%   r&   <module>   s    
#c:'
?*+3h=%X5) w