文章目录
学习来源: 日撸 Java 三百行(51-60天,kNN 与 NB)
一、具体含义
- KNN是K-Nearest Neighbor的英文简写,中文直译就是K个最近邻,有人干脆称之为“最近邻算法”。
- 字母“K”也许看着新鲜,不过作用其实早在中学就接触过。在学习排列组合时,教材都喜欢用字母“n”来指代多个,譬如“求n个数的和”,这里面也没有什么秘密,就是约定俗成的用法。而KNN算法的字母K扮演的就是与n同样的角色。K的值是多少,就代表使用了多少个最近邻。机器学习总要有自己的约定俗成,没来由地就是喜爱用“K”而不是“n”来指代多个,类似的命名方法还有后面将要提到的K-means算法。
- KNN的关键在于最近邻,光看名字似乎与分类没有什么关系,但前面我们介绍了, KNN的核心在于多数表决,而谁有投票表决权呢?就是这个“最近邻”,也就是以待分类样本点为中心,距离最近的K个点。这K个点中什么类别的占比最多,待分类样本点就属于什么类别。
二、特点
- 简单. 没有学习过程, 也被称为惰性学习 lazy learning. 类似于开卷考试, 在已有数据中去找答案.
- 本源. 找相似, 正是人类认识事物的常用方法, 隐藏于人类或者其他动物的基因里面. 当然, 人类也会上当, 例如有人把邻居的滴水观音误认为是芋头, 偷食后中毒.
- 效果好. 永远不要小视 kNN, 对于很多数据, 你很难设计算法超越它.
- 适应性强. 可用于分类, 回归. 可用于各种数据.
- 可扩展性强. 设计不同的度量, 可获得意想不到的效果.
- 一般需要对数据归一化.
- 复杂度高. 这也是 kNN 最重要的缺点. 对于每一个测试数据, 复杂度为 O ( ( m + k ) n ) O((m+k)n)O((m+k)n), 其中 n nn 为训练数据个数, m mm 为条件属性个数, k kk 为邻居个数. 代码见 computeNearests().
三、基本思路
- KNN最核心的功能“分类”是通过多数表决“投票”来完成的,具体方法是在待分类点的K个最近邻中查看哪个类别占比最多。哪个类别多,待分类点就属于哪个类别。
- 怎样确定K——是一个需要根据实际情况调节以便取得更好拟合效果的参数,可以根据交叉验证等实验方法,结合工作经验进行设置。
- KNN中常用的度量方法——欧几里得距离和曼哈顿距离。
四、具体步骤
- 找K个最近邻。KNN分类算法的核心就是找最近的K个点,选定度量距离的方法之后,以待分类样本点为中心,分别测量它到其他点的距离,找出其中的距离最近的 K个,这就是K个最近邻。
- 统计最近邻的类别占比。确定了最近邻之后,统计出每种类别在最近邻中的占比。
- 选取占比最多的类别作为待分类样本的类别。
五、实现代码
package machinelearning.knn;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.Instances;
/**
 * 
 * @author Ling Lin E-mail:linling0.0@foxmail.com
 * 
 * @version 创建时间:2022年4月30日 上午10:19:24
 * 
 */
