from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .configuration_vit import ViTConfig


VIT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

VIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class FlaxViTPatchEmbeddings(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        image_size = self.config.image_size
        patch_size = self.config.patch_size
        num_patches = (image_size // patch_size) * (image_size // patch_size)
        self.num_patches = num_patches
        self.num_channels = self.config.num_channels
        self.projection = nn.Conv(
            self.config.hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
        )

    def __call__(self, pixel_values):
        num_channels = pixel_values.shape[-1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
            )
        embeddings = self.projection(pixel_values)
        batch_size, _, _, channels = embeddings.shape
        return jnp.reshape(embeddings, (batch_size, -1, channels))


class FlaxViTEmbeddings(nn.Module):
    """Construct the CLS token, position and patch embeddings."""

    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.cls_token = self.param(
            "cls_token",
            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
            (1, 1, self.config.hidden_size),
        )
        self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = self.param(
            "position_embeddings",
            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
            (1, num_patches + 1, self.config.hidden_size),
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, pixel_values, deterministic=True):
        batch_size = pixel_values.shape[0]

        embeddings = self.patch_embeddings(pixel_values)

        cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
        embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
        embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings, deterministic=deterministic)
        return embeddings
class FlaxViTSelfAttention(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of"
                f" `config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
            ),
            use_bias=self.config.qkv_bias,
        )

    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
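
# Attention shape walk-through (illustrative): `dot_product_attention_weights` computes
# softmax(Q @ K^T / sqrt(head_dim)) per head, yielding weights of shape
# (batch, num_heads, q_len, k_len); the einsum above contracts them with the value tensor
# (batch, k_len, num_heads, head_dim) to produce (batch, q_len, num_heads, head_dim),
# which is then flattened back to (batch, q_len, hidden_size).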
class FlaxViTSelfOutput(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxViTAttention(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype)
        self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False):
        attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions)
        attn_output = attn_outputs[0]
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs


class FlaxViTIntermediate(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxViTOutput(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = hidden_states + attention_output
        return hidden_states


class FlaxViTLayer(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxViTAttention(self.config, dtype=self.dtype)
        self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxViTOutput(self.config, dtype=self.dtype)
        self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
        attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            deterministic=deterministic,
            output_attentions=output_attentions,
        )

        attention_output = attention_outputs[0]

        # first residual connection
        attention_output = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(attention_output)

        hidden_states = self.intermediate(layer_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs
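
# Block structure note (descriptive): FlaxViTLayer follows the pre-LayerNorm transformer
# layout used by ViT -- normalize, attend, add the residual, then normalize again, run the
# MLP (FlaxViTIntermediate -> FlaxViTOutput), and add the second residual inside FlaxViTOutput.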
class FlaxViTLayerCollection(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states,)
        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
class FlaxViTEncoder(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxViTPooler(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.pooler_output_size,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.pooler_act]

    def __call__(self, hidden_states):
        cls_hidden_state = hidden_states[:, 0]
        cls_hidden_state = self.dense(cls_hidden_state)
        return self.activation(cls_hidden_state)


class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ViTConfig,
        input_shape=None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        pixel_values,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixel_values, dtype=jnp.float32),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )


class FlaxViTModule(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype)
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(pixel_values, deterministic=deterministic)

        outputs = self.encoder(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        hidden_states = self.layernorm(hidden_states)
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
    VIT_START_DOCSTRING,
)
class FlaxViTModel(FlaxViTPreTrainedModel):
    module_class = FlaxViTModule


FLAX_VISION_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxViTModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    >>> model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig)


class FlaxViTForImageClassificationModule(nn.Module):
    config: ViTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.variance_scaling(
                self.config.initializer_range**2, "fan_in", "truncated_normal"
            ),
        )

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.classifier(hidden_states[:, 0, :])

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    """,
    VIT_START_DOCSTRING,
)
class FlaxViTForImageClassification(FlaxViTPreTrainedModel):
    module_class = FlaxViTForImageClassificationModule


FLAX_VISION_CLASSIF_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxViTForImageClassification
    >>> from PIL import Image
    >>> import jax
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    >>> model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
    ```
"""

overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
    FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig
)

__all__ = ["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]
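
# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an editorial addition, not part of the upstream
# module): builds a randomly initialized FlaxViTModel from a default ViTConfig
# and runs a single forward pass on a zero image. The expected output shape
# assumes the ViT-Base defaults (224x224 inputs, 16x16 patches, hidden_size 768).
if __name__ == "__main__":
    import numpy as np

    config = ViTConfig()
    model = FlaxViTModel(config, seed=0)

    # NCHW input, as expected by FlaxViTPreTrainedModel.__call__ (it transposes to NHWC internally)
    pixel_values = np.zeros((1, config.num_channels, config.image_size, config.image_size), dtype=np.float32)
    outputs = model(pixel_values)

    # 224 // 16 = 14 patches per side -> 14 * 14 + 1 CLS token = 197 positions
    print(outputs.last_hidden_state.shape)  # (1, 197, 768)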