In this blog, we are going to go through one of the most widely used classification algorithms: KNN (K-Nearest Neighbors).
Since I started doing data science, I have observed that most problems end up needing a classification model.
The main reason behind this bias is that most analytic problems are about decision making.
For instance, we may need to classify loan applicants as low, medium, or high credit risk, or decide whether an email is spam. These analyses are more insightful and directly linked to an implementation roadmap.
Index for this blog:
- What is KNN (K-Nearest Neighbors)?
- How does the KNN algorithm actually work?
- How do we choose the factor K?
- Pseudocode for KNN (K-Nearest Neighbors)
- Implementation in Python from scratch and using scikit-learn
What is KNN (K-Nearest Neighbors)?
K-Nearest Neighbors can be used for both classification and regression. It is a simple algorithm that stores all available cases and classifies new cases based on a similarity measure (for example, a distance function).
KNN is a type of instance-based learning, or lazy learning, where the function is only approximated locally and all computation is deferred until classification. The k-NN algorithm is among the simplest of all machine learning algorithms.
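To make "lazy learning" concrete, here is a minimal sketch, assuming Python 3.8+ (for `math.dist`) and Euclidean distance. The class name `LazyKNN` and its methods are hypothetical, for illustration only. `fit` merely stores the data; all distance computation happens at prediction time. The same neighbor search serves classification (majority vote) and regression (mean of the neighbors' values).

```python
import math
from collections import Counter


class LazyKNN:
    """Hypothetical minimal k-NN: fit() only memorizes, work happens at query time."""

    def __init__(self, k=3):
        self.k = k

    def fit(self, points, targets):
        # "Training" is just storing the data; no parameters are learned.
        self.points = list(points)
        self.targets = list(targets)
        return self

    def _neighbor_targets(self, query):
        # Rank the stored points by Euclidean distance to the query, keep the k closest.
        ranked = sorted(zip(self.points, self.targets),
                        key=lambda pair: math.dist(pair[0], query))
        return [target for _, target in ranked[: self.k]]

    def predict_class(self, query):
        # Classification: majority vote among the k nearest targets.
        return Counter(self._neighbor_targets(query)).most_common(1)[0][0]

    def predict_value(self, query):
        # Regression: average of the k nearest targets.
        neighbors = self._neighbor_targets(query)
        return sum(neighbors) / len(neighbors)


clf = LazyKNN(k=3).fit([(1, 1), (1, 2), (2, 1), (6, 6), (7, 6)],
                       ["RC", "RC", "RC", "GS", "GS"])
print(clf.predict_class((1.5, 1.5)))  # RC

reg = LazyKNN(k=2).fit([(0,), (1,), (10,)], [0.0, 2.0, 50.0])
print(reg.predict_value((0.4,)))  # 1.0 (mean of the targets 0.0 and 2.0)
```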
How does the k-Nearest Neighbors algorithm actually work?
Let’s take a simple case to understand this algorithm. The following figure shows a spread of red circles (RC) and green squares (GS):
You intend to find out the class of the blue star (BS). BS can be either RC or GS and nothing else. The “K” in the KNN algorithm is the number of nearest neighbors we wish to take a vote from. Let’s say K = 3. We now draw a circle centered on BS, just big enough to enclose only three data points on the plane. Refer to the following diagram for more details:
The 3 nearest points to BS are all RC. Hence, with a good level of confidence, we can say that BS should belong to the class RC. Here, the choice was obvious because all 3 votes from the nearest neighbors went to RC. The choice of the parameter K is crucial in this algorithm. Next, we will look at the factors to consider when choosing the best K.
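The blue-star example can be replayed in a few lines. The figure's exact coordinates are not given in the text, so the six points below are hypothetical and only mimic its layout (RC clustered near BS, GS further away). The radius of the "just big enough" circle is the distance from BS to its K-th nearest neighbor.

```python
import math
from collections import Counter

# Hypothetical coordinates standing in for the figure: 3 RC near BS, 3 GS farther away.
training = [
    ((2.0, 3.0), "RC"),
    ((2.5, 2.0), "RC"),
    ((3.0, 3.5), "RC"),
    ((6.0, 6.5), "GS"),
    ((7.0, 5.5), "GS"),
    ((6.5, 7.0), "GS"),
]
blue_star = (3.0, 2.8)
K = 3

# Sort the training points by their distance to BS, closest first.
by_distance = sorted(training, key=lambda item: math.dist(item[0], blue_star))

# The circle that encloses exactly K points has the K-th neighbor on its edge.
radius = math.dist(by_distance[K - 1][0], blue_star)

# Each of the K nearest neighbors casts one vote for its class.
votes = Counter(label for _, label in by_distance[:K])
prediction = votes.most_common(1)[0][0]
print(f"radius={radius:.2f}, votes={dict(votes)}, prediction={prediction}")
# radius=1.02, votes={'RC': 3}, prediction=RC
```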
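Note: some rules of thumb for choosing K:
- Choose an odd value of K when you have 2 classes to avoid ties, i.e. cases where the neighbors' votes split evenly between the two classes and the algorithm cannot decide which one to go with.
- K should not be a multiple of the number of classes, since that makes an exact tie between all classes possible.
- If K is very small, the model overfits: each prediction rests on one or a few neighbors, so noisy points lead to inaccurate predictions.
- If K is very large, the model underfits. In the extreme, setting K equal to the number of data points n makes every prediction the majority class of the whole training set.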
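A small sketch of the first and last rules, assuming a plain majority vote. The `vote` helper (which reports a tie as `None`) and the list of neighbor labels are hypothetical.

```python
from collections import Counter


def vote(neighbor_labels):
    """Hypothetical majority vote that returns None when the top classes tie."""
    ranked = Counter(neighbor_labels).most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return None  # tie: no single winning class
    return ranked[0][0]


# Labels of one query's neighbors, already sorted from nearest to farthest.
sorted_labels = ["RC", "GS", "GS", "RC", "RC", "GS", "GS"]

print(vote(sorted_labels[:3]))  # GS: K=3 is odd, so 2 classes cannot tie
print(vote(sorted_labels[:4]))  # None: K=4 is even and the vote splits 2-2
# K = n: every training point votes, so the answer is the overall majority
# no matter where the query is.
print(vote(sorted_labels))      # GS (4 of 7)
```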
How do we choose the factor K?
First, let us understand what exactly K influences in the algorithm. Going back to the last example, if all 6 training observations remain constant, each value of K gives us a set of boundaries between the classes.
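One way to see how K shapes the boundaries is to classify every point of a small grid and print the result. This is a sketch under assumptions: the six training points are hypothetical (not the figure's), labels are shortened to R and G for a compact map, and one G point is deliberately placed inside the R group as noise. With K=1 that lone G point claims its own location; with K=3 it is outvoted there.

```python
import math
from collections import Counter

# Six hypothetical training points; (2, 2) is a lone "G" inside the "R" group.
training = [
    ((1, 1), "R"), ((2, 3), "R"), ((3, 1), "R"),
    ((2, 2), "G"),
    ((7, 7), "G"), ((8, 5), "G"),
]


def classify(query, k):
    # Majority label of the k training points closest to the query (Euclidean).
    nearest = sorted(training, key=lambda item: math.dist(item[0], query))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]


def region_map(k, size=10):
    """ASCII map of the predicted class at each integer grid point, top row = largest y."""
    rows = []
    for y in range(size - 1, -1, -1):
        rows.append("".join(classify((x, y), k) for x in range(size)))
    return "\n".join(rows)


for k in (1, 3):
    print(f"K={k}\n{region_map(k)}\n")
```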
As you can see, the training error rate at K=1 is always zero. This is because the closest point to any training data point is itself (assuming no two identical points carry different labels), so the prediction with K=1 is always correct on the training sample. If the validation error curve were similar, our choice of K would be 1. Following is the validation error curve for varying values of K:
This makes the story clearer. At K=1, we were overfitting the boundaries. Hence, the validation error rate initially decreases as K grows and reaches a minimum. After that minimum, it increases with increasing K. To get the optimal value of K, split the initial dataset into training and validation sets, then plot the validation error curve and pick the K where it is lowest. This value of K should be used for all predictions.
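The selection procedure can be sketched in plain Python, assuming Euclidean distance and a simple 70/30 hold-out split. The dataset is synthetic and hypothetical (two overlapping Gaussian clusters), so the printed error values only illustrate the method and are not results from the figures. The sketch also checks the claim that the training error is zero at K=1.

```python
import math
import random
from collections import Counter


def knn_predict(train, query, k):
    """Majority label among the k training points closest to query (Euclidean)."""
    nearest = sorted(train, key=lambda item: math.dist(item[0], query))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]


def error_rate(train, evaluated, k):
    """Fraction of points in `evaluated` whose label differs from the k-NN prediction."""
    wrong = sum(knn_predict(train, point, k) != label for point, label in evaluated)
    return wrong / len(evaluated)


# Hypothetical data: two overlapping 2-D clusters, shuffled and split 70/30.
random.seed(42)
data = [((random.gauss(0, 1.5), random.gauss(0, 1.5)), "RC") for _ in range(60)]
data += [((random.gauss(2, 1.5), random.gauss(2, 1.5)), "GS") for _ in range(60)]
random.shuffle(data)
split = int(0.7 * len(data))
train, valid = data[:split], data[split:]

# Odd K values only, since there are 2 classes.
candidates = range(1, 30, 2)
train_err = {k: error_rate(train, train, k) for k in candidates}
valid_err = {k: error_rate(train, valid, k) for k in candidates}

for k in candidates:
    print(f"K={k:2d}  train error={train_err[k]:.3f}  validation error={valid_err[k]:.3f}")

# Pick the K with the lowest validation error (the smallest such K wins ties).
best_k = min(valid_err, key=valid_err.get)
print("best K:", best_k)
```

Plotting `valid_err` against K gives the validation error curve described above; the training error column stays at 0.000 for K=1 because every training point is its own nearest neighbor.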
For more, you can refer to Determined K value.
Pseudocode for KNN (K-Nearest Neighbors)
Anyone can implement a KNN model by following the pseudocode steps below.
- Load the data
- Initialize the value of k
- To get the predicted class, iterate from 1 to the total number of training data points
- Calculate the distance between the test data and each row of the training data. Here we will use Euclidean distance as our distance metric since it’s the most popular choice. Other metrics that can be used include Chebyshev, cosine, etc.
- Sort the calculated distances in ascending order based on distance values
- Get top k rows from the sorted array
- Get the most frequent class of these rows
- Return the predicted class
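The steps above translate almost line for line into Python. This is a from-scratch sketch, not the author's implementation: the toy data and the function names `euclidean`, `chebyshev` and `knn_classify` are hypothetical, and Chebyshev is included only to show how the distance metric can be swapped (cosine is not implemented here).

```python
import math
from collections import Counter


def euclidean(a, b):
    # Straight-line distance: square root of the summed squared differences.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def chebyshev(a, b):
    # Alternative metric: the largest absolute coordinate difference.
    return max(abs(x - y) for x, y in zip(a, b))


def knn_classify(train_X, train_y, query, k, distance=euclidean):
    # Iterate over every training row and compute its distance to the query.
    distances = []
    for row, label in zip(train_X, train_y):
        distances.append((distance(row, query), label))
    # Sort the calculated distances in ascending order.
    distances.sort(key=lambda pair: pair[0])
    # Get the top k rows from the sorted list.
    top_k = distances[:k]
    # Get the most frequent class among these rows and return it.
    counts = Counter(label for _, label in top_k)
    return counts.most_common(1)[0][0]


# Load the data (hypothetical toy data in place of a real dataset).
train_X = [(1.0, 1.0), (1.5, 2.0), (2.0, 1.5), (6.0, 6.0), (6.5, 7.0), (7.0, 6.0)]
train_y = ["RC", "RC", "RC", "GS", "GS", "GS"]
# Initialize the value of k.
k = 3

print(knn_classify(train_X, train_y, (2.0, 2.0), k))                       # RC
print(knn_classify(train_X, train_y, (6.0, 6.5), k, distance=chebyshev))  # GS
```

Passing the metric as a parameter keeps the voting logic unchanged when you try a different distance function.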