フレクトのクラウドblog re:newal

http://blog.flect.co.jp/cloud/からさらに引っ越しています

きゅうり画像で回帰問題を解いてみた

みなさんこんにちは。技術開発室の岡田です。

前回の投稿では、AWS re:invent2019のレポートをしました。 いやー、楽しいイベントだったなー。来年も行きたいなー(と書き続けたら行かせてもらえないだろうか。ちなみに増員はしてもらえそうな雰囲気だった。)

cloud.flect.co.jp

さて、今回から私の担当している機械学習に関する投稿に戻りますが、今回は「画像から回帰問題を解く」をテーマにしたいと思います。

はじめに

画像を用いた機械学習といえば分類問題(Classification)が一般的ですが、実は回帰問題(Regression)を解くこともできます。 有名なところでは「顔の画像から年齢を推定する」というものが有ります。こちらの記事で詳しく紹介されている方がおられますので、ご参照ください。

FLECTでも同様の技術を用いたソリューションを提供しています。 詳しくは述べられないのですが、ざっくり一例をご紹介しますと、下記の図はとある工業製品の摩耗状態を画像から判定するモデルになります。 この工業製品の摩耗状態はとある指標で定量的に10段階で計測できるのですが、図の横軸が計測した正しい摩耗状態、縦軸が画像から推定した摩耗状態となります。かなりしっかりと推定ができているように見えると思います。

f:id:Wok:20191223195757p:plain

上図ではテストサンプルの件数が多くわかりにくいですが、推定した摩耗状態と実際の摩耗状態の誤差をバイオリンチャートで表現すると次の図のようになっています。 縦軸が0を中心に摩耗状態の段階誤差の範囲を示しています。ほぼ1段階以内の誤差で推定することができています。 (自画自賛ですがすごいと思ってる!興味のある方はご連絡を。)

f:id:Wok:20191223195822p:plain

この回帰の技術に関連して、ちょっとした実験をしてみましたのでご紹介します。

実験の動機

まず、実験をしてみた動機からご説明します。

みなさんもGoogleなどで検索していただくとわかると思いますが、画像で回帰問題を解くことについてWeb上にはそれほど多くの情報がありません。大体が上記ブログと同様に年齢推定の話です。 FLECTでも回帰問題を扱うことになりブログで情報発信をしようと思ってその時にその理由がわかったのですが、業務上扱っているデータは当然あるのですがブログで扱えるような公のデータがないのです(私の観測範囲では)。 なので、実験的に何かをしてブログを書くというのが難しかったのではないかと考えています。

で、我々も同じ理由でだいぶ長い間ブログを書けずにいたのですが、あるとき、ぱっと、あることを思い出しました。 そういえば、昔、きゅうりの仕分け(ランクづけ)を画像分類でやっていた人がいたなぁ、と。

それからずっと、このきゅうりのランクを回帰で推定できないか?というのが頭から離れずにいたのですが、この年末幸運にも少し時間が空いたので早速実験をしてみた、ということになります。

ということで、以下、きゅうりのランクを回帰で推定する話をつらつらと書いていきます。 なお、今回用いたソースコード(notebook)は文末のgit repositoryに格納しておきます。

きゅうりの仕分けとデータセット

ご存知の方も多いと思いますが、自動車部品メーカーに務められていた方が退職して、Deep Learningを用いてきゅうりの仕分けをするシステムを作りました。 当時だいぶバズって、Googleのブログにも特集されています。詳細はこの記事を見ていただくのが良いかと思います。

cloudplatform-jp.googleblog.com

この仕分けシステムでは、2LからCまで、きゅうりを9つのランクに仕分けることができるそうです。 そして、この学習用のデータはなんとgitで公開されているのです(Creative Commons)。ありがたい!

github.com

今回は、この中でもprototype_2の学習用のデータを用いて実験を行いました。 なお、このprototype_2は上下側面の3方向からの撮影したデータとなっていますが、今回は上から撮影したもののみを用いることにしました。 また、画像の中には、手が写り込んでいたりする写真などが含まれていますが、これらは事前に取り除きました(OTHERラベルがついている)。

一応、分布を確認するとこんな感じに各クラス800枚前後でまんべんなく格納されています。ちなみに横軸の数値は2L〜Cまでのランクに対応づいています。縦軸は画像数。

f:id:Wok:20191220152417p:plain

なお、本データはCIFAR10の形式で格納されていますので、読み出し方はCIFAR10のサイトかgit repositoryを参考にしてください。

