« home

Normalizing Flow Coupling Layer

machine learning · generative modeling · probability · statistics · cetz · tikz

Simple 2D example illustrating the role of the Jacobian determinant in the change of variables formula. Inspired by Ari Seff's video: https://youtu.be/i7LjDvsLWCg?t=250.


Normalizing Flow Coupling Layer

  Download

PNG · PDF · SVG

  Code

  LaTeX

normalizing-flow-coupling-layer.tex (50 lines)

\documentclass[tikz]{standalone}

% mathtools provides \mathclap and loads amsmath, which provides \boldsymbol.
\usepackage{mathtools}

% calc: midpoint syntax ($(a)!0.5!(b)$)
% positioning: right=of / below=of placement
% shapes.geometric: the diamond node shape
\usetikzlibrary{calc,positioning,shapes.geometric}

% Typeset vectors in bold italic instead of with an arrow accent.
\renewcommand\vec[1]{\boldsymbol{#1}}

\begin{document}
\begin{tikzpicture}[
    thick,
    node distance=15mm,
    % Variable blocks are diamonds; operations are orange circles.
    set/.style={draw, diamond, text width=8mm, align=center},
    op/.style={draw, circle, text width=5mm, align=center, fill=orange!40},
    % One fill colour per kind of variable block.
    copied/.style={set, fill=blue!20},
    zrest/.style={set, fill=green!30},
    xrest/.style={set, fill=yellow!40},
  ]

  % ===== Forward pass: z -> x =====
  % Top row: the first d components pass through the "=" node unchanged.
  \node[copied] (z1) {$\vec z_{1:d}$};
  \node[op, right=of z1] (eq) {\raisebox{-1ex}=};
  \node[copied, right=of eq] (x1) {$\vec x_{1:d}$};
  \draw[->] (z1) -- (eq);
  \draw[->] (eq) -- (x1);

  % Bottom row: the remaining components go through g.
  % \mathclap keeps the wide labels from enlarging the diamonds.
  \node[zrest, below=1 of z1] (z2) {$\mathclap{\vec z_{d+1:D}}$};
  \node[op, right=of z2] (g) {$g$};
  \node[below=1em of g] (forward) {forward pass};
  \node[xrest, right=of g] (x2) {$\mathclap{\vec x_{d+1:D}}$};
  \draw[->] (z2) -- (g);
  \draw[->] (g) -- (x2);

  % m sits halfway between z_{1:d} and g and feeds from the former into the latter.
  \node[op] (m) at ($(z1)!0.5!(g)$) {$m$};
  \draw[->] (z1) -- (m);
  \draw[->] (m) -- (g);

  % ===== Inverse pass: x -> z, drawn 9cm to the right =====
  % Node names are reused; from here on they refer to the inverse-pass nodes.
  \begin{scope}[xshift=9cm]

    \node[copied] (z1) {$\vec z_{1:d}$};
    \node[op, right=of z1] (eq) {\raisebox{-1ex}=};
    \node[copied, right=of eq] (x1) {$\vec x_{1:d}$};
    \draw[->] (eq) -- (z1);
    \draw[->] (x1) -- (eq);

    \node[zrest, below=1 of z1] (z2) {$\mathclap{\vec z_{d+1:D}}$};
    \node[op, right=of z2] (g) {$\mathclap{g^{-1}}$};
    \node[below=1em of g] (inverse) {inverse pass};
    \node[xrest, right=of g] (x2) {$\mathclap{\vec x_{d+1:D}}$};
    \draw[->] (g) -- (z2);
    \draw[->] (x2) -- (g);

    % In the inverse direction m reads x_{1:d} instead of z_{1:d}.
    \node[op] (m) at ($(x1)!0.5!(g)$) {$m$};
    \draw[->] (x1) -- (m);
    \draw[->] (m) -- (g);

  \end{scope}

\end{tikzpicture}
\end{document}

  Typst

normalizing-flow-coupling-layer.typ (114 lines)

#import "@preview/cetz:0.3.2": canvas, draw

#set page(width: auto, height: auto, margin: 8pt)

// Coupling layer of a normalizing flow: forward pass (left), inverse pass (right).
//
// NOTE: variable nodes are drawn as squares, not diamonds. Polygon support was
// added in https://github.com/cetz-package/cetz/pull/777; once it ships in a
// released cetz version, use it to draw proper diamonds (rotating the rect with
// rotate(z: 45deg) did not work).
#canvas({
  import draw: line, content, circle, rect

  let node-sep = 2.5 // horizontal distance between node centers
  let vert-sep = 2.5 // vertical distance between the two rows

  // Shared style for every connecting arrow.
  let arrow-style = (
    mark: (end: "stealth", fill: black, scale: 0.75),
    stroke: black + 0.7pt,
  )

  // Fill colors, approximating the LaTeX version's blue!20 / green!30 / yellow!40.
  let fill-first = rgb("#cce5ff")  // first d components (1:d)
  let fill-z-rest = rgb("#ccffcc") // remaining components of z (d+1:D)
  let fill-x-rest = rgb("#fff5cc") // remaining components of x (d+1:D)

  // Midpoint of two (x, y) positions, like ($(a)!0.5!(b)$) in TikZ.
  let midpoint(a, b) = ((a.at(0) + b.at(0)) / 2, (a.at(1) + b.at(1)) / 2)

  // Variable node: a 1x1 square (stand-in for a diamond) centered at pos, with a label.
  let diamond(pos, name, label, fill: none) = {
    rect(
      (pos.at(0) - 0.5, pos.at(1) - 0.5),
      (pos.at(0) + 0.5, pos.at(1) + 0.5),
      stroke: black + 0.7pt,
      fill: fill,
      name: name,
    )
    content(pos, label, anchor: "center")
  }

  // Operation node: an orange circle of radius 0.4 centered at pos, with a label.
  let circle-node(pos, name, label) = {
    circle(pos, radius: 0.4, name: name, stroke: black + 0.7pt, fill: rgb("#ffa64d").lighten(40%))
    content(pos, label, anchor: "center")
  }

  // One coupling-layer diagram whose left column sits at x = x0.
  //   inverse: false -> arrows run z -> x through g
  //   inverse: true  -> arrows run x -> z through g^(-1)
  // m always reads the 1:d block on the input side (z forward, x inverse).
  // `suffix` keeps element names unique when several diagrams share a canvas.
  let coupling-diagram(x0, suffix, inverse: false) = {
    let id(base) = base + suffix

    let z1 = (x0, 0)
    let eq = (x0 + node-sep, 0)
    let x1 = (x0 + 2 * node-sep, 0)
    let z2 = (x0, -vert-sep)
    let g = (x0 + node-sep, -vert-sep)
    let x2 = (x0 + 2 * node-sep, -vert-sep)
    let m = midpoint(if inverse { x1 } else { z1 }, g)

    // g is a function, not a vector, so it gets no arrow accent (as in the LaTeX $g$).
    let g-label = if inverse { $g^(-1)$ } else { $g$ }
    let caption = if inverse [inverse pass] else [forward pass]

    diamond(z1, id("z1"), $arrow(z)_(1:d)$, fill: fill-first)
    circle-node(eq, id("eq"), "=")
    diamond(x1, id("x1"), $arrow(x)_(1:d)$, fill: fill-first)
    diamond(z2, id("z2"), $arrow(z)_(d+1:D)$, fill: fill-z-rest)
    circle-node(g, id("g"), g-label)
    diamond(x2, id("x2"), $arrow(x)_(d+1:D)$, fill: fill-x-rest)
    circle-node(m, id("m"), $m$)

    // Data flows left-to-right in the forward pass and right-to-left in the inverse.
    let (in1, out1, in2, out2) = if inverse {
      ("x1", "z1", "x2", "z2")
    } else {
      ("z1", "x1", "z2", "x2")
    }
    for (src, dst) in (
      (in1, "eq"), ("eq", out1),
      (in2, "g"), ("g", out2),
      (in1, "m"), ("m", "g"),
    ) {
      line(id(src), id(dst), ..arrow-style)
    }

    // Caption below the g node.
    content(
      (rel: (0, -1), to: id("g")),
      caption,
      anchor: "south",
    )
  }

  coupling-diagram(0, "-fwd")
  coupling-diagram(5 * node-sep, "-inv", inverse: true)
})