« home

Masked Autoregressive Flow

machine learning · generative modeling · probability · statistics · normalizing flows · cetz · tikz

Illustration of the slow (sequential) forward pass of a Masked Autoregressive Flow (MAF) layer, as introduced in arXiv:1705.07057. Inspired by https://blog.evjang.com/2018/01/nf2.html.


Masked Autoregressive Flow

  Download

PNG · PDF · SVG

  Code

  LaTeX

masked-autoregressive-flow.tex (49 lines)

% Masked Autoregressive Flow (MAF) layer diagram.
% Top row: x_1 … x_D, bottom row: z_1 … z_D. The functions s and t fan out from
% x_1 … x_d and feed the ⊙ / ⊕ operations on the z_{d+1} → x_{d+1} arrow.
\documentclass[tikz]{standalone}

% backgrounds: lets the translucent triangles be painted beneath the nodes
% calc:        ($(a)!0.5!(b)$) interpolation between node centres
% positioning: `right=2 of …` / `below=3 of …` placement
\usetikzlibrary{backgrounds,calc,positioning}

\begin{document}
\begin{tikzpicture}[
    thick, text centered,
    % plain cell of either row
    box/.style={draw, thin, minimum width=1cm},
    % round function / operation marker with a white label
    func/.style={circle, text=white},
    % red outline marking the cells used as inputs in this step
    input/.style={draw=red, very thick},
  ]

  % x nodes: x_1 … x_d carry the input outline, x_{d+1} … x_D are faded
  \node[box, input, fill=blue!20] (x1) {$x_1$};
  \node[box, input, fill=blue!20, right of=x1] (x2) {$x_2$};
  \node[right of=x2] (xdots1) {\dots};
  \node[box, input, fill=blue!20, right of=xdots1] (xd) {$x_d$};
  \node[box, fill=green!60!black, text opacity=1, opacity=0.4, right=2 of xd] (xdp1) {$x_{d+1}$};
  \node[right of=xdp1] (xdots2) {\dots};
  \node[box, fill=green!60!black, text opacity=1, opacity=0.4, right of=xdots2] (xD) {$x_D$};

  % z nodes: only z_{d+1} carries the input outline
  \node[box, fill=blue!20, below=3 of x1] (z1) {$z_1$};
  \node[box, fill=blue!20, right of=z1] (z2) {$z_2$};
  \node[right of=z2] (zdots1) {\dots};
  \node[box, fill=blue!20, right of=zdots1] (zd) {$z_d$};
  \node[box, input, fill=orange!40, right=2 of zd] (zdp1) {$z_{d+1}$};
  \node[right of=zdp1] (zdots2) {\dots};
  \node[box, fill=orange!40, right of=zdots2] (zD) {$z_D$};

  % z to x lines
  \draw[->] (zdp1) -- (xdp1);

  % scale and translate functions, offset to either side of the zd–xdp1 midpoint
  \node[func, font=\large, fill=teal, above right=0.1] (t) at ($(zd)!0.5!(xdp1)$) {$t$};
  \node[func, font=\large, fill=orange, below left=0.1] (s) at ($(zd)!0.5!(xdp1)$) {$s$};

  % Triangles from the bottom edge of x_1 … x_d to each function node.
  % They end at the node centres, so they must be painted *below* the nodes:
  % drawn on the main layer after the nodes, the half-transparent fill would
  % cover part of each circle and tint its white label.
  \begin{scope}[on background layer]
    \fill[teal, opacity=0.5] (x1.south west) -- (t.center) -- (xd.south east) -- cycle;
    \fill[orange, opacity=0.5] (x1.south west) -- (s.center) -- (xd.south east) -- cycle;
  \end{scope}

  % feeding in s and t: operation markers at 40% and 70% of the way up the arrow
  \node[func, inner sep=0, fill=orange] (odot1) at ($(zdp1)!0.4!(xdp1)$) {$\odot$};
  \node[func, inner sep=0, fill=teal] (oplus1) at ($(zdp1)!0.7!(xdp1)$) {$\oplus$};
  \draw[orange, ->] (s) to[bend right=5] (odot1);
  \draw[teal, ->] (t) to[bend right=5] (oplus1);

\end{tikzpicture}
\end{document}

  Typst

masked-autoregressive-flow.typ (119 lines)

#import "@preview/cetz:0.3.2": canvas, draw

#set page(width: auto, height: auto, margin: 8pt)

