Decision Tree in Apache Spark

What is a Decision Tree?

A decision tree is a powerful method for classification and prediction, and for facilitating decision making in sequential decision problems. A decision tree is made up of two components:-
  • A decision
  • An outcome

A decision tree also includes three types of nodes:-

  • Root node: The top node of the tree comprising all the data.
  • Splitting node: A node that assigns data to a subgroup.
  • Terminal node: Final decision (outcome).

To reach an outcome, the process starts at the root node. Based on the decision made at the root node, the next node (a splitting node) is selected; based on the decision made at that splitting node, one of its child nodes is selected, and so on until we reach a terminal node. The value of that terminal node is our outcome.
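The walk from the root node to a terminal node can be sketched in a few lines of Scala. This is only an illustration: the `Node`, `Terminal` and `Split` types, the feature names and the thresholds are all hypothetical and are not part of any Spark API.

```scala
// Hypothetical, illustrative types; not part of the Spark API.
sealed trait Node
// Terminal node: carries the final decision (outcome).
final case class Terminal(outcome: String) extends Node
// Root or splitting node: tests one feature against a threshold.
final case class Split(feature: String, threshold: Double, left: Node, right: Node) extends Node

object TreeWalk {
  // Walk from the given node down to a terminal node and return its outcome.
  @annotation.tailrec
  def decide(node: Node, sample: Map[String, Double]): String = node match {
    case Terminal(outcome) => outcome
    case Split(feature, threshold, left, right) =>
      // Go left when the feature value is at or below the threshold, else go right.
      if (sample(feature) <= threshold) decide(left, sample) else decide(right, sample)
  }

  // A tiny hand-built tree: the root splits on "age", its right child on "income".
  val tree: Node =
    Split("age", 30.0,
      Terminal("reject"),
      Split("income", 50000.0, Terminal("review"), Terminal("approve")))
}
```

For example, `TreeWalk.decide(TreeWalk.tree, Map("age" -> 42.0, "income" -> 65000.0))` goes right at the root, right again at the income node, and returns `"approve"`.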

Decision Tree in Apache Spark

It might sound strange, but technically there is no standalone implementation of the decision tree in Apache Spark. What Spark does implement is the Random Forest algorithm, in which the number of trees can be specified by the user. So, under the hood, Apache Spark calls Random Forest with a single tree whenever you train a decision tree.

In Apache Spark, the decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature space. The tree predicts the same label for each bottommost (leaf) partition. Each partition is chosen greedily by selecting the best split from a set of possible splits, in order to maximize the information gain at a tree node.

The node impurity is a measure of the homogeneity of the labels at the node. The current implementation provides two impurity measures for classification (Gini impurity and entropy) and one for regression (variance).
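To make the impurity measures and the greedy split selection more concrete, here is a minimal sketch in plain Scala. It is not Spark's implementation: the object and function names are hypothetical, it handles a single continuous feature, and it tries every observed value as a threshold, whereas Spark first discretizes continuous features into bins.

```scala
object ImpuritySketch {
  // Fraction of each distinct label among the given labels.
  private def proportions(labels: Seq[Double]): Seq[Double] =
    labels.groupBy(identity).values.map(_.size.toDouble / labels.size).toSeq

  // Gini impurity: sum over classes of p * (1 - p).
  def gini(labels: Seq[Double]): Double =
    proportions(labels).map(p => p * (1 - p)).sum

  // Entropy: sum over classes of -p * log2(p).
  def entropy(labels: Seq[Double]): Double =
    proportions(labels).map(p => -p * math.log(p) / math.log(2)).sum

  // Information gain: parent impurity minus the size-weighted impurity of the children.
  def infoGain(parent: Seq[Double], left: Seq[Double], right: Seq[Double],
               impurity: Seq[Double] => Double): Double = {
    val n = parent.size.toDouble
    impurity(parent) - (left.size / n) * impurity(left) - (right.size / n) * impurity(right)
  }

  // Greedy choice for one continuous feature: try each observed value (except the
  // largest) as a threshold and keep the one with the highest information gain.
  // Input points are (label, featureValue); the result is (threshold, gain), if any.
  def bestSplit(points: Seq[(Double, Double)],
                impurity: Seq[Double] => Double): Option[(Double, Double)] = {
    val parent = points.map(_._1)
    val candidates = points.map(_._2).distinct.sorted.dropRight(1)
    val scored = candidates.map { t =>
      val (l, r) = points.partition(_._2 <= t)
      (t, infoGain(parent, l.map(_._1), r.map(_._1), impurity))
    }
    if (scored.isEmpty) None else Some(scored.maxBy(_._2))
  }
}
```

With the points `Seq((0.0, 1.0), (0.0, 2.0), (1.0, 3.0), (1.0, 4.0))` and Gini impurity, the threshold `2.0` separates the two labels perfectly, so it is the split this sketch selects.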

Stopping rule

The recursive tree construction is stopped at a node when one of the following conditions is met:

  1. The node depth is equal to the `maxDepth` training parameter.
  2. No split candidate leads to an information gain greater than `minInfoGain`.
  3. No split candidate produces child nodes which each have at least `minInstancesPerNode` training instances.
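The three conditions above can be written as a single predicate. The following is a sketch under stated assumptions, not Spark's internal code: `Params`, `Candidate` and `shouldStop` are hypothetical names, and each candidate split is assumed to already carry its information gain and the instance counts of its two children.

```scala
object StoppingRuleSketch {
  // Hypothetical container for the three training parameters named in the rules.
  final case class Params(maxDepth: Int, minInfoGain: Double, minInstancesPerNode: Int)

  // One candidate split: its information gain and the instance count of each child.
  final case class Candidate(gain: Double, leftCount: Long, rightCount: Long)

  // Returns true when tree construction should stop at this node.
  def shouldStop(depth: Int, candidates: Seq[Candidate], p: Params): Boolean = {
    // Rule 1: the node has reached the maximum depth.
    val depthReached = depth >= p.maxDepth
    // Rule 2: no candidate has an information gain greater than minInfoGain.
    val noUsefulGain = !candidates.exists(_.gain > p.minInfoGain)
    // Rule 3: no candidate gives both children at least minInstancesPerNode instances.
    val noValidChildren = !candidates.exists { c =>
      c.leftCount >= p.minInstancesPerNode && c.rightCount >= p.minInstancesPerNode
    }
    depthReached || noUsefulGain || noValidChildren
  }
}
```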

Useful Parameters

  • algo :- The type of decision tree; it can be either Classification or Regression.
  • numClasses :- Number of classes (used for classification only).
  • maxDepth :- Maximum depth of a tree. A depth of 0 means a single leaf node; a depth of 1 means one internal node and two leaf nodes.
  • minInstancesPerNode :- For a node to be split further, each of its children must receive at least this number of training instances.
  • minInfoGain :- For a node to be split further, the split must improve the information gain by at least this much.
  • maxBins :- Number of bins used when discretizing continuous features.
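One way these parameters might be set together is through the `Strategy` class of the RDD-based MLlib API. The sketch below assumes `spark-mllib` is on the classpath; the object name `StrategySketch` is hypothetical and the values are illustrative only, not tuned recommendations.

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

object StrategySketch {
  // Illustrative values only; parameters not listed here keep their defaults.
  val strategy = new Strategy(
    algo = Algo.Classification, // classification rather than regression
    impurity = Gini,            // impurity measure used to score splits
    maxDepth = 5,               // stop splitting below this depth
    numClasses = 2,             // binary classification
    maxBins = 32,               // bins used to discretize continuous features
    minInstancesPerNode = 1,    // each child must receive at least one instance
    minInfoGain = 0.0)          // a split must have a gain greater than this
}
```

A strategy like this could then be passed to `DecisionTree.train(trainingData, StrategySketch.strategy)`, where `trainingData` is assumed to be an `RDD[LabeledPoint]`.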

Preparing training data for Decision Tree

You cannot feed arbitrary data directly to the decision tree; it expects the training data in a specific format, namely labelled points (a label together with a feature vector). For text input, for example, you can use the HashingTF technique to convert the training data into labelled data that the decision tree can understand. In this post, this process is referred to as standardization of data.

Feeding Data and Obtaining the Result
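As an illustration of this conversion, the sketch below turns a couple of text samples into `LabeledPoint`s with MLlib's `HashingTF`. It assumes `spark-mllib` is on the classpath; the sample sentences, the label meanings and the feature size of 1000 are made up for the example, and `LabelledDataSketch` is a hypothetical name.

```scala
import org.apache.spark.mllib.feature.HashingTF
import org.apache.spark.mllib.regression.LabeledPoint

object LabelledDataSketch {
  // Hypothetical raw input: (label, text) pairs, where 1.0 = spam and 0.0 = not spam.
  val raw: Seq[(Double, String)] = Seq(
    (1.0, "win money now"),
    (0.0, "meeting at noon"))

  // Hash each term into one of 1000 feature slots; the size is an arbitrary choice.
  val hashingTF = new HashingTF(1000)

  // Each LabeledPoint pairs a label with a term-frequency feature vector.
  val points: Seq[LabeledPoint] = raw.map { case (label, text) =>
    LabeledPoint(label, hashingTF.transform(text.split(" ").toSeq))
  }
}
```

Given an existing `SparkContext` named `sc` (an assumption here), `sc.parallelize(LabelledDataSketch.points)` would produce the `RDD[LabeledPoint]` that the decision tree trainer expects.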

Once the data has been standardized, you can feed it to the decision tree algorithm for classification. Before that, however, you need to split the data for training and testing purposes, i.e. to test the accuracy you need to hold back some part of the data for testing. You can feed the data like this.

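import org.apache.spark.mllib.tree.DecisionTree

// Split the data into training (70%) and test (30%) sets.
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.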
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

Here, `data` is my standardized input data, which I split in a 7:3 ratio for training and testing respectively. We are using `gini` impurity with a maximum depth of `5`.

Once the model is generated, you can try to predict the classification of other data, but before that we need to validate the accuracy of the newly generated model. You can validate the accuracy by computing the `Test Error`.

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
// Test error = fraction of test points whose prediction differs from the label.
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println("Test Error = " + testErr)
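The test error is simply the fraction of test points whose predicted label differs from the actual label. The same arithmetic can be tried on a small local collection, without Spark; the pairs below are made up for the example and `TestErrorSketch` is a hypothetical name.

```scala
object TestErrorSketch {
  // Fraction of (label, prediction) pairs that disagree.
  def testError(labelAndPreds: Seq[(Double, Double)]): Double =
    labelAndPreds.count { case (label, pred) => label != pred }.toDouble / labelAndPreds.size

  // Made-up pairs: the second and fifth predictions are wrong.
  val sample: Seq[(Double, Double)] =
    Seq((1.0, 1.0), (0.0, 1.0), (1.0, 1.0), (0.0, 0.0), (1.0, 0.0))
}
```

Here two of the five predictions are wrong, so `TestErrorSketch.testError(TestErrorSketch.sample)` is 2 / 5 = 0.4, which corresponds to an accuracy of 0.6.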

The lower the value of the test error, the better the model. You can take a look at a running example here.

References:- Apache Spark Documentation



Written by 

Principal Architect at Knoldus Inc
