« home

MADE

machine learning · generative modeling · probability · statistics · cetz · tikz

TikZ-reproduction of fig. 1 from the paper MADE: Masked Autoencoder for Distribution Estimation (arxiv:1502.03509).


MADE

  Download

PNG · PDF · SVG

  Code

  LaTeX

made.tex (110 lines)

\documentclass[tikz]{standalone}

\usepackage{xstring}

\usetikzlibrary{calc,positioning}

\newcommand\drawNodes[2]{
  % Draw one horizontal row of `neuron` nodes per layer, stacked bottom to top.
  % #1 (str): namespace; nodes are named <namespace>-<layer>-<node>, both 1-based
  % #2 (list[list[str]]): for each layer, the label to print inside each neuron
  \foreach \lyrLabels [count=\lyrNum] in #2 {
    % xstring counts the commas in the layer's label list, i.e. (layer size - 1);
    % half of that is used below to centre the row horizontally
    \StrCount{\lyrLabels}{,}[\numCommas]
    \foreach \nodeTxt [count=\nodeNum] in \lyrLabels {
      % x decreases as the node index grows, so node 1 is the right-most of its row
      \node[neuron] (#1-\lyrNum-\nodeNum) at (\numCommas/2-\nodeNum, 1.5*\lyrNum) {\nodeTxt};
    }
  }
}

\newcommand\denselyConnectNodes[2]{
  % Draw an arrow from every node of a layer to every node of the layer above it.
  % #1 (str): namespace of the nodes (as passed to \drawNodes)
  % #2 (list[int]): number of nodes in each layer, bottom layer first
  \foreach \lyrSize [count=\lyrNum, remember=\lyrNum as \prevNum, remember=\lyrSize as \prevSize] in #2 {
    \foreach \dstNode in {1,...,\lyrSize} {
      % the first layer has no predecessor (\prevNum/\prevSize are not set yet), so skip it
      \ifnum \lyrNum > 1
        \foreach \srcNode in {1,...,\prevSize} {
          \draw[->] (#1-\prevNum-\srcNode) -- (#1-\lyrNum-\dstNode);
        }
      \fi
    }
  }
}

\newcommand\connectSomeNodes[2]{
  % Draw only an explicitly listed subset of the edges between consecutive layers.
  % #1 (str): namespace of the nodes (as passed to \drawNodes)
  % #2 (list[list[list[int]]]): per layer, per node, the indices of the nodes in the
  %    next layer that this node connects to
  \foreach \lyrEdges [count=\srcLyr, evaluate=\srcLyr as \dstLyr using int(\srcLyr+1)] in #2 {
    \foreach \targets [count=\srcNode] in \lyrEdges {
      \foreach \dstNode in \targets {
        \draw[->] (#1-\srcLyr-\srcNode) -- (#1-\dstLyr-\dstNode);
      }
    }
  }
}

