« home

Random Forest

machine learning · ensembles · classification · regression · cetz · tikz

Diagram of the random forest (RF) algorithm (Breiman 2001). RFs are ensemble models consisting of binary decision trees; they predict the mode of the individual tree predictions in classification and their mean in regression. Every node in a decision tree is a condition on a single feature, chosen to split the dataset in two so that similar samples end up in the same subset. RFs are inspectable, invariant to scaling and other monotonic feature transformations, robust to the inclusion of irrelevant features, and can estimate feature importance via mean decrease in impurity (MDI).


Random Forest

  Download

PNG · PDF · SVG

  Code

  LaTeX

random-forest.tex (45 lines)

% Standalone TikZ/forest figure of a random forest: training data, sample and
% feature bagging, n boxed decision trees with one highlighted path each, and
% the aggregation step that yields the prediction.
\documentclass[tikz]{standalone}

\usepackage{forest}
% fit: outlines fitted around each tree; positioning: `below=... of ...` placement.
% NOTE(review): the `single arrow` shape used by `red arrow` belongs to TikZ's
% shapes.arrows library, which is not loaded here; presumably forest pulls it
% in (e.g. via the shapes library) -- confirm before reusing these styles.
\usetikzlibrary{fit,positioning}

\tikzset{
  % Small filled red arrow placed at the middle of a tree edge (applied through
  % forest's `edge label`) to mark the decision path; `sloped` aligns it with
  % the edge and xscale/yscale shrink the shape.
  red arrow/.style={
    midway,red,sloped,fill, minimum height=3cm, single arrow, single arrow head extend=.5cm, single arrow head indent=.25cm,xscale=0.3,yscale=0.15,
    allow upside down
  },
  % Stealth-tipped connector: #1 shortens the end, #2 shortens the start.
  black arrow/.style 2 args={-stealth, shorten >=#1, shorten <=#2},
  black arrow/.default={1mm}{1mm},
  % Rounded outline fitted around each decision tree.
  tree box/.style={draw, rounded corners, inner sep=0.5em},
  % Rounded rectangle for the pipeline labels: colour set to white, with a
  % black outline and black text.
  node box/.style={white, draw=black, text=black, rectangle, rounded corners},
}

\begin{document}
% The whole figure is one forest tree:
% - root: the "Training Data" label; its child is the bagging label, whose
%   children are the roots of the decision trees plus an ellipsis placeholder.
% - `for tree`: every node is a blue circle; level-2 nodes (the decision-tree
%   roots) get no forest edge -- the arrows to them are drawn manually below.
% - red!70 nodes form each tree's highlighted decision path, and
%   `edge label={node[...,red arrow]{}}` puts a red arrow on the edge leading
%   into such a node.
% - alias=a1..a3, b1..b3, c1..c4 mark nodes of trees 1, 2 and n that the
%   `tree box` outlines below are fitted around.
\begin{forest}
  for tree={l sep=3em, s sep=2em, anchor=center, inner sep=0.4em, fill=blue!50, circle, where level=2{no edge}{}}
  [
  Training Data, node box
  [sample and feature bagging, node box, alias=bagging, above=3em
  [,red!70,alias=a1[[,alias=a2][]][,red!70,edge label={node[above=1ex,red arrow]{}}[[][]][,red!70,edge label={node[above=1ex,red arrow]{}}[,red!70,edge label={node[below=1ex,red arrow]{}}][,alias=a3]]]]
  [,red!70,alias=b1[,red!70,edge label={node[below=1ex,red arrow]{}}[[,alias=b2][]][,red!70,edge label={node[above=1ex,red arrow]{}}]][[][[][,alias=b3]]]]
  [~~~$\dots$~,scale=2,no edge,fill=none,yshift=-3em]
  [,red!70,alias=c1[[,alias=c2][]][,red!70,edge label={node[above=1ex,red arrow]{}}[,red!70,edge label={node[above=1ex,red arrow]{}}[,alias=c3][,red!70,edge label={node[above=1ex,red arrow]{}}]][,alias=c4]]]]
  ]
  % Outline each decision tree with a box fitted around its aliased nodes.
  \node[tree box, fit=(a1)(a2)(a3)] (t1) {};
  \node[tree box, fit=(b1)(b2)(b3)] (t2) {};
  \node[tree box, fit=(c1)(c2)(c3)(c4)] (tn) {};
  % Tree titles in the top-left corner of each box.
  \node[below right=0.5em, inner sep=0pt] at (t1.north west) {Tree 1};
  \node[below right=0.5em, inner sep=0pt] at (t2.north west) {Tree 2};
  \node[below right=0.5em, inner sep=0pt] at (tn.north west) {Tree $n$};
  % Aggregation node: 4em below the midpoint of the line from tree 1's
  % bottom-left corner to tree n's bottom-right corner; prediction below it.
  \path (t1.south west)--(tn.south east) node[midway,below=4em, node box] (mean) {mean in regression or majority vote in classification};
  \node[below=3em of mean, node box] (pred) {prediction};
  % Connectors bagging -> trees -> aggregation -> prediction. The outer arrows
  % use larger shortening (black arrow={end}{start}) than the 1mm default.
  \draw[black arrow={5mm}{4mm}] (bagging) -- (t1.north);
  \draw[black arrow] (bagging) -- (t2.north);
  \draw[black arrow={5mm}{4mm}] (bagging) -- (tn.north);
  \draw[black arrow={5mm}{5mm}] (t1.south) -- (mean);
  \draw[black arrow] (t2.south) -- (mean);
  \draw[black arrow={5mm}{5mm}] (tn.south) -- (mean);
  \draw[black arrow] (mean) -- (pred);
\end{forest}
\end{document}

  Typst

random-forest.typ (280 lines)

#import "@preview/cetz:0.3.3": canvas, draw
#import draw: line, content, circle, rect, group, set-style, on-layer

#set page(width: auto, height: auto, margin: 8pt)

// TODO add arrows next to each connecting line between red nodes

// Shared style dictionaries, spread into cetz calls as `..style`.

// Black connector arrows between the pipeline stages and the tree boxes.
#let arrow-style = (
  mark: (end: "stealth", fill: black, scale: 0.5),
  stroke: 0.5pt,
)
// Plain black edges between decision-tree nodes.
#let line-style = (stroke: 0.5pt)
// Red arrows drawn beside the edges of each tree's highlighted path: a bar at
// the start and a stealth tip at the end. The line thickness is part of the
// Typst stroke itself (cetz has no separate `stroke-width` style key).
#let red-arrow-style = (
  stroke: (paint: red, thickness: 1.2pt),
  fill: red,
  mark: (end: "stealth", scale: 0.6, offset: 0.2, start: (symbol: "|", width: 1pt)),
)
// Rounded white box (currently unused).
#let node-box-style = (
  stroke: 0.3pt,
  fill: white,
  radius: 3pt,
)
// Rounded, unfilled outline for a tree box.
#let tree-box-style = (
  stroke: 0.3pt,
  fill: none,
  radius: 3pt,
)