#canvas({
  import draw: line, content, rect, circle, hobby, on-layer

  // Geometry: cell size, then horizontal / vertical spacing between cells and rows
  let (node-width, node-height) = (1, 0.45)
  let (horiz-sep, vert-sep) = (1.2, 4)

  // Palette; the blue hex code has a fourth (alpha) byte
  let orange = rgb("#e8c268")
  let blue = rgb("#63a7e390")
  let teal = rgb("#008080")

  // Strokes and arrow heads shared by the elements below
  let input-style = (paint: red, thickness: 1pt)
  let arrow-style = (end: "stealth", fill: black, scale: .5)

  // Draw one labelled cell: a node-width × node-height rectangle whose corner is at
  // `pos`, with `label` placed on the rectangle via its element name.
  // `input: true` swaps the thin black outline for the red input outline.
  let box(pos, label, fill: none, name: none, input: false) = {
    let outline = if input { input-style } else { (paint: black, thickness: 0.3pt) }
    rect(pos, (rel: (node-width, node-height)), fill: fill, stroke: outline, name: name)
    content(name, text(baseline: -1pt)[#label])
  }

  // Put a centred ellipsis halfway between two named elements
  let c-dots(left-name, right-name) = content(
    (left-name, 50%, right-name),
    text(size: 14pt, $dots.c$),
  )

  // Lay out the two rows of cells: x at y = 0 and z at y = -vert-sep.
  // Each entry is (y offset, name prefix, (left-group fill, right-group fill)).
  let rows = (
    (0, "x", (blue, rgb(0%, 100%, 0%, 20%))),
    (-vert-sep, "z", (blue, orange)),
  )
  for (y, prefix, colors) in rows {
    let (left-fill, right-fill) = colors

    // Left group: cells 1, 2 and d in columns 0, 1 and 3.
    // Only the x row gets the input outline here.
    for (idx, col) in ((1, 0), (2, 1), ("d", 3)) {
      box(
        (col * horiz-sep, y),
        $#prefix#sub(str(idx))$,
        fill: left-fill,
        name: prefix + str(idx),
        input: prefix == "x",
      )
    }
    c-dots(prefix + "2", prefix + "d")

    // Right group: cells d+1 and D in columns 5 and 7.
    // Each tuple is (subscript shown, element-name suffix, column); the element
    // name uses "d-plus-1" in place of "d+1". Only z_(d+1) gets the input outline.
    for (idx, suffix, col) in (("d+1", "d-plus-1", 5), ("D", "D", 7)) {
      box(
        (col * horiz-sep, y),
        $#prefix#sub(str(idx))$,
        fill: right-fill,
        name: prefix + suffix,
        input: prefix == "z" and idx == "d+1",
      )
    }
    c-dots(prefix + "d-plus-1", prefix + "D")
  }

  // Arrow from z_(d+1) to x_(d+1); it is named so other elements can anchor along it
  line(
    "zd-plus-1",
    "xd-plus-1",
    name: "line-d-plus-1",
    mark: arrow-style,
  )

  // Function nodes t and s, each with a translucent triangle spanning from the
  // bottom corners of x1 … xd to the node.
  // t sits at fixed coordinates; s is positioned relative to the t circle, so t
  // must come first.
  for (label, color, center-pos) in (
    ("t", teal, (4.3 * horiz-sep, 0.4 * -vert-sep)),
    ("s", orange, (rel: (-.6, -.75), to: "t-circle")),
  ) {
    let circle-name = label + "-circle"
    let label-baseline = if label == "s" { -1pt } else { 0pt }

    // Layer 1 keeps the circle rendered above the filled triangles
    on-layer(1, content(
      center-pos,
      text(fill: white, baseline: label-baseline)[#label],
      frame: "circle",
      name: circle-name,
      stroke: none,
      fill: color,
      padding: 2pt,
    ))

    // Closed, unstroked triangle: x1 south-west → circle → xd south-east
    line(
      "x1.south-west",
      circle-name,
      "xd.south-east",
      name: label + "-triangle",
      close: true,
      stroke: none,
      fill: color.transparentize(40%),
    )
  }

  // Operation markers placed on the z_(d+1) → x_(d+1) arrow, addressed through the
  // arrow's "40%" and "70%" anchors.
  // Each tuple is (name suffix, fill colour, symbol, anchor on the arrow).
  for (op, color, label, pos) in (
    ("odot", orange, $dot.circle$, "40%"),
    ("oplus", teal, $plus.circle$, "70%"),
  ) {
    content(
      "line-d-plus-1." + pos,
      text(fill: white, baseline: -.2pt)[#label],
      name: "line-d-plus-1-" + op,
      frame: "circle",
      fill: color,
      stroke: none,
      padding: .1pt,
    )
  }

  // Coloured arrows from each function circle to the operation it feeds:
  // s → odot, t → oplus
  for (fn-label, op-name, tint) in (("s", "odot", orange), ("t", "oplus", teal)) {
    let start = fn-label + "-circle"
    let target = "line-d-plus-1-" + op-name
    // NOTE(review): `tension` looks like a catmull style key; confirm that hobby
    // honours it (its documented curl parameter appears to be `omega`).
    hobby(
      start,
      target,
      stroke: tint + 0.75pt,
      mark: (..arrow-style, offset: 5pt),
      tension: 0.8,
    )
  }
})