しぷぜん

しぷぜん

素人プログラマなおいがStep Zero to Oneしていくブログ

畳み込みニューラルネットワークでウォーリーを探そうと思う(その3)

にほんブログ村 IT技術ブログ Pythonへ

こんにちは、なおいです。

前回、前々回と「ウォーリーを探せ」シリーズ書いてきましたが、今回は、適当なデータに対して正しい判定ができるかどうかを確認する検証のフェーズに入っていきます。

 

その前に、前の記事はこちら

ct-innovation01.hateblo.jp

 

 

 

 

検証用データを集める

まずは、正解データとして前回も利用した購入した本からウォーリーを再び探してデータを集めていきます。順調に私の探索能力が上がっていっており、早いとページを開いて数秒で見つけられるようになってきています。(あんまりこの能力の向上はうれしくないですが。。。)ともかく新たに4枚の正解データを作成しました。

 

次に、不正解データを用意します。ここで絶対に正解と判断されたくないキャラクターがいます。ウォーリーにはガールフレンドがいるんですが、これがそれなりに似ています。名前は「ウェンダ」というらしいです。見た目はこんな感じ。

 

f:id:ct-innovation01:20171101171940p:plain f:id:ct-innovation01:20171026090425p:plain

 

帽子と髪型はほぼ一致してるのでメガネと口で不正解と見抜いていかないといけません。これが、うまくいってるかどうかがきになるところです。

 

実際に検証するためのプログラムを組む

予測する関数プログラムとメインのプログラムを以下のように組みました。

モデルのインスタンス化から、実際に予測を行い結果を返して1なら「ウォーリー」という感じにしています。

def return_result(data_input):

    model = cnn.MyCNN(2)
    serializers.load_npz("mymodel.npz", model)

    x = Variable(np.array([data_input], dtype=np.float32))
    y = model.forward(x)

    return y


def main():
    train_pathes = {'folder_path': '0'}
    for fil in train_pathes.keys():
        append_data(fil, train_pathes[fil], 'train')

    N = len(x_train_data)
    x_train_data_np = np.array(x_train_data, dtype=np.float32)

    x_train_data_reshape = x_train_data_np.reshape(N, 1, 60, 60)
    x_train_data_reshape /= 255  # to 0-1

    for i in range(len(x_train_data_reshape)):
        data_input = x_train_data_reshape[i]
        r_data = return_result(data_input)
        data = np.argmax(r_data.data)
        if data == 1:
            print(str(t_train_data[i]) + 'はウォーリー')
        else:
            print(str(t_train_data[i]) + 'は違う')

 

実際にかけたファイルは以下の8つ。

f:id:ct-innovation01:20171101172306p:plain

では、早速結果をお見せ致します。

f:id:ct-innovation01:20171101172315p:plain

おお!ちゃんと判断してくれた!成功だ!

 

よかった。とりあえずは第一関門は突破しました。次回は走査的に全体を探索していくプログラムを組んで実際にウォーリーを探していきます。