ねほり.com

何もないから何かみつかる

KaggleのDigit Recognizerで画像分析(skorch編)

      2020/07/23

前回の日記で「Pytorchは難しい」という記載しましたが、このままで終わることはできません。

なぜなら、将来的に「Tensorflow」と「Pytorch」のどちらが生き残るか現時点では分からないからです。

ネットで調べると、PytorchをラップするScikit-Learning互換のニューラルネットワークライブラリ「skorch」というものを見つけました。

今回は「skorch」を試してみます。

[参考] 過去の機械学習関係の記事

skorchとは何?

前述通り、PytorchをラップするScikit-Learning互換のニューラルネットワークライブラリです。

つまり、skorchを用いて作成したmodelオブジェクトには、scikit-learnのようにfitやpredictなどのメソッドが一通り揃っています。
 

ただし、マイナーなライブラリで利用者は少ないです。

少ないというか皆無です。比例するようにググっても情報が少ないです。

現在の最新は2019年11月の0.7.0です。

v0.1.0が2017年12月に登場し、Pytorchのバージョンアップに伴い進化してきていますが、未だにalpha版的な扱いです。

コミュニティもあまり活発ではなく、急に辞めてしまう可能性があります。

ただ、このままではPytorchの学習が進まないので、使ってみます。

インストール

pipでインストールができました。

使い方

結論から言うとScikit-Learning形式なので超簡単でした。

学習モデルの記載方法

これはPytorchの記述をそのまま利用します。

要するに「Define by run」な学習モデルを記載可能です。

学習用データ、検証用データの作成方法

今まで通りです。

ただし、Kerasとはデータのフォーマットの変換形式が異なります。

【Kerasの場合】

【Pytorchの場合】

ただ、Kerasのようにy_train に対してOnehotエンコーディングをすると、うまく学習が進みませんでした・・・(誰か理由を教えて下さい・・・)

学習方法の記載方法

skorch.net に次の3種類が用意されています。

関数説明
NeuralNetscikit-learnライクなmodelオブジェクトを作成するskorchのクラス。GridSearchCV(model, param_grid, scoring=’accuracy’)はできない
NeuralNetClassifier分類器をsklearn風に。Netオブジェクトのforwardメソッドの最後の活性化関数は必ずF.softmax(dim=-1)すること
NeuralNetRegressor回帰をsklearn風に

 

説明を読む限り、少し挙動に制限がありそうです。

初期化の際に、学習の仕方を決めます。パラメータは、PyTorchの関数を使用できます。

  • criterion : 損失関数の設定
  • optimizer : 最適化関数の設定
  • lr : 学習率の決定
  • module : pytorchで実装したnn.Module継承クラス
  • max_epochs : Epoch数
  • batch_size : ミニバッチサイズ
  • device : GPUの設定

その上で、訓練データの学習はScikit-Learning形式で次のように記載します。

予測の記載方法

これもScikit-Learning形式なので簡単です。

なお、skorch.NeuralNetを用いて作成したmodelオブジェクトは、最後のLinear層->log_softmaxの活性化が施された値で「y_pred.shape=(10000, 10)」となります。

このため、y_predに対してargmax(axis=1)を取る必要があります。

結果

出力されたCSVをサブミットすると「scored 0.98585」でKerasの「scored 0.98700」より少し低かったです。なぜ・・・。

なお、実際の出力画面は前回に比べて良くなったエポックに関しては色がつくよう仕様になっています。

まとめ

容易さはKerasと同じレベルです。

しばらくディープラーニングで解く場合には両方で記述するように努力してみます。

ただ「skorch」は癖も強そうなので、早めに「Pytorch」を使えるようになる必要はありそうです。

ソースコード

後学に向けてコメントも多くつけてます。

 - 2020年(社会人16年), 機械学習, テクノロジー

  関連記事

情報処理学会「CVIM研究会」で「卒論セッション」発表

さて、2月頃には提出が決まっていた情報処理学会「コンピュータビジョンとイメージメ …

Webサービスを支えるトレンド技術まとめ

FishEyeやCrucibleを使って開発を進め、Node.js+Expres …

i-gotU GT-600を買いました。

GPSデバイスを先週買った。が・・・先週はどこにも行ってないのでログを載せれず。 …

HTML5 canvas+JavaScriptでオセロ作成

2011年11月27日(日) HTML5 canvas+JavaScriptでオ …

千葉市検見川浜・稲毛海岸でマテ貝・シオフキガイ捕りと砂出し調理法

昨日、マテ貝を大量に貰ったので、自分たちでも挑戦したくなり「マテ貝」を取りに来ま …

KaggleのHouse Pricesで回帰分析(EDA:Exploratory Data Analysis編)

今回は、kaggleの入門者向けチュートリアルコンペ「住宅価格予測」をやってみま …

printf関数が自作できないと「C言語が書ける」と言うなかれ

2005年07月10日(日) C言語 プリンタを購入。やっぱ必要になりました・・ …

「ネットランナー」のトレーディングカード「ねとらん者」を大人買い(1/2)

もしも 童話世界に 2ch を作ったら ‥‥?  ■掲示板に戻る■ 全 …

curlと1024バイトとExpect: 100-continue

GWも出社して動作確認の手使いと、今月末リリース予定の実装を進め中。 山手線も会 …

「ついに証明された、新型コロナは空気感染する」はデマか?本当か?

新型コロナウイルスがまだ猛威を奮っています。   薬局に開店前に40分 …