o
    Zh                     @   s  d dl mZmZmZmZ d dlZd dlmZ d dl	Z	d dl
mZ d dlZd dlmZmZmZ d dlmZ d dlmZmZ ddlmZmZmZmZ ddlmZmZmZm Z  dd	l!m"Z"m#Z# d
dl$m%Z% ej&j'G dd deZ(dZ)dZ*dee+e+f dej,fddZ-ej.fddZ/G dd dej0Z1G dd dej0Z2G dd dej0Z3G dd dej0Z4G dd dej0Z5G d d! d!ej0Z6G d"d# d#ej0Z7G d$d% d%ej0Z8G d&d' d'ej0Z9G d(d) d)ej0Z:G d*d+ d+ej0Z;G d,d- d-ej0Z<G d.d/ d/eZ=G d0d1 d1ej0Z>G d2d3 d3ej0Z?e"d4e)G d5d6 d6e=Z@d7ZAe e@eA ee@e(e%d8 G d9d: d:ej0ZBe"d;e)G d<d= d=e=ZCd>ZDe eCeD eeCee%d8 G d?d@ d@ej0ZEe"dAe)G dBdC dCe=ZFdDZGe eFeG eeFee%d8 g dEZHdS )F    )CallableListOptionalTupleN)
FrozenDictfreezeunfreeze)dot_product_attention_weights)flatten_dictunflatten_dict   )FlaxBaseModelOutputFlaxBaseModelOutputWithPoolingFlaxMaskedLMOutputFlaxSequenceClassifierOutput)ACT2FNFlaxPreTrainedModel append_replace_return_docstringsoverwrite_call_docstring)add_start_docstrings%add_start_docstrings_to_model_forward   )
BeitConfigc                   @   s   e Zd ZdZdS )FlaxBeitModelOutputWithPoolinga  
    Class for outputs of [`FlaxBeitModel`].

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
            *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
            will be returned.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    N)__name__
__module____qualname____doc__ r   r   Z/var/www/auris/lib/python3.10/site-packages/transformers/models/beit/modeling_flax_beit.pyr   ,   s    r   a  

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BeitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
a  
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
window_sizereturnc                 C   s  d| d  d d| d  d  d }t | d }t | d }t t j||dd}t |d}|dddddf |dddddf  }t |d	}|dddddf  | d d 7  < |dddddf  | d d 7  < |dddddf  d| d  d 9  < t j| d | d  d fd |jd
}|d|ddddf< |d |dddf< |d |dddf< |d |d< t	
|S )zP
    get pair-wise relative position index for each token inside the window
       r   r   r   Zij)Zindexing)r"   N)r   r"   r   shapedtyper#   )r   r   )npZarangestackZmeshgridreshape	transposezerosr&   sumjnparray)r    num_relative_distanceZcoords_hZcoords_wZcoordsZcoords_flattenZrelative_coordsrelative_position_indexr   r   r   relative_position_index_initw   s    $,&&*&
r1   c                 C   s   t ||| S N)r-   Zones)keyr%   scaler&   r   r   r   ones_with_scale   s   r5   c                   @   s6   e Zd ZU dZeed< ejjdde	e
 fddZdS )	FlaxBeitDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).rateTdeterministicc           	      C   sv   | j dkr|S d| j  }|r|S |jd fd|jd   }| d}|tjj|||jd }t	|}|| | }|S )N        g      ?r   )r   r   droppathr$   )
r7   r%   ndimmake_rngjaxrandomuniformr&   r-   floor)	selfZinputsr8   Z	keep_probr%   rngZrandom_tensorZbinary_tensoroutputr   r   r   __call__   s   



zFlaxBeitDropPath.__call__NT)r   r   r   r   float__annotations__nnmodulecompactr   boolrD   r   r   r   r   r6      s
   
 r6   c                   @   6   e Zd ZU eed< ejZejed< dd Zdd Z	dS )FlaxBeitPatchEmbeddingsconfigr&   c              	   C   s~   | j j| _| j j}| j j}|| ||  }|| || f}|| _|| _tj| j j||f||fd| j	t
