トップページ -> 点数制リバーシのAIの仕様について

点数制リバーシのAIの仕様について

このページでは点数制リバーシで使っている対戦相手AIの入力形式・学習方法・ブラウザでの実行方法について簡単に説明します.

目次

1. 入力形式とモデル

10ブロック128フィルタのResNet型モデルに以下のような入力を与え,次の一手の確率(Policy)と勝率(Value)を出力させます.(ブラウザ上で動いているのは軽量化した5ブロック128フィルタです)

入力形式(obs_to_tensorの内容)

モデルの全体


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        out = F.relu(out)
        return out

class PointOthelloNet(nn.Module):
    def __init__(self, n_res_blocks=10, filters=128):
        """
        デフォルトで 10ブロック・128フィルタ のモデルを作成します。
        軽量化したい場合は n_res_blocks=5, filters=64 などを指定可能です。
        """
        super(PointOthelloNet, self).__init__()
        self.n_res_blocks = n_res_blocks
        self.filters = filters
        
        # 入力8チャンネル: [自石, 敵石, 自pt, 敵pt, CurPt, Next1, Next2, Next3]
        self.conv_input = nn.Conv2d(8, filters, kernel_size=3, padding=1, bias=False)
        self.bn_input = nn.BatchNorm2d(filters)
        
        # Residual Tower
        self.res_blocks = nn.ModuleList([
            ResidualBlock(filters) for _ in range(n_res_blocks)
        ])
        
        # Policy Head (64マス)
        # filters -> 2ch -> 128(64*2) -> 64
        self.policy_conv = nn.Conv2d(filters, 2, kernel_size=1, bias=False)
        self.policy_bn = nn.BatchNorm2d(2)
        self.policy_fc = nn.Linear(2 * 64, 64) 
        
        # Value Head (勝率 -1 ~ 1)
        # filters -> 1ch -> 64 -> 64 -> 1
        self.value_conv = nn.Conv2d(filters, 1, kernel_size=1, bias=False)
        self.value_bn = nn.BatchNorm2d(1)
        self.value_fc1 = nn.Linear(64, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        # 共通部分 (Residual Tower)
        x = F.relu(self.bn_input(self.conv_input(x)))
        for block in self.res_blocks:
            x = block(x)
            
        # Policy Head
        p = F.relu(self.policy_bn(self.policy_conv(x)))
        p = p.view(p.size(0), -1) # Flatten
        p = self.policy_fc(p)
        p = F.log_softmax(p, dim=1)
        
        # Value Head
        v = F.relu(self.value_bn(self.value_conv(x)))
        v = v.view(v.size(0), -1) # Flatten
        v = F.relu(self.value_fc1(v))
        v = torch.tanh(self.value_fc2(v))
        
        return p, v

# ==========================================
# 観測データ変換関数 
# ==========================================
def obs_to_tensor(obs, device):
    """観測データをDNN入力形式に変換"""
    board_owner = obs["board_owner"]
    board_points = obs["board_points"]
    turn = obs["turn"]
    
    # 視点を「手番プレイヤー」に合わせる
    my_board = (board_owner == turn).astype(np.float32)
    op_board = (board_owner == -turn).astype(np.float32)
    
    # ポイント正規化 (-30~30 -> -1.0~1.0)
    norm_points = board_points.astype(np.float32) / 30.0
    my_points = norm_points * my_board
    op_points = norm_points * op_board
    
    # 手元の石と未来の石 (Plane全体を埋める)
    def make_plane(val):
        return np.full((8, 8), val / 30.0, dtype=np.float32)
    
    cur_plane = make_plane(obs["current_stone_val"])
    next_planes = [make_plane(v) for v in obs["next_stones_vals"]]
    
    # 8チャンネル結合
    layers = [my_board, op_board, my_points, op_points, cur_plane] + next_planes
    tensor = torch.tensor(np.stack(layers), dtype=torch.float32).unsqueeze(0)
    return tensor.to(device)

2. 学習方法

学習はAI同士を自己対局させて行いました. モンテカルロ木探索を用いてより良い手を探し,その結果得られた次の一手の確率(Policy)と勝敗(Value 勝ち:+1, 負け:-1, 引き分け:0)を教師データとしてニューラルネットワークを更新しました. 学習時の探索回数は800とし,探索時には3手先以降の非公開情報も確認できる設計で学習させました.(実際の対局時には最終的な石の合計が0になるように再生成させています.)

3. ブラウザ上で動かす

10ブロック128フィルタのままだと10MB程度とファイルサイズも推論速度も重いので,10ブロック128フィルタのモデルが生成した4万局程度の棋譜からPolicyとValueを5ブロック128フィルタに学習させました. 最終的に出来上がったpthファイルを以下のコードでonnxファイルに変換し,JavaScriptで動かしています.


import torch
from model import PointOthelloNet as ModelV2

# 学習済みモデルのロード
args = {'n_res_blocks': 5, 'filters': 128}
model = ModelV2(**args)
model.load_state_dict(torch.load("distillation_models/model_epoch_20.pth"))
model.eval()

# ダミー入力の作成 (8ch, 8x8)
# バッチサイズ1, チャンネル8, 縦8, 横8
dummy_input = torch.randn(1, 8, 8, 8, device='cpu')

# ONNXへ変換
try:
    torch.onnx.export(
        model,
        dummy_input,
        "point_othello_5b_v2.onnx",
        input_names=['input'],
        output_names=['policy', 'value'],
        opset_version=11
    )
    print("変換成功: point_othello_5b_v2.onnx")
except Exception as e:
    print(f"変換エラー: {e}")