ネットワーク

今回はresnet v2を用いてやってみました。ポイントはFlatten()のあとにsoftmaxではなくreluを活性化関数とした全結合層を置くことです。 また、今回はCIFAR10と同じ画像なので、Inputのサイズは32x32x3になります。

    <略>
    x = AveragePooling2D(pool_size=8)(x)
    x = Flatten()(x)
    x = Dense(8, activation='relu')(x)
    outputs = Dense(1)(x)

モデル全体は次のような感じです。 なお、ある程度エイヤで作っているネットワークなので、チューニングの余地は有りますので、興味のある方は適当にいじってみるといいと思います。

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_8 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
conv2d_218 (Conv2D)             (None, 32, 32, 16)   448         input_8[0][0]                    
__________________________________________________________________________________________________
batch_normalization_197 (BatchN (None, 32, 32, 16)   64          conv2d_218[0][0]                 
__________________________________________________________________________________________________
activation_197 (Activation)     (None, 32, 32, 16)   0           batch_normalization_197[0][0]    
__________________________________________________________________________________________________
conv2d_219 (Conv2D)             (None, 32, 32, 16)   272         activation_197[0][0]             
__________________________________________________________________________________________________
batch_normalization_198 (BatchN (None, 32, 32, 16)   64          conv2d_219[0][0]                 
__________________________________________________________________________________________________
activation_198 (Activation)     (None, 32, 32, 16)   0           batch_normalization_198[0][0]    
__________________________________________________________________________________________________
conv2d_220 (Conv2D)             (None, 32, 32, 16)   2320        activation_198[0][0]             
__________________________________________________________________________________________________
batch_normalization_199 (BatchN (None, 32, 32, 16)   64          conv2d_220[0][0]                 
__________________________________________________________________________________________________
activation_199 (Activation)     (None, 32, 32, 16)   0           batch_normalization_199[0][0]    
__________________________________________________________________________________________________
conv2d_222 (Conv2D)             (None, 32, 32, 64)   1088        activation_197[0][0]             
__________________________________________________________________________________________________
conv2d_221 (Conv2D)             (None, 32, 32, 64)   1088        activation_199[0][0]             
__________________________________________________________________________________________________
add_64 (Add)                    (None, 32, 32, 64)   0           conv2d_222[0][0]                 
                                                                 conv2d_221[0][0]                 
__________________________________________________________________________________________________
batch_normalization_200 (BatchN (None, 32, 32, 64)   256         add_64[0][0]                     
__________________________________________________________________________________________________
activation_200 (Activation)     (None, 32, 32, 64)   0           batch_normalization_200[0][0]    
__________________________________________________________________________________________________
conv2d_223 (Conv2D)             (None, 32, 32, 16)   1040        activation_200[0][0]             
__________________________________________________________________________________________________
batch_normalization_201 (BatchN (None, 32, 32, 16)   64          conv2d_223[0][0]                 
__________________________________________________________________________________________________
activation_201 (Activation)     (None, 32, 32, 16)   0           batch_normalization_201[0][0]    
__________________________________________________________________________________________________
conv2d_224 (Conv2D)             (None, 32, 32, 16)   2320        activation_201[0][0]             
__________________________________________________________________________________________________
batch_normalization_202 (BatchN (None, 32, 32, 16)   64          conv2d_224[0][0]                 
__________________________________________________________________________________________________
activation_202 (Activation)     (None, 32, 32, 16)   0           batch_normalization_202[0][0]    
__________________________________________________________________________________________________
conv2d_225 (Conv2D)             (None, 32, 32, 64)   1088        activation_202[0][0]             
__________________________________________________________________________________________________
add_65 (Add)                    (None, 32, 32, 64)   0           add_64[0][0]                     
                                                                 conv2d_225[0][0]                 
__________________________________________________________________________________________________
batch_normalization_203 (BatchN (None, 32, 32, 64)   256         add_65[0][0]                     
__________________________________________________________________________________________________
activation_203 (Activation)     (None, 32, 32, 64)   0           batch_normalization_203[0][0]    
__________________________________________________________________________________________________
conv2d_226 (Conv2D)             (None, 32, 32, 16)   1040        activation_203[0][0]             
__________________________________________________________________________________________________
batch_normalization_204 (BatchN (None, 32, 32, 16)   64          conv2d_226[0][0]                 
__________________________________________________________________________________________________
activation_204 (Activation)     (None, 32, 32, 16)   0           batch_normalization_204[0][0]    
__________________________________________________________________________________________________
conv2d_227 (Conv2D)             (None, 32, 32, 16)   2320        activation_204[0][0]             
__________________________________________________________________________________________________
batch_normalization_205 (BatchN (None, 32, 32, 16)   64          conv2d_227[0][0]                 
__________________________________________________________________________________________________
activation_205 (Activation)     (None, 32, 32, 16)   0           batch_normalization_205[0][0]    
__________________________________________________________________________________________________
conv2d_228 (Conv2D)             (None, 32, 32, 64)   1088        activation_205[0][0]             
__________________________________________________________________________________________________
add_66 (Add)                    (None, 32, 32, 64)   0           add_65[0][0]                     
                                                                 conv2d_228[0][0]                 
__________________________________________________________________________________________________
batch_normalization_206 (BatchN (None, 32, 32, 64)   256         add_66[0][0]                     
__________________________________________________________________________________________________
activation_206 (Activation)     (None, 32, 32, 64)   0           batch_normalization_206[0][0]    
__________________________________________________________________________________________________
conv2d_229 (Conv2D)             (None, 16, 16, 64)   4160        activation_206[0][0]             
__________________________________________________________________________________________________
batch_normalization_207 (BatchN (None, 16, 16, 64)   256         conv2d_229[0][0]                 
__________________________________________________________________________________________________
activation_207 (Activation)     (None, 16, 16, 64)   0           batch_normalization_207[0][0]    
__________________________________________________________________________________________________
conv2d_230 (Conv2D)             (None, 16, 16, 64)   36928       activation_207[0][0]             
__________________________________________________________________________________________________
batch_normalization_208 (BatchN (None, 16, 16, 64)   256         conv2d_230[0][0]                 
__________________________________________________________________________________________________
activation_208 (Activation)     (None, 16, 16, 64)   0           batch_normalization_208[0][0]    
__________________________________________________________________________________________________
conv2d_232 (Conv2D)             (None, 16, 16, 128)  8320        add_66[0][0]                     
__________________________________________________________________________________________________
conv2d_231 (Conv2D)             (None, 16, 16, 128)  8320        activation_208[0][0]             
__________________________________________________________________________________________________
add_67 (Add)                    (None, 16, 16, 128)  0           conv2d_232[0][0]                 
                                                                 conv2d_231[0][0]                 
__________________________________________________________________________________________________
batch_normalization_209 (BatchN (None, 16, 16, 128)  512         add_67[0][0]                     
__________________________________________________________________________________________________
activation_209 (Activation)     (None, 16, 16, 128)  0           batch_normalization_209[0][0]    
__________________________________________________________________________________________________
conv2d_233 (Conv2D)             (None, 16, 16, 64)   8256        activation_209[0][0]             
__________________________________________________________________________________________________
batch_normalization_210 (BatchN (None, 16, 16, 64)   256         conv2d_233[0][0]                 
__________________________________________________________________________________________________
activation_210 (Activation)     (None, 16, 16, 64)   0           batch_normalization_210[0][0]    
__________________________________________________________________________________________________
conv2d_234 (Conv2D)             (None, 16, 16, 64)   36928       activation_210[0][0]             
__________________________________________________________________________________________________
batch_normalization_211 (BatchN (None, 16, 16, 64)   256         conv2d_234[0][0]                 
__________________________________________________________________________________________________
activation_211 (Activation)     (None, 16, 16, 64)   0           batch_normalization_211[0][0]    
__________________________________________________________________________________________________
conv2d_235 (Conv2D)             (None, 16, 16, 128)  8320        activation_211[0][0]             
__________________________________________________________________________________________________
add_68 (Add)                    (None, 16, 16, 128)  0           add_67[0][0]                     
                                                                 conv2d_235[0][0]                 
__________________________________________________________________________________________________
batch_normalization_212 (BatchN (None, 16, 16, 128)  512         add_68[0][0]                     
__________________________________________________________________________________________________
activation_212 (Activation)     (None, 16, 16, 128)  0           batch_normalization_212[0][0]    
__________________________________________________________________________________________________
conv2d_236 (Conv2D)             (None, 16, 16, 64)   8256        activation_212[0][0]             
__________________________________________________________________________________________________
batch_normalization_213 (BatchN (None, 16, 16, 64)   256         conv2d_236[0][0]                 
__________________________________________________________________________________________________
activation_213 (Activation)     (None, 16, 16, 64)   0           batch_normalization_213[0][0]    
__________________________________________________________________________________________________
conv2d_237 (Conv2D)             (None, 16, 16, 64)   36928       activation_213[0][0]             
__________________________________________________________________________________________________
batch_normalization_214 (BatchN (None, 16, 16, 64)   256         conv2d_237[0][0]                 
__________________________________________________________________________________________________
activation_214 (Activation)     (None, 16, 16, 64)   0           batch_normalization_214[0][0]    
__________________________________________________________________________________________________
conv2d_238 (Conv2D)             (None, 16, 16, 128)  8320        activation_214[0][0]             
__________________________________________________________________________________________________
add_69 (Add)                    (None, 16, 16, 128)  0           add_68[0][0]                     
                                                                 conv2d_238[0][0]                 
__________________________________________________________________________________________________
batch_normalization_215 (BatchN (None, 16, 16, 128)  512         add_69[0][0]                     
__________________________________________________________________________________________________
activation_215 (Activation)     (None, 16, 16, 128)  0           batch_normalization_215[0][0]    
__________________________________________________________________________________________________
conv2d_239 (Conv2D)             (None, 8, 8, 128)    16512       activation_215[0][0]             
__________________________________________________________________________________________________
batch_normalization_216 (BatchN (None, 8, 8, 128)    512         conv2d_239[0][0]                 
__________________________________________________________________________________________________
activation_216 (Activation)     (None, 8, 8, 128)    0           batch_normalization_216[0][0]    
__________________________________________________________________________________________________
conv2d_240 (Conv2D)             (None, 8, 8, 128)    147584      activation_216[0][0]             
__________________________________________________________________________________________________
batch_normalization_217 (BatchN (None, 8, 8, 128)    512         conv2d_240[0][0]                 
__________________________________________________________________________________________________
activation_217 (Activation)     (None, 8, 8, 128)    0           batch_normalization_217[0][0]    
__________________________________________________________________________________________________
conv2d_242 (Conv2D)             (None, 8, 8, 256)    33024       add_69[0][0]                     
__________________________________________________________________________________________________
conv2d_241 (Conv2D)             (None, 8, 8, 256)    33024       activation_217[0][0]             
__________________________________________________________________________________________________
add_70 (Add)                    (None, 8, 8, 256)    0           conv2d_242[0][0]                 
                                                                 conv2d_241[0][0]                 
__________________________________________________________________________________________________
batch_normalization_218 (BatchN (None, 8, 8, 256)    1024        add_70[0][0]                     
__________________________________________________________________________________________________
activation_218 (Activation)     (None, 8, 8, 256)    0           batch_normalization_218[0][0]    
__________________________________________________________________________________________________
conv2d_243 (Conv2D)             (None, 8, 8, 128)    32896       activation_218[0][0]             
__________________________________________________________________________________________________
batch_normalization_219 (BatchN (None, 8, 8, 128)    512         conv2d_243[0][0]                 
__________________________________________________________________________________________________
activation_219 (Activation)     (None, 8, 8, 128)    0           batch_normalization_219[0][0]    
__________________________________________________________________________________________________
conv2d_244 (Conv2D)             (None, 8, 8, 128)    147584      activation_219[0][0]             
__________________________________________________________________________________________________
batch_normalization_220 (BatchN (None, 8, 8, 128)    512         conv2d_244[0][0]                 
__________________________________________________________________________________________________
activation_220 (Activation)     (None, 8, 8, 128)    0           batch_normalization_220[0][0]    
__________________________________________________________________________________________________
conv2d_245 (Conv2D)             (None, 8, 8, 256)    33024       activation_220[0][0]             
__________________________________________________________________________________________________
add_71 (Add)                    (None, 8, 8, 256)    0           add_70[0][0]                     
                                                                 conv2d_245[0][0]                 
__________________________________________________________________________________________________
batch_normalization_221 (BatchN (None, 8, 8, 256)    1024        add_71[0][0]                     
__________________________________________________________________________________________________
activation_221 (Activation)     (None, 8, 8, 256)    0           batch_normalization_221[0][0]    
__________________________________________________________________________________________________
conv2d_246 (Conv2D)             (None, 8, 8, 128)    32896       activation_221[0][0]             
__________________________________________________________________________________________________
batch_normalization_222 (BatchN (None, 8, 8, 128)    512         conv2d_246[0][0]                 
__________________________________________________________________________________________________
activation_222 (Activation)     (None, 8, 8, 128)    0           batch_normalization_222[0][0]    
__________________________________________________________________________________________________
conv2d_247 (Conv2D)             (None, 8, 8, 128)    147584      activation_222[0][0]             
__________________________________________________________________________________________________
batch_normalization_223 (BatchN (None, 8, 8, 128)    512         conv2d_247[0][0]                 
__________________________________________________________________________________________________
activation_223 (Activation)     (None, 8, 8, 128)    0           batch_normalization_223[0][0]    
__________________________________________________________________________________________________
conv2d_248 (Conv2D)             (None, 8, 8, 256)    33024       activation_223[0][0]             
__________________________________________________________________________________________________
add_72 (Add)                    (None, 8, 8, 256)    0           add_71[0][0]                     
                                                                 conv2d_248[0][0]                 
__________________________________________________________________________________________________
batch_normalization_224 (BatchN (None, 8, 8, 256)    1024        add_72[0][0]                     
__________________________________________________________________________________________________
activation_224 (Activation)     (None, 8, 8, 256)    0           batch_normalization_224[0][0]    
__________________________________________________________________________________________________
average_pooling2d_8 (AveragePoo (None, 1, 1, 256)    0           activation_224[0][0]             
__________________________________________________________________________________________________
flatten_8 (Flatten)             (None, 256)          0           average_pooling2d_8[0][0]        
__________________________________________________________________________________________________
dense_15 (Dense)                (None, 8)            2056        flatten_8[0][0]                  
__________________________________________________________________________________________________
dense_16 (Dense)                (None, 1)            9           dense_15[0][0]                   
==================================================================================================
Total params: 848,497
Trainable params: 843,281
Non-trainable params: 5,216
__________________________________________________________________________________________________

レーニン

ネットワークを定義できたら、トレーニングを開始します。損失関数は今回はMSEにしています。 また、全サンプル数は7861で、これをTraining, Validation, Testでそれぞれ6:2:2に分割しています。 またこちらもエイヤですがエポック数は200にしています。トレーニングの結果は次の通り。 青線がlossでオレンジ線がval_lossです。大体130エポックで収束しているようですね。

f:id:Wok:20191220160258p:plain

評価

では、トレーニングしたモデルで評価してみましょう。 推定ランクをプロットした図です。なかなかうまく推定できているように見えます。 f:id:Wok:20191220160321p:plain

ただ、やはり、サンプル数が多いのでどの程度分散しているのかわかりにくいですね。 なので、こちらのバイオリンチャートを見ていただきましょう。 f:id:Wok:20191220160335p:plain

ご覧になっておわかりになると思いますが、ほぼ正しいランクの位置をてっぺんに正規分布していますね。 やや裾が重いところがあり、他のクラスと認識されてしまっているものが有るようですが、そこそこ正しく推定されたと見て良いと思います。 (作成者の方もおっしゃっていますが、人間がやっているのでランク付けもある程度えいやなところがあるらしいので。)

TensorFlowでディープラーニングによる『キュウリ』の仕分け | Workpiles

今回の場合は、各分類クラス間が離れていないため、人間がやっても“2LとL”や“LとBL”の判別は難しい(けっこう適当)だったりします。

といことで、今回の実験はひとまず成功したかなと思います。画像からの回帰、みなさんもチャレンジしてみてはいかがでしょうか。 結果をシェアしてもらえるとありがたいです。

しきい値調整

推定されたランクは浮動小数点となっていますが、実際のランクは整数値です。 この整数値に変換するためにしきい値を決める必要が有ります。 デフォルトで小数点第1位が5のところでしきい値にしてもいいですが、このしきい値をいい感じに調整してくれる方法も有ります。 今回は画像から回帰をすることろまでが主題となりますので、そこまでは踏み込みません。また機会があればご紹介します。

余談ですが、、、

resnet v2でクラス分類させたら、全然クラス分類のほうが精度が高かったです。 タスクによって解き方をしっかり使い分けることが重要です。この辺の話も機会があれば別途。

最後に

今回は、きゅうりの画像のランク付けを題材に、画像から回帰問題を解く方法についてご紹介しました。 最初にご紹介したとおり、画像から連続値を推定する技術は様々な分野に応用可能であり、FLECTでも実績があります。 ぜひご活用いただければありがたいと思います。

次回は、また別の技術をご紹介する予定です。ではでは。

リンク

github.com