jj| j jd| _d S )NZVALID)Zkernel_sizestridespaddingr&   kernel_init)rN   num_channels
image_size
patch_sizenum_patchespatch_shaperH   ZConvhidden_sizer&   r=   initializersnormalinitializer_range
projection)rA   rS   rT   rU   rV   r   r   r   setup   s   
zFlaxBeitPatchEmbeddings.setupc                 C   sF   |j d }|| jkrtd| |}|j \}}}}t||d|fS )Nr#   zeMake sure that the channel dimension of the pixel values match with the one set in the configuration.)r%   rR   
ValueErrorr[   r-   r)   )rA   pixel_valuesrR   
embeddings
batch_size_Zchannelsr   r   r   rD      s   


z FlaxBeitPatchEmbeddings.__call__N
r   r   r   r   rG   r-   float32r&   r\   rD   r   r   r   r   rM      s
   
 rM   c                   @   s<   e Zd ZU dZeed< ejZejed< dd Z	d
dd	Z
dS )FlaxBeitEmbeddingsz7Construct the CLS token, position and patch embeddings.rN   r&   c                 C   s   |  dtjjdd| jjf| _| jjr"|  dtjjdd| jjf| _t	| j| j
d| _| jj}| jjrD|  dtjjd|d | jjf| _tj| jjd| _d S )N	cls_tokenr   
mask_tokenr&   position_embeddingsr7   )paramrH   rX   r+   rN   rW   re   Zuse_mask_tokenrf   rM   r&   patch_embeddingsrU    use_absolute_position_embeddingsrh   Dropouthidden_dropout_probdropout)rA   rU   r   r   r   r\      s   zFlaxBeitEmbeddings.setupNTc                 C   s   |  |}|j\}}}t| j|d| jjf}||j}|d urDt| j	||| jjf}	|	|j}	tj
|dd}
|d|
  |	|
  }tj||fdd}| jjrZ|| j|j }| j||d}|S )Nr   r#   Zaxisr8   )rk   r%   r-   Zbroadcast_tore   rN   rW   astyper&   rf   expand_dimsZconcatenaterl   rh   ro   )rA   r^   bool_masked_posr8   r_   r`   Zseq_lenra   Z
cls_tokensZmask_tokenswr   r   r   rD      s   
zFlaxBeitEmbeddings.__call__)NT)r   r   r   r   r   rG   r-   rc   r&   r\   rD   r   r   r   r   rd      s   
 rd   c                   @   sF   e Zd ZU eed< eeef ed< ejZ	ej	ed< dd Z
dd ZdS )	FlaxBeitRelativePositionBiasrN   r    r&   c                 C   sT   d| j d  d d| j d  d  d }| dtjj|| jjf| _t| j | _	d S )Nr"   r   r   r   relative_position_bias_table)
r    rj   rH   rX   r+   rN   num_attention_headsrw   r1   r0   )rA   r/   r   r   r   r\      s   (
z"FlaxBeitRelativePositionBias.setupc                 C   sZ   | j d}| jd | jd  d | jd | jd  d df}| j| |}t|dS )Nr#   r   r   )r"   r   r   )r0   r)   r    rw   r-   r*   )rA   indexr%   relative_position_biasr   r   r   rD     s   2z%FlaxBeitRelativePositionBias.__call__N)r   r   r   r   rG   r   intr-   rc   r&   r\   rD   r   r   r   r   rv      s   
 rv   c                   @   sT   e Zd ZU eed< eeef ed< ejZ	ej	ed< dd Z
	dd	ed
