Decision Trees

About This Module

Prework

Prework Reading

Pre-lecture Reflections

Lecture

Learning Objectives

Machine learning: supervised vs. unsupervised

Supervised

  • Labeled data
    • Example 1: images labeled with the objects: cat, dog, monkey, elephant, etc.
    • Example 2: medical data labeled with likelihood of cancer
  • Goal: discover a relationship between attributes to predict unknown labels

Unsupervised

  • Unlabeled data
  • Want to discover a relationship between data points
  • Examples:
    • clustering: partition your data into groups of similar objects
    • dimension reduction: for high-dimensional data, discover the important attributes

Machine learning: descriptive vs. predictive vs. prescriptive analytics

Descriptive

Use our data to explain what has happened in the past (i.e. find patterns in data that has already been observed)

Predictive

Use our data to predict what may happen in the future (i.e. apply the observed patterns to new observations and predict outcomes)

Prescriptive

Use our data and model to inform decisions we can make to achieve certain outcomes. This assumes a certain level of control over the data inputs to whatever process is being modeled. If no such control exists then prescriptive is not possible.

Supervised learning: Decision trees

Popular machine learning tool for predictive data analysis:

  • rooted tree
  • start at the root and keep going down
  • every internal node labeled with a condition
    • if satisfied, go left
    • if not satisfied, go right
  • leaves labeled with predicted labels (see the sketch below)
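
A decision tree maps naturally onto a Rust enum. Here is a minimal sketch (the type and helper names are our own, not from any library): an internal node stores a threshold condition on one feature, and a leaf stores a predicted label.

// Minimal sketch of a decision tree over f64 feature vectors.
// All names here are illustrative, not from a library.
enum Tree {
    // Internal node: test `features[feature] <= threshold`;
    // go left if satisfied, right otherwise.
    Node {
        feature: usize,
        threshold: f64,
        left: Box<Tree>,
        right: Box<Tree>,
    },
    // Leaf: the predicted class label.
    Leaf(u32),
}

// Start at the root and keep going down until a leaf is reached.
fn predict(tree: &Tree, features: &[f64]) -> u32 {
    match tree {
        Tree::Node { feature, threshold, left, right } => {
            if features[*feature] <= *threshold {
                predict(left, features)
            } else {
                predict(right, features)
            }
        }
        Tree::Leaf(label) => *label,
    }
}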

Does a player like bluegrass?

[Drawing: decision tree for "Does a player like bluegrass?"]
Big challenge: finding a decision tree that matches the data!
// First let's read in the sample data

:dep csv = { version = "^1.3" }
:dep serde = { version = "^1", features = ["derive"] }
:dep ndarray = { version = "^0.15.6" }
use ndarray::Array2;

// This lets us write `#[derive(Deserialize)]`.
use serde::Deserialize;

// We don't need to derive `Debug` (which doesn't require Serde), but it's a
// good habit to do it for all your types.
//
// Notice that the field names in this struct are NOT in the same order as
// the fields in the CSV data!
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct SerRecord {
    name: String,
    number: usize,
    year_born: usize,
    total_points: usize,
    PPG: f64,
}

let mut rdr = csv::Reader::from_path("players.csv").unwrap();
let mut v:Vec<SerRecord> = Vec::new();
// Loop over each record.
for result in rdr.deserialize() {
    // An error may occur, so abort the program in an unfriendly way.
    // We will make this more friendly later!
    let record:SerRecord = result.expect("a CSV record");
    v.push(record);
}
println!("{:#?}", v);
[
    SerRecord {
        name: "Kareem",
        number: 33,
        year_born: 1947,
        total_points: 38387,
        PPG: 24.6,
    },
    SerRecord {
        name: "Karl",
        number: 32,
        year_born: 1963,
        total_points: 36928,
        PPG: 25.0,
    },
    SerRecord {
        name: "LeBron",
        number: 23,
        year_born: 1984,
        total_points: 36381,
        PPG: 27.0,
    },
    SerRecord {
        name: "Kobe",
        number: 24,
        year_born: 1978,
        total_points: 33643,
        PPG: 25.0,
    },
    SerRecord {
        name: "Michael",
        number: 23,
        year_born: 1963,
        total_points: 32292,
        PPG: 30.1,
    },
]

Heuristics for constructing decision trees -- I

  • Start from a single node with all samples
  • Iterate:
    • select a node
    • use the samples in the node to split it into children using some splitting criteria
    • pass each sample to respective child
  • Label leaves (a construction sketch follows below)
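
In code, this heuristic is a recursive loop: stop when a node is pure (or a depth limit is hit), otherwise split and recurse. Here is a sketch, assuming the `Tree` enum from the earlier sketch; the splitting criterion itself is passed in as a function (one concrete criterion is sketched in the split-selection section below).

use std::collections::HashMap;

// Most common label among the samples (used to label leaves).
fn majority_label(samples: &[(Vec<f64>, u32)]) -> u32 {
    let mut counts: HashMap<u32, usize> = HashMap::new();
    for (_, y) in samples {
        *counts.entry(*y).or_insert(0) += 1;
    }
    *counts.iter().max_by_key(|&(_, c)| *c).unwrap().0
}

// Generic construction loop. A real implementation would also guard
// against splits that leave one child empty.
fn build(
    samples: &[(Vec<f64>, u32)],
    depth: usize,
    max_depth: usize,
    best_split: &dyn Fn(&[(Vec<f64>, u32)]) -> (usize, f64),
) -> Tree {
    // Stop if the node is pure or the depth limit is reached.
    let pure = samples.iter().all(|(_, y)| *y == samples[0].1);
    if pure || depth == max_depth {
        return Tree::Leaf(majority_label(samples));
    }
    // Split the node and pass each sample to the respective child.
    let (feature, threshold) = best_split(samples);
    let (left, right): (Vec<_>, Vec<_>) = samples
        .iter()
        .cloned()
        .partition(|(x, _)| x[feature] <= threshold);
    Tree::Node {
        feature,
        threshold,
        left: Box::new(build(&left, depth + 1, max_depth, best_split)),
        right: Box::new(build(&right, depth + 1, max_depth, best_split)),
    }
}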

Let's try to predict a player's favorite color.

[Decision tree]

Heuristics for constructing decision trees -- II

  • Start from a single node with all samples
  • Iterate:
    • select a node
    • use the samples in the node to split it into children using some splitting criteria
    • pass each sample to respective child
  • Label leaves

We'll split on PPG.

The goal is to have each leaf be a single class.

Favorite color?

[Decision tree]

Heuristics for constructing decision trees -- III

  • Start from a single node with all samples
  • Iterate:
    • select a node
    • use the samples in the node to split it into children using some splitting criteria
    • pass each sample to respective child
  • Label leaves
Favorite color?

[Decision tree]

Heuristics for constructing decision trees -- IV

  • Start from a single node with all samples
  • Iterate:
    • select a node
    • use the samples in the node to split it into children
    • pass each sample to respective child
  • Label leaves

Favorite color?

[Decision tree]

Split selection

  • Typical heuristic: select a split that improves classification most
  • Various measures of "goodness" or "badness":
    • Information gain / Entropy
    • Gini impurity
    • Variance
  • ID3
  • C4.5
  • C5.0
  • CART (used by linfa-trees, rustlearn, and scikit-learn)

You can read more about those algorithms at https://scikit-learn.org/stable/modules/tree.html#tree-algorithms-id3-c4-5-c5-0-and-cart

and see the mathematical formulation for CART here: https://scikit-learn.org/stable/modules/tree.html#mathematical-formulation

The Gini coefficient and Entropy (Impurity Measures)

  • Let's assume we have $k$ classes that we are trying to predict.

  • For a node $M$ containing $N$ samples, of which $N_i$ belong to class $i$, we can estimate the probability of class $i$ by $p_i = N_i / N$.

  • The Gini coefficient of $M$ is then defined as $G(M) = 1 - \sum_{i=1}^{k} p_i^2$.

  • And its entropy is defined by $H(M) = -\sum_{i=1}^{k} p_i \log_2 p_i$.
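
As a concrete check, here are both measures computed from per-class sample counts (our own helper functions, not a library API). For example, a node holding 2 samples of one class and 3 of another has Gini coefficient 1 - (2/5)² - (3/5)² = 0.48 and entropy ≈ 0.97.

// Gini coefficient of a node, from per-class sample counts.
fn gini(counts: &[usize]) -> f64 {
    let n: usize = counts.iter().sum();
    let sum_sq: f64 = counts
        .iter()
        .map(|&c| {
            let p = c as f64 / n as f64;
            p * p
        })
        .sum();
    1.0 - sum_sq
}

// Entropy of a node (base-2 logarithm), skipping empty classes.
fn entropy(counts: &[usize]) -> f64 {
    let n: usize = counts.iter().sum();
    counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / n as f64;
            -p * p.log2()
        })
        .sum()
}

println!("{:.2}", gini(&[2, 3]));    // 0.48
println!("{:.2}", entropy(&[2, 3])); // 0.97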

Advantages and disadvantages of decision trees

Advantages:

  • easy to interpret
  • not much data preparation needed
  • categorical and numerical data
  • relatively fast

Disadvantages:

  • can be very sensitive to data changes
  • can create an overcomplicated tree that matches the sample, but not the underlying problem
  • hard to find an optimal tree (constructing a minimum-size tree is NP-hard, hence the heuristics)

Decision tree construction using linfa-tree

https://rust-ml.github.io/linfa/, https://crates.io/crates/linfa

https://docs.rs/linfa-trees/latest/linfa_trees/

Note: ignore machine learning context for now

First, we read our sample data and add information about who likes pizza.

Visualized

// Rust 2021
//:dep plotters={version = "^0.3.0", default_features = false, features = ["evcxr", "all_series"]}

// Rust 2024
:dep plotters={version = "^0.3.0", default-features = false, features = ["evcxr", "all_series"]}

:dep csv = { version = "^1.3" }
:dep serde = { version = "^1", features = ["derive"] }
:dep ndarray = { version = "^0.15.6" }
use ndarray::Array2;

// This lets us write `#[derive(Deserialize)]`.
use serde::Deserialize;

// We don't need to derive `Debug` (which doesn't require Serde), but it's a
// good habit to do it for all your types.
//
// Notice that the field names in this struct are NOT in the same order as
// the fields in the CSV data!
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct SerRecord {
    name: String,
    number: usize,
    year_born: usize,
    total_points: usize,
    PPG: f64,
}

let mut rdr = csv::Reader::from_path("players.csv").unwrap();
let mut v:Vec<SerRecord> = Vec::new();
// Loop over each record.
for result in rdr.deserialize() {
    // An error may occur, so abort the program in an unfriendly way.
    // We will make this more friendly later!
    let record:SerRecord = result.expect("a CSV record");
    v.push(record);
}

let mut flat_values: Vec<f64> = Vec::new();
for s in &v {
    flat_values.push(s.total_points as f64);
    flat_values.push(s.PPG);
    flat_values.push(s.year_born as f64);
    flat_values.push(s.number as f64);
}
let array = Array2::from_shape_vec((v.len(), 4), flat_values).expect("Error creating ndarray");
println!("{:?}", array);

let likes_pizza = [1,0,0,1,0];

extern crate plotters;
use plotters::prelude::*;
{
let x_values = array.column(0);
let y_values = array.column(1);

evcxr_figure((800, 800), |root| {
    let mut chart = ChartBuilder::on(&root)
    // the caption for the chart
        .caption("Scatter Plot", ("Arial", 20).into_font())
        .x_label_area_size(40)
        .y_label_area_size(40)
        .build_cartesian_2d(32000f64..39000f64, 24f64..31f64)?;
   // the X and Y coordinates spaces for the chart
    chart.configure_mesh()
        .x_desc("Total Points")
        .y_desc("PPG")
        .draw()?;

    chart.draw_series(
            x_values.iter()
                .zip(y_values.iter())
                .zip(likes_pizza.iter())
                .map(|((total, ppg), likes)| {
                    let point = (*total, *ppg);
                    let size = 20;
                    let color = Palette99::pick(*likes as usize % 10); // Choose color based on 'LikesPizza'
                    Circle::new(point, size as i32, color.filled())
                })
        )?;

    Ok(())
})}

[[38387.0, 24.6, 1947.0, 33.0],
 [36928.0, 25.0, 1963.0, 32.0],
 [36381.0, 27.0, 1984.0, 23.0],
 [33643.0, 25.0, 1978.0, 24.0],
 [32292.0, 30.1, 1963.0, 23.0]], shape=[5, 4], strides=[4, 1], layout=Cc (0x5), const ndim=2
[Scatter plot: Total Points (x-axis) vs. PPG (y-axis), points colored by LikesPizza]

Question:

If we were to try to split RED and GREEN based on PPG, where would we split?
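
One way to check our answer programmatically: scan the candidate thresholds (midpoints between consecutive sorted PPG values) and pick the one minimizing the weighted Gini coefficient of the two children. This sketch reuses the `gini` helper from the impurity section; `best_threshold` and `class_counts` are our own names, not a library API.

// Per-class counts [likes, dislikes] for a slice of 0/1 labels.
fn class_counts(labels: &[u32]) -> [usize; 2] {
    let ones = labels.iter().filter(|&&y| y == 1).count();
    [ones, labels.len() - ones]
}

// Scan midpoint thresholds; return (threshold, weighted Gini).
fn best_threshold(ppg: &[f64], labels: &[u32]) -> (f64, f64) {
    let mut pairs: Vec<(f64, u32)> =
        ppg.iter().copied().zip(labels.iter().copied()).collect();
    pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
    let n = pairs.len() as f64;
    let mut best = (f64::NAN, f64::INFINITY);
    for w in pairs.windows(2) {
        if w[0].0 == w[1].0 {
            continue; // no threshold fits between equal values
        }
        let t = (w[0].0 + w[1].0) / 2.0;
        let left: Vec<u32> = pairs.iter().filter(|p| p.0 <= t).map(|p| p.1).collect();
        let right: Vec<u32> = pairs.iter().filter(|p| p.0 > t).map(|p| p.1).collect();
        let weighted = left.len() as f64 / n * gini(&class_counts(&left))
            + right.len() as f64 / n * gini(&class_counts(&right));
        if weighted < best.1 {
            best = (t, weighted);
        }
    }
    best
}

// PPG values and likes-pizza labels from the sample data above.
// With this data the best cut falls between 25.0 and 27.0 PPG.
let (t, g) = best_threshold(&[24.6, 25.0, 27.0, 25.0, 30.1], &[1, 0, 0, 1, 0]);
println!("best split: PPG <= {:.2} (weighted Gini {:.2})", t, g);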

Data selection

  • set of inputs: X
  • set of desired outputs: y
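
In the linfa code below, X becomes an ndarray Array2 (one row per sample, one column per feature) and y an Array1 with one label per sample. A tiny sketch using values from the players table:

use ndarray::{array, Array1, Array2};

// X: one row per sample, one column per feature (total points, PPG).
let x: Array2<f64> = array![[38387.0, 24.6], [36928.0, 25.0]];
// y: one desired output (label) per sample.
let y: Array1<usize> = array![1, 0];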

Decision tree construction

  • How do we decide which feature belongs at the root node?
  • Which features should serve as internal nodes, and which as leaves?
  • How do we split a node?
  • How do we measure the quality of a split? (And more.)

First, the preamble with all the dependencies and the code to read the CSV file:

//:dep plotters={version = "^0.3.0", default_features = false, features = ["evcxr", "all_series"]}
:dep csv = { version = "^1.3" }
:dep serde = { version = "^1", features = ["derive"] }
:dep ndarray = { version = "^0.15.6" }
:dep linfa = { git = "https://github.com/rust-ml/linfa" }
:dep linfa-trees = { git = "https://github.com/rust-ml/linfa" }


use ndarray::Array2;
use ndarray::array;
use linfa_trees::DecisionTree;
use linfa::prelude::*;
// This lets us write `#[derive(Deserialize)]`.
use serde::Deserialize;
use std::fs::File;
use std::io::Write;

// We don't need to derive `Debug` (which doesn't require Serde), but it's a
// good habit to do it for all your types.
//
// Notice that the field names in this struct are NOT in the same order as
// the fields in the CSV data!
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct SerRecord {
    name: String,
    number: usize,
    year_born: usize,
    total_points: usize,
    PPG: f64,
}

fn process_csv_file() -> Vec<SerRecord> {
  let mut rdr = csv::Reader::from_path("players.csv").unwrap();
  let mut v:Vec<SerRecord> = Vec::new();
  // Loop over each record.
  for result in rdr.deserialize() {
    // An error may occur, so abort the program in an unfriendly way.
    // We will make this more friendly later!
    let record:SerRecord = result.expect("a CSV record");
    v.push(record);
  }
  v
}


And the code to construct, train, and evaluate a decision tree:

We use the linfa Dataset:

  • The most commonly used type of dataset.
  • It contains a number of records stored as an Array2 and
  • each record may correspond to multiple targets.
  • The targets are stored as an Array1 or Array2.

And construct a DecisionTree structure.

Then export to a TeX file and render the tree.

fn main() {
  let v = process_csv_file();
  let mut flat_values: Vec<f64> = Vec::new();
  for s in &v {
    flat_values.push(s.total_points as f64);
    flat_values.push(s.PPG);
    flat_values.push(s.year_born as f64);
  }
  let array = Array2::from_shape_vec((v.len(), 3), flat_values).expect("Error creating ndarray");

  let likes_pizza = array![1,0,0,1,0];

  let dataset = Dataset::new(array, likes_pizza).with_feature_names(vec!["total points", "PPG", "year born"]);
  let decision_tree = DecisionTree::params()
        .max_depth(Some(2))
        .fit(&dataset)
        .unwrap();

  let accuracy = decision_tree.predict(&dataset).confusion_matrix(&dataset).unwrap().accuracy();
    
  println!("The accuracy is: {:?}", accuracy);

  let mut tikz = File::create("decision_tree_example.tex").unwrap();
    tikz.write_all(
        decision_tree
            .export_to_tikz()
            .with_legend()
            .to_string()
            .as_bytes(),
    )
    .unwrap();
    println!(" => generate tree description with `latex decision_tree_example.tex`!");
}

main();


The accuracy is: 0.8
 => generate tree description with `latex decision_tree_example.tex`!
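
With five samples and a depth cap of 2, an accuracy of 0.8 means the tree classifies 4 of the 5 players correctly, and note this is measured on the same data the tree was trained on. Next we render the exported TikZ file: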
use std::process::Command;
// Sanity check: print the working directory.
let output = Command::new("pwd").output().expect("Failed to execute");
// Render the exported TikZ description to a PDF.
let output = Command::new("pdflatex").arg("decision_tree_example.tex").output().expect("Failed to execute");
// Convert the PDF to a PNG (`sips` is macOS-only; use ImageMagick or similar elsewhere).
let output = Command::new("sips").args(["-s", "format", "png", "decision_tree_example.pdf", "--out", "decision_tree_example.png"])
.output().expect("Failed to execute");
[Rendered decision tree: decision_tree_example.png]

"Impurity" is another term for Gini index.

Note that Impurity decreases as we do more splits.

Also note that we can overfit the dataset. Will do worse on new data.

Techniques to avoid overfitting include setting a maximum tree depth.
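
In linfa-trees these controls live on the params builder. Here is a sketch reusing the dataset from the cell above; max_depth is the same knob the code above already uses, while min_weight_leaf and split_quality are hedged (we believe they exist on the params type, but check docs.rs/linfa-trees for your version).

use linfa::prelude::*;
use linfa_trees::{DecisionTree, SplitQuality};

// Regularized tree: shallower, with a minimum weight in every leaf.
// `min_weight_leaf` and `split_quality` are hedged: verify the exact
// method names against the linfa-trees docs for your version.
let shallow_tree = DecisionTree::params()
    .max_depth(Some(2))                 // cap tree depth
    .min_weight_leaf(1.0)               // require weight in every leaf
    .split_quality(SplitQuality::Gini)  // impurity measure for splits
    .fit(&dataset)
    .unwrap();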

Technical Coding Challenge

Coding Challenge

Coding Challenge Review