前回の記事
皆さんこんにちは。さっそくですが、まずはこの記事を読んでみて下さい。
この記事、深層学習の得意分野である手書き文字認識をあえて深層学習を用いずに実装しています。学習の際の教師データはscikit-learnの”digits“データ。8×8画素の数字データです。
ロジスティック回帰を用いましたが、実際に書いてみた手書き文字の認識精度は63%… 良くはないけど悪くもないといった感じです。
流石にこの精度ではいけないのでは…? ということで、今回は真打ち登場、深層学習を用いて同じデータセットを識別してみたいと思います。
ロジスティック回帰の問題点
今回の実装には、畳み込みニュートラルネットワークを使います。
上記の記事ではロジスティック回帰を使っていますが、その基本的な考え方は
画像のピクセル毎の明暗データを説明変数とし、画像に書かれている数を目的変数にすることで機械学習のモデルに適用させる。
といったものでした。ここで大事なのが、各ピクセルごとのデータを独立な変数としていた事です。ところが、文字や画像は空間的に連続なデータの集まりなので、周囲のピクセルとの関係を切って独立に見ていくことはナンセンスでしょう。
そこで、その周囲のピクセルとの関係までうまいこと評価に落とし込められれば精度は上がるだろう、というのが畳み込みニューラルネットワーク(CNN)の基本的な考え方です。
このあたりはエキスパート本がいくつかあるので、そちらを参考にしてみてください。
CNNをはじめとする深層学習の実装に関してより理解を深めたい方は、以下の本がおすすめです。
- 作者: 斎藤康毅
- 出版社/メーカー: オライリージャパン
- 発売日: 2016/09/24
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (17件) を見る
理解を深めたい人のための本: とてもわかりやすい。ニューラルネットワークの基本からCNNまで学べる。
CNNを用いた手書き文字認識の実装
というわけで、前回作ったプログラムを少し書き換えてCNNによる手書き文字認識を実装してみました。
使用言語はpython3系
使用したパッケージは
- os: ディレクトリ操作
- NumPy(np): 配列操作
- pillow(PIL): 画像操作
- scikit-learn(sklearn): 機械学習のユーティリティ、digitsデータセットの読み込み
- tensorflow(tf): 深層学習のユーティリティ
さらに今回は自前のCNNのクラスをまとめたモジュールを用意しました。
- CNNclasses(cnn)
今回の実装の目的は、
- sklearnのデータセットを使って機械学習をする
- 深層学習を用いて手書きの数字を判別し、用いなかった時との精度の差を確認する
この2つです。
コードは前回とほぼ似ているので、主に変更点を。読み飛ばしてしまっても問題ありません。
パッケージのインポート
割愛
画像データ読み込み、加工
これも前回とほぼ同じですが、ちょとした変更点が。
そこまで重要ではないのですが、CNNのクラスが受け取る数は[0.0, 1.0] なので、そこに注意します。
加工は前回と同じですが、一応こんな感じになります。
画像1: 手書き文字の画像をデータセットに合う型に加工する。
ついでに、教師データの方はこんな感じ。
画像2: sklearnの手書き文字データ(digits)
教師データの加工
変更点は二箇所。Xの次元をCNNが受け取る形にする事と、yを1-of-k表現にする事です。
今回学習に使うdigitsデータ(train)と評価用データ(valid)の比率は9:1となっています。
CNNの層、tfのグラフを構築
ここからCNNの実装です。まずCNNの層を用意するのですが、今回データのサイズが小さいこともあって畳み込み層(cnn.Conv)とPooling層(cnn.Pooling)はひとつずつしか用意できませんでした。
それ以降の部分はtensorflowの変数を用意したり、出力や誤差関数を定義したりしています。
深層学習をするには教師データが少なすぎるので、学習回数は500回にしました。
教師データから学習
いろいろやっていますが、大事な部分は
sess.run(train, feed_dict={x: train_X[start:end], t: train_y[start:end]})
です。ここで出力と教師データとの誤差を小さくするように調整しているわけです。
また、学習100回毎に教師データに対する学習精度と評価データに対する精度を出力します。
画像データの判別
学習が完了した vaild にデータを放り込むと、判別した結果を返します。
それ以外は前回と同じ。代わり映えしないですね!
以上になります。長かった…
それでは動かしてみます。
精度の確認、検証
今回判別してもらう文字は前回と同じものです。こんな感じの手書き数字が0〜9まで各3個、計30個あります。
画像3: 用意した手書き文字。スマホで書いた。
それではターミナルからプログラムを動かします。果たして精度はどのくらい上がるのか…
学習に時間がかかる…(30秒ほど)
結果はこうなりました。
“判別結果”より上は教師データによる学習の過程を表しています。
“誤差”、”F1値”は学習に用いていないdigitsデータ(valid_X,)について識別した結果の誤差関数の値とF1値(精度の指標)を表しています。
ここを見ると精度は常に上がっているので、過学習の心配は無さそうです。問題はその下で、“観測”が用意した手書き文字の実際の値、”予測”がプログラムによって判断された値となっています
あれ??
私の書いた文字は正答率56%って… ロジスティック回帰の時は精度63%だったのに対して低くなっていますね。3の予測がすべて外れています。ロジスティック回帰のときは予想がすべて外れる文字など無かったのですが…
このあとショックだったので勉強の回数(n_epochs)をいろいろいじってみたのですが、せいぜい66%止まりでした。乱数のseed値によっては40%台のときもありました。
ロジスティック回帰よりも精度が低い…?
結果
CNNを用いても精度は上がらないどころか、精度が下がる。
考えられる理由としては、
- 画素数が低く、畳み込みやPoolingが十分にできなかった
- 前回と同じ画像の加工方法自体に問題があった
この2点でしょうか。
と言うのも、実は手書き文字の教師データにはsklearnのdigitsの他にMNISTという有名なものがありまして(28×28画素と解像度が段違いに高い)、そちらを使えばより多くの畳み込み、Pooling層を用意できるのです。
ちなみにそちらの方の結果は下記(学習には3分ほどかかる)
学習回数は10回ですが、正答率が高いです。
まとめ
変数の少ないデータに畳み込みニューラルネットを用いても、その精度は高くはならない。
当然ながら、深層学習にも不得意なデータがあるようです。
CNNを用いる場合は小さなデータで学習しようとしてもたくさんの学習を必要とするため、単純な機械学習よりも処理時間が増えてしまいますし、精度も上がりません。
といったことを確認できたところで、今回はおしまいです。それではまた次の記事でお会いしましょう。最後までご覧くださりありがとうございました。