« home

Variational Autoencoder

machine learning, generative modeling, Bayesian, neural networks, probability, cetz, tikz

Architecture of a variational autoencoder (VAE), one of the earliest types of deep generative machine learning models. Inspired by https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf.


Variational Autoencoder

  Download

PNG, PDF, SVG

  Code

  LaTeX

variational-autoencoder.tex (78 lines)

\documentclass[tikz]{standalone}

\usepackage{xstring}

\usetikzlibrary{fit,positioning}

\newcommand\drawNodes[2]{
  % Place one column of `neuron` nodes per layer; each node is named
  % <namespace>-<layer index>-<neuron index> (both indices start at 1).
  % #1 (str): namespace
  % #2 (list[list[str]]): list of labels to print in the node of each neuron
  \foreach \layerLabels [count=\layerNo] in #2 {
    % xstring's \StrCount stores the number of commas in this layer's
    % label list in \commaCount; it is used to offset the column vertically
    \StrCount{\layerLabels}{,}[\commaCount]
    \foreach \neuronLabel [count=\neuronNo] in \layerLabels {
      \node[neuron] (#1-\layerNo-\neuronNo) at (2*\layerNo, \commaCount/2-1.4*\neuronNo) {\neuronLabel};
    }
  }
}

\newcommand\denselyConnectNodes[2]{
  % Draw an arrow from every node of each layer to every node of the next
  % layer, using the <namespace>-<layer>-<neuron> node names.
  % #1 (str): namespace
  % #2 (list[int]): number of nodes in each layer
  \foreach \size [count=\layerNo, remember=\layerNo as \prevLayerNo, remember=\size as \prevSize] in #2 {
    % the first layer has no predecessor (\prevLayerNo and \prevSize are
    % not set yet), so edges are only drawn from the second layer onwards
    \ifnum \layerNo > 1
      \foreach \dst in {1,...,\size} {
        \foreach \src in {1,...,\prevSize} {
          \draw[->] (#1-\prevLayerNo-\src) -- (#1-\layerNo-\dst);
        }
      }
    \fi
  }
}

\begin{document}
\begin{tikzpicture}[
    shorten >=1pt, shorten <=1pt,
    neuron/.style={circle, draw, minimum size=4ex, thick},
    legend/.style={font=\large\bfseries},
  ]

  % --- encoder: 5 -> 4 -> 3 neurons (empty labels) ---
  \drawNodes{encoder}{{{,,,,}, {,,,}, {,,}}}
  \denselyConnectNodes{encoder}{{5, 4, 3}}

  % --- decoder: 3 -> 4 -> 5 neurons, shifted to the right of the latent block ---
  \begin{scope}[xshift=11cm]
    \drawNodes{decoder}{{{,,}, {,,,}, {,,,,}}}
    \denselyConnectNodes{decoder}{{3, 4, 5}}
  \end{scope}

  % --- latent block: three mu, sigma and sample positions each,
  %     all placed relative to the middle neuron of the last encoder layer ---
  \foreach \latentIdx in {1,2,3} {
    \coordinate[neuron, right=2 of encoder-3-2, yshift=\latentIdx cm, fill=yellow, fill opacity=0.2] (mu-\latentIdx);
    \coordinate[neuron, right=2 of encoder-3-2, yshift=-\latentIdx cm, fill=blue, fill opacity=0.1] (sigma-\latentIdx);
    \coordinate[neuron, right=4 of encoder-3-2, yshift=\latentIdx cm-2cm, fill=green, fill opacity=0.1] (sample-\latentIdx);
  }

  % labelled boxes fitted around the first and last position of each group
  \node [label=$\mu$, fit=(mu-1) (mu-3), draw, fill=yellow, opacity=0.45] (mu) {};
  \node [label=$\sigma$, fit=(sigma-1) (sigma-3), draw, fill=blue, opacity=0.3] (sigma) {};
  \node [label=sample, fit=(sample-1) (sample-3), draw, fill=green, opacity=0.3] (sample) {};

  % mu and sigma boxes both feed the sample box
  \draw[->] (mu.east) edge (sample.west) (sigma.east) -- (sample.west);

  % encoder output -> mu/sigma positions, and sample positions -> decoder input
  \foreach \src in {1,2,3} {
    \foreach \dst in {1,2,3} {
      \draw[->] (encoder-3-\src) -- (mu-\dst);
      \draw[->] (encoder-3-\src) -- (sigma-\dst);
      \draw[->] (sample-\src) -- (decoder-1-\dst);
    }
  }

  % --- labels next to the input and output neurons ---
  \foreach \ioIdx in {1,...,5} {
    \node[left=0 of encoder-1-\ioIdx] {$x_\ioIdx$};
    \node[right=0 of decoder-3-\ioIdx] {$\hat x_\ioIdx$};
  }
  \node[above=0.1 of encoder-1-1] {input};
  \node[above=0.1 of decoder-3-1] {output};