efddZdS )FlaxBeitSelfAttentionrN   r    r&   c                 C   s   | j j| j j dkrt| j dstd| j j d| j j dtj| j j| jtjj	
| j jd| _tj| j j| jtjj	
| j jdd| _tj| j j| jtjj	
| j jd| _| jrit| j | j| jd	| _d S d | _d S )
Nr   Zembedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .)r&   rQ   F)r&   rQ   Zuse_biasr    r&   )rN   rW   rx   hasattrr]   rH   Denser&   r=   rX   rY   rZ   queryr3   valuer    rv   rz   rA   r   r   r   r\     s:   zFlaxBeitSelfAttention.setupNTFr8   output_attentionsc                 C   sN  | j j| j j }| ||jd d | j j|f }| ||jd d | j j|f }| ||jd d | j j|f }d }	|sP| j jdkrP| 	d}	t
jd| jd}
| jd urkt
|  d}
|
|j}
|d urw|
||
j }
t|||
|	| j jd|| jd d	}t
d||}||jd d d	 }|r||f}|S |f}|S )
Nr"   r9   ro   rg   r   T)Zbiasdropout_rngZdropout_rateZbroadcast_dropoutr8   r&   	precisionz...hqk,...khd->...qhd)r#   )rN   rW   rx   r   r)   r%   r   r3   Zattention_probs_dropout_probr<   r-   r.   r&   rz   rs   rr   r	   Zeinsum)rA   hidden_statesrz   r8   r   Zhead_dimZquery_statesZvalue_statesZ
key_statesr   Zattention_biasZattn_weightsattn_outputoutputsr   r   r   rD   -  sH   




zFlaxBeitSelfAttention.__call__NTFr   r   r   r   rG   r   r{   r-   rc   r&   r\   rK   rD   r   r   r   r   r|     s   
 !r|   c                   @   >   e Zd ZU eed< ejZejed< dd Zd
de	fddZ
d	S )FlaxBeitSelfOutputrN   r&   c                 C   <   t j| jjtj j| jj| jd| _	t j
| jjd| _d S NrQ   r&   ri   rH   r   rN   rW   r=   rX   rY   rZ   r&   denserm   rn   ro   r   r   r   r   r\   a     zFlaxBeitSelfOutput.setupTr8   c                 C      |  |}| j||d}|S Nrq   r   ro   rA   r   r8   r   r   r   rD   i  s   
zFlaxBeitSelfOutput.__call__NrE   r   r   r   r   rG   r-   rc   r&   r\   rK   rD   r   r   r   r   r   ]  
   
 r   c                   @   sP   e Zd ZU eed< eeef ed< ejZ	ej	ed< dd Z
	dd	efd
dZdS )FlaxBeitAttentionrN   r    r&   c                 C   s,   t | j| j| jd| _t| j| jd| _d S )Nrg   )r|   rN   r    r&   	attentionr   rC   r   r   r   r   r\   t  s   zFlaxBeitAttention.setupNTFr   c                 C   sD   | j ||||d}|d }| j||d}|f}|r ||d f7 }|S Nr8   r   r   rq   r   )r   rC   )rA   r   rz   r8   r   Zattn_outputsr   r   r   r   r   rD   x  s   zFlaxBeitAttention.__call__r   r   r   r   r   r   r   o  s   
 r   c                   @   rL   )FlaxBeitIntermediaterN   r&   c                 C   s8   t j| jjtj j| jj| jd| _	t
| jj | _d S )Nr   )rH   r   rN   Zintermediate_sizer=   rX   rY   rZ   r&   r   r   Z
hidden_act
activationr   r   r   r   r\     s   zFlaxBeitIntermediate.setupc                 C   s   |  |}| |}|S r2   )r   r   )rA   r   r   r   r   rD     s   

zFlaxBeitIntermediate.__call__Nrb   r   r   r   r   r     s
   
 r   c                   @   r   )FlaxBeitOutputrN   r&   c                 C   r   r   r   r   r   r   r   r\     r   zFlaxBeitOutput.setupTr8   c                 C   r   r   r   r   r   r   r   rD     s   
