らんだむな記憶

blogというものを体験してみようか!的なー

pretrainedモデルのウェイトのロード (2)

pretrainedモデルのウェイトのロード (1) - らんだむな記憶の続き。

実際の実装を行う。

def create_vgg19_model(caffe_pt):
    vision_name2caffe_name = {
        "features.0" : "conv1_1",
        "features.2" : "conv1_2",
        "features.5" : "conv2_1",
        "features.7" : "conv2_2",
        "features.10": "conv3_1",
        "features.12": "conv3_2",
        "features.14": "conv3_3",
        "features.16": "conv3_4",
        "features.19": "conv4_1",
        "features.21": "conv4_2",
        "features.23": "conv4_3",
        "features.25": "conv4_4",
        "features.28": "conv5_1",
        "features.30": "conv5_2",
        "features.32": "conv5_3",
        "features.34": "conv5_4",
        "classifier.0": "fc6_1",
        "classifier.3": "fc7_1",
        "classifier.6": "fc8_1",
    }

    model = models.vgg19(pretrained=False)
    caffe_model = torch.load(caffe_pt)

    caffe_name2layer = {}
    for name, layer in caffe_model.named_modules():
        if name in vision_name2caffe_name.values():
            caffe_name2layer[name] = layer

    with torch.no_grad():
        for name, layer in model.named_modules():
            if name not in vision_name2caffe_name:
                continue
            caffe_name = vision_name2caffe_name[name]
            caffe_layer = caffe_name2layer[caffe_name]
            layer.state_dict()["weight"].copy_(caffe_layer.state_dict()["weight"])
            layer.state_dict()["bias"].copy_(caffe_layer.state_dict()["bias"])
            
    return model

で多分いけてる。

for name, param in model.named_parameters():
    print(param)

で中身を見ると、Visual Geometry Group - University of Oxfordから入手できる caffemodel のパラメータに置き換わっていることがわかる。この準備のもと

model.classifier = nn.Sequential(nn.Linear(512*7*7, 10),
                                 nn.LogSoftmax(dim=1))

で Fashion-MNIST の学習を行わせると 1 エポックで 86% 程度の精度が出る。

やっていることは基本的にdeep-learning-from-scratch-3/models.py at master · oreilly-japan/deep-learning-from-scratch-3 · GitHubを同じであり、またこの内容は KitModel とほぼ同じなのでとても参考になる。

ついでなので、 caffemodelMMdnn で TensorFlow 2 用に変換しようと tf.keras.applications.VGG19 の実装を確認してみたがよく分からなかった・・・。なんか難しい。こっちはまた今度だな。