DeepLearningフレームワークの推論時間の観察

pytorchで定義されたモデルをonnx, tflite形式へと変換し推論の実行時間を計測した。

pytorch->onnx->tensorflow->tfliteへの変換は以下の変換処理をまとめたライブラリを利用。 また、モデルのconvertする際には量子化等は適用しておらず、fp32のまま推論。

github.com

測定環境

CPU pytorch onnx onnxruntime tensorflow
AMD Ryzen 5 3600 6-Core Processor 1.8.1 1.9.0 1.7.0 2.5.0

推論時間の測定結果

実行時間は100回連続推論させた平均ですが、数値は参考程度にご覧ください。 ちなみにpytorchとtfliteの推論一発目は実行時間がresnetだと100ms程遅いので(アルゴリズム選定とかでしたっけ?)推論1回目の結果は除いてあります。onnxruntimeは1回目も推論速度は早かったです。

  • Resnet34
    input_shape: (1, 3, 224, 224)
pytorch onnxruntime tflite
29.18ms 12.29ms 39.37ms
  • transformer
    src_shape = (10, 32, 512)
    tgt_shape = (20, 32, 512)
pytorch onnxruntime
273.50ms 92.80ms

※transformerモデルのtfliteへの変換はまだうまくいってません。issue

この結果だけみるとonnxruntime良さそうです(tfとtorchで最適なshapeが異なっていて平等ではない可能性あり)
量子化した時とデバイスを変更したときにどうなるかって感じです。 convertした後のモデルでも眺めてみようかなー

測定コード

import torch                                                                                                                                                                                                       
import numpy as np                                                                                                                                                                                                 
import nne                                                                                                                                                                                                         
import torchvision                                                                                                                                                                                                 
                                                                                                                                                                                                                   
input_shape = (1, 3, 224, 224)                                                                                                                                                                                     
model = torchvision.models.resnet34(pretrained=True)                                                                                                                                                               
torch.save(model, "resnet.pt")                                                                                                                                                                                     
                                                                                                                                                                                                                   
bm = nne.Benchmark(name='torch')                                                                                                                                                                                   
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)                                                                                                                                      
output_data = nne.infer_torch(model, input_data, bm=bm)                                                                                                                                                            
                                                                                                                                                                                                                   
## onnx                                                                                                                                                                                                            
onnx_file = "resnet.onnx"                                                                                                                                                                                          
bm = nne.Benchmark(name='onnx')                                                                                                                                                                                    
nne.cv2onnx(model, input_shape, onnx_file)                                                                                                                                                                         
onnx_model = nne.load_onnx(onnx_file)                                                                                                                                                                              
nne.infer_onnx(onnx_model, input_data, bm=bm)                                                                                                                                                                      
                                                                                                                                                                                                                   
## tflite                                                                                                                                                                                                          
tflite_file = "resnet.tflite"                                                                                                                                                                                      
bm = nne.Benchmark(name='tflite')                                                                                                                                                                                  
nne.cv2tflite(model, input_shape, tflite_file)                                                                                                                                                                     
tflite_model = nne.load_tflite(tflite_file)                                                                                                                                                                        
nne.infer_tflite(tflite_model, input_data, bm=bm)