public class KnnClassification {
	// Manhattan distance.曼哈顿距离
	public static final int MANHATTAN = 0;
	// Euclidean distance.欧几里得距离
	public static final int EUCLIDEAN = 1;
	// The distance measure.
	public int distanceMeasure = EUCLIDEAN;
	// A random instance
	public static final Random random = new Random();
	// The number of neighbors
	int numNeighbors = 7;
	// The whole dataset.
	Instances dataset;
	// The training set. Represented by the indices of the data.
	int[] trainingSet;
	// The testing set. Represented by the indices of the data.
	int[] testingSet;
	// The predictions.
	int[] predictions;
	/**
	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The arff filename.
	 * 
	 */
	public KnnClassification(String paraFilename) {
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			// The last attribute is the decision class.
			dataset.setClassIndex(dataset.numAttributes() - 1);
			fileReader.close();
		} catch (Exception ee) {
			System.out.println("Error occurred while trying to read \'" + paraFilename
					+ "\' in KnnClassification constructor.\r\n" + ee);
			System.exit(0);
		} // Of try
	}// Of the first constructor
	/**
	 * Get a random indices for data randomization.
	 * 
	 * @param paraLength
	 *            The length of the sequence.
	 * @return An array of indices,e.g., {4, 3, 1, 5, 0, 2} with length 6.
	 */
	public static int[] getRandomIndices(int paraLength) {
		int[] resultIndices = new int[paraLength];
		// Step 1. Initialize.
		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i
		// Step 2. Randomly swap.
		int tempFirst, tempSecond, tempValue;
		for (int i = 0; i < paraLength; i++) {
			// Generate two random indices.
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);
			// Swap.
			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i
		return resultIndices;
	}// Of getRandomInndices
	/**
	 * Split the data into training and testing parts.
	 * 
	 * @param paraTrainingFraction
	 *            The fraction of the training set.
	 */
	public void splitTrainingTesting(double paraTrainingFraction) {
		int tempSize = dataset.numInstances();
		int[] tempIndices = getRandomIndices(tempSize);
		int tempTrainingSize = (int) (tempSize * paraTrainingFraction);
		trainingSet = new int[tempTrainingSize];
		testingSet = new int[tempSize - tempTrainingSize];
		for (int i = 0; i < tempTrainingSize; i++) {
			trainingSet[i] = tempIndices[i];
		} // Of for i
		for (int i = 0; i < tempSize - tempTrainingSize; i++) {
			testingSet[i] = tempIndices[tempTrainingSize + i];
		} // Of for i
	}// Of splitTrainingTesting
	/**
	 * Predict for the whole testing set. The results are stored in predictions.
	 * #see predictions.
	 */
	public void predict() {
		predictions = new int[testingSet.length];
		for (int i = 0; i < predictions.length; i++) {
			predictions[i] = predict(testingSet[i]);
		} // Of for i
	}// Of predict
	/**
	 * Predict for given instance.
	 * 
	 * @return The prediction.
	 */
	public int predict(int paraIndex) {
		int[] tempNeighbors = computeNearests(paraIndex);
		int resultPrediction = simpleVoting(tempNeighbors);
		return resultPrediction;
	}// Of predict
	/**
	 * The distance between two instances.
	 * 
	 * @param paraI
	 *            The index of the first instance.
	 * @param paraJ
	 *            The index of the second instance.
	 * @return The distance.
	 */
	public double distance(int paraI, int paraJ) {
		double resultDistance = 0;
		double tempDifference;
		switch (distanceMeasure) {
		case MANHATTAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				if (tempDifference < 0) {
					resultDistance -= tempDifference;
				} else {
					resultDistance += tempDifference;
				} // Of if
			} // Of for i
			break;
		case EUCLIDEAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				resultDistance += tempDifference * tempDifference;
			} // Of for i
			break;
		default:
			System.out.println("Unsupported distance measure: " + distanceMeasure);
		}// Of switch
		return resultDistance;
	}// Of distance
	/**
	 * Get the accuracy of the classifier.
	 * 
	 * @return The accuracy.
	 */
	public double getAccuracy() {
		// A double divides an int gets another double.
		double tempCorrect = 0;
		for (int i = 0; i < predictions.length; i++) {
			if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i
		return tempCorrect / testingSet.length;
	}// Of getAccuracy
	/**
	 * Compute the nearest k neighbors. Select one neighbor in each scan. In
	 * fact we can scan only once. You may implement it by yourself.
	 * 
	 * @param paraK
	 *            the k value for kNN.
	 * @param paraCurrent
	 *            current instance. We are comparing it with all others.
	 * @return the indices of the nearest instances.
	 */
	public int[] computeNearests(int paraCurrent) {
		int[] resultNearests = new int[numNeighbors];
		boolean[] tempSelected = new boolean[trainingSet.length];
		double tempDistance;
		double tempMinimalDistance;
		int tempMinimalIndex = 0;
		// Select the nearest paraK indices.
		for (int i = 0; i < numNeighbors; i++) {
			tempMinimalDistance = Double.MAX_VALUE;
			for (int j = 0; j < trainingSet.length; j++) {
				if (tempSelected[j]) {
					continue;
				} // Of if
				tempDistance = distance(paraCurrent, trainingSet[j]);
				if (tempDistance < tempMinimalDistance) {
					tempMinimalDistance = tempDistance;
					tempMinimalIndex = j;
				} // Of if
			} // Of for j
			resultNearests[i] = trainingSet[tempMinimalIndex];
			tempSelected[tempMinimalIndex] = true;
		} // Of for i
		System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
		return resultNearests;
	}// Of computeNearests
	/**
	 * Voting using the instances.
	 * 
	 * @param paraNeighbors
	 *            The indices of the neighbors.
	 * @return The predicted label.
	 */
	public int simpleVoting(int[] paraNeighbors) {
		int[] tempVotes = new int[dataset.numClasses()];
		for (int i = 0; i < paraNeighbors.length; i++) {
			tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
		} // Of for i
		int tempMaximalVotingIndex = 0;
		int tempMaximalVoting = 0;
		for (int i = 0; i < dataset.numClasses(); i++) {
			if (tempVotes[i] > tempMaximalVoting) {
				tempMaximalVoting = tempVotes[i];
				tempMaximalVotingIndex = i;
			} // Of if
		} // Of for i
		return tempMaximalVotingIndex;
	}// Of simpleVoting
	/**
	 * The entrance of the program.
	 * 
	 * @param args
	 *            Not used now.
	 */
	public static void main(String args[]) {
		KnnClassification tempClassifier = new KnnClassification("D:/00/data/iris.arff");
		tempClassifier.splitTrainingTesting(0.8);
		tempClassifier.predict();
		System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
	}// Of main
}// Of class KnnClassification
运行结果:
 