zFlaxBeitOutput.__call__NrE   r   r   r   r   r   r     r   r   c                   @   s\   e Zd ZU eed< eeef ed< eed< ej	Z
ej
ed< dd Z		dd
edefddZdS )FlaxBeitLayerrN   r    drop_path_rater&   c                 C   s   t | j| j| jd| _t| j| jd| _t| j| jd| _t	j
| jj| jd| _t| jd| _t	j
| jj| jd| _| jj| _| jdkr^| dt| jj| j| _| dt| jj| j| _d S d | _d | _d S )Nrg   epsilonr&   ri   r   lambda_1lambda_2)r   rN   r    r&   r   r   intermediater   rC   rH   	LayerNormlayer_norm_epslayernorm_beforer6   r   	drop_pathlayernorm_afterZlayer_scale_init_valueZinit_valuesrj   r5   rW   r   r   r   r   r   r   r\     s   


zFlaxBeitLayer.setupNTFr8   r   c           	      C   s   | j | ||||d}|d }| jd ur| j|j| }| j||d| }| |}| |}| j||d}| j	d urF| j	|j| }| j||d| }|f}|r[||d f7 }|S r   )
r   r   r   rr   r&   r   r   r   rC   r   )	rA   r   rz   r8   r   Zself_attention_outputsZattention_outputZlayer_outputr   r   r   r   rD     s(   



zFlaxBeitLayer.__call__r   )r   r   r   r   rG   r   r{   rF   r-   rc   r&   r\   rK   rD   r   r   r   r   r     s   
 r   c                	   @   s   e Zd ZU eed< eeef ed< ee ed< e	g e
jf ed< e
jZe
jed< dd Z						dd
edededefddZdS )FlaxBeitLayerCollectionrN   r    drop_path_ratesrz   r&   c                    s     fddt  jjD  _d S )Nc              	      s:   g | ]}t  j jjr jnd  j| t| jdqS )N)r    r   namer&   )r   rN   Zuse_relative_position_biasr    r   strr&   ).0ir   r   r   
<listcomp>  s    z1FlaxBeitLayerCollection.setup.<locals>.<listcomp>)rangerN   num_hidden_layerslayersr   r   r   r   r\     s   

zFlaxBeitLayerCollection.setupTFr8   r   output_hidden_statesreturn_dictc                 C   s   |rdnd }|r
dnd }t | jD ]+\}}	|r||f7 }| jd ur%|  nd }
|	||
||d}|d }|r<||d f7 }q|rD||f7 }|f}|sRtdd |D S t|||dS )Nr   r   r   r   c                 s   s    | ]	}|d ur|V  qd S r2   r   )r   vr   r   r   	<genexpr>  s    z3FlaxBeitLayerCollection.__call__.<locals>.<genexpr>)last_hidden_stater   
attentions)	enumerater   rz   tupler   )rA   r   r8   r   r   r   Zall_attentionsZall_hidden_statesr   layerrz   Zlayer_outputsr   r   r   r   rD     s*   

z FlaxBeitLayerCollection.__call__NTFFT)r   r   r   r   rG   r   r{   r   rF   r   r-   ndarrayrc   r&   r\   rK   rD   r   r   r   r   r     s(   
 r   c                	   @   sb   e Zd ZU eed< eeef ed< ejZ	ej	ed< dd Z
				dded	ed
edefddZdS )FlaxBeitEncoderrN   r    r&   c                 C   sd   | j jrt| j | j| jd| _ttd| j j	| j j
}t| j | j|| j jr)| jnd | jd| _d S )N)rN   r    r&   r   )r    r   rz   r&   )rN   Z!use_shared_relative_position_biasrv   r    r&   rz   listr'   Zlinspacer   r   r   r   )rA   r   r   r   r   r\   (  s   zFlaxBeitEncoder.setupTFr8   r   r   r   c                 C   s   | j |||||dS )Nr8   r   r   r   )r   )rA   r   r8   r   r   r   r   r   r   rD   :  s   zFlaxBeitEncoder.__call__Nr   r   r   r   r   r   r   #  s$   
 r   c                       s   e Zd ZU dZeZdZdZdZe	j