#canvas({
  // Global content style: every `content` call below gets a thin white
  // rounded-rectangle frame unless it overrides these keys.
  set-style(
    content: (
      frame: "rect",
      stroke: 0.1pt,
      fill: white,
      inset: 3pt,
      radius: 3pt,
      padding: (3pt, 5pt, 2pt),
    ),
  )

  // Node fills: ordinary nodes vs. nodes on a tree's highlighted decision path.
  let plain-fill = blue.lighten(40%)
  let path-fill = red.lighten(30%)

  // Draw one decision-tree node as a small filled circle.
  let tree-node(position, fill: plain-fill, name: none) = {
    circle(position, radius: 0.2, fill: fill, stroke: 0.3pt, name: name)
  }

  // Default (left, right) number of level-2 children per tree name; any other
  // name falls back to `fallback-shape`.
  let tree-shapes = (tree1: (3, 1), tree2: (2, 2))
  let fallback-shape = (1, 3)

  // Horizontal layout of the children of one level-1 node, indexed by
  // (child count - 1): (spacing as a fraction of node-spacing, offset factors).
  let child-spreads = (
    (0, (0,)),
    (0.5, (-1, 1)),
    (0.67, (-1, 0, 1)),
  )

  // Draw one boxed, three-level decision tree whose root sits at (x-position, -2).
  //
  // x-position: horizontal position of the root.
  // tree-name:  group name (its "north"/"south" anchors are used by the
  //             connecting arrows) and label source: "tree3" is labelled
  //             "Tree n", any other name "Tree <last character>".
  // path-nodes: node names to highlight in red: "root", "left", "right" and
  //             the level-2 names "l0", "l1", ... / "r0", "r1", ... (numbered
  //             left to right under each side). A red arrow is drawn beside
  //             every edge whose two end nodes are both highlighted.
  // children:   optional (left, right) child counts, each 1 to 3; by default
  //             looked up from tree-name in `tree-shapes`.
  let draw-tree(x-position, tree-name, path-nodes: (), children: none) = {
    let y-position = -2
    let level-spacing = 1.2 // Vertical spacing between levels
    let node-spacing = 1.0 // Horizontal distance from the root to each level-1 node
    let arrow-offset = 0.15 // Horizontal shift of a red arrow away from its edge

    let (left-count, right-count) = if children == none {
      tree-shapes.at(tree-name, default: fallback-shape)
    } else {
      children
    }
    let fill-for(node-name) = if node-name in path-nodes { path-fill } else { plain-fill }

    let root-pos = (x-position, y-position)
    // Every non-root node as a dictionary: its name, position, parent name,
    // parent position, and the x-shift of the red arrow beside its edge.
    let entries = ()
    for (side, prefix, count) in ((-1, "l", left-count), (1, "r", right-count)) {
      assert(
        count >= 1 and count <= child-spreads.len(),
        message: "each level-1 node needs 1 to 3 children, got " + str(count),
      )
      let parent-name = if side < 0 { "left" } else { "right" }
      let parent-pos = (x-position + side * node-spacing, y-position - level-spacing)
      entries.push((
        name: parent-name,
        pos: parent-pos,
        parent: "root",
        origin: root-pos,
        shift: side * arrow-offset,
      ))

      let (spread, factors) = child-spreads.at(count - 1)
      for (idx, factor) in factors.enumerate() {
        // An only child keeps its arrow on the outer side of the tree; otherwise
        // the arrow sits on the side the child branches towards (centre: right).
        let shift = if count == 1 {
          side * arrow-offset
        } else if factor < 0 {
          -arrow-offset
        } else {
          arrow-offset
        }
        entries.push((
          name: prefix + str(idx),
          pos: (parent-pos.at(0) + factor * (node-spacing * spread), y-position - 2 * level-spacing),
          parent: parent-name,
          origin: parent-pos,
          shift: shift,
        ))
      }
    }

    let tree-label = if tree-name == "tree3" { $n$ } else { tree-name.last() }

    group(
      name: tree-name,
      {
        // Black edges between node centres come first, so the filled node
        // circles drawn next cover the edge ends instead of being crossed.
        for entry in entries {
          line(entry.origin, entry.pos, ..line-style)
        }
        tree-node(root-pos, name: "root", fill: fill-for("root"))
        for entry in entries {
          tree-node(entry.pos, name: entry.name, fill: fill-for(entry.name))
        }
        // Red arrows last, beside each edge of the highlighted path.
        for entry in entries {
          if entry.parent in path-nodes and entry.name in path-nodes {
            line(
              (entry.origin.at(0) + entry.shift, entry.origin.at(1)),
              (entry.pos.at(0) + entry.shift, entry.pos.at(1)),
              ..red-arrow-style,
            )
          }
        }

        // Tree box and label
        rect(
          (x-position - 2.6, y-position - 2.6 * level-spacing),
          (x-position + 2.6, y-position + 0.5),
          ..tree-box-style,
          name: "box",
        )
        content((x-position - 1.7, y-position + 0.1), [Tree #tree-label])
      },
    )
  }

  // Draw main nodes
  content((0, 0), [Training Data], name: "training")
  content((0, 1), [Sample and Feature Bagging], name: "bagging")

  // Draw trees with different paths highlighted
  draw-tree(-6, "tree1", path-nodes: ("root", "left", "l1"))
  draw-tree(0, "tree2", path-nodes: ("root", "right", "r0"))
  draw-tree(6.75, "tree3", path-nodes: ("root", "right", "r2"))

  // Ellipsis standing in for the trees between tree 2 and tree n
  content((3.375, -3.5), text(size: 1.2em)[$dots.c$])

  // Aggregation and output nodes
  content(
    (0, -6.5),
    align(center)[Mean in regression or\ majority vote in classification],
    name: "mean",
    padding: (5pt, 5pt, 1.5pt),
  )
  content((0, -8), [Prediction], name: "pred", padding: (4pt, 5pt, 3pt))

  // Connect everything.
  // NOTE(review): the TikZ version of this figure runs Training Data -> bagging
  // -> trees, while here bagging points into Training Data, which then feeds
  // the trees -- confirm which reading is intended.
  for tree-name in ("tree1", "tree2", "tree3") {
    line("training", tree-name + ".north", ..arrow-style)
    line(tree-name + ".south", "mean", ..arrow-style)
  }
  line("bagging", "training", ..arrow-style)
  line("mean", "pred", ..arrow-style)
})