What is a Decision Tree?
- A Decision
- Outcome
A decision tree includes three types of nodes:
- Root node: The top node of the tree, comprising all the data.
- Splitting node: A node that assigns data to a subgroup.
- Terminal node: The final decision (outcome).
To reach an outcome, the process starts at the root node. Based on the decision made at the root node, the next node (a splitting node) is selected; based on the decision made at that splitting node, another child node is selected. This continues until we reach a terminal node, and the value of that terminal node is our outcome.
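The walk from root to terminal node can be sketched in a few lines of plain Scala. This is only an illustration of the idea: the names `TreeWalk`, `Split` and `Leaf`, the feature meanings and the thresholds are all hypothetical and are not Spark classes or real data.

```scala
// A minimal sketch of walking a decision tree from the root to a terminal node.
// Every name here is hypothetical; none of this is part of Apache Spark.
object TreeWalk {
  sealed trait Node
  // Splitting node: compares one feature against a threshold and picks a child.
  final case class Split(featureIndex: Int, threshold: Double, left: Node, right: Node) extends Node
  // Terminal node: holds the final outcome.
  final case class Leaf(outcome: String) extends Node

  // Start at the given node and follow decisions until a terminal node is reached.
  @annotation.tailrec
  def predict(node: Node, features: Vector[Double]): String = node match {
    case Leaf(outcome) => outcome
    case Split(i, t, left, right) =>
      predict(if (features(i) <= t) left else right, features)
  }

  // Example tree: the root splits on feature 0, its right child splits on feature 1.
  val tree: Node =
    Split(0, 30.0,
      Leaf("reject"),
      Split(1, 50000.0, Leaf("review"), Leaf("approve")))
}
```

Calling `TreeWalk.predict(TreeWalk.tree, Vector(40.0, 60000.0))` takes the right branch at the root and the right branch at the second split, ending at the `approve` leaf.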
Decision Tree in Apache Spark
It might sound strange or geeky, but there is no standalone implementation of the decision tree in Apache Spark. Technically that is true because Apache Spark provides an implementation of the Random Forest algorithm in which the number of trees can be specified by the user. So, under the hood, Apache Spark calls Random Forest with one tree.
In Apache Spark, the decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature space. The tree predicts the same label for each bottommost (leaf) partition. Each partition is chosen greedily by selecting the best split from a set of possible splits, in order to maximize the information gain at a tree node.
The node impurity is a measure of the homogeneity of the labels at the node. The current implementation provides two impurity measures for classification (Gini impurity and entropy).
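To make impurity and information gain concrete, here is a small plain-Scala sketch using the standard textbook definitions. It assumes integer class labels and non-empty parent nodes; the object and function names are made up for illustration and this is not Spark's internal code.

```scala
// Textbook impurity measures and information gain, written as a sketch.
// Names are hypothetical; label sequences passed to gini/entropy should be non-empty.
object Impurity {
  // Fraction of each label among the labels at a node.
  private def proportions(labels: Seq[Int]): Seq[Double] =
    labels.groupBy(identity).values.map(_.size.toDouble / labels.size).toSeq

  // Gini impurity: 1 - sum(p_i^2); 0 means the node is pure.
  def gini(labels: Seq[Int]): Double =
    if (labels.isEmpty) 0.0 else 1.0 - proportions(labels).map(p => p * p).sum

  // Entropy: -sum(p_i * log2(p_i)); 0 means the node is pure.
  def entropy(labels: Seq[Int]): Double =
    proportions(labels).map(p => -p * math.log(p) / math.log(2)).sum

  // Information gain: parent impurity minus the size-weighted impurity of the children.
  def infoGain(parent: Seq[Int], left: Seq[Int], right: Seq[Int],
               impurity: Seq[Int] => Double): Double = {
    val n = parent.size.toDouble
    impurity(parent) - (left.size / n) * impurity(left) - (right.size / n) * impurity(right)
  }
}
```

For a node with labels `Seq(0, 0, 1, 1)`, a split that separates the two classes perfectly removes all of the parent's Gini impurity, so that split would be preferred over any split that leaves the children mixed.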
Stopping rule
The recursive tree construction is stopped at a node when one of the following conditions is met:
- The node depth is equal to the maxDepth training parameter.
- No split candidate leads to an information gain greater than minInfoGain.
- No split candidate produces child nodes which each have at least minInstancesPerNode training instances.
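The three conditions above can be read as a single predicate. The following plain-Scala sketch follows the wording of the list literally; `Candidate` and `shouldStop` are hypothetical names, and Spark's own implementation is organised differently.

```scala
// A literal reading of the stopping rule as one function (illustrative only).
object StoppingRule {
  // Hypothetical summary of one candidate split at a node.
  final case class Candidate(infoGain: Double, leftCount: Int, rightCount: Int)

  // Returns true when the node should become a leaf under any of the three conditions.
  def shouldStop(depth: Int, candidates: Seq[Candidate],
                 maxDepth: Int, minInfoGain: Double, minInstancesPerNode: Int): Boolean = {
    // Condition 1: the node already sits at the maximum depth.
    val depthReached = depth == maxDepth
    // Condition 2: no candidate improves information gain beyond the threshold.
    val noUsefulGain = !candidates.exists(_.infoGain > minInfoGain)
    // Condition 3: no candidate gives both children enough training instances.
    val noBigEnoughChildren = !candidates.exists(c =>
      c.leftCount >= minInstancesPerNode && c.rightCount >= minInstancesPerNode)
    depthReached || noUsefulGain || noBigEnoughChildren
  }
}
```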
Useful Parameters
- algo :- It can be either Classification or Regression.
- numClasses :- Number of classification classes.
- maxDepth :- Maximum depth of a tree in terms of nodes.
- minInstancesPerNode :- For a node to be split further, each of its children must receive at least this number of training instances.
- minInfoGain :- For a node to be split further, the split must improve information gain by at least this much.
- maxBins :- Number of bins used when discretizing continuous features.
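In the RDD-based MLlib API these parameters can be collected in a `Strategy` object. The fragment below is a configuration sketch that assumes `spark-mllib` is on the classpath; the values are illustrative, not tuned recommendations, and `trainingData` in the final comment is assumed to be an existing `RDD[LabeledPoint]`.

```scala
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

// Collect the parameters described above in one MLlib Strategy object.
// All values here are placeholders chosen for illustration.
val strategy = new Strategy(
  algo = Algo.Classification,   // or Algo.Regression
  impurity = Gini,
  maxDepth = 5,
  numClasses = 2,
  maxBins = 32,
  minInstancesPerNode = 1,
  minInfoGain = 0.0
)

// With org.apache.spark.mllib.tree.DecisionTree imported and an
// RDD[LabeledPoint] named trainingData (assumed), training would look like:
// val model = DecisionTree.train(trainingData, strategy)
```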
Preparing training data for Decision Tree
You cannot feed arbitrary data directly to the decision tree; it expects the training data in a specific format, namely labelled points, each pairing a label with a numeric feature vector. For text data, you can use the HashingTF technique to turn the raw input into feature vectors and then attach labels so that the decision tree can understand it. This post refers to this preparation step as standardization of the data.
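The idea behind hashing term frequencies can be shown without Spark. The sketch below is a plain-Scala stand-in: `HashingSketch` and `Labelled` are hypothetical names, it uses `String.hashCode` for bucketing, and it does not reproduce the hash function or output of Spark's `HashingTF`.

```scala
// The hashing trick in miniature: tokens are hashed into a fixed-size count vector.
// Illustrative only; not Spark's HashingTF and not its LabeledPoint class.
object HashingSketch {
  // Map each token to a bucket in a fixed-size vector and count occurrences.
  def termFrequencies(tokens: Seq[String], numFeatures: Int): Array[Double] = {
    val vector = Array.fill(numFeatures)(0.0)
    tokens.foreach { token =>
      // floorMod keeps the index non-negative even when hashCode is negative.
      val index = java.lang.Math.floorMod(token.hashCode, numFeatures)
      vector(index) += 1.0
    }
    vector
  }

  // Hypothetical stand-in for a labelled training example: a label plus its features.
  final case class Labelled(label: Double, features: Array[Double])

  // Lower-case the text, split on whitespace, and pair the hashed counts with a label.
  def toLabelled(label: Double, text: String, numFeatures: Int): Labelled =
    Labelled(label, termFrequencies(text.toLowerCase.split("\\s+").toSeq, numFeatures))
}
```

Because different tokens can land in the same bucket, a larger `numFeatures` reduces collisions at the cost of a wider vector.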
Feeding and obtaining Result
Once the data has been standardized, you can feed it to the decision tree algorithm for classification. Before that, you need to split the data for training and testing; that is, to test the accuracy you need to hold back some part of the data. You can feed the data like this:
```scala
import org.apache.spark.mllib.tree.DecisionTree

// Hold out 30% of the data for testing.
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)
```
Here, data is my standardized input data, which I split in a 7:3 ratio for training and testing respectively. We are using `gini` impurity with a maximum depth of `5`.
Once the model is generated you can try to predict the classification for other data, but before that we need to validate the accuracy of the newly generated model. You can do so by computing the `Test Error`.
```scala
// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
println("Test Error = " + testErr)
```
The lower the value of Test Error, the better the model. You can take a look at a running example here.
References:- Apache Spark Documentation