ed< ddejdfded	ed
ejdef fddZddejjdededefddZeed							ddee dejjdedee dee dee fddZ  ZS )FlaxBeitPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    beitr^   Nmodule_classr   TrN   seedr&   _do_initc                    sL   | j d||d|}|d u rd|j|j|jf}t j||||||d d S )N)rN   r&   r   )input_shaper   r&   r   r   )r   rS   rR   super__init__)rA   rN   r   r   r&   r   kwargsrI   	__class__r   r   r   V  s   	z FlaxBeitPreTrainedModel.__init__rB   r   paramsr!   c                 C   s   t j|| jd}tj|\}}tj|\}}|||d}| jj||ddd }	|d urOtt	|	}	tt	|}| j
D ]}
|	|
 ||
< q<t | _
tt|S |	S )Nrg   )r   ro   r:   F)r   r   )r-   r+   r&   r=   r>   splitrI   initr
   r   Z_missing_keyssetr   r   )rA   rB   r   r   r^   Z
params_rngr   droppath_rngrngsZrandom_paramsZmissing_keyr   r   r   init_weightsd  s   
z$FlaxBeitPreTrainedModel.init_weightszbatch_size, sequence_lengthFr   trainr   r   r   c	              
   C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}t|d}i }	|d ur:tj|\}}
||	d< |
|	d< | j	j
d|pB| jitj|tjd|| ||||	dS )N)r   r"   r   r   ro   r:   r   rg   )r   )rN   r   r   r   r-   r*   r=   r>   r   rI   applyr   r.   rc   )rA   r^   rt   r   r   r   r   r   r   r   r   r   r   r   rD   x  s(   z FlaxBeitPreTrainedModel.__call__r2   )NNNFNNN)r   r   r   r   r   config_classZbase_model_prefixZmain_input_namer   rH   ModulerG   r-   rc   r{   r&   rK   r   r=   r>   ZPRNGKeyr   r   r   r   BEIT_INPUTS_DOCSTRINGformatr   dictrD   __classcell__r   r   r   r   r   K  sR   
  	r   c                   @   rL   )FlaxBeitPoolerrN   r&   c                 C   s&   | j jrtj| j j| jd| _d S d S )Nr   )rN   use_mean_poolingrH   r   r   r&   	layernormr   r   r   r   r\     s   zFlaxBeitPooler.setupc                 C   sN   | j jr|d d dd d d f }| tj|dd}|S |d d df }|S )Nr   rp   r   )rN   r   r   r-   mean)rA   r   Zpatch_tokenspooled_outputr   r   r   rD     s   zFlaxBeitPooler.__call__Nrb   r   r   r   r   r     s
   
 r   c                	   @   s`   e Zd ZU eed< ejZejed< dZe	ed< dd Z
					dd	e	d
e	de	de	fddZdS )FlaxBeitModulerN   r&   Tadd_pooling_layerc                 C   sp   t | j| jd| _t| j| jjj| jd| _| jjs%t	j
| jj| jd| _| jr3t| j| jd| _d S d | _d S )Nrg   r~   r   )rd   rN   r&   r_   r   rk   rV   encoderr   rH   r   r   r   r   r   poolerr   r   r   r   r\     s   &zFlaxBeitModule.setupNFr8   r   r   r   c           
      C   s   | j |||d}| j|||||d}|d }| jjs| |}| jr'| |nd }	|sB|	d u r8|f|dd   S ||	f|dd   S t||	|j|j	dS )Nrq   r   r   r   )r   Zpooler_outputr   r   )
r_   r   rN   r   r   r   r   r   r   r   )
rA   r^   rt   r8   r   r   r   r   r   Zpooledr   r   r   rD     s,   	
zFlaxBeitModule.__call__)NTFFT)r   r   r   r   rG   r-   rc   r&   r   rK   r\   rD   r   r   r   r   r     s&   
 r   z^The bare Beit Model transformer outputting raw hidden-states without any specific head on top.c                   @      e Zd ZeZdS )FlaxBeitModelN)r   r   r   r   r   r   r   r   r   r         r   a  
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
    >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
)output_typer   c                   @   J   e Zd ZU eed< ejZejed< dd Z						d
de	fdd	Z
dS )$FlaxBeitForMaskedImageModelingModulerN   r&   c                 C   sT   t | jd| jd| _tj| jj| jd| _tj| jj	t
jj| jj| jd| _d S )NF)r   r&   r   r   )r   rN   r&   r   rH   r   r   r   r   Z
vocab_sizer=   rX   rY   rZ   lm_headr   r   r   r   r\     s   z*FlaxBeitForMaskedImageModelingModule.setupNTr8   c                 C   s   |d ur|n| j j}| j||||||d}|d }| |}| |d d dd f }	|s8|	f|dd   }
|
S t|	|j|jdS )Nr   r   r   r"   logitsr   r   )rN   use_return_dictr   r   r   r   r   r   )rA   r^   rt   r8   r   r   r   r   Zsequence_outputZprediction_scoresrC   r   r   r   rD     s(   		
z-FlaxBeitForMaskedImageModelingModule.__call__NNTNNNr   r   r   r   r   r     s   
 r   zYBeit Model transformer with a 'language' modeling head on top (to predict visual tokens).c                   @   r   )FlaxBeitForMaskedImageModelingN)r   r   r   r   r   r   r   r   r   r   9  r   r   a?  
    bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`):
        Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
    >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
