"""PyTorch OpenAI ImageGPT model."""

import math
import os
import warnings
from typing import Any, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
from ...utils import auto_docstring, logging, torch_float
from .configuration_imagegpt import ImageGPTConfig


logger = logging.get_logger(__name__)
ddl }ddl}W n ty   td  w tj|}td	| |j
|}g }g }|D ] \}	}
td	|	|
 |j
||	}||	 ||  q2t||D ]\}	}|	dd }	|	d}	tdd	 |	D sw|	d
 dv rtd	d|	 qX| }|	d
 dvrt|d}|	D ]}|d|r|d|}n|g}|d dks|d dkrt|d}n|d dkrt|d}nw|d dks|d dkrt||d }t|d}n^|d dv rt|d}t|d}nMt|	dkr|	d dkr|d dkrt||d }t|d}n+|d dkr t|d}t|d}n|d d kr2t|d}t|d}nt||d }t|d!krJt|d }|| }qt|	dkrY|	d dksn|	d
 dksn|	d
 d ksn|	d
 dkron%z|j|jksyJ W n ty } z| j|j|jf7  _ d}~ww td"	|	 |	d
 d#krt||j|jj|jddd|jf< qX|	d
 d$krt||j|jj|jdd|jd!|j f< qX|	d
 d%krt||j|jj|jddd!|j df< qXt|	dkr|	d dkr|	d! dkrt||j|j|_qX|	d
 dkr+t||_qX|	d
 dkrDt||jd|j d ddf< qX|	d
 d krTt||jd
< qXt||_qX| S )&z0
    Load tf checkpoints in a pytorch model
    r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z(Converting TensorFlow checkpoint from {}z"Loading TF weight {} with shape {}   /c                 s   s    | ]}|d v V  qdS ))Zadam_vZadam_mZAdamWeightDecayOptimizerZAdamWeightDecayOptimizer_1Zglobal_stepN ).0nr   r   ]/var/www/auris/lib/python3.10/site-packages/transformers/models/imagegpt/modeling_imagegpt.py	<genexpr>Q   s
    
z.load_tf_weights_in_imagegpt.<locals>.<genexpr>)Z_stepzSkipping {})wtettransformerz[A-Za-z]+\d+z(\d+)wgweightbbiaswpewte)q_projk_projv_projc_attnr   r   attnc_projr"   lm_headsos   zInitialize PyTorch weight {}r+   r,   r-   )!reZ
tensorflowImportErrorloggererrorospathabspathinfoformattrainZlist_variablesZload_variableappendsqueezezipsplitanyjoingetattr	fullmatchlenintshapeAssertionErrorargstorchZ
from_numpyreshapen_embdTdata
vocab_size)modelconfigZimagegpt_checkpoint_pathr4   tfZtf_pathZ	init_varsnamesZarraysnamerH   arrayZpointerZm_nameZscope_namesnumer   r   r   load_tf_weights_in_imagegpt0   s   



*

F.2*$rY   c                       sB   e Zd Zd
dee def fddZdejdejfdd	Z	  Z
S )ImageGPTLayerNormh㈵>hidden_sizeepsc                    s&   t    || _tt|| _d S N)super__init__r]   r   	ParameterrK   Tensorr&   )selfr\   r]   	__class__r   r   r`      s   
zImageGPTLayerNorm.__init__tensorreturnc                 C   s4   |t t jt |ddd| j  }|| j }|S )Nr!   T)ZaxisZkeepdim)rK   sqrtmeanZsquarer]   r&   )rc   rf   r   r   r   forward   s   &
zImageGPTLayerNorm.forward)r[   )__name__
__module____qualname__r   rG   floatr`   rK   rb   rj   __classcell__r   r   rd   r   rZ      s    rZ   c                       s   e Zd Zddee dee f fddZdd Zdd	d
ZdddZ	dd Z
dd Z							ddejdee deej deej deej deej dee dee defddZ  ZS )ImageGPTAttentionFNis_cross_attention	layer_idxc                    sF  t    |j}| jdttj||ftjddd||dd | jdt	ddd |j
| _|j| _| j| j | _| j| _| j| j | jkrUtd| j d	| j d
|j| _|| _|j| _|| _|j| _| jr}td| j | j| _t| j| j| _n
td| j | j| _t| j| j| _t|j| _t|j| _t  | _!d S )Nr(   dtyper   F)
persistentZmasked_biasg     z=`embed_dim` must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).r3   r   )"r_   r`   max_position_embeddingsZregister_bufferrK   Ztrilonesboolviewrf   r\   	embed_dimZnum_attention_heads	num_headshead_dim
split_size
ValueErrorscale_attn_weightsrq   scale_attn_by_inverse_layer_idxrr   reorder_and_upcast_attnr   r.   q_attnr0   r   DropoutZ
attn_pdropattn_dropoutresid_pdropresid_dropoutsetpruned_heads)rc   rR   rq   rr   Zmax_positionsrd   r   r   r`      sB   

zImageGPTAttention.__init__c                 C   s   t |dkrd S t|| j| j| j\}}t||| j |d| j  g}t| j	|dd| _	t| j
|dd| _
| j| j | jt |  | _| jt | | _| j|| _d S )Nr   r3   r   dim)rF   r   r{   r|   r   rK   catr}   r   r.   r0   union)rc   headsindexZ
index_attnr   r   r   prune_heads   s    zImageGPTAttention.prune_headsc                 C   s  t ||dd}| jr|t|dd  }| jr$|t| jd  }| j	s]|d|d}}| j
d d d d || |d |f }	t |jj}
t j|
|j|jd}
t |	||
}|d ure|| }tjdd|}||j}| |}|d ur|| }t ||}||fS )Nr!         ?r   rt   devicer   )rK   matmul	transposer   r   sizer   rn   rr   rq   r(   finfort   minrf   r   wherer   Softmaxtyper   )rc   querykeyvalueattention_mask	head_maskattn_weightsquery_length
key_lengthcausal_mask
mask_valueattn_outputr   r   r   _attn   s(   &
zImageGPTAttention._attnc                 C   s  |  \}}}}	|  \}
}
}}
tj|| ||tj|jd}d}| jr.|t| dd  }| jr:|t| jd  }t	dd1 |
d||	|dd
d|	|}}tj|| | d	|d
}|
||||}W d    n1 ssw   Y  | js| d| d}}| jd d d d || |d |f }t|jj}tj||j|jd}t|||}|d ur|| }tjdd|}|jtjkrtd||j}| |}|d ur|| }t||}||fS )Nr         ?r!   r   r   F)enabledr   r   )betaalphar   zDError with upcasting, attn_weights does not have dtype torch.float32)r   rK   emptyZfloat32r   r   rn   r   rr   r   rL   r   Zbaddbmmrq   r(   r   rt   r   rf   r   r   r   RuntimeErrorr   r   r   )rc   r   r   r   r   r   Zbszr{   Z	q_seq_lenZdk_Z	k_seq_lenr   Zscale_factorqkr   r   r   r   r   r   r   r   _upcast_and_reordered_attn
  s<   &&
z,ImageGPTAttention._upcast_and_reordered_attnc                 C   s2   |  dd ||f }|j| }|ddddS )zJ
        Splits hidden_size dim into attn_head_size and num_heads
        Nr!   r   r3   r   r   )r   ry   permuterc   rf   r{   Zattn_head_sizeZ	new_shaper   r   r   _split_heads>  s   
zImageGPTAttention._split_headsc                 C   s8   | dddd }| dd || f }||S )zS
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        r   r3   r   r   Nr   )r   
contiguousr   ry   r   r   r   r   _merge_headsF  s   
zImageGPTAttention._merge_headshidden_states
layer_pastr   r   encoder_hidden_statesencoder_attention_mask	use_cacheoutput_attentionsrg   c	                 C   sT  |d ur"t | dstd| |}	| |j| jdd\}
}|}n| |j| jdd\}	}
}| |	| j| j}	| |
| j| j}
| || j| j}|d ure|\}}t	j
||
fdd}
t	j
||fdd}|du rn|
|f}nd }| jr| |	|
|||\}}n| |	|
|||\}}| || j| j}| |}| |}||f}|r||f7 }|S )Nr   zIf class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `ImageGPTAttention(..., is_cross_attention=True)`.r3   r   r   T)hasattrr~   r   r.   rA   r}   r   r{   r|   rK   r   r   r   r   r   r0   r   )rc   r   r   r   r   r   r   r   r   r   r   r   Zpast_keyZ
past_valueZpresentr   r   outputsr   r   r   rj   N  s:   





zImageGPTAttention.forward)FN)NNNNNNNFF)rk   rl   rm   r   rx   rG   r`   r   r   r   r   r   rK   rb   tuplerj   ro   r   r   rd   r   rp      sB     +

&4	
rp   c                       s2   e Zd Z fddZdejdejfddZ  ZS )ImageGPTMLPc                    sF   t    |j}t||| _t||| _t|j | _t	
|j| _d S r^   )r_   r`   r\   r   c_fcr0   r   Zactivation_functionactr   r   r   dropout)rc   Zintermediate_sizerR   rz   rd   r   r   r`     s   
zImageGPTMLP.__init__r   rg   c                 C   s,   |  |}| |}| |}| |}|S r^   )r   r   r0   r   )rc   r   r   r   r   rj     s
   



zImageGPTMLP.forward)rk   rl   rm   r`   rK   rb   rj   ro   r   r   rd   r   r     s    r   c                       s   e Zd Zd fdd	Z							ddejdee deej deej d	eej d
eej dee dee defddZ	  Z
S )ImageGPTBlockNc                    s   t    |j}|jd ur|jnd| }t||jd| _t||d| _t||jd| _	|j
r>t|d|d| _t||jd| _t||| _d S )N   r]   rr   T)rq   rr   )r_   r`   r\   Zn_innerrZ   layer_norm_epsilonln_1rp   r/   ln_2add_cross_attentioncrossattentionln_cross_attnr   mlp)rc   rR   rr   r\   Z	inner_dimrd   r   r   r`     s   
zImageGPTBlock.__init__Fr   r   r   r   r   r   r   r   rg   c	                 C   s   |}	|  |}| j||||||d}
|
d }|
dd  }||	 }|d urSt| ds1td|  d|}	| |}| j||||||d}|d }|	| }||dd   }|}	| |}| |}|	| }|f|rl| }|S |dd   }|S )	N)r   r   r   r   r   r   r   r   z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`)r   r   r   r   r   r3   )r   r/   r   r~   r   r   r   r   )rc   r   r   r   r   r   r   r   r   ZresidualZattn_outputsr   r   Zcross_attn_outputsZfeed_forward_hidden_statesr   r   r   rj     sN   





zImageGPTBlock.forwardr^   r   )rk   rl   rm   r`   rK   rb   r   rx   r   rj   ro   r   r   rd   r   r     s8    	
r   c                       s>   e Zd ZeZeZdZdZdZ	dgZ
 fddZdd Z  ZS )	ImageGPTPreTrainedModelr#   	input_idsTr   c                    s   t  j|i | d S r^   )r_   r`   )rc   Zinputskwargsrd   r   r   r`     s   z ImageGPTPreTrainedModel.__init__c                 C   s   t |tjtfr|jjjd| jjd |j	dur|j	j
  n,t |tjr?|jjjd| jjd |jdur>|jj|j 
  nt |trK|jjd | D ]\}}d|v rnd|v rn|jjd| jjtd| jj  d qOdS )zInitialize the weights.g        )ri   ZstdNr   r0   r&   r3   )
isinstancer   Linearr   r&   rO   Znormal_rR   Zinitializer_ranger(   Zzero_	EmbeddingZpadding_idxrZ   Zfill_Znamed_parametersmathrh   n_layer)rc   modulerU   pr   r   r   _init_weights  s"   


&z%ImageGPTPreTrainedModel._init_weights)rk   rl   rm   r   Zconfig_classrY   Zload_tf_weightsZbase_model_prefixZmain_input_nameZsupports_gradient_checkpointingZ_no_split_modulesr`   r   ro   r   r   rd   r   r     s    r   c                "       s   e Zd Zdef fddZdd Zdd Zdd	 Ze	
	
	
	
	
	
	
	
	
	
	
	
	
dde	e
j de	eee
j   de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e de	e de	e de	e dedeeef fddZ  ZS )ImageGPTModelrR   c                    s   t     j| _t j| j| _t j| j| _	t
 j| _t fddt jD | _t| j jd| _d| _d | _d| _|   d S )Nc                    s   g | ]}t  |d qS )r   )r   )r   irR   r   r   
<listcomp>  s    z*ImageGPTModel.__init__.<locals>.<listcomp>r   F)r_   r`   r\   rz   r   r   rP   r*   rv   r)   r   Z
embd_pdropdropZ
ModuleListrangeZnum_hidden_layershrZ   r   ln_fmodel_parallel
device_mapgradient_checkpointing	post_initrc   rR   rd   r   r   r`   	  s    zImageGPTModel.__init__c                 C      | j S r^   r*   rc   r   r   r   get_input_embeddings     z"ImageGPTModel.get_input_embeddingsc                 C   
   || _ d S r^   r   rc   Znew_embeddingsr   r   r   set_input_embeddings     
z"ImageGPTModel.set_input_embeddingsc                 C   s(   |  D ]\}}| j| j| qdS )zv
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        N)itemsr   r/   r   )rc   Zheads_to_prunelayerr   r   r   r   _prune_heads"  s   zImageGPTModel._prune_headsNr   past_key_valuesr   token_type_idsposition_idsr   inputs_embedsr   r   r   r   output_hidden_statesreturn_dictr   rg   c           $         s  d|v rt dt |durtd|d}|dur|n| jj}|dur'|n| jj}|
dur1|
n| jj}
|dur;|n| jj	}|durK|durKtd|durg| 
|| | }|d|d }|jd }n|dury| dd }|jd }ntd|dur|jn|j}|dur|d|d }|du rd}tdgt| j }n	|d d d	}|du rtj||d | tj|d
}|d}|dur|dkrtd||d}|ddddddf }|j| jd}d| t| jj }| jjr|dur| \}}}||f}|	du rtj||d}	| |	}	nd}	| || jj}|du r/| |}|  |}|||j  |durJ| |} |  | !  | df }| j"ri| j#ri|
rit$%d d}
|
rndnd}|rudnd}|r| jjrdnd}|rdnd}t&t'| j|D ]\}\}} | j(rtj)* j | durt fdd| D } |dur| j}t+|tj,r| j}|r| f }| j"r| j#r| -|j. d||| ||	|
|	}!n| | ||| ||	|
|d}!|!d  |
du r
||!d f }|r*||!|
rdnd f }| jjr*||!|
r%dnd f }| j(rU| j/0 D ]!\}"}#||#d krSdt1|" | j2krS dt1|"d   q3q| 3   j|  |ri| f }|sztdd  ||||fD S t4 ||||dS )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ImageGPTModel
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
        >>> model = ImageGPTModel.from_pretrained("openai/imagegpt-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        ```pixel_values`The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.N_You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`.zDYou cannot specify both input_ids and inputs_embeds at the same timer!   r   z5You have to specify either input_ids or inputs_embedsr   r   z$batch_size has to be defined and > 0rs   r   )r   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fr   c                 3   s    | ]	}|  jV  qd S r^   )tor   r   Z
past_stater   r   r   r      s    z(ImageGPTModel.forward.<locals>.<genexpr>)r   r   r   r   r   r   r   Tr   r3   r   zcuda:c                 s   s    | ]	}|d ur|V  qd S r^   r   )r   vr   r   r   r      s    )Zlast_hidden_stater   r   
attentionscross_attentions)5warningswarnFutureWarningr~   poprR   r   r   r   use_return_dictZ%warn_if_padding_and_no_attention_maskr   ry   rH   r   r   rF   r   rK   ZarangelongZ	unsqueezer   rt   r   r   r   rw   Zinvert_attention_maskZget_head_maskr   r*   r)   r   r   Ztrainingr6   Zwarning_once	enumerater@   r   cudaZ
set_devicer   rb   Z_gradient_checkpointing_func__call__r   r   strZlast_devicer   r   )$rc   r   r   r   r   r   r   r   r   r   r   r   r   r   r   Zinput_shapeZ
batch_sizer   Zpast_lengthZencoder_batch_sizeZencoder_sequence_lengthr   Zencoder_hidden_shapeZposition_embedsZtoken_type_embedsZoutput_shapeZpresentsZall_self_attentionsZall_cross_attentionsZall_hidden_statesr   blockr   r   r   r   r   r   r   rj   )  s  /













"


zImageGPTModel.forward)NNNNNNNNNNNNN)rk   rl   rm   r   r`   r   r   r   r   r   rK   rb   r   rx   r   r   r   rj   ro   r   r   rd   r   r     sd    	

r   z
@auto_docstring(
    custom_intro="""
    The ImageGPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ImageGPTConfig):
        super().__init__(config)
        self.transformer = ImageGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`.

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ImageGPTForCausalImageModeling
        >>> import torch
        >>> import matplotlib.pyplot as plt
        >>> import numpy as np

        >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
        >>> model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small")
        >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>> model.to(device)  # doctest: +IGNORE_RESULT

        >>> # unconditional generation of 4 images
        >>> batch_size = 4
        >>> context = torch.full((batch_size, 1), model.config.vocab_size - 1)  # initialize with SOS token
        >>> context = context.to(device)
        >>> output = model.generate(
        ...     input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40
        ... )

        >>> clusters = image_processor.clusters
        >>> height = image_processor.size["height"]
        >>> width = image_processor.size["width"]

        >>> samples = output[:, 1:].detach().cpu().numpy()
        >>> samples_img = [
        ...     np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
        ... ]  # convert color cluster tokens back to pixels
        >>> f, axes = plt.subplots(1, batch_size, dpi=300)

        >>> for img, ax in zip(samples_img, axes):  # doctest: +IGNORE_RESULT
        ...     ax.axis("off")
        ...     ax.imshow(img)
        ```"""
        if "pixel_values" in kwargs:
            warnings.warn(
                "The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.",
                FutureWarning,
            )

            if input_ids is not None:
                raise ValueError(
                    "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
                )

            input_ids = kwargs.pop("pixel_values")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
@auto_docstring(
    custom_intro="""
    The ImageGPT Model transformer with an image classification head on top (linear layer).
    [`ImageGPTForImageClassification`] average-pools the hidden states in order to do the classification.
    """
)
class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
    def __init__(self, config: ImageGPTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = ImageGPTModel(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, ImageGPTForImageClassification
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
        >>> model = ImageGPTForImageClassification.from_pretrained("openai/imagegpt-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```"""
        if "pixel_values" in kwargs:
            warnings.warn(
                "The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.",
                FutureWarning,
            )

            if input_ids is not None:
                raise ValueError(
                    "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
                )

            input_ids = kwargs.pop("pixel_values")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # average-pool the hidden states along the sequence dimension
        pooled_hidden_states = hidden_states.mean(dim=1)
        # project from (batch_size, hidden_size) to (batch_size, num_labels)
        logits = self.score(pooled_hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


__all__ = [
    "ImageGPTForCausalImageModeling",
    "ImageGPTForImageClassification",
    "ImageGPTModel",
    "ImageGPTPreTrainedModel",
    "load_tf_weights_in_imagegpt",
]
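

# A minimal smoke-test sketch, not part of the library API: it assumes `transformers` is installed and
# exposes the ImageGPT classes above, and uses arbitrary tiny config values chosen only for illustration
# (no released checkpoint has these dimensions). It checks the shape contract implied by the code above:
# the backbone returns (batch, seq_len, n_embd), and the causal head emits vocab_size - 1 logits because
# the SOS token (id vocab_size - 1) is input-only and never predicted.
if __name__ == "__main__":
    from transformers import ImageGPTConfig, ImageGPTForCausalImageModeling, ImageGPTModel

    # hypothetical toy configuration: 16 color-cluster tokens + 1 SOS token, 16 positions
    config = ImageGPTConfig(vocab_size=17, n_positions=16, n_embd=8, n_layer=2, n_head=2)

    backbone = ImageGPTModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    hidden = backbone(input_ids=input_ids).last_hidden_state
    assert hidden.shape == (2, 16, config.n_embd)

    lm = ImageGPTForCausalImageModeling(config)
    logits = lm(input_ids=input_ids).logits
    # one logit per color-cluster token; the SOS token is excluded from the output distribution
    assert logits.shape == (2, 16, config.vocab_size - 1)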