Coursera Machine Learning - Week3-2 Regularization

CourseraのMachine Learningについてまとめています。 前回は、Week3の前半、Classificationについてまとめました。

今回は、Week3の後半、Regularizationについてです。


Regularization

Regularizationは、オーバーフィットを避けるために行います。Linear regressionとLogistic regressionのオーバーフィットのイメージです。


Linear regression_Under fit Over fit


Logistic regression_Under fit Over fit

Overfittingを避けるためのオプションとしては以下があります。

  • Feature(変数、特徴量)の数を減らす
    • どのFeatureを使うのかマニュアルで選ぶ
    • Model selection algorithm(本Machine Learningコースで後ほど学ぶそうです)
  • Regularization
    • 全てのFeatureをキープするが、パラメータの大きさを小さくする(多くのFeatureがあり、それらが少しずつ予測に寄与している場合にうまく動くそうです)

 

Regularizationでは、Overfitを避けるために、パラメータθにペナルティをかけます。具体的には、下記の図のように、コスト関数にθの二乗を加え、θが大きくなり過ぎないようにします。こうすることで、予測のラインがグニャグニャの複雑すぎるものになることを防ぎます。


Intuition of Regularization

Regularizationを行う場合のLinear Regressionのコスト関数は、以下になります。


Regularization Cost function

Regularizationの対象とするθは、θ0(Featureと関係のないバイアス項)を除いた、θ1以降です。また、λはRegularizationパラメータで、マニュアルで決める必要があります。λが大き過ぎると、θが全て0に近くため、Under fitしてしまい、小さ過ぎると、Regularizationの意味がなくなり、Overfitしてしまいます。

Gradient Descentについては、これまでと同様に、コスト関数をθで微分したものにLearning rateを掛けて更新していきます。具体的には下記の数式になります。


Regularization Gradient Descent


Logistic Regressionの場合も、Linear Regressionと同様に考えれば大丈夫です。コスト関数は、下記のものになり、Gradient Descentは、Linear Regressionと同じになります。


Regularized Logistic Regression Cost Function

 

プログラミング演習

Week3のプログラミング演習では、前半で学んだLogistic Regressionを実装し、その後、RegularizedされたLogistic Regressionを実装します。演習の項目は、具体的には以下のものです。

  • Sigmoid Function
  • Compute cost for logistic regression
  • Gradient for logistic regression
  • Predict Function
  • Compute cost for regularized LR
  • Gradient for regularized LR

 


次回は、Week4 Neural Netoworks Representationについてまとめます。ついに、Neural Networkが登場します。

 

コース全体の目次とそのまとめ記事へのリンクは、下記の記事にまとめていますので、参照ください。

Coursera Machine Learningまとめ



本記事を読んでいただきありがとうございます。
機械学習を実際に使うにあたり、Coursera MLと合わせておすすめしたい書籍を紹介します。


Pythonではじめる機械学習 ―scikit-learnで学ぶ特徴量エンジニアリングと機械学習の基礎

scikit-learnを用いた機械学習を学ぶのに最適な本です。



ゼロから作るDeep Learning ―Pythonで学ぶディープラーニングの理論と実装

Deep Learningと言えばこれ。TensorFlowやPyTorch等のフレームワークを用いずに、基礎の理論からDeep Learningを実装します。Week4Week5の記事を読んで、より深く理解したいと思った人におすすめです。



Kaggleで勝つデータ分析の技術

データ分析について学び始めた人におすすめです。