"""TF IdeficsVision model: a copy of CLIPVisionModel using a simpler config object"""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
from ...modeling_tf_utils import TFPreTrainedModel, shape_list
from ...tf_utils import flatten
from ...utils import ModelOutput, logging
from .configuration_idefics import IdeficsVisionConfig


logger = logging.get_logger(__name__)


@dataclass
class TFIdeficsVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nimage_embedslast_hidden_statehidden_states
attentions)__name__
__module____qualname____doc__r   r   tfTensor__annotations__r   r   r   r    r   r   T/var/www/auris/lib/python3.10/site-packages/transformers/models/idefics/vision_tf.pyr   "   s   
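
# A minimal usage sketch for the container above (hypothetical shapes, not part
# of the original module): `ModelOutput` subclasses behave both like a tuple and
# like an attribute container, so indexing and attribute access return the same
# tensor.
#
#     out = TFIdeficsVisionModelOutput(last_hidden_state=tf.zeros((1, 257, 1280)))
#     assert out[0] is out.last_hidden_state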
 r   c                       sf   e Zd Zdef fddZdejdededejfdd	Zddejde	dejfddZ
dddZ  ZS )TFIdeficsVisionEmbeddingsconfigc              	      s   t  jd
i | || _|j| _|j| _|j| _tjj	j
| j| j| jddddd| _| j| j d | _| jd | _tjj	j| j| jdd	| _d S )NFZvalidZchannels_lastpatch_embedding)filtersZkernel_sizestridesZuse_biaspaddingZdata_formatname   r   position_embeddingr%   r   )super__init__r    hidden_size	embed_dim
image_size
patch_sizer   keraslayersZConv2Dr!   num_patchesnum_positionsZ	Embeddingr'   selfr    kwargs	__class__r   r   r*   @   s&   

z"TFIdeficsVisionEmbeddings.__init__
embeddingsheightwidthreturnc                 C   s  t |d d }| | j}t |d d }||kr ||kr |S |d d df }|d d dd f }t |d }	|| jj }
|| jj }|
d |d }
}tt|}t	|dt
|t
||	f}|
| }|| }tt|d tj}tt|d tj}t|| tj}t|| tj}tjj|||gtjjjd}t
|
t |d kst
|t |d krtd	t
|
t
|f d
t |d t |d f dt	|dd|	f}tj|tjd d f |fddS )Nr   r   g?r&   )sizemethodzNumber of patches for images (z/) don't match the shape of position embedding ()Zaxis)r   r'   position_idsr    r.   mathsqrtfloatr   reshapeintcastshapeZfloat32Zint32imageresizeZResizeMethodZBICUBIC
ValueErrorconcatnewaxis)r4   r8   r9   r:   r1   Z	pos_embedr2   Zclass_pos_embedZpatch_pos_embedr,   Znum_h_patchesZnum_w_patchesZsqrt_num_positionsZscale_heightZscale_widthZoriginal_heightZoriginal_widthZ
new_heightZ	new_widthr   r   r   interpolate_pos_encodingX   s>    z2TFIdeficsVisionEmbeddings.interpolate_pos_encodingFpixel_valuesrP   c           
   
   C   s   t |tr	|d }tj|dd}t|\}}}}|s7|| jks$|| jkr7td| d| d| j d| j d	| |}t|dd	}t	| j
tjtjd d f |d| jg}tj||gdd
}	|rl|	| |	|| }	|	S |	| | j }	|	S )NrQ   )r   r&   r   r   permzInput image size (*z) doesn't match model (z8). You should try to set `interpolate_pos_encoding=True`r   r&   rB   )
isinstancedictr   	transposer   r-   rM   r!   r   Zbroadcast_toclass_embeddingrO   r,   rN   rP   r'   rC   )
r4   rQ   rP   Z
batch_sizer9   r:   num_channelsZpatch_embedsZclass_embedsr8   r   r   r   call   s0   

 zTFIdeficsVisionEmbeddings.callNc                 C   s   | j rd S d| _ tj| jddtjd d f | _| j| jfdd| _t	| dd d urMt
| jj | jd d d | jjg W d    n1 sHw   Y  t	| dd d urut
| jj | jd  W d    d S 1 snw   Y  d S d S )NTzself.position_idsr(   rX   )rJ   r%   r!   r'   )builtr   ranger2   rO   rC   Z
add_weightr,   rX   getattr
name_scoper!   r%   buildr    rY   r'   r4   input_shaper   r   r   r_      s    "zTFIdeficsVisionEmbeddings.buildFN)r   r   r   r   r*   r   r   rH   rP   boolrZ   r_   __classcell__r   r   r6   r   r   ?   s
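
# A worked example of the interpolation above (illustrative numbers, not part of
# the original module): with `image_size=224` and `patch_size=14` the embedding
# table stores a 16x16 grid (256 patch positions + 1 CLS position). Feeding a
# 448x448 image yields 32x32 = 1024 patches, so the stored grid is bicubically
# resized from (1, 16, 16, embed_dim) to (1, 32, 32, embed_dim), flattened back
# to (1, 1024, embed_dim), and the CLS position is re-attached in front.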


class TFIdeficsVisionAttention(tf.keras.layers.Layer):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = tf.keras.layers.Dense(self.embed_dim, name="k_proj")
        self.v_proj = tf.keras.layers.Dense(self.embed_dim, name="v_proj")
        self.q_proj = tf.keras.layers.Dense(self.embed_dim, name="q_proj")
        self.out_proj = tf.keras.layers.Dense(self.embed_dim, name="out_proj")

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: Optional[tf.Tensor] = None,
        causal_attention_mask: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[tf.Tensor, Optional[tf.Tensor], Optional[Tuple[tf.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = shape_list(hidden_states)

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
        key_states = tf.reshape(key_states, proj_shape)
        value_states = tf.reshape(value_states, proj_shape)

        src_len = shape_list(key_states)[1]
        attn_weights = tf.linalg.matmul(query_states, key_states, transpose_b=True)

        tf.debugging.assert_equal(
            tf.shape(attn_weights),
            [bsz * self.num_heads, tgt_len, src_len],
            message=(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {tf.shape(attn_weights)}"
            ),
        )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if shape_list(causal_attention_mask) != [bsz, 1, tgt_len, src_len]:
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {shape_list(causal_attention_mask)}"
                )
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + causal_attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        if attention_mask is not None:
            if shape_list(attention_mask) != [bsz, 1, tgt_len, src_len]:
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {shape_list(attention_mask)}"
                )
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = tf.nn.softmax(attn_weights, axis=-1)

        if output_attentions:
            # reshaping twice keeps the gradient flowing through `attn_weights`
            # while exposing a `(bsz, num_heads, tgt_len, src_len)` view to the caller
            attn_weights_reshaped = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
            attn_weights = tf.reshape(attn_weights_reshaped, (bsz * self.num_heads, tgt_len, src_len))
        else:
            attn_weights_reshaped = None

        attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
        attn_output = tf.linalg.matmul(attn_probs, value_states)

        tf.debugging.assert_equal(
            tf.shape(attn_output),
            [bsz * self.num_heads, tgt_len, self.head_dim],
            message=(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {tf.shape(attn_output)}"
            ),
        )

        attn_output = tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim))
        attn_output = tf.transpose(attn_output, perm=[0, 2, 1, 3])
        attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build((self.embed_dim, self.embed_dim))
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build((self.embed_dim, self.embed_dim))
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build((self.embed_dim, self.embed_dim))
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build((self.embed_dim, self.embed_dim))


class TFIdeficsVisionMLP(tf.keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.activation_fn = get_tf_activation(config.hidden_act)
        self.fc1 = tf.keras.layers.Dense(config.intermediate_size, name="fc1")
        self.fc2 = tf.keras.layers.Dense(config.hidden_size, name="fc2")

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "fc1", None) is not None:
            with tf.name_scope(self.fc1.name):
                self.fc1.build(self.config.hidden_size)
        if getattr(self, "fc2", None) is not None:
            with tf.name_scope(self.fc2.name):
                self.fc2.build(self.config.intermediate_size)


class TFIdeficsVisionEncoderLayer(tf.keras.layers.Layer):
    def __init__(self, config: IdeficsVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = config.hidden_size
        self.self_attn = TFIdeficsVisionAttention(config, name="self_attn")
        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
        self.mlp = TFIdeficsVisionMLP(config, name="mlp")
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        causal_attention_mask: tf.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[tf.Tensor]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build([None, None, self.embed_dim])
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build([None, None, self.embed_dim])


class TFIdeficsVisionEncoder(tf.keras.layers.Layer):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`TFIdeficsVisionEncoderLayer`].

    Args:
        config: IdeficsVisionConfig
    r    c                    s<   t  jdi |  | _ fddt jD | _d| _d S )Nc                    s   g | ]}t  d | dqS )zlayers.r(   )r   ).0ir    r   r   
<listcomp>  s    z3TFIdeficsVisionEncoder.__init__.<locals>.<listcomp>Fr   )r)   r*   r    r\   Znum_hidden_layersr0   gradient_checkpointingr3   r6   r   r   r*     s   

zTFIdeficsVisionEncoder.__init__Nru   rv   rw   output_hidden_statesreturn_dicttrainingr;   c                    s   dur n| j j |dur|n| j j}|dur|n| j j}|r"dnd} r(dnd}	|}
t| jD ]6\}}|r<||
f }| jrR|rR fdd}t|||
||}n||
|| d}|d }
 rg|	|d f }	q1|ro||
f }|s}t	dd	 |
||	fD S t
|
||	d
S )a  
        Args:
            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = tf.recompute_grad(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)


class TFIdeficsVisionTransformer(TFPreTrainedModel):
    def __init__(self, config: IdeficsVisionConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.embed_dim = config.hidden_size

        self.embeddings = TFIdeficsVisionEmbeddings(config, name="embeddings")
        self.pre_layrnorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="pre_layrnorm")
        self.encoder = TFIdeficsVisionEncoder(config, name="encoder")
        self.post_layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="post_layernorm")

    def call(
        self,
        pixel_values: Optional[tf.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
        r"""
        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "pre_layrnorm", None) is not None:
            with tf.name_scope(self.pre_layrnorm.name):
                self.pre_layrnorm.build([None, None, self.embed_dim])
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "post_layernorm", None) is not None:
            with tf.name_scope(self.post_layernorm.name):
                self.post_layernorm.build([None, self.embed_dim])