import warnings
from argparse import ArgumentParser
from os import listdir, makedirs
from pathlib import Path
from typing import Optional

from packaging.version import Version, parse

from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
from transformers.utils import ModelOutput, is_tf_available, is_torch_available


# Minimum onnxruntime version required for the quantization features used below
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")

SUPPORTED_PIPELINES = [
    "feature-extraction",
    "ner",
    "sentiment-analysis",
    "fill-mask",
    "question-answering",
    "text-generation",
    "translation_en_to_fr",
    "translation_en_to_de",
    "translation_en_to_ro",
]


class OnnxConverterArgumentParser(ArgumentParser):
    """
    Wraps all the script arguments supported to export transformers models to ONNX IR
    """

    def __init__(self):
        super().__init__("ONNX Converter")

        self.add_argument(
            "--pipeline",
            type=str,
            choices=SUPPORTED_PIPELINES,
            default="feature-extraction",
        )
        self.add_argument(
            "--model",
            type=str,
            required=True,
            help="Model's id or path (ex: google-bert/bert-base-cased)",
        )
        self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)")
        self.add_argument(
            "--framework",
            type=str,
            choices=["pt", "tf"],
            required=True,
            help="Framework for loading the model",
        )
        self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
        self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
        self.add_argument("--use-external-format", action="store_true", help="Allow exporting model >= than 2Gb")
        self.add_argument("--quantize", action="store_true", help="Quantize the neural network to be run with int8")
        self.add_argument("output")


def generate_identified_filename(filename: Path, identifier: str) -> Path:
    """
    Append a string-identifier at the end (before the extension, if any) to the provided filepath

    Args:
        filename: pathlib.Path The actual path object we would like to add an identifier suffix
        identifier: The suffix to add

    Returns: String with concatenated identifier at the end of the filename
    """
    return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)


def check_onnxruntime_requirements(minimum_version: Version):
    """
    Check onnxruntime is installed and if the installed version is recent enough

    Raises:
        ImportError: If onnxruntime is not installed or too old version is found
    """
    try:
        import onnxruntime

        # Parse the version of the installed onnxruntime
        ort_version = parse(onnxruntime.__version__)

        # We require 1.4.0 minimum for the quantization features
        if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
            raise ImportError(
                f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
                f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
                "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
            )

    except ImportError:
        raise ImportError(
            "onnxruntime doesn't seem to be currently installed. "
            "Please install the onnxruntime by running `pip install onnxruntime`"
            " and relaunch the conversion."
        )


def ensure_valid_input(model, tokens, input_names):
    """
    Ensure inputs are presented in the correct order, without any None values

    Args:
        model: The model used to forward the input data
        tokens: BatchEncoding holding the input data
        input_names: The name of the inputs

    Returns: Tuple
    """
    print("Ensuring inputs are in correct order")

    model_args_name = model.forward.__code__.co_varnames
    model_args, ordered_input_names = [], []
    for arg_name in model_args_name[1:]:  # start at index 1 to skip "self" argument
        if arg_name in input_names:
            ordered_input_names.append(arg_name)
            model_args.append(tokens[arg_name])
        else:
            print(f"{arg_name} is not present in the generated input list.")
            break

    print(f"Generated inputs order: {ordered_input_names}")
    return ordered_input_names, tuple(model_args)


def infer_shapes(nlp: Pipeline, framework: str) -> tuple[list[str], list[str], dict, BatchEncoding]:
    """
    Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model

    Args:
        nlp: The pipeline object holding the model to be exported
        framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)

    Returns:

        - List of the inferred input variable names
        - List of the inferred output variable names
        - Dictionary with input/output variables names as key and shape tensor as value
        - a BatchEncoding reference which was used to infer all the above information
    """

    def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]

        else:
            # Assume batch is the first axis with only 1 element (might not always be true)
            axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
            if is_input:
                if len(tensor.shape) == 2:
                    axes[1] = "sequence"
                else:
                    raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
            else:
                seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
                axes.update(dict.fromkeys(seq_axes, "sequence"))

        print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
        return axes

    tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]
    outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
    if isinstance(outputs, ModelOutput):
        outputs = outputs.to_tuple()
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # Generate input names & axes
    input_vars = list(tokens.keys())
    input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}

    # Flatten potentially grouped outputs (past for gpt2, attentions)
    outputs_flat = []
    for output in outputs:
        if isinstance(output, (tuple, list)):
            outputs_flat.extend(output)
        else:
            outputs_flat.append(output)

    # Generate output names & axes
    output_names = [f"output_{i}" for i in range(len(outputs_flat))]
    output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}

    # Create the aggregated axes representation
    dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
    return input_vars, output_names, dynamic_axes, tokens


def load_graph_from_args(
    pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
) -> Pipeline:
    """
    Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model)

    Args:
        pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
        framework: The actual model to convert the pipeline from ("pt" or "tf")
        model: The model name which will be loaded by the pipeline
        tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value

    Returns: Pipeline object
    """
    # If no tokenizer is provided, fall back to the model identifier
    if tokenizer is None:
        tokenizer = model

    # Check the wanted framework is available
    if framework == "pt" and not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
    if framework == "tf" and not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")

    # Allocate tokenizer and model
    return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)


def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where will be stored the generated ONNX model
        use_external_format: Split the model definition from its parameters to allow model bigger than 2GB

    Returns:

    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export

    print(f"Using framework PyTorch: {torch.__version__}")

    with torch.no_grad():
        # Infer input/output names and dynamic axes, then order the inputs as expected by forward()
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
        ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)

        export(
            nlp.model,
            model_args,
            f=output.as_posix(),
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset,
        )


def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
    """
    Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        nlp: The pipeline to be exported
        opset: The actual version of the ONNX operator set to use
        output: Path where will be stored the generated ONNX model

    Notes: TensorFlow cannot export model bigger than 2GB due to internal constraint from TensorFlow
    """
    if not is_tf_available():
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")

    try:
        import tensorflow as tf
        import tf2onnx
        from tf2onnx import __version__ as t2ov

        print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")

        # Build the graph, run a forward pass once, then export through tf2onnx
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
        nlp.model.predict(tokens.data)
        input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
        model_proto, _ = tf2onnx.convert.from_keras(
            nlp.model, input_signature, opset=opset, output_path=output.as_posix()
        )

    except ImportError as e:
        raise Exception(
            f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
        )


def convert(
    framework: str,
    model: str,
    output: Path,
    opset: int,
    tokenizer: Optional[str] = None,
    use_external_format: bool = False,
    pipeline_name: str = "feature-extraction",
    **model_kwargs,
):
    """
    Convert the pipeline object to the ONNX Intermediate Representation (IR) format

    Args:
        framework: The framework the pipeline is backed by ("pt" or "tf")
        model: The name of the model to load for the pipeline
        output: The path where the ONNX graph will be stored
        opset: The actual version of the ONNX operator set to use
        tokenizer: The name of the model to load for the pipeline, default to the model's name if not provided
        use_external_format:
            Split the model definition from its parameters to allow model bigger than 2GB (PyTorch only)
        pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
        model_kwargs: Keyword arguments to be forwarded to the model constructor

    Returns:

    """
    warnings.warn(
        "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
        " Transformers",
        FutureWarning,
    )
    print(f"ONNX opset version set to: {opset}")

    # Load the pipeline
    nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)

    if not output.parent.exists():
        print(f"Creating folder {output.parent}")
        makedirs(output.parent.as_posix())
    elif len(listdir(output.parent.as_posix())) > 0:
        raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")

    # Export the graph
    if framework == "pt":
        convert_pytorch(nlp, opset, output, use_external_format)
    else:
        convert_tensorflow(nlp, opset, output)


def optimize(onnx_model_path: Path) -> Path:
    """
    Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the
    optimizations possible

    Args:
        onnx_model_path: filepath where the model binary description is stored

    Returns: Path where the optimized model binary description has been saved
    """
    from onnxruntime import InferenceSession, SessionOptions

    # Generate the model name with the suffix "-optimized"
    opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
    sess_option = SessionOptions()
    sess_option.optimized_model_filepath = opt_model_path.as_posix()
    _ = InferenceSession(onnx_model_path.as_posix(), sess_option)

    print(f"Optimized model has been written at {opt_model_path}: ✔")
    print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")

    return opt_model_path


def quantize(onnx_model_path: Path) -> Path:
    """
    Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU

    Args:
        onnx_model_path: Path to location the exported ONNX model is stored

    Returns: The Path generated for the quantized model
    """
    import onnx
    import onnxruntime
    from onnx.onnx_pb import ModelProto
    from onnxruntime.quantization import QuantizationMode
    from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
    from onnxruntime.quantization.registry import IntegerOpsRegistry

    # Load the ONNX model
    onnx_model = onnx.load(onnx_model_path.as_posix())

    if parse(onnx.__version__) < parse("1.5.0"):
        print(
            "Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
            "Please upgrade to onnxruntime >= 1.5.0."
        )

    # Copy the model so the original graph is left untouched
    copy_model = ModelProto()
    copy_model.CopyFrom(onnx_model)

    # Construct the quantizer; onnxruntime renamed input_qType to activation_qType in 1.13.1,
    # hence the version check below
    if parse(onnxruntime.__version__) < parse("1.13.1"):
        quantizer = ONNXQuantizer(
            model=copy_model,
            per_channel=False,
            reduce_range=False,
            mode=QuantizationMode.IntegerOps,
            static=False,
            weight_qType=True,
            input_qType=False,
            tensors_range=None,
            nodes_to_quantize=None,
            nodes_to_exclude=None,
            op_types_to_quantize=list(IntegerOpsRegistry),
        )
    else:
        quantizer = ONNXQuantizer(
            model=copy_model,
            per_channel=False,
            reduce_range=False,
            mode=QuantizationMode.IntegerOps,
            static=False,
            weight_qType=True,
            activation_qType=False,
            tensors_range=None,
            nodes_to_quantize=None,
            nodes_to_exclude=None,
            op_types_to_quantize=list(IntegerOpsRegistry),
        )

    # Quantize and export
    quantizer.quantize_model()

    # Append "-quantized" at the end of the model's name
    quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")

    # Save model
    print(f"Quantized model has been written at {quantized_model_path}: ✔")
    onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())

    return quantized_model_path


def verify(path: Path):
    from onnxruntime import InferenceSession, SessionOptions
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

    print(f"Checking ONNX model loading from: {path} ...")
    try:
        onnx_options = SessionOptions()
        _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
        print(f"Model {path} correctly loaded: ✔")
    except RuntimeException as re:
        print(f"Error while loading the model {re}: ✘")


if __name__ == "__main__":
    parser = OnnxConverterArgumentParser()
    args = parser.parse_args()

    # Make sure the output path is absolute
    args.output = Path(args.output).absolute()

    try:
        print("\n====== Converting model to ONNX ======")
        # Convert
        convert(
            args.framework,
            args.model,
            args.output,
            args.opset,
            args.tokenizer,
            args.use_external_format,
            args.pipeline,
        )

        if args.quantize:
            # Ensure the requirements for quantization with onnxruntime are met
            check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)

            # onnxruntime optimizations don't provide the same level of performance on TensorFlow as on PyTorch
            if args.framework == "tf":
                print(
                    "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
                    "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
                    "\t For more information, please refer to the onnxruntime documentation:\n"
                    "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
                )

            print("\n====== Optimizing ONNX model ======")

            # Quantization works best when using the optimized version of the model
            args.optimized_output = optimize(args.output)

            # Do the quantization on the right graph
            args.quantized_output = quantize(args.optimized_output)

        # And verify
        if args.check_loading:
            print("\n====== Check exported ONNX model(s) ======")
            verify(args.output)

            if hasattr(args, "optimized_output"):
                verify(args.optimized_output)

            if hasattr(args, "quantized_output"):
                verify(args.quantized_output)

    except Exception as e:
        print(f"Error while converting the model: {e}")
        exit(1)
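# ---------------------------------------------------------------------------
# Example usage (illustrative sketch). The flags mirror OnnxConverterArgumentParser
# above; the model id and output path are placeholders, any supported model and
# an empty target folder will do.
#
# Command line: export a PyTorch pipeline, then optimize, quantize and check
# that every produced graph loads:
#
#   python -m transformers.convert_graph_to_onnx \
#       --framework pt \
#       --model google-bert/bert-base-cased \
#       --quantize \
#       --check-loading \
#       onnx/bert-base-cased.onnx
#
# Programmatic equivalent using the functions defined above (convert() requires
# the parent folder of the output to be empty):
#
#   from pathlib import Path
#   from transformers.convert_graph_to_onnx import convert, optimize, quantize
#
#   output = Path("onnx/bert-base-cased.onnx").absolute()
#   convert("pt", "google-bert/bert-base-cased", output, opset=11)
#   quantize(optimize(output))
# ---------------------------------------------------------------------------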