EfficientDetのsingle-machine model parallelを実装して、D8(D7x)を学習させる
はじめに
魚群コンペ記事の第二弾です。
EfficientDetの良さそうなリポジトリを見つけ、このリポジトリをコンペに使おうと思いました。 github.com
しかし、EfficientDetの性質上、高い係数のモデルを学習させようとすると、バックボーンのネットワークにかなり大きなVRAMが必要になります。 12GB VRAM x4の環境で試しに学習させてみたのですが、案の定CUDA out of memoryです。
そこで、以下の記事のように、modelをGPUにスライスして学習させる手法を見つけたので、これを実装してみました。 pytorch.org
実装したリポジトリ
使い方はシンプルに--model_parallel
を引数に追加するだけで行けます。
実装解説
バックボーン
バックボーンがメモリの8割以上を占めているそうなので、それを分割しようと考えました。
MBConvBlockのリストを単純に.to()
して移せば良いかと思いましたが、pytorchの仕様上、nn.Module
を継承したクラスのインスタンスを.to()
したところで、
中の重みに相当するテンソルは移動しないようです。
なので、ぼちぼち中身をいじる羽目になりました…。
NMS
model parallelの実装が終わり、学習が上手く行って喜んでいたのですが、
https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/issues/225
ここで議論されてる通り、d5以上のモデルだとtorch.visionのmnsだとintがオーバーフローしてしまい、エラーを吐かれてしまいました。
なので、nms実装も上のリポジトリで変更しています。
まとめ
お世辞にもきれいと言える実装ではないですが、とりあえず動くものができたので公開しました。
ただ、batch 1で学習したところで、Normalization系のモジュールが機能せず学習が思ったように進まないのであしからず…。