c                   @   r   )$FlaxBeitForImageClassificationModulerN   r&   c                 C   s>   t | j| jdd| _tj| jjtjj	| jj
| jd| _d S )NT)rN   r&   r   r   )r   rN   r&   r   rH   r   Z
num_labelsr=   rX   rY   rZ   
classifierr   r   r   r   r\   d  s   z*FlaxBeitForImageClassificationModule.setupNTr8   c                 C   sf   |d ur|n| j j}| j|||||d}|d }| |}	|s*|	f|dd   }
|
S t|	|j|jdS )Nr   r   r"   r   )rN   r   r   r   r   r   r   )rA   r^   rt   r8   r   r   r   r   r   r   rC   r   r   r   rD   l  s$   	
z-FlaxBeitForImageClassificationModule.__call__r   r   r   r   r   r   r   `  s   
 
r   z
    Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
    hidden states of the patch tokens) e.g. for ImageNet.
    c                   @   r   )FlaxBeitForImageClassificationN)r   r   r   r   r   r   r   r   r   r     s    r   aM  
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitForImageClassification
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
    >>> model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = logits.argmax(-1).item()
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
    ```
)r   r   r   r   )Itypingr   r   r   r   ZflaxZ
flax.linenZlinenrH   r=   Z	jax.numpynumpyr-   r'   Zflax.core.frozen_dictr   r   r   Zflax.linen.attentionr	   Zflax.traverse_utilr
   r   Zmodeling_flax_outputsr   r   r   r   Zmodeling_flax_utilsr   r   r   r   utilsr   r   Zconfiguration_beitr   struct	dataclassr   ZBEIT_START_DOCSTRINGr   r{   r   r1   rc   r5   r   r6   rM   rd   rv   r|   r   r   r   r   r   r   r   r   r   r   r   ZFLAX_BEIT_MODEL_DOCSTRINGr   r   ZFLAX_BEIT_MLM_DOCSTRINGr   r   ZFLAX_BEIT_CLASSIF_DOCSTRING__all__r   r   r   r   <module>   s   # )U=7(S3
2
-
