みなさんこんにちは。技術開発室の岡田です。
前回の投稿では、AWS re:invent2019のレポートをしました。 いやー、楽しいイベントだったなー。来年も行きたいなー(と書き続けたら行かせてもらえないだろうか。ちなみに増員はしてもらえそうな雰囲気だった。)
さて、今回から私の担当している機械学習に関する投稿に戻りますが、今回は「画像から回帰問題を解く」をテーマにしたいと思います。
はじめに
画像を用いた機械学習といえば分類問題(Classification)が一般的ですが、実は回帰問題(Regression)を解くこともできます。 有名なところでは「顔の画像から年齢を推定する」というものが有ります。こちらの記事で詳しく紹介されている方がおられますので、ご参照ください。
FLECTでも同様の技術を用いたソリューションを提供しています。 詳しくは述べられないのですが、ざっくり一例をご紹介しますと、下記の図はとある工業製品の摩耗状態を画像から判定するモデルになります。 この工業製品の摩耗状態はとある指標で定量的に10段階で計測できるのですが、図の横軸が計測した正しい摩耗状態、縦軸が画像から推定した摩耗状態となります。かなりしっかりと推定ができているように見えると思います。
上図ではテストサンプルの件数が多くわかりにくいですが、推定した摩耗状態と実際の摩耗状態の誤差をバイオリンチャートで表現すると次の図のようになっています。 縦軸が0を中心に摩耗状態の段階誤差の範囲を示しています。ほぼ1段階以内の誤差で推定することができています。 (自画自賛ですがすごいと思ってる!興味のある方はご連絡を。)
この回帰の技術に関連して、ちょっとした実験をしてみましたのでご紹介します。
実験の動機
まず、実験をしてみた動機からご説明します。
みなさんもGoogleなどで検索していただくとわかると思いますが、画像で回帰問題を解くことについてWeb上にはそれほど多くの情報がありません。大体が上記ブログと同様に年齢推定の話です。 FLECTでも回帰問題を扱うことになりブログで情報発信をしようと思ってその時にその理由がわかったのですが、業務上扱っているデータは当然あるのですがブログで扱えるような公のデータがないのです(私の観測範囲では)。 なので、実験的に何かをしてブログを書くというのが難しかったのではないかと考えています。
で、我々も同じ理由でだいぶ長い間ブログを書けずにいたのですが、あるとき、ぱっと、あることを思い出しました。 そういえば、昔、きゅうりの仕分け(ランクづけ)を画像分類でやっていた人がいたなぁ、と。
それからずっと、このきゅうりのランクを回帰で推定できないか?というのが頭から離れずにいたのですが、この年末幸運にも少し時間が空いたので早速実験をしてみた、ということになります。
ということで、以下、きゅうりのランクを回帰で推定する話をつらつらと書いていきます。 なお、今回用いたソースコード(notebook)は文末のgit repositoryに格納しておきます。
きゅうりの仕分けとデータセット
ご存知の方も多いと思いますが、自動車部品メーカーに務められていた方が退職して、Deep Learningを用いてきゅうりの仕分けをするシステムを作りました。 当時だいぶバズって、Googleのブログにも特集されています。詳細はこの記事を見ていただくのが良いかと思います。
cloudplatform-jp.googleblog.com
この仕分けシステムでは、2LからCまで、きゅうりを9つのランクに仕分けることができるそうです。 そして、この学習用のデータはなんとgitで公開されているのです(Creative Commons)。ありがたい!
今回は、この中でもprototype_2の学習用のデータを用いて実験を行いました。 なお、このprototype_2は上下側面の3方向からの撮影したデータとなっていますが、今回は上から撮影したもののみを用いることにしました。 また、画像の中には、手が写り込んでいたりする写真などが含まれていますが、これらは事前に取り除きました(OTHERラベルがついている)。
一応、分布を確認するとこんな感じに各クラス800枚前後でまんべんなく格納されています。ちなみに横軸の数値は2L〜Cまでのランクに対応づいています。縦軸は画像数。
なお、本データはCIFAR10の形式で格納されていますので、読み出し方はCIFAR10のサイトかgit repositoryを参考にしてください。
ネットワーク
今回はresnet v2を用いてやってみました。ポイントはFlatten()のあとにsoftmaxではなくreluを活性化関数とした全結合層を置くことです。 また、今回はCIFAR10と同じ画像なので、Inputのサイズは32x32x3になります。
<略> x = AveragePooling2D(pool_size=8)(x) x = Flatten()(x) x = Dense(8, activation='relu')(x) outputs = Dense(1)(x)
モデル全体は次のような感じです。 なお、ある程度エイヤで作っているネットワークなので、チューニングの余地は有りますので、興味のある方は適当にいじってみるといいと思います。
__________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_8 (InputLayer) (None, 32, 32, 3) 0 __________________________________________________________________________________________________ conv2d_218 (Conv2D) (None, 32, 32, 16) 448 input_8[0][0] __________________________________________________________________________________________________ batch_normalization_197 (BatchN (None, 32, 32, 16) 64 conv2d_218[0][0] __________________________________________________________________________________________________ activation_197 (Activation) (None, 32, 32, 16) 0 batch_normalization_197[0][0] __________________________________________________________________________________________________ conv2d_219 (Conv2D) (None, 32, 32, 16) 272 activation_197[0][0] __________________________________________________________________________________________________ batch_normalization_198 (BatchN (None, 32, 32, 16) 64 conv2d_219[0][0] __________________________________________________________________________________________________ activation_198 (Activation) (None, 32, 32, 16) 0 batch_normalization_198[0][0] __________________________________________________________________________________________________ conv2d_220 (Conv2D) (None, 32, 32, 16) 2320 activation_198[0][0] __________________________________________________________________________________________________ batch_normalization_199 (BatchN (None, 32, 32, 16) 64 conv2d_220[0][0] __________________________________________________________________________________________________ activation_199 (Activation) (None, 32, 32, 16) 0 batch_normalization_199[0][0] __________________________________________________________________________________________________ conv2d_222 (Conv2D) (None, 32, 32, 64) 1088 activation_197[0][0] __________________________________________________________________________________________________ conv2d_221 (Conv2D) (None, 32, 32, 64) 1088 activation_199[0][0] __________________________________________________________________________________________________ add_64 (Add) (None, 32, 32, 64) 0 conv2d_222[0][0] conv2d_221[0][0] __________________________________________________________________________________________________ batch_normalization_200 (BatchN (None, 32, 32, 64) 256 add_64[0][0] __________________________________________________________________________________________________ activation_200 (Activation) (None, 32, 32, 64) 0 batch_normalization_200[0][0] __________________________________________________________________________________________________ conv2d_223 (Conv2D) (None, 32, 32, 16) 1040 activation_200[0][0] __________________________________________________________________________________________________ batch_normalization_201 (BatchN (None, 32, 32, 16) 64 conv2d_223[0][0] __________________________________________________________________________________________________ activation_201 (Activation) (None, 32, 32, 16) 0 batch_normalization_201[0][0] __________________________________________________________________________________________________ conv2d_224 (Conv2D) (None, 32, 32, 16) 2320 activation_201[0][0] __________________________________________________________________________________________________ batch_normalization_202 (BatchN (None, 32, 32, 16) 64 conv2d_224[0][0] __________________________________________________________________________________________________ activation_202 (Activation) (None, 32, 32, 16) 0 batch_normalization_202[0][0] __________________________________________________________________________________________________ conv2d_225 (Conv2D) (None, 32, 32, 64) 1088 activation_202[0][0] __________________________________________________________________________________________________ add_65 (Add) (None, 32, 32, 64) 0 add_64[0][0] conv2d_225[0][0] __________________________________________________________________________________________________ batch_normalization_203 (BatchN (None, 32, 32, 64) 256 add_65[0][0] __________________________________________________________________________________________________ activation_203 (Activation) (None, 32, 32, 64) 0 batch_normalization_203[0][0] __________________________________________________________________________________________________ conv2d_226 (Conv2D) (None, 32, 32, 16) 1040 activation_203[0][0] __________________________________________________________________________________________________ batch_normalization_204 (BatchN (None, 32, 32, 16) 64 conv2d_226[0][0] __________________________________________________________________________________________________ activation_204 (Activation) (None, 32, 32, 16) 0 batch_normalization_204[0][0] __________________________________________________________________________________________________ conv2d_227 (Conv2D) (None, 32, 32, 16) 2320 activation_204[0][0] __________________________________________________________________________________________________ batch_normalization_205 (BatchN (None, 32, 32, 16) 64 conv2d_227[0][0] __________________________________________________________________________________________________ activation_205 (Activation) (None, 32, 32, 16) 0 batch_normalization_205[0][0] __________________________________________________________________________________________________ conv2d_228 (Conv2D) (None, 32, 32, 64) 1088 activation_205[0][0] __________________________________________________________________________________________________ add_66 (Add) (None, 32, 32, 64) 0 add_65[0][0] conv2d_228[0][0] __________________________________________________________________________________________________ batch_normalization_206 (BatchN (None, 32, 32, 64) 256 add_66[0][0] __________________________________________________________________________________________________ activation_206 (Activation) (None, 32, 32, 64) 0 batch_normalization_206[0][0] __________________________________________________________________________________________________ conv2d_229 (Conv2D) (None, 16, 16, 64) 4160 activation_206[0][0] __________________________________________________________________________________________________ batch_normalization_207 (BatchN (None, 16, 16, 64) 256 conv2d_229[0][0] __________________________________________________________________________________________________ activation_207 (Activation) (None, 16, 16, 64) 0 batch_normalization_207[0][0] __________________________________________________________________________________________________ conv2d_230 (Conv2D) (None, 16, 16, 64) 36928 activation_207[0][0] __________________________________________________________________________________________________ batch_normalization_208 (BatchN (None, 16, 16, 64) 256 conv2d_230[0][0] __________________________________________________________________________________________________ activation_208 (Activation) (None, 16, 16, 64) 0 batch_normalization_208[0][0] __________________________________________________________________________________________________ conv2d_232 (Conv2D) (None, 16, 16, 128) 8320 add_66[0][0] __________________________________________________________________________________________________ conv2d_231 (Conv2D) (None, 16, 16, 128) 8320 activation_208[0][0] __________________________________________________________________________________________________ add_67 (Add) (None, 16, 16, 128) 0 conv2d_232[0][0] conv2d_231[0][0] __________________________________________________________________________________________________ batch_normalization_209 (BatchN (None, 16, 16, 128) 512 add_67[0][0] __________________________________________________________________________________________________ activation_209 (Activation) (None, 16, 16, 128) 0 batch_normalization_209[0][0] __________________________________________________________________________________________________ conv2d_233 (Conv2D) (None, 16, 16, 64) 8256 activation_209[0][0] __________________________________________________________________________________________________ batch_normalization_210 (BatchN (None, 16, 16, 64) 256 conv2d_233[0][0] __________________________________________________________________________________________________ activation_210 (Activation) (None, 16, 16, 64) 0 batch_normalization_210[0][0] __________________________________________________________________________________________________ conv2d_234 (Conv2D) (None, 16, 16, 64) 36928 activation_210[0][0] __________________________________________________________________________________________________ batch_normalization_211 (BatchN (None, 16, 16, 64) 256 conv2d_234[0][0] __________________________________________________________________________________________________ activation_211 (Activation) (None, 16, 16, 64) 0 batch_normalization_211[0][0] __________________________________________________________________________________________________ conv2d_235 (Conv2D) (None, 16, 16, 128) 8320 activation_211[0][0] __________________________________________________________________________________________________ add_68 (Add) (None, 16, 16, 128) 0 add_67[0][0] conv2d_235[0][0] __________________________________________________________________________________________________ batch_normalization_212 (BatchN (None, 16, 16, 128) 512 add_68[0][0] __________________________________________________________________________________________________ activation_212 (Activation) (None, 16, 16, 128) 0 batch_normalization_212[0][0] __________________________________________________________________________________________________ conv2d_236 (Conv2D) (None, 16, 16, 64) 8256 activation_212[0][0] __________________________________________________________________________________________________ batch_normalization_213 (BatchN (None, 16, 16, 64) 256 conv2d_236[0][0] __________________________________________________________________________________________________ activation_213 (Activation) (None, 16, 16, 64) 0 batch_normalization_213[0][0] __________________________________________________________________________________________________ conv2d_237 (Conv2D) (None, 16, 16, 64) 36928 activation_213[0][0] __________________________________________________________________________________________________ batch_normalization_214 (BatchN (None, 16, 16, 64) 256 conv2d_237[0][0] __________________________________________________________________________________________________ activation_214 (Activation) (None, 16, 16, 64) 0 batch_normalization_214[0][0] __________________________________________________________________________________________________ conv2d_238 (Conv2D) (None, 16, 16, 128) 8320 activation_214[0][0] __________________________________________________________________________________________________ add_69 (Add) (None, 16, 16, 128) 0 add_68[0][0] conv2d_238[0][0] __________________________________________________________________________________________________ batch_normalization_215 (BatchN (None, 16, 16, 128) 512 add_69[0][0] __________________________________________________________________________________________________ activation_215 (Activation) (None, 16, 16, 128) 0 batch_normalization_215[0][0] __________________________________________________________________________________________________ conv2d_239 (Conv2D) (None, 8, 8, 128) 16512 activation_215[0][0] __________________________________________________________________________________________________ batch_normalization_216 (BatchN (None, 8, 8, 128) 512 conv2d_239[0][0] __________________________________________________________________________________________________ activation_216 (Activation) (None, 8, 8, 128) 0 batch_normalization_216[0][0] __________________________________________________________________________________________________ conv2d_240 (Conv2D) (None, 8, 8, 128) 147584 activation_216[0][0] __________________________________________________________________________________________________ batch_normalization_217 (BatchN (None, 8, 8, 128) 512 conv2d_240[0][0] __________________________________________________________________________________________________ activation_217 (Activation) (None, 8, 8, 128) 0 batch_normalization_217[0][0] __________________________________________________________________________________________________ conv2d_242 (Conv2D) (None, 8, 8, 256) 33024 add_69[0][0] __________________________________________________________________________________________________ conv2d_241 (Conv2D) (None, 8, 8, 256) 33024 activation_217[0][0] __________________________________________________________________________________________________ add_70 (Add) (None, 8, 8, 256) 0 conv2d_242[0][0] conv2d_241[0][0] __________________________________________________________________________________________________ batch_normalization_218 (BatchN (None, 8, 8, 256) 1024 add_70[0][0] __________________________________________________________________________________________________ activation_218 (Activation) (None, 8, 8, 256) 0 batch_normalization_218[0][0] __________________________________________________________________________________________________ conv2d_243 (Conv2D) (None, 8, 8, 128) 32896 activation_218[0][0] __________________________________________________________________________________________________ batch_normalization_219 (BatchN (None, 8, 8, 128) 512 conv2d_243[0][0] __________________________________________________________________________________________________ activation_219 (Activation) (None, 8, 8, 128) 0 batch_normalization_219[0][0] __________________________________________________________________________________________________ conv2d_244 (Conv2D) (None, 8, 8, 128) 147584 activation_219[0][0] __________________________________________________________________________________________________ batch_normalization_220 (BatchN (None, 8, 8, 128) 512 conv2d_244[0][0] __________________________________________________________________________________________________ activation_220 (Activation) (None, 8, 8, 128) 0 batch_normalization_220[0][0] __________________________________________________________________________________________________ conv2d_245 (Conv2D) (None, 8, 8, 256) 33024 activation_220[0][0] __________________________________________________________________________________________________ add_71 (Add) (None, 8, 8, 256) 0 add_70[0][0] conv2d_245[0][0] __________________________________________________________________________________________________ batch_normalization_221 (BatchN (None, 8, 8, 256) 1024 add_71[0][0] __________________________________________________________________________________________________ activation_221 (Activation) (None, 8, 8, 256) 0 batch_normalization_221[0][0] __________________________________________________________________________________________________ conv2d_246 (Conv2D) (None, 8, 8, 128) 32896 activation_221[0][0] __________________________________________________________________________________________________ batch_normalization_222 (BatchN (None, 8, 8, 128) 512 conv2d_246[0][0] __________________________________________________________________________________________________ activation_222 (Activation) (None, 8, 8, 128) 0 batch_normalization_222[0][0] __________________________________________________________________________________________________ conv2d_247 (Conv2D) (None, 8, 8, 128) 147584 activation_222[0][0] __________________________________________________________________________________________________ batch_normalization_223 (BatchN (None, 8, 8, 128) 512 conv2d_247[0][0] __________________________________________________________________________________________________ activation_223 (Activation) (None, 8, 8, 128) 0 batch_normalization_223[0][0] __________________________________________________________________________________________________ conv2d_248 (Conv2D) (None, 8, 8, 256) 33024 activation_223[0][0] __________________________________________________________________________________________________ add_72 (Add) (None, 8, 8, 256) 0 add_71[0][0] conv2d_248[0][0] __________________________________________________________________________________________________ batch_normalization_224 (BatchN (None, 8, 8, 256) 1024 add_72[0][0] __________________________________________________________________________________________________ activation_224 (Activation) (None, 8, 8, 256) 0 batch_normalization_224[0][0] __________________________________________________________________________________________________ average_pooling2d_8 (AveragePoo (None, 1, 1, 256) 0 activation_224[0][0] __________________________________________________________________________________________________ flatten_8 (Flatten) (None, 256) 0 average_pooling2d_8[0][0] __________________________________________________________________________________________________ dense_15 (Dense) (None, 8) 2056 flatten_8[0][0] __________________________________________________________________________________________________ dense_16 (Dense) (None, 1) 9 dense_15[0][0] ================================================================================================== Total params: 848,497 Trainable params: 843,281 Non-trainable params: 5,216 __________________________________________________________________________________________________
トレーニング
ネットワークを定義できたら、トレーニングを開始します。損失関数は今回はMSEにしています。 また、全サンプル数は7861で、これをTraining, Validation, Testでそれぞれ6:2:2に分割しています。 またこちらもエイヤですがエポック数は200にしています。トレーニングの結果は次の通り。 青線がlossでオレンジ線がval_lossです。大体130エポックで収束しているようですね。
評価
では、トレーニングしたモデルで評価してみましょう。 推定ランクをプロットした図です。なかなかうまく推定できているように見えます。
ただ、やはり、サンプル数が多いのでどの程度分散しているのかわかりにくいですね。 なので、こちらのバイオリンチャートを見ていただきましょう。
ご覧になっておわかりになると思いますが、ほぼ正しいランクの位置をてっぺんに正規分布していますね。 やや裾が重いところがあり、他のクラスと認識されてしまっているものが有るようですが、そこそこ正しく推定されたと見て良いと思います。 (作成者の方もおっしゃっていますが、人間がやっているのでランク付けもある程度えいやなところがあるらしいので。)
TensorFlowでディープラーニングによる『キュウリ』の仕分け | Workpiles
今回の場合は、各分類クラス間が離れていないため、人間がやっても“2LとL”や“LとBL”の判別は難しい(けっこう適当)だったりします。
といことで、今回の実験はひとまず成功したかなと思います。画像からの回帰、みなさんもチャレンジしてみてはいかがでしょうか。 結果をシェアしてもらえるとありがたいです。
しきい値調整
推定されたランクは浮動小数点となっていますが、実際のランクは整数値です。 この整数値に変換するためにしきい値を決める必要が有ります。 デフォルトで小数点第1位が5のところでしきい値にしてもいいですが、このしきい値をいい感じに調整してくれる方法も有ります。 今回は画像から回帰をすることろまでが主題となりますので、そこまでは踏み込みません。また機会があればご紹介します。
余談ですが、、、
resnet v2でクラス分類させたら、全然クラス分類のほうが精度が高かったです。 タスクによって解き方をしっかり使い分けることが重要です。この辺の話も機会があれば別途。
最後に
今回は、きゅうりの画像のランク付けを題材に、画像から回帰問題を解く方法についてご紹介しました。 最初にご紹介したとおり、画像から連続値を推定する技術は様々な分野に応用可能であり、FLECTでも実績があります。 ぜひご活用いただければありがたいと思います。
次回は、また別の技術をご紹介する予定です。ではでは。