Cryo-ET 鞭毛モーターの3D検出 | Kaggle BYU 2025 1位解法コード解説 (学習)

AI

Kaggle BYU 2025 1位解法について、これまでモデルからDatasetまで見てきて、今回は学習部分となります。

先に結果を言うと、学習部分については性能上げるために尖ったことはしてないです。
(尖ったことしてないですが、ひたすら基礎の積み重ねであり理解の甘さを突き付けられてしんどかったです…)

コンペ:BYU – Locating Bacterial Flagellar Motors 2025 | Kaggle
参考リポジトリ:GitHub – brendanartley/BYU-competition: 1st place solution for the BYU Locating Bacterial Flagellar Motors Competition

学習

学習部分のGithubリンク:
BYU-competition/src/modules/train.py at main · brendanartley/BYU-competition · GitHub

最初の”def run_eval()”は推論のところともかぶるのでスキップして、def train()からです。

trainについて、特に図示することがないので大まかな流れだけ書いておきます。
※DDPやoptimizerなど先人方が詳細に書いてくれてるのでここでは深く掘らないようにします。

  1. DataLoader, Dataset, Sampler等の学習に必要なものの準備
  2. DDP(Distributed Data Parallel)による分散学習準備 ← Distributed* 周り
  3. autocastで最適な数値精度(FP16, FP32, BF16)に自動調整されるようにしてforward
  4. 準備が整ったらlossをスケールしてbackwardで勾配算出 ← 285行目
  5. スケール戻して勾配ノルム収集(ログ確認用) ← 286~297行
  6. 勾配から重み更新、次ループ用のスケール値更新、Optimizer勾配リセット ← 298~300行
  7. EMAモデルの重み更新 ※補足あり
  8. 次ループ用の学習率更新
  9. loss, 勾配ノルムのログ取得
  10. 最後にモデル重みを保存

やっている内容は普通のことなんですが、

  • 「数値精度(FP16, FP32, BF16)は何が違うんだっけ?」
  • 「BF16が深層モデルでいいのなんだっけ?」、
  • 「DDPの原理どうなってる?」
  • 「EMAモデルの更新どうだったっけ?」

と色々基礎的な話が飛んでしまっており、一つ一つ見直してたら時間がかかってしまいました…

ただ、こういう基礎的なところをしっかり積み上げられないと、精度出ないときの原因の切り分けができず、せっかく良いアイディアが出ても迷走してしまうと思い、改めて基礎の重要さを感じました。

補足 EMAモデル

EMA(Exponential Moving Average)モデルはその名の通り、指数移動平均したモデルです。

何を指数移動平均してるかといったら、重みやバッチ正則化のパラメータといったモデルが保持するパラメータです。

式として表現すると以下のようになってます。EMAはEMAの値、EMAの添え字は時系列、βは定数、θは現在のパラメータ(現在のepochのパラメータ)、を示してます。βが大きいと過去値の影響が大きくなります。

EMAt=(1β)θt+βEMAt1EMA_t = (1-\beta)\cdot\theta_t + \beta\cdot EMA_{t-1}

各epochで学習したパラメータをEMAすることで、過去の学習パラメータも考慮されるので汎化性能が上がるというものになります。

学習方向(?)に対してアンサンブルしている感じかなと思いました。適当なこと言ってたらすみません(笑)

実装は以下のようになってました。def update(self, model)でEMA更新してます。

class ModelEMA(nn.Module):
    """
    EMA for model weights.
    Source: https://www.kaggle.com/competitions/blood-vessel-segmentation/discussion/475080#2641635
    """
    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)

EMAモデルは学習中は全く学習に寄与しておらず、更新されたパラメータを受け取りEMAを更新してるだけであり、最後に完成したものが推論で使われるというものになります。

まとめ

今回は短く、内容もあっさりですが以上となります。

度々書いてましたが、見た目はあっさりですがAIだけでなく情報系の話も含めて基礎知識が要求されるので、正しく読み解こうとするとヘビーではありました…

どうしても解法とかモデルとかに目が行きがちであり、「学習させるところは学習させるだけでしょ?」という気持ちになりますが、正しく検証を積み上げられるようにこういう細かいところも見ておく価値はあると思いました。

次回はいよいよ推論になります。

最後までご覧いただきありがとうございました。

タイトルとURLをコピーしました