« home

Single-head attention

Creator: Francois Fleuret (original)

Tags: machine learning, attention mechanism, attention is all you need, transformer, cetz, tikz

Flow diagram of single-head attention illustrating the equation $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}_\text{row} \left( \frac{Q K^\top}{\sqrt{d}} \right) V$, with border colors indicating tensor dimensions.


Single-head attention

  Download

PNG · PDF · SVG

  Code

  LaTeX

single-head-attention.tex (75 lines)

% Written by Francois Fleuret <francois@fleuret.org>
% https://twitter.com/francoisfleuret/status/1529744066086424577

% Original TikZ source: https://fleuret.org/git-extract/tex/single-attention.tex

% Any copyright is dedicated to the Public Domain.
% https://creativecommons.org/publicdomain/zero/1.0

\documentclass[tikz]{standalone}

% Math helpers used by the equation under the diagram.
\usepackage{mathtools}
\def\transpose{^{\top}}
\DeclareMathOperator\softmax{softmax}
\DeclareMathOperator\Attention{Attention}

% `positioning` provides `right=... of`, `above=... of`; `arrows.meta` provides
% the `Bar` arrow tip used by the `v2f` edge style.
\usetikzlibrary{positioning, arrows.meta}

% `lightblue` is not one of xcolor's predefined color names, yet the
% `parameter` node style fills with `lightblue!15`; without this definition
% any node using that style fails with "Undefined color". Values are the
% CSS/X11 "lightblue" (#ADD8E6).
\definecolor{lightblue}{RGB}{173,216,230}

\begin{document}

% Flow diagram of single-head attention. Tensors are drawn as white `value`
% boxes and operations as teal `operation` boxes. Colored bars along the top
% (columns) and left (rows) of each tensor encode its dimensions: bars of the
% same color mark dimensions of the same size.
\begin{tikzpicture}[
  % Tensor nodes: white fill, gray border.
  value/.style = {
    font=\scriptsize, rectangle, draw=black!50, fill=white,   thick,
    inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
  },
  % Not used in this picture.
  % NOTE(review): `lightblue` is not a predefined xcolor name; confirm it is
  % defined before this style is used.
  parameter/.style = {
    font=\scriptsize, rectangle, draw=black!50, fill=lightblue!15, thick,
    inner sep=0pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
  },
  % Operation nodes: teal fill, gray border.
  operation/.style = {
    font=\scriptsize, rectangle,    draw=black!50, fill=teal!30, thick,
    inner sep=3pt, minimum size=10pt, minimum height=20pt
  },
  % Edge styles: `flow` is a generic arrow (unused here); `f2f` joins two
  % operations; `v2f` leaves a tensor with a bar tip; `f2v` ends on a tensor
  % with an arrow head.
  flow/.style={->,shorten <= 1pt,shorten >= 1pt, draw=black!50, thick},
  f2f/.style={draw=black!50, thick},
  v2f/.style={{Bar[width=1.5mm]}-,shorten <= 0.75pt,draw=black!50, thick},
  f2v/.style={->,shorten >= 0.75pt,draw=black!50, thick}
]

  % Title above the diagram and the equation it depicts below it.
  \node[font=\bfseries] at (3.5, 2.3) {Single-head attention};
  \node at (3.5, -2.5) {$\displaystyle \Attention(Q, K, V) = \softmax_\text{row} \left( \frac{Q K\transpose}{\sqrt{d}} \right) V$};
  % Tensors and operations, laid out left to right. Box sizes follow the
  % color code: Q, A and Y share a height (1.2cm); K and V share a height
  % (0.8cm), which is also A's width; Q and K share a width (0.7cm); V and Y
  % share a width (1.0cm).
  \node[value,    minimum height=0.8cm,minimum width=0.7cm] (K) at (0, 0) {$K$};
  \node[value,    minimum height=1.2cm,minimum width=0.7cm] (Q) [above=0.5cm of K] {$Q$};
  \node[value,    minimum height=0.8cm,minimum width=1.0cm] (V) [below=0.5cm of K] {$V$};
  \node[operation,minimum height=0.4cm,minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$};
  \node[operation,minimum height=0.4cm,minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$};
  \node[value,    minimum height=1.2cm,minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$};
  \node[operation,minimum height=0.4cm,minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$};
  \node[value,    minimum height=1.2cm,minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$};

  % K and Q feed the transpose-product node; its output passes through softmax
  % into A. The arrow into A stops 1pt left of A's west anchor, the same
  % offset at which A's cyan row marker is drawn below.
  \draw[v2f,rounded corners=1mm] (K) -- (att);
  \draw[v2f,rounded corners=1mm] (Q) -| (att);
  \draw[f2f,rounded corners=1mm] (att) -- (sm);
  \draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west);

  % A and V feed the product node, whose result is Y (same 1pt offset as above).
  \draw[v2f,rounded corners=1mm] (A) -- (prod);
  \draw[v2f,rounded corners=1mm] (V) -| (prod);
  \draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west);

  % Column markers, 1pt above each box: yellow on Q and K, orange on V and Y.
  \draw[very thick,yellow] ([yshift=1pt]Q.north west) -- ([yshift=1pt]Q.north east);
  \draw[very thick,yellow] ([yshift=1pt]K.north west) -- ([yshift=1pt]K.north east);
  \draw[very thick,orange] ([yshift=1pt]V.north west) -- ([yshift=1pt]V.north east);
  \draw[very thick,orange] ([yshift=1pt]Y.north west) -- ([yshift=1pt]Y.north east);

  % Row markers, 1pt left of each box: red on V and K, cyan on Q and Y.
  \draw[very thick,red] ([xshift=-1pt]V.north west) -- ([xshift=-1pt]V.south west);
  \draw[very thick,red] ([xshift=-1pt]K.north west) -- ([xshift=-1pt]K.south west);
  \draw[very thick,cyan] ([xshift=-1pt]Q.north west) -- ([xshift=-1pt]Q.south west);
  \draw[very thick,cyan] ([xshift=-1pt]Y.north west) -- ([xshift=-1pt]Y.south west);

  % A: rows match Q's rows (cyan), columns match K's rows (red).
  \draw[very thick,cyan] ([xshift=-1pt]A.north west) -- ([xshift=-1pt]A.south west);
  \draw[very thick,red] ([yshift=1pt]A.north west) -- ([yshift=1pt]A.north east);

