ONNX Model Conversion

ONNX (Open Neural Network Exchange) is an open format for representing machine learning models, used to convert trained models between frameworks such as PyTorch and TensorFlow.

Installation

pip install onnx onnxruntime
pip install tf2onnx  # TensorFlow to ONNX
pip install onnx2pytorch  # ONNX to PyTorch
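
If you plan to run models on GPU, ONNX Runtime's CUDA build ships as a separate package; install it in place of the CPU build:

pip install onnxruntime-gpu  # CUDA-enabled runtime (instead of onnxruntime)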

PyTorch to ONNX

import torch
import torch.onnx

# Load PyTorch model
model = MyModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()

# Create dummy input
dummy_input = torch.randn(1, 3, 224, 224)

# Export
torch.onnx.export(
    model,
    dummy_input,
    'model.onnx',
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)
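
After exporting, it is worth checking that ONNX Runtime reproduces the PyTorch outputs. A minimal sanity check, reusing the model and dummy_input from above and the 'input' name set during export:

import numpy as np
import onnxruntime as ort

# Run the original model and the exported model on the same input
with torch.no_grad():
    torch_out = model(dummy_input).numpy()

session = ort.InferenceSession('model.onnx')
onnx_out = session.run(None, {'input': dummy_input.numpy()})[0]

# Small numerical differences are expected; exact equality is not
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print('PyTorch and ONNX Runtime outputs match')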

TensorFlow to ONNX

# Command line
python -m tf2onnx.convert \
    --saved-model tensorflow_model/ \
    --output model.onnx \
    --opset 13

# From Keras
python -m tf2onnx.convert \
    --keras model.h5 \
    --output model.onnx \
    --opset 13

# Python API
import tf2onnx
import tensorflow as tf

model = tf.keras.models.load_model('model.h5')

spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name='input'),)
output_path = 'model.onnx'

model_proto, _ = tf2onnx.convert.from_keras(
    model, input_signature=spec, opset=13, output_path=output_path
)
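
To confirm the conversion kept the intended input signature, load the result in ONNX Runtime and inspect it (a quick check, assuming the export above):

import onnxruntime as ort

# The dynamic batch dimension shows up as a symbolic name rather than a fixed size
session = ort.InferenceSession(output_path)
inp = session.get_inputs()[0]
print(inp.name, inp.shape)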

Run ONNX Model

import onnxruntime as ort
import numpy as np

# Load model
session = ort.InferenceSession('model.onnx')

# Get input/output names
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# Prepare input
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Run inference
outputs = session.run([output_name], {input_name: input_data})
print(outputs[0])
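
Because the PyTorch export above declared a dynamic batch axis, the same session accepts other batch sizes without re-exporting (continuing the snippet above):

# Batch of 8 through the same session
batch = np.random.randn(8, 3, 224, 224).astype(np.float32)
batch_outputs = session.run([output_name], {input_name: batch})
print(batch_outputs[0].shape)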

Validate ONNX Model

import onnx

# Load model
model = onnx.load('model.onnx')

# Check model
onnx.checker.check_model(model)
print('Model is valid!')

# Print model info
print(onnx.helper.printable_graph(model.graph))
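
Beyond the structural check, onnx can report which opset the model targets and propagate tensor shapes through the graph, both useful when debugging conversions:

# Opset(s) the model was exported against
print([(op.domain or 'ai.onnx', op.version) for op in model.opset_import])

# Infer intermediate tensor shapes and save the annotated model
inferred = onnx.shape_inference.infer_shapes(model)
onnx.save(inferred, 'model_shapes.onnx')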

Optimize ONNX Model

from onnxruntime.transformers import optimizer

# Optimize
optimized_model = optimizer.optimize_model(
    'model.onnx',
    model_type='bert',
    num_heads=12,
    hidden_size=768
)

optimized_model.save_model_to_file('model_optimized.onnx')
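
The transformers optimizer above applies BERT-style fusions and expects the matching num_heads and hidden_size. For arbitrary models, ONNX Runtime can instead apply its generic graph optimizations at session creation and optionally write the optimized graph to disk; a sketch:

import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.optimized_model_filepath = 'model_ort_optimized.onnx'  # save the optimized graph

session = ort.InferenceSession('model.onnx', sess_options)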

ONNX to PyTorch

import onnx
import torch
from onnx2pytorch import ConvertModel

# Load ONNX model
onnx_model = onnx.load('model.onnx')

# Convert to PyTorch
pytorch_model = ConvertModel(onnx_model)

# Use like a normal PyTorch model
input_tensor = torch.randn(1, 3, 224, 224)
output = pytorch_model(input_tensor)
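
ConvertModel returns a regular torch.nn.Module, so the result can be saved or fine-tuned like any other PyTorch model. Note that onnx2pytorch covers only a subset of ONNX operators, so the forward pass above doubles as a smoke test:

# Persist the converted weights like any nn.Module
torch.save(pytorch_model.state_dict(), 'converted_model.pth')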

ONNX Runtime Providers

import onnxruntime as ort

# List available providers
print(ort.get_available_providers())

# Use specific provider
session = ort.InferenceSession(
    'model.onnx',
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
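
CUDAExecutionProvider requires the onnxruntime-gpu package from the Installation section; the providers list is a priority order, falling back to later entries when earlier ones are unavailable. To see which providers the session actually selected:

# Providers actually in use, in priority order
print(session.get_providers())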
