[ML 101] Decision trees

How do decision trees work? How do they choose which attribute to split on?

Andy Cheung · 4 minute read

This article aims to introduce decision trees and explain the algorithms they use to split data.

When I first used DecisionTreeClassifier() in sklearn, a question came up: "How does the tree know which attribute to split on?" What criteria are used to select the root node? Before answering this question, we must know what a decision tree is.
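For reference, here is a minimal sketch of that kind of call, using scikit-learn's built-in iris dataset purely for illustration:

    # A minimal sketch: fit a decision tree on scikit-learn's built-in iris dataset.
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    X, y = load_iris(return_X_y=True)

    # criterion="entropy" splits by information gain; the default "gini" uses Gini impurity
    clf = DecisionTreeClassifier(criterion="entropy", random_state=0)
    clf.fit(X, y)

    print(clf.predict(X[:3]))        # class predictions for the first three samples
    print(clf.feature_importances_)  # how informative each attribute turned out to be

Both criteria, entropy and Gini impurity, are covered later in this article.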


What is a decision tree?

In short, a decision tree asks a question and then divides the dataset based on the answer.


Building blocks of a decision tree

Immediately, we might ask: what rule does a decision tree use to ask a question?

First, we need to understand the basic building blocks of a decision tree.

  • The root is the origin of the tree; there is only one root per tree.
  • An edge is the link between two nodes; a tree with N nodes has N-1 edges, and note that edges have a direction.
  • A parent is a node with an edge pointing towards other nodes.
  • A child is a node that has a parent node.
  • A leaf node is a node without any child nodes.
  • The height of the tree is the number of edges on the longest path from the root to a leaf node.
  • The depth of a node is the number of edges from the node to the root. Here is a great explanation on Stack Overflow.

Example: In the example above, the root is the node with the question "Is this article useful?". It has two children connected by two edges, and both children are leaf nodes. The height of this tree is 1, and both leaf nodes have a depth of 1.
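To make these terms concrete, here is a tiny sketch of a node structure with depth and height computed recursively (purely illustrative; this is not how sklearn or any library represents trees internally):

    class Node:
        def __init__(self, question=None, parent=None):
            self.question = question      # e.g. "Is this article useful?"
            self.parent = parent          # None for the root
            self.children = []            # an empty list means this is a leaf node

        def add_child(self, child):
            child.parent = self
            self.children.append(child)   # each parent-child link is one edge
            return child

        def depth(self):
            # number of edges from this node up to the root
            return 0 if self.parent is None else 1 + self.parent.depth()

        def height(self):
            # number of edges on the longest path down to a leaf
            return 0 if not self.children else 1 + max(c.height() for c in self.children)

    root = Node("Is this article useful?")
    root.add_child(Node("Yes"))
    root.add_child(Node("No"))
    print(root.height())               # 1
    print(root.children[0].depth())    # 1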


Select the most informative attribute


Finally, we can talk about the question I asked at the beginning. The answer is to use entropy to find the most informative attribute and then split the data on it. There are three frequently used algorithms for creating a decision tree:

  • Iterative Dichotomiser 3 (ID3)
  • C4.5
  • Classification And Regression Trees (CART)

They each use a slightly different method to measure the impurity of the data.

Entropy

Entropy measures randomness or uncertainty: for a constant dataset the entropy is zero, so our problem becomes minimizing entropy. The entropy of the classes is defined as

H = -Σ p_i * log2(p_i)

where p_i is the probability of class i.
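Here is a minimal sketch of that formula in Python (just an illustration, not how sklearn implements it):

    from collections import Counter
    from math import log2

    def entropy(labels):
        """Entropy of a list of class labels: H = -sum(p_i * log2(p_i))."""
        total = len(labels)
        probs = [count / total for count in Counter(labels).values()]
        return sum(-p * log2(p) for p in probs)

    print(entropy(["Y", "Y", "N", "N"]))   # 1.0 -> totally random
    print(entropy(["Y", "Y", "Y", "Y"]))   # 0.0 -> a constant dataset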

Information gain

Remember, our goal is to find the most informative attribute, so finding the attribute with the highest information gain helps. The information gain of an attribute is defined as

IG = H_before - H_after

where H_before is the entropy of the current node and H_after is the weighted sum of the entropies of the subsets produced by the split.
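As a sketch, information gain can be computed directly on top of the entropy helper above (an illustration only, not an optimized implementation):

    def information_gain(attribute_values, labels):
        total = len(labels)
        h_before = entropy(labels)                   # entropy of the current node
        h_after = 0.0
        for value in set(attribute_values):
            subset = [lbl for val, lbl in zip(attribute_values, labels) if val == value]
            h_after += len(subset) / total * entropy(subset)   # weighted entropy of each branch
        return h_before - h_after

With that in hand, let's take a look at the following example: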

Income  Tech_lover  Age     iPhone_user
high    Y           young   Y
high    N           old     Y
low     Y           middle  Y
high    Y           middle  Y
low     N           young   N
low     Y           young   N
low     N           old     N
high    N           middle  N

Let's say our target is to predict whether someone is an iPhone user, and we have the data above.

First, the entropy of the class attribute is:

H = -4/8 * log2(4/8) - 4/8 * log2(4/8)

H = 1

This means the class iPhone_user is totally random, but that doesn't stop us from calculating the information gain of each attribute! For the attribute Income, the entropy for low income is:

H = -3/4 * log2(3/4) - 1/4 * log2(1/4)

H = 0.8113

Similarly, for high income:

H = 0.8113

The weighted entropy is:

H = 4/8 * H_low + 4/8 * H_high

H = 0.8113

Information gain = entropy of the class - weighted entropy after splitting by the attribute

IG = 1 - 0.8113

IG = 0.1887

The information gain for the attribute Income is therefore 0.1887. Repeating the above steps for the other attributes, we get the following information gains:

IG_income = 0.1887

IG_tech = 0.1887

IG_age = 0.0613

So we can say that the attribute Age is not as informative as Income and Tech_lover. ID3 uses entropy and information gain to generate the decision tree.
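Using the entropy and information_gain helpers sketched earlier, we can check these numbers directly (the variable names below are just for this illustration):

    income     = ["high", "high", "low", "high", "low", "low", "low", "high"]
    tech_lover = ["Y", "N", "Y", "Y", "N", "Y", "N", "N"]
    age        = ["young", "old", "middle", "middle", "young", "young", "old", "middle"]
    iphone     = ["Y", "Y", "Y", "Y", "N", "N", "N", "N"]

    print(round(information_gain(income, iphone), 4))      # 0.1887
    print(round(information_gain(tech_lover, iphone), 4))  # 0.1887
    print(round(information_gain(age, iphone), 4))         # 0.0613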

Gini impurity

The Gini index can also be used as an impurity measurement; the CART algorithm uses Gini impurity to build its decision tree. It is defined as

Gini = 1 - Σ p_i^2

where p_i is the probability of class i.
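As a sketch, Gini impurity is even simpler to compute than entropy (again, just an illustration):

    from collections import Counter

    def gini(labels):
        """Gini impurity = 1 - sum(p_i^2); like entropy, it is 0 for a pure node."""
        total = len(labels)
        return 1.0 - sum((count / total) ** 2 for count in Counter(labels).values())

    print(gini(["Y", "Y", "N", "N"]))   # 0.5 -> maximum impurity for two classes
    print(gini(["Y", "Y", "Y", "Y"]))   # 0.0 -> pure subset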

Split ratio

The split ratio reduces the bias towards multi-valued attributes by taking the number and size of branches into account when choosing an attribute (source).

The C4.5 algorithm measures impurity slightly differently from ID3: C4.5 uses a normalized information gain, also known as the gain ratio:

Gain ratio = IG / SplitInfo, where SplitInfo = -Σ (N(t_i)/N(t)) * log2(N(t_i)/N(t))

where N(t_i) is the number of times event i appears, N(t) is the total count of events, and t is the set of events. This normalized information gain helps ensure the chosen split is genuinely effective rather than simply favoring attributes with many values.
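C4.5 itself does more bookkeeping (it also handles continuous attributes and missing values), but a sketch of the gain ratio idea, building on the earlier helpers and the example data, could look like this:

    from collections import Counter
    from math import log2

    def split_information(attribute_values):
        # Entropy of the split itself: how the samples spread over the branches.
        total = len(attribute_values)
        return sum(-(c / total) * log2(c / total) for c in Counter(attribute_values).values())

    def gain_ratio(attribute_values, labels):
        return information_gain(attribute_values, labels) / split_information(attribute_values)

    # Income splits 4/4, so its split information is 1 bit and its gain is unchanged;
    # an attribute with many small branches would have its gain scaled down more.
    print(round(gain_ratio(income, iphone), 4))   # 0.1887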

Summary

The fundamental function of a decision tree is to divide data into subsets. An impurity measurement is used to find the most informative attribute; different algorithms measure impurity in different ways, and in the real world three algorithms are favored: ID3, C4.5, and CART.

In real life, decision trees often suffer from overfitting; in that case, multiple trees can make a better decision, which I will discuss later.

