情報系院生のノート

情報技術系全般,自分用メモを公開してます。

EfficientDetのsingle-machine model parallelを実装して、D8(D7x)を学習させる

はじめに

魚群コンペ記事の第二弾です。

tmyoda.hatenablog.com

EfficientDetの良さそうなリポジトリを見つけ、このリポジトリをコンペに使おうと思いました。 github.com

しかし、EfficientDetの性質上、高い係数のモデルを学習させようとすると、バックボーンのネットワークにかなり大きなVRAMが必要になります。 12GB VRAM x4の環境で試しに学習させてみたのですが、案の定CUDA out of memoryです。

そこで、以下の記事のように、modelをGPUにスライスして学習させる手法を見つけたので、これを実装してみました。 pytorch.org

実装したリポジトリ

github.com

使い方はシンプルに--model_parallelを引数に追加するだけで行けます。

実装解説

バックボーン

バックボーンがメモリの8割以上を占めているそうなので、それを分割しようと考えました。

MBConvBlockのリストを単純に.to()して移せば良いかと思いましたが、pytorchの仕様上、nn.Moduleを継承したクラスのインスタンス.to()したところで、 中の重みに相当するテンソルは移動しないようです。

なので、ぼちぼち中身をいじる羽目になりました…。

NMS

model parallelの実装が終わり、学習が上手く行って喜んでいたのですが、

https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch/issues/225

ここで議論されてる通り、d5以上のモデルだとtorch.visionmnsだとintがオーバーフローしてしまい、エラーを吐かれてしまいました。

なので、nms実装も上のリポジトリで変更しています。

まとめ

お世辞にもきれいと言える実装ではないですが、とりあえず動くものができたので公開しました。

ただ、batch 1で学習したところで、Normalization系のモジュールが機能せず学習が思ったように進まないのであしからず…。