\end{tikzpicture}
\end{document}

  Typst

variational-autoencoder.typ (164 lines)

#import "@preview/cetz:0.3.2": canvas, draw

#set page(width: auto, height: auto, margin: 8pt)

#canvas({
  import draw: line, circle, content, rect

  // Shared styling and layout constants used throughout the drawing.
  // node-style is spread into the neuron circle() calls; arrow-style is
  // spread into every line() call.
  let node-style = (stroke: black + 1pt, fill: white)
  let layer-sep = 2 // Horizontal separation between layers
  let node-sep = 1.4 // Vertical separation between nodes
  let arrow-style = (stroke: .5pt, mark: (end: "stealth", fill: black, scale: .3))

  // Draws a vertical column of `nodes` neurons at horizontal position `x`.
  // Neuron k (1-based, counted from the top) is named `prefix + str(k)`.
  // Two zero-radius circles named `prefix + "-top"` and `prefix + "-bottom"`
  // are placed at the y-positions of the first and last neuron.
  let draw-layer(x, nodes, prefix: "") = {
    // y-position of the neuron with 0-based index `idx`
    let y-of(idx) = nodes / 2 - node-sep * idx

    for idx in range(nodes) {
      circle((x, y-of(idx)), radius: 0.3, name: prefix + str(idx + 1), ..node-style)
    }

    // Named points for the layer bounds
    circle((x, y-of(0)), radius: 0, name: prefix + "-top", fill: none)
    circle((x, y-of(nodes - 1)), radius: 0, name: prefix + "-bottom", fill: none)
  }

  // Draws an arrow from every node of one layer to every node of another.
  // Node names are built as prefix + 1-based index, the naming scheme
  // draw-layer uses for its neurons.
  let connect-layers(from-prefix, to-prefix, from-nodes, to-nodes) = {
    for src in range(1, from-nodes + 1) {
      let src-name = from-prefix + str(src)
      for dst in range(1, to-nodes + 1) {
        line(src-name, to-prefix + str(dst), ..arrow-style)
      }
    }
  }

  // Encoder: input, hidden and output layer with 5, 4 and 3 neurons,
  // named e1*, e2*, e3* and spaced layer-sep apart starting at x = 0
  let encoder-sizes = (5, 4, 3)
  for (idx, size) in encoder-sizes.enumerate() {
    draw-layer(layer-sep * idx, size, prefix: "e" + str(idx + 1))
  }

  // Fully connect each encoder layer to the one after it
  for idx in range(1, encoder-sizes.len()) {
    connect-layers(
      "e" + str(idx),
      "e" + str(idx + 1),
      encoder-sizes.at(idx - 1),
      encoder-sizes.at(idx),
    )
  }

  // Draw the latent nodes: three mu nodes above, three sigma nodes below and
  // three sample nodes one layer further to the right.
  //
  // `..node-style` is spread BEFORE the explicit `fill:` on purpose: when a
  // named argument is supplied more than once, the last occurrence wins, so
  // spreading node-style (which carries `fill: white`) after `fill:` would
  // silently discard the per-group tint.

  // Draw mu nodes
  let mu-x = layer-sep * 3
  for ii in range(3) {
    circle(
      (mu-x, 1.5 + ii),
      radius: 0.4,
      name: "mu" + str(ii + 1),
      ..node-style,
      fill: rgb(100%, 100%, 0%, 20%),
    )
  }

  // Draw sigma nodes
  for ii in range(3) {
    circle(
      (mu-x, -1.5 - ii),
      radius: 0.4,
      name: "sigma" + str(ii + 1),
      ..node-style,
      fill: rgb(0%, 0%, 100%, 10%),
    )
  }

  // Draw sample nodes
  let sample-x = mu-x + layer-sep
  for ii in range(3) {
    circle(
      (sample-x, ii - 1),
      radius: 0.4,
      name: "sample" + str(ii + 1),
      ..node-style,
      fill: rgb(0%, 100%, 0%, 10%),
    )
  }

  // Draw a tinted, thinly outlined box around each latent group (mu, sigma,
  // sample) with a caption sitting on top of the box's north anchor.
  let group-box(x, y-min, y-max, tint, box-name, caption) = {
    rect((x - 0.5, y-min), (x + 0.5, y-max), fill: tint, name: box-name, stroke: .1pt)
    content(box-name + ".north", caption, anchor: "south", padding: 3pt)
  }

  group-box(mu-x, 1, 4, rgb(100%, 100%, 0%, 45%), "mu-box", $mu$)
  group-box(mu-x, -4, -1, rgb(0%, 0%, 100%, 30%), "sigma-box", $sigma$)
  group-box(sample-x, -1.5, 1.5, rgb(0%, 100%, 0%, 30%), "sample-box", text(size: 0.8em)[Sample])

  // Fully connect the last encoder layer to every mu and sigma node
  for src in range(1, 4) {
    for dst in range(1, 4) {
      line("e3" + str(src), "mu" + str(dst), ..arrow-style)
      line("e3" + str(src), "sigma" + str(dst), ..arrow-style)
    }
  }

  // One arrow from each of the mu and sigma boxes to the sample box
  for source-box in ("mu-box", "sigma-box") {
    line(source-box, "sample-box", ..arrow-style)
  }

  // Decoder (mirrored structure of encoder): 3, 4 and 5 neurons,
  // named d1*, d2*, d3*, starting one layer to the right of the sample nodes
  let decoder-x = sample-x + layer-sep
  let decoder-sizes = (3, 4, 5)
  for (idx, size) in decoder-sizes.enumerate() {
    draw-layer(decoder-x + layer-sep * idx, size, prefix: "d" + str(idx + 1))
  }

  // Fully connect each decoder layer to the one after it
  for idx in range(1, decoder-sizes.len()) {
    connect-layers(
      "d" + str(idx),
      "d" + str(idx + 1),
      decoder-sizes.at(idx - 1),
      decoder-sizes.at(idx),
    )
  }

  // Fully connect the sample nodes to the first decoder layer
  for src in range(1, 4) {
    for dst in range(1, 4) {
      line("sample" + str(src), "d1" + str(dst), ..arrow-style)
    }
  }

  // Label every input neuron with x_k on its left and every output neuron
  // with x-hat_k on its right (k = 1..5)
  for k in range(1, 6) {
    content("e1" + str(k) + ".west", $x_#(k)$, anchor: "east", padding: 3pt)
    content("d3" + str(k) + ".east", $hat(x)_#(k)$, anchor: "west", padding: 3pt)
  }

  // Captions above the first neuron of the input and output layers
  for (node-name, caption) in (("e11", [Input]), ("d31", [Output])) {
    content(node-name + ".north", text(weight: "regular", caption), anchor: "south", padding: 5pt)
  }
})