\end{tikzpicture}

\end{document}

  Typst

single-head-attention.typ (176 lines)

#import "@preview/cetz:0.3.2": canvas, draw

#set page(width: auto, height: auto, margin: 8pt)

#canvas({
  import draw: rect, content, line

  // Horizontal pitch between the columns of the flow (canvas units).
  let h-sep = 1.2

  // Shared colors and strokes. cetz takes line width only from the `stroke`
  // style (a dictionary with `paint` and `thickness`); a separate top-level
  // `thickness:` key is silently ignored. The intended width is therefore
  // folded into `stroke` once here and reused through `wire`.
  let line-gray = rgb(50%, 50%, 50%)
  let wire = (paint: line-gray, thickness: 1.5pt)

  // Tensors: white boxes with a gray border.
  let value-style = (stroke: wire, fill: white)

  // Operations: framed content on a translucent teal background. Shared by
  // all three operation nodes so they render identically.
  let operation-style = (
    frame: "rect",
    stroke: line-gray + .75pt,
    fill: rgb(30%, 80%, 80%, 30%),
  )

  // Tensor -> operation edge: a bar marks where the edge leaves the tensor.
  let edge-style = (
    mark: (start: "|", offset: 0.075, scale: 1.3),
    stroke: wire,
  )

  // Operation -> tensor edge: bar at the start, arrow head at the end.
  let arrow-style = (
    mark: (
      start: (symbol: "|", offset: 0.075, scale: 1.3),
      end: (symbol: "stealth", offset: 0.15, scale: 0.45),
      fill: line-gray,
    ),
    stroke: wire,
  )

  // Draw a tensor as a labelled rectangle, with optional colored bars along
  // its top edge (columns) and left edge (rows) encoding its dimensions.
  //   pos:        (x, y) of the top-left corner
  //   size:       (width, height)
  //   label:      text shown in the box; also used as the element name, so
  //               edges can attach to anchors such as "Q.east"
  //   top-color:  color of the bar above the box (none to omit)
  //   left-color: color of the bar left of the box (none to omit)
  //   style:      style dictionary passed to `rect`
  let matrix(
    pos,
    size,
    label,
    top-color: none,
    left-color: none,
    style: value-style,
  ) = {
    let (x, y) = pos
    let (w, h) = size
    let offset = 0.1 // gap between the box and its dimension bars

    rect(pos, (x + w, y - h), ..style, name: label)
    content(label, $#label$)

    // The bars overhang the box by 0.02 at both ends.
    if top-color != none {
      line(
        (x - 0.02, y + offset),
        (x + w + 0.02, y + offset),
        stroke: (paint: top-color, thickness: 2pt),
      )
    }
    if left-color != none {
      line(
        (x - offset, y + 0.02),
        (x - offset, y - h - 0.02),
        stroke: (paint: left-color, thickness: 2pt),
      )
    }
  }

  // Title and equation
  content((4, 2.5), text(weight: "bold", size: 1.2em)[Single-head attention], name: "title")
  content(
    (4, -2.75),
    $"Attention"(Q, K, V) = "softmax"_"row" ( (Q K^top) / sqrt(d)) V$,
    name: "equation",
  )

  // Dimension color code (same color = same size, so same box length):
  //   yellow: columns of Q and K (contracted in Q K^T)
  //   cyan:   rows of Q, A and Y
  //   red:    rows of K and V, columns of A
  //   orange: columns of V and Y
  matrix(
    (0, 2.7),
    (0.7, 1.8),
    "Q",
    top-color: rgb("#FFFF00"),
    left-color: rgb("#00FFFF"),
    style: value-style,
  )

  matrix(
    (0, 0.4),
    (0.7, 0.8),
    "K",
    top-color: rgb("#FFFF00"),
    left-color: rgb("#FF0000"),
    style: value-style,
  )

  // V has as many rows as K (both red), so it shares K's height; it sits the
  // same distance below K as Q sits above it.
  matrix(
    (0, -0.9),
    (1.0, 0.8),
    "V",
    top-color: rgb("#FFA500"),
    left-color: rgb("#FF0000"),
    style: value-style,
  )

  // Operation nodes and the intermediate/output tensors, left to right.
  content(
    (h-sep + 0.4, 0),
    $dot.op^top$,
    ..operation-style,
    padding: (5pt, 3pt, 1pt),
    name: "att",
  )

  content(
    (2 * h-sep + 0.6, 0),
    [softmax],
    ..operation-style,
    padding: (2pt, 3pt, 3pt),
    name: "softmax",
  )

  matrix(
    (3 * h-sep + 0.7, 0.9),
    (0.8, 1.8),
    "A",
    top-color: rgb("#FF0000"),
    left-color: rgb("#00FFFF"),
    style: value-style,
  )

  content(
    (4 * h-sep + 1, 0),
    $dot.op$,
    ..operation-style,
    padding: (1pt, 4pt, 2pt),
    name: "prod",
  )

  matrix(
    (5 * h-sep + 0.7, 0.9),
    (1.0, 1.8),
    "Y",
    top-color: rgb("#FFA500"),
    left-color: rgb("#00FFFF"),
    style: value-style,
  )

  // Edges. "-|" coordinates give right-angled routes into the operations.
  line("K.east", "att.west", ..edge-style, name: "k-to-att")
  line("Q.east", ("Q.east", "-|", "att.north"), "att.north", ..edge-style, name: "q-to-att")
  line("V.east", ("V.east", "-|", "prod.south"), "prod.south", ..edge-style, name: "v-to-prod")

  line("att.east", "softmax.west", stroke: wire, name: "att-to-sm")
  line("softmax.east", "A.west", ..arrow-style, name: "sm-to-a")
  line("A.east", "prod.west", stroke: wire, name: "a-to-prod")
  line("prod.east", "Y.west", ..arrow-style, name: "prod-to-y")
})