\begin{document}
\begin{tikzpicture}[
    shorten >=1pt, shorten <=1pt,
    neuron/.style={circle, draw, minimum size=4ex, thick},
    legend/.style={font=\large\bfseries},
  ]

  % Plain autoencoder: layers of 3, 4, 4 and 3 unlabelled neurons, densely connected
  \drawNodes{fcnn}{{{,,}, {,,,}, {,,,}, {,,}}}
  \denselyConnectNodes{fcnn}{{3, 4, 4, 3}}

  % Weight-matrix labels, placed to the right of the midpoint between the first
  % nodes of two consecutive layers (tuple: lower layer / upper layer / node name / symbol)
  \foreach \fromLyr/\toLyr/\wName/\wTex in {1/2/W1/W_1, 2/3/W2/W_2, 3/4/V/V}
    \path (fcnn-\fromLyr-1) -- (fcnn-\toLyr-1) node[midway, right=1ex] (\wName) {$\wTex$};

  % MADE net, drawn 9cm to the right of the autoencoder
  \begin{scope}[xshift=9cm]
    % One list per layer (bottom layer first); each entry is the number printed
    % inside the corresponding neuron.
    \drawNodes{made}{{{3,1,2}, {2,1,2,2}, {1,2,2,1}, {3,1,2}}}
    % One row per source layer, one inner list per source node: the indices of the
    % nodes in the next layer that the node is connected to. An empty list ({})
    % gives that node no outgoing edges.
    \connectSomeNodes{made}{{
          {{}, {1,2,3,4}, {1,3,4}},
          {{2,3}, {1,2,3,4}, {2,3}, {2,3}},
          {{1,3}, {1}, {1}, {1,3}},
        }}
  \end{scope}

  % Input + output labels. Node index \idx carries variable x_\idx in every layer
  % (node 1 is the right-most node of its row, see \drawNodes).
  \foreach \idx in {1,2,3} {
      \node[below=0 of fcnn-1-\idx] {$x_\idx$};
      \node[above=0 of fcnn-4-\idx] {$\hat x_\idx$};
      \node[below=0 of made-1-\idx] {$x_\idx$};
    }

  % MADE output labels. Each output sits above the input it models, and its
  % conditioning set follows from the masked connectivity drawn above:
  %   made-4-1 (order 3) is fed by every hidden unit, hence sees x_2 and x_3;
  %   made-4-2 (order 1) has no incoming edges;
  %   made-4-3 (order 2) is only reachable from x_2.
  % The outer labels are nudged outwards so they do not collide with p(x_2).
  \node[xshift=4ex, above=0 of made-4-1] {$p(x_1|x_2,x_3)$};
  \node[above=0 of made-4-2] {$p(x_2)$};
  \node[xshift=-2.5ex, above=0 of made-4-3] {$p(x_3|x_2)$};

  % Bottom legend, reading "autoencoder × masks → MADE"
  \node[legend, below=of fcnn-1-2] (encoder) {autoencoder};
  \node[legend, below=of made-1-2] (made) {MADE};
  \node[legend, right=2.5cm of encoder] (masks) {masks};
  % multiplication sign between "autoencoder" and "masks"
  \node[legend, yshift=-1pt] (times) at ($(encoder)!0.55!(masks)$) {\texttimes};
  % arrow before "MADE"; its position is interpolated starting from the × node
  \node[legend, yshift=-1pt] (arrow) at ($(times)!0.65!(made)$) {$\longrightarrow$};

  % Mask matrices, stacked between the two networks (top to bottom: M_V, M_{W_2}, M_{W_1}).
  % Each one is a unit grid with some cells painted black on top of it.
  \begin{scope}[shift={(3cm,5cm)}, scale=0.4]
    % M_V: 4 columns x 3 rows
    \draw (0,0) grid (4,3);
    \node at (-1.8,1.5) {$M_V =$};
    \fill[black] (0,1) rectangle (4,2)
                 (1,0) rectangle (3,1);

    % M_{W_2}: 4 columns x 4 rows
    \begin{scope}[yshift=-5cm]
      \draw (0,0) grid (4,4);
      \node at (-1.8,2) {$M_{W_2} =$};
      \fill[black] (0,0) rectangle (1,1)
                   (0,3) rectangle (1,4)
                   (2,0) rectangle (4,1)
                   (2,3) rectangle (4,4);
    \end{scope}

    % M_{W_1}: 3 columns x 4 rows
    \begin{scope}[yshift=-10cm]
      \draw (0,0) grid (3,4);
      \node at (-1.8,2) {$M_{W_1} =$};
      \fill[black] (0,0) rectangle (1,4)
                   (2,2) rectangle (3,3);
    \end{scope}

  \end{scope}

\end{tikzpicture}
\end{document}

  Typst

made.typ (211 lines)

#import "@preview/cetz:0.3.2": canvas, draw

#set page(width: auto, height: auto, margin: 8pt)

