Generative Adversarial Network

Creator: Petar Veličković (original)

Generative adversarial network (GAN) architecture. A GAN has two parts. The discriminator $D$ acts as a classifier that learns to distinguish fake data produced by the generator $G$ from real data. $G$ incurs a penalty whenever $D$ detects implausible results. This signal is backpropagated through the generator's weights so that $G$ learns to produce more realistic samples over time, eventually fooling the discriminator if training succeeds.

    ->, thick,
    node/.style={circle, fill=teal!60},
    label/.style={below, font=\footnotesize},

  % Generator stage: an arrow labelled "generator" runs from the latent
  % input node z_in to the fake-sample node x_fake.
  \node[node] (zin) {$\vec z_\text{in}$};
  \node[node, right=5em of zin] (fake) {$\vec x_\text{fake}$};
  % The generator is applied to the latent input, so the label is G(z)
  % (matching the z_in node and the p_theta(z) noise label), not G(x).
  \draw (zin) -- node[above] {$G(\vec z)$} node[label] {generator} (fake);

  % Arrow into z_in from 3 units to its left (the `<-` tip sits at z_in),
  % labelled with the noise distribution above and "latent noise" below.
  \draw[<-] (zin) -- node[above] {$p_\theta(\vec z)$} node[label] {latent noise} ++(-3,0);
  % Real-data node, placed above the fake node, fed from the left by an arrow
  % labelled with the data distribution.
  \node[node, above=of fake] (real) {$\vec x_\text{real}$};
  \draw[<-] (real) -- node[above] {$p_\text{data}(\vec x)$} ++(-3,0);
  % Discriminator input node. Its explicit `at` coordinate is the midpoint of
  % the fake and real nodes (calc syntax `!0.5!`).
  % NOTE(review): it also sets `right=6em of fake`; confirm which placement
  % TikZ ends up applying.
  \node[node, right=6em of fake] (D) at ($(fake)!0.5!(real)$) {$\vec x$};
  % Plain-text output to the right of D, joined by the discriminator arrow.
  \node[right=7em of D] (out) {real?};
  \draw (D) -- node[above] {$D(\vec x)$} node[label] {discriminator} (out);

  % Junction coordinates (styled as small filled circles) to the right of the
  % fake and real nodes.
  \coordinate[right=2.5em of fake, circle, fill, inner sep=0.15em] (pt1);
  \coordinate[right=2.5em of real, circle, fill, inner sep=0.15em] (pt2);

  % Dashed bent edge between the two junctions. An orange coordinate (pt3) sits
  % 70% of the way along it; presumably it depicts a switch choosing which
  % sample reaches D.
  \draw[-, dashed] (pt1) edge[bend left] coordinate[circle, fill=orange, inner sep=1mm, pos=0.7] (pt3) (pt2);
  % Wire each source to its junction, and the orange point to D.
  \draw (fake) -- (pt1) (real) -- (pt2) (pt3) -- (D);



generative-adversarial-network.typ (84 lines)

#import "@preview/cetz:0.3.2": canvas, draw

#set page(width: auto, height: auto, margin: 8pt)

  import draw: line, content, circle, group, hobby, on-layer

  // Style definitions
  let node-style = (
    stroke: none,
    fill: rgb("#66B2B2"), // teal!60 equivalent
    radius: 0.53,
  let arrow-style = (
    stroke: black + 0.8pt,
    mark: (end: "stealth", scale: 0.4, fill: black),

  // Vertical positions of the two rows: real data on top (y = 2), fake
  // samples on the bottom (y = 0).
  let (y-real, y-fake) = (2, 0)

  // Draw nodes: each node is a circle styled with node-style, with a math
  // label placed on it via content(). Circles are drawn before their labels
  // so the labels sit on top.
  // z_in node: latent input to the generator
  circle((0, y-fake), name: "zin", ..node-style)
  content("zin", $arrow(z)_"in"$)

  // x_fake node: generator output
  circle((3, y-fake), name: "fake", ..node-style)
  content("fake", $arrow(x)_"fake"$)

  // x_real node: sample fed by the data-distribution arrow
  circle((3, y-real), name: "real", ..node-style)
  content("real", $arrow(x)_"real"$)

  // x node (discriminator input) at y-real / 2, i.e. midway between the rows
  // because y-fake is 0. The trailing radius: 0.4 is meant to override
  // node-style's 0.53, making this node smaller.
  circle((6, y-real / 2), name: "D", ..node-style, radius: 0.4)
  content("D", $arrow(x)$)

  // Output node: a text label only (no circle), level with the discriminator input
  content((9, y-real / 2), text(size: 0.9em, baseline: -1pt)[real?], name: "out", padding: 2pt)

  // Draw arrows and their labels
  // Generator input arrow: enters z_in from the left. The noise distribution
  // is placed above the arrow's midpoint (anchor: "south") and "latent noise"
  // below it (anchor: "north").
  line((-2.5, y-fake), "zin", ..arrow-style, name: "zin-line")
  content("zin-line.mid", $p_theta (arrow(z))$, anchor: "south", padding: 0.1)
  content("zin-line.mid", text(size: 0.8em)[latent noise], anchor: "north", padding: 0.1)

  // Generator arrow: z_in -> x_fake, with the generator's symbol above the
  // midpoint and the word "generator" below it. The generator is applied to
  // the latent input, so the label is G(z) (matching the z_in node and the
  // p_theta(z) noise label), not G(x).
  line("zin", "fake", ..arrow-style, name: "fake-line")
  content("fake-line.mid", $G(arrow(z))$, anchor: "south", padding: 0.1)
  content("fake-line.mid", text(size: 0.8em)[generator], anchor: "north", padding: 0.1)

  // Real data arrow: enters x_real from the left, labelled with the data
  // distribution above its midpoint.
  line((-2, y-real), "real", ..arrow-style, name: "real-line")
  content("real-line.mid", $p_"data" (arrow(x))$, anchor: "south", padding: 0.1)

  // Connection points with names: small black junction dots to the right of
  // the fake row (dot1) and the real row (dot2), both at x = 4.5.
  circle((4.5, y-fake), radius: 0.06, fill: black, name: "dot1")
  circle((4.5, y-real), radius: 0.06, fill: black, name: "dot2")
    circle((4.25, 2 * y-real / 3), radius: 0.12, fill: orange, stroke: none, name: "dot3"),

  // Draw connecting lines with names: each source node feeds its junction
  // dot, and the orange dot (dot3) feeds the discriminator input D.
  line("fake", "dot1", ..arrow-style, name: "conn1")
  line("real", "dot2", ..arrow-style, name: "conn2")
  line("dot3", "D", ..arrow-style, name: "conn3")

  // Draw dashed curve using named points
    (4.2, (y-real - y-fake) / 2),
    stroke: (dash: "dashed"),
    omega: 2,
    name: "dashed-curve",

  // Discriminator arrow and labels: D -> "real?" output, with D(x) above the
  // midpoint and "discriminator" below it (slightly larger padding than the
  // other below-labels).
  line("D", "out", ..arrow-style, name: "disc-line")
  content("disc-line.mid", $D(arrow(x))$, anchor: "south", padding: 0.1)
  content("disc-line.mid", text(size: 0.8em)[discriminator], anchor: "north", padding: 0.15)