Multiclass Classification: All you need to know

In this blog we dive into a multiclass classification problem and show how to build a good predictive model on an image dataset.

Classification is one of the most important skills for AI/ML practitioners. It is worth mastering because it is used so widely in the real world. For example, Tesla uses classification alongside other models to identify objects near the car, and it is quite successful at telling humans and cars apart.

What is classification?

Classification means identifying which of a set of classes a piece of data belongs to. A machine learning model learns to sort input data into those classes by finding trends or patterns within the data. In image classification, the model is trained to recognize different objects.

Multiclass classification is a classification task with more than two classes, where each sample is labelled with exactly one class. For example, you might classify features extracted from a set of fruit images, where each image shows an orange, an apple, or a pear.
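To make "exactly one class per sample" concrete, here is a small sketch of the fruit example. The class list and the five labels are made up for illustration. A label can be stored as an integer class index (the form MNIST uses for its digit labels) or as a one-hot vector with a single 1 in the column of its class.

```python
import numpy as np

# Hypothetical class list for the fruit example
classes = ["orange", "apple", "pear"]

# Made-up labels for five images, stored as integer class indices
labels = np.array([1, 0, 2, 2, 1])

# One-hot encoding: one row per sample, a single 1 in the column of its class
one_hot = np.eye(len(classes), dtype=int)[labels]

print(one_hot)
# Every row sums to 1 because each sample belongs to exactly one class
print(one_hot.sum(axis=1))
print([classes[i] for i in labels])
```

Both encodings carry the same information: taking the argmax of each one-hot row recovers the integer label.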

Practical Implementation:

So now let's see how to implement a basic multiclass classifier. For this demo we will build a TensorFlow Keras model and use the MNIST digits dataset, which is readily available in TensorFlow. MNIST is a collection of 28×28 grayscale images of handwritten digits, and we will classify each image by the digit (0 to 9) it shows. So let's get started:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

First we load the digits dataset from the Keras library. load_data() already returns it split into training and testing sets, and we then normalize the pixel values with tf.keras.utils.normalize.

mnist = tf.keras.datasets.mnist
(x_train,y_train) , (x_test,y_test) = mnist.load_data()
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)
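tf.keras.utils.normalize scales values by their L2 norm along the chosen axis. The NumPy sketch below is our approximation of that behaviour, not the library's source. For a batch of images shaped (N, 28, 28), axis=1 means every pixel column of every image is divided by that column's Euclidean length, and all-zero columns stay zero. The random "images" and the helper name l2_normalize are placeholders, not part of the original code.

```python
import numpy as np

def l2_normalize(x, axis):
    """Divide x by its L2 norm along `axis` (a sketch of tf.keras.utils.normalize)."""
    norm = np.linalg.norm(x, ord=2, axis=axis, keepdims=True)
    # Avoid dividing by zero: all-zero slices stay zero
    norm[norm == 0] = 1
    return x / norm

# Placeholder batch: 4 fake 28x28 "images" with pixel values 0..255
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 28, 28)).astype(np.float64)
images[:, :, 0] = 0  # make one column blank, like MNIST's empty borders

scaled = l2_normalize(images, axis=1)

# Each non-blank column now has unit length; the blank column stays 0
col_norms = np.linalg.norm(scaled, axis=1)
print(col_norms.shape)
print(np.allclose(col_norms[:, 1:], 1.0), np.allclose(col_norms[:, 0], 0.0))
```

Dividing by 255.0 is a common alternative that simply maps every pixel into the range [0, 1].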

The function below lets you view the image data.

def draw(n):
    plt.imshow(x_train[n], cmap=plt.cm.binary)  # show the n-th training image
    plt.show()

Calling draw(0) shows the first image in the training set, which looks like a 5.

Now let's build a model with the TensorFlow Keras library. Note that the input shape is 28×28 pixels and the output layer has 10 neurons, one for each digit class. I trained it for 3 epochs, but feel free to run your own experiments.

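model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))  # 28x28 image -> 784 values
model.add(tf.keras.layers.Dense(128, activation="relu"))  # hidden layer; 128 units is an assumed size
model.add(tf.keras.layers.Dense(10, activation="softmax"))  # one output per digit 0-9
# Labels are integers 0-9, so use the sparse categorical cross-entropy loss
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.fit(x_train, y_train, epochs=3)  # 3 passes over the training set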
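To see what flattening and a 10-way softmax output do to a single image, here is a NumPy sketch of one forward pass, assuming a softmax output layer with 10 units. It skips any hidden layers, and the weights are random placeholders rather than trained values, so the predicted digit is meaningless. The point is the shapes: a 28×28 image becomes a vector of 784 values, and the output is 10 probabilities that sum to 1.

```python
import numpy as np

def softmax(z):
    """Turn raw scores into probabilities; subtracting the max keeps exp() stable."""
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(42)
image = rng.random((28, 28))  # placeholder for one normalized MNIST image

# Flatten: 28x28 pixels -> one vector of 784 features
x = image.reshape(-1)

# Output layer with random (untrained) placeholder weights: 784 inputs -> 10 scores
W = rng.normal(scale=0.01, size=(784, 10))
b = np.zeros(10)
scores = x @ W + b

probs = softmax(scores)
print(x.shape, probs.shape)
print(probs.sum())            # total probability over the 10 digit classes
print(int(np.argmax(probs)))  # index of the most likely digit
```

Training adjusts W and b so that the highest probability lands on the correct digit.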


Now that we have trained our model, let’s check the accuracy on the testing set.

val_loss,val_acc = model.evaluate(x_test,y_test)
print("loss-> ",val_loss,"\naccuracy-> ",val_acc)

The loss will be around 0.08802 and the accuracy around 0.9728; exact values vary slightly from run to run.
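model.evaluate reports the loss and the accuracy averaged over the test set. Assuming the model is compiled with sparse categorical cross-entropy (the usual choice for integer labels like MNIST's), both numbers can be reproduced by hand from the predicted probabilities. The sketch below uses three made-up predictions over 4 classes instead of real model output; the function names are our own.

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean negative log-probability assigned to the true class of each sample."""
    p_true = probs[np.arange(len(labels)), labels]
    return float(np.mean(-np.log(np.clip(p_true, eps, 1.0))))

def accuracy(labels, probs):
    """Fraction of samples whose highest-probability class equals the label."""
    return float(np.mean(np.argmax(probs, axis=1) == labels))

# Made-up softmax outputs for 3 samples over 4 classes (each row sums to 1)
probs = np.array([
    [0.70, 0.10, 0.10, 0.10],
    [0.20, 0.50, 0.20, 0.10],
    [0.25, 0.25, 0.40, 0.10],
])
labels = np.array([0, 1, 3])  # the third sample is misclassified

print("loss->", sparse_categorical_crossentropy(labels, probs))
print("accuracy->", accuracy(labels, probs))
```

The loss keeps falling as the model grows more confident in the correct class, while accuracy only counts whether the top guess is right.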

Great! Now that we have a model with good accuracy, pick a digit from the test set and check whether the model predicts it correctly.

predictions = model.predict(x_test)  # one row of 10 class probabilities per test image
print('label -> ', y_test[2])
print('prediction -> ', np.argmax(predictions[2]))
As we can see, the image looks like a 1, and our model has identified it correctly.
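Checking one index at a time gets tedious. Since predictions holds one row of 10 probabilities per test image, np.argmax along axis 1 turns the whole array into predicted labels, and comparing them with the true labels shows which images were misclassified. The arrays below are small made-up stand-ins for predictions and y_test so the snippet runs on its own.

```python
import numpy as np

# Made-up stand-ins: 4 test samples, 10 class probabilities each
predictions = np.full((4, 10), 0.02)
for row, digit in enumerate([7, 2, 1, 5]):
    predictions[row, digit] = 0.82  # 9 * 0.02 + 0.82 = 1.0 per row
y_test = np.array([7, 2, 1, 0])  # true labels; the last sample is wrong

# Most likely class for every row at once
predicted_labels = np.argmax(predictions, axis=1)
# Indices where the prediction disagrees with the label
wrong = np.flatnonzero(predicted_labels != y_test)

print("predicted ->", predicted_labels)
print("misclassified indices ->", wrong)
for i in wrong:
    print(f"index {i}: label {y_test[i]}, predicted {predicted_labels[i]}")
```

Passing a misclassified index to a plotting helper is a quick way to see why the model got it wrong.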


So we have now seen how to build a multiclass classifier and how it identifies which class an input belongs to. Feel free to explore it with other datasets as well. I hope this blog helps you understand the basics.