k近傍法アルゴリズムをフルスクラッチで実装してみたよ。

k近傍法アルゴリズムをフルスクラッチで実装してみたよ。

こんにちは!エンジョンワークス 機械学習エンジニアのshunです!

今日はもっともシンプルな機械学習アルゴリズムと言われている、K近傍方(KNN)をpythonで実装してみましょう。

Knnの特徴は、シンプルが故に教師データからは学習しません、線形回帰のようにコスト関数を求める工程がないのです。なので教師データからパラメーターを調整することがなく、教師データを丸暗記します。怠惰学習(lazy learner)とも言われている所以です。

具体的に何に使うのでしょうか?

分類問題に適用できます。

  • ユーザーの購買行動予測
  • おすすめ映画予測
  • クレジットカードの不正利用検知

では、K近傍方(KNN)とは具体的に何でしょうか?

図の場合、未知のデータ☆はどちらのクラス(紫?黄色?)に属するでしょうか?

クラスB(紫)ですね。

kとは何でしょう?☆を推論したい場合、もっとも近くにあるデータポイントの数です。

k=3とすると、直近の3ポイントを比較します。☆はクラスBの2ポイントと近いですね。

k=6の場合はどうでしょうか?クラスAが4ポイントですね。

☆はこの場合クラスAとなります。

このようにどちらのクラスに所属するかは、多数決で決めていきます。

では、未知のデータとの距離はどうやって測定するのでしょうか?

かの有名なユークリッド距離で測定します。これで座標同士の距離を測定します。データ間でまっすぐ定規を引くイメージです。

エッセンスはこれだけです。

未知のデータをプロットして、直近のk個との距離を測る。

シンプルですよね?

シンプル故、強力なアルゴリズムです。

早速、実装してみましょう。プロダクションであればscikit-learnを使えばいいと思います。しかし、一からアルゴリズムを組んだほうが、理解が進みますよね。

1 .ユークリッド距離

import math
 def euclidanDistance(instance1,instance2,lenght):
     distance = 0
     for x in range(lenght):
         distance += pow((instance1[x] - instance2[x]),2)
     return math.sqrt(distance)
 data1 = [2,2,2,'a']
 data2 = [4,4,4,'b']
 distance = euclidanDistance(data1,data2,3)
 print('Distance:'  + repr(distance))

2 .もっとも近いデータポイントを返す。

import operator

def getNeighbors(trainingSet,testInstance,k):

    distance = []
    length = len(testInstance)-1

    for x in range(len(trainingSet)):
        dist = euclidanDistance(testInstance,trainingSet[x],length)
        distance.append((trainingSet[x],dist))
    distance.sort(key=operator.itemgetter(1))
    neighbors = []

    for x in range(k):
        neighbors.append(distance[x][0])
    return neighbors

実行してみます。

trainSet = [[2,2,2,'a'],[4,4,4,'b']]
testInstance = [5,5,5]
k  =1

neighbors = getNeighbors(trainSet,testInstance,1)
print(neighbors)

3 .どのクラスに属すか推論する

import operator

def getResponse(neighbors):
    classVotes = {}
    for x in range(len(neighbors)):
        response = neighbors[x][-1]
        if response in classVotes:
            classVotes[response] += 1
        else:
            classVotes[response] = 1
    sortedVotes = sorted(classVotes.items(),key=operator.itemgetter(1),reverse=True)

    return sortedVotes[0][0]

#neighbors = [[1,1,1,'a'],[2,2,2,'a'],[3,3,3,'b']]
response = getResponse(neighbors)
print(response)

4 .精度を計算

def getAccuracy(testSet,predictions):
    correct = 0

    for x in range(len(testSet)):
        if testSet[x][-1] is predictions[x]:
            correct +=1
    return (correct/float(len(testSet)))  * 100.0

testSet = [[1,1,1,'a'],[2,2,2,'a'],[3,3,3,'b']]
predictions = ['a','a','a']
accuracy = getAccuracy(testSet,predictions)

print(accuracy)

5 .最後に作った関数を1つにしてみましょう。

def main(trainingSet,testSet):

    print('Train set:' + repr(len(trainingSet)))
    print('Test set:' + repr(len(testSet)))
    predictions = []
    k=1

    for x in range(len(testSet)):
        neighbors = getNeighbors(trainingSet,testSet[x],k)
        result = getResponse(neighbors)
        predictions.append(result)
        print('> predicted=' + repr(result)+ ',actual='+ repr(testSet[x][-1]))
    accuracy = getAccuracy(testSet,predictions)
    print('Accuracy: ' + repr(accuracy)+'%')

trainSet = [[2,2,2,'a'],[4,4,4,'b'],[3,3,3,'c'],[7,7,7,'e']]
testSet =[[8,8,8,'e']]
main(trainSet,testSet)

いかがでしたか?機械学習はライブラリやクラウドサービスが出揃って便利になりましたね。なので機械学習を実装するのは簡単です。一方で本当に理解したければ、数式やアルゴリズムレベルでマスターすることが大切です。

またデータセットも予め用意されているデータを使うのが楽ですが、自分が推論したいリアル問題を扱ったデータセットを作ってみるのも理解の助けになります。irisデータでscikit-learnだけだと、何となく理解した気になりますよね。

特に自分が関心ある分野のデータを使うとよく理解できます。

最後まで読んでくれた人はありがとうございます。

エンジョイワークスではバックエンドエンジニアを募集しております。プログラミングを通じて空き家問題を解決したい人、街づくりする人を増やしたい人、ご応募お待ちしております!

応募はこちらから↓
https://enjoyworks.jp/recruit

一覧へ戻る

最新記事