#canvas({
  import draw: line, content, circle, rect

  // Shared styling and spacing constants
  let arrow-style = (
    mark: (end: "stealth", fill: black, scale: 0.5, offset: 1pt),
    stroke: 0.5pt,
  )
  let node-style = (stroke: 0.7pt + black)
  let layer-sep = 2 // vertical distance between consecutive layers
  let horiz-sep = 1.3 // horizontal distance between neighbouring nodes of a layer

  // Draws one horizontal row of `nodes` circles at height `y`, centred on `x-offset`.
  // Circles are named `prefix + str(i)`; index 0 is the right-most circle.
  // If `masks` is given, `masks.at(i)` is printed inside circle i.
  let draw-layer(y, nodes, prefix: "", masks: none, x-offset: 0) = {
    let half-span = (nodes - 1) * horiz-sep / 2
    for i in range(nodes) {
      let center = (half-span - i * horiz-sep + x-offset, y)
      circle(center, radius: 0.3, name: prefix + str(i), ..node-style)
      if masks != none {
        content(center, str(masks.at(i)))
      }
    }
  }

  // Draws an arrow from each of the `from-nodes` nodes named `from-prefix + str(i)`
  // to each of the `to-nodes` nodes named `to-prefix + str(j)`.
  let connect-layers(from-prefix, to-prefix, from-nodes, to-nodes) = {
    for src in range(from-nodes) {
      let start = from-prefix + str(src)
      for dst in range(to-nodes) {
        line(start, to-prefix + str(dst), ..arrow-style)
      }
    }
  }

  // Horizontal centres of the three components
  let fcnn-x = -5 // autoencoder (left)
  let mask-x = 0 // mask matrices (middle)
  let made-x = 5 // MADE network (right)

  // Autoencoder: layer sizes from bottom to top, layer `idx` drawn at height idx * layer-sep
  for (idx, nodes) in (3, 4, 4, 3).enumerate() {
    draw-layer(
      idx * layer-sep,
      nodes,
      prefix: "fcnn" + str(idx) + "-",
      x-offset: fcnn-x,
    )
  }

  // Connect autoencoder layers and add weight labels
  for (from-idx, to-idx, layer-label) in ((0, 1, $W_1$), (1, 2, $W_2$), (2, 3, $V$)) {
    // Only the input layer (0) and the output layer (3) have 3 nodes; hidden layers have 4
    let from-nodes = if from-idx == 0 { 3 } else { 4 }
    let to-nodes = if to-idx == 3 { 3 } else { 4 }

    // Dense connections between the two layers, via the shared helper
    connect-layers(
      "fcnn" + str(from-idx) + "-",
      "fcnn" + str(to-idx) + "-",
      from-nodes,
      to-nodes,
    )

    // Weight label to the right of the connections, vertically halfway between the
    // two layers. W_2 (from-idx == 1) sits between the two 4-node layers and is
    // pushed a little further out; selecting it by index avoids comparing content.
    let label-x = fcnn-x + 2.1 + if from-idx == 1 { 0.3 } else { 0 }
    let mid-y = (from-idx + 0.5) * layer-sep
    content((label-x, mid-y), layer-label)
  }

  // Mask matrices in the middle column
  let mask-base-size = 1.25
  let mask-sep = 2.5

  // Draws a `rows` x `cols` grid whose bottom edge is at `y` and which is centred
  // horizontally on `x`, then paints every (row, col) pair of `filled-cells` black.
  // Rows are counted from the top, columns from the left, both starting at 0.
  let draw-mask(x, y, rows, cols, filled-cells) = {
    // a matrix with 3 columns / 3 rows is exactly mask-base-size wide / high
    let width = mask-base-size * cols / 3
    let height = mask-base-size * rows / 3
    let (cell-width, cell-height) = (width / cols, height / rows)
    let left = x - width / 2
    let right = x + width / 2
    let rule = (stroke: .2pt)

    // vertical grid lines
    for k in range(cols + 1) {
      let x-pos = left + k * cell-width
      line((x-pos, y), (x-pos, y + height), ..rule)
    }

    // horizontal grid lines
    for k in range(rows + 1) {
      let y-pos = y + k * cell-height
      line((left, y-pos), (right, y-pos), ..rule)
    }

    // black cells; row 0 is the top row, hence the (rows - row) flip
    for (row, col) in filled-cells {
      let lower-left = (left + col * cell-width, y + (rows - row - 1) * cell-height)
      let upper-right = (left + (col + 1) * cell-width, y + (rows - row) * cell-height)
      rect(lower-left, upper-right, fill: black)
    }
  }

  // Each mask has one row per node of the layer it feeds into and one column per
  // node of the layer it reads from. Filled cells are given as (row, col), with
  // row 0 at the top. The outer rect only provides a named anchor for the label.

  // Draw M_V mask (top, 3x4: 3 output nodes, 4 hidden nodes)
  let mv-width = mask-base-size * 4 / 3
  let mv-height = mask-base-size * 3 / 3
  rect(
    (mask-x - mv-width / 2, 2 * mask-sep),
    (mask-x + mv-width / 2, 2 * mask-sep + mv-height),
    name: "mv-box",
  )
  content((rel: (-.8, 0), to: "mv-box.west"), $M_V =$)
  // top row empty, middle row fully filled, bottom row filled in its two middle cells
  draw-mask(mask-x, 2 * mask-sep, 3, 4, ((1, 0), (1, 1), (1, 2), (1, 3), (2, 1), (2, 2)))

  // Draw M_W2 mask (middle, 4x4)
  let mw2-size = mask-base-size * 4 / 3
  rect(
    (mask-x - mw2-size / 2, mask-sep),
    (mask-x + mw2-size / 2, mask-sep + mw2-size),
    name: "mw2-box",
  )
  content((rel: (-.8, 0), to: "mw2-box.west"), $M_(W_2) =$)
  draw-mask(mask-x, mask-sep, 4, 4, ((0, 0), (0, 2), (0, 3), (3, 0), (3, 2), (3, 3)))

  // Draw M_W1 mask (bottom, 4x3: 4 hidden nodes, 3 input nodes)
  let mw1-width = mask-base-size
  let mw1-height = mask-base-size * 4 / 3
  rect(
    (mask-x - mw1-width / 2, 0),
    (mask-x + mw1-width / 2, mw1-height),
    name: "mw1-box",
  )
  content((rel: (-.8, 0), to: "mw1-box.west"), $M_(W_1) =$)
  // first column fully filled, plus the single cell in row 1 of the last column
  draw-mask(mask-x, 0, 4, 3, ((0, 0), (1, 0), (2, 0), (3, 0), (1, 2)))

  // MADE network (right side): one tuple per layer, bottom to top, holding the
  // number printed inside each node; the tuple length is the layer size
  let made-masks = ((3, 1, 2), (2, 1, 2, 2), (1, 2, 2, 1), (3, 1, 2))
  for (idx, masks) in made-masks.enumerate() {
    draw-layer(
      idx * layer-sep,
      masks.len(),
      prefix: "made" + str(idx) + "-",
      masks: masks,
      x-offset: made-x,
    )
  }

  // Connect MADE layers with masked connections.
  // made-edges.at(l).at(i) lists the nodes of layer l + 1 that node i of layer l
  // feeds into; an empty tuple means the node has no outgoing edge.
  let made-edges = (
    ((), (0, 1, 2, 3), (0, 2, 3)), // layer 1 -> 2
    ((1, 2), (0, 1, 2, 3), (1, 2), (1, 2)), // layer 2 -> 3
    ((0, 2), (0,), (0,), (0, 2)), // layer 3 -> 4
  )
  for (layer, layer-edges) in made-edges.enumerate() {
    for (from, tos) in layer-edges.enumerate() {
      for to in tos {
        line(
          "made" + str(layer) + "-" + str(from),
          "made" + str(layer + 1) + "-" + str(to),
          ..arrow-style,
        )
      }
    }
  }

  // Add input and output labels. Node names are 0-based, but the variables are
  // numbered x_1..x_3 (as in the MADE output labels below), hence k = i + 1.
  for i in range(3) {
    let k = i + 1
    content((rel: (0, -0.6), to: "fcnn0-" + str(i)), $x_#k$)
    content((rel: (0, 0.6), to: "fcnn3-" + str(i)), $hat(x)_#k$)
    content((rel: (0, -0.6), to: "made0-" + str(i)), $x_#k$)
  }

  // Add MADE output labels. Each output sits above the input it models, and its
  // conditioning set follows from the masked connections:
  //   made3-0 (number 3) is fed by every hidden node, so it sees x_2 and x_3;
  //   made3-1 (number 1) has no incoming connection;
  //   made3-2 (number 2) is only reachable from x_2.
  // The long label is nudged outwards so it does not run into p(x_2).
  content((rel: (.2, 0.6), to: "made3-0"), $p(x_1|x_2,x_3)$)
  content((rel: (0, 0.6), to: "made3-1"), $p(x_2)$)
  content((rel: (0, 0.6), to: "made3-2"), $p(x_3|x_2)$)

  // Bottom legend "autoencoder × masks → MADE": bold, enlarged, all on one line
  let label-size = 1.5em
  let bottom-y = -1.5 // common y coordinate of every legend entry
  let legend-items = (
    (fcnn-x, [autoencoder]),
    (mask-x - 2, [$times$]),
    (mask-x, [masks]),
    (mask-x + 2, [$arrow.r$]),
    (made-x, [MADE]),
  )
  for (x, body) in legend-items {
    content((x, bottom-y), text(weight: "bold", size: label-size, body))
  }
})