Creator: Francois Fleuret (original)
Flow diagram of single-head attention that illustrates the attention equation, using colored borders to indicate tensor dimensions.
% Written by Francois Fleuret <francois@fleuret.org>
% https://twitter.com/francoisfleuret/status/1529744066086424577.
% Original TikZ source: https://fleuret.org/git-extract/tex/single-attention.tex
% Any copyright is dedicated to the Public Domain.
% https://creativecommons.org/publicdomain/zero/1.0
\documentclass[tikz]{standalone}
\usepackage{mathtools}
% \transpose expands to the superscript ^{\top}; it is used in the
% $\cdot\transpose$ operation label and in the equation under the diagram.
\def\transpose{^{\top}}
\DeclareMathOperator\softmax{softmax}
\DeclareMathOperator\Attention{Attention}
\usetikzlibrary{positioning, arrows.meta}
% xcolor (loaded by TikZ) predefines no colour named `lightblue`, so the
% `parameter` style below would raise "Undefined color" as soon as a node used
% it. Define it explicitly with the X11/CSS value of lightblue.
\definecolor{lightblue}{HTML}{ADD8E6}
\begin{document}
\begin{tikzpicture}[
% Data matrices: white boxes with a grey border.
value/.style = {
font=\scriptsize, rectangle, draw=black!50, fill=white, thick,
inner sep=3pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
},
% Parameter boxes (light blue). No node in this picture uses this style.
parameter/.style = {
font=\scriptsize, rectangle, draw=black!50, fill=lightblue!15, thick,
inner sep=0pt, inner xsep=2pt, minimum size=10pt, minimum height=20pt
},
% Operations: teal-tinted boxes.
operation/.style = {
font=\scriptsize, rectangle, draw=black!50, fill=teal!30, thick,
inner sep=3pt, minimum size=10pt, minimum height=20pt
},
% Connectors: `flow` and `f2v` end in an arrowhead, `v2f` starts with a bar,
% `f2f` is a plain line. (`flow` is not used below.)
flow/.style={->,shorten <= 1pt,shorten >= 1pt, draw=black!50, thick},
f2f/.style={draw=black!50, thick},
v2f/.style={{Bar[width=1.5mm]}-,shorten <= 0.75pt,draw=black!50, thick},
f2v/.style={->,shorten >= 0.75pt,draw=black!50, thick}
]
% Title above the diagram and the equation it illustrates below it.
\node[font=\bfseries] at (3.5, 2.3) {Single-head attention};
\node at (3.5, -2.5) {$\displaystyle \Attention(Q, K, V) = \softmax_\text{row} \left( \frac{Q K\transpose}{\sqrt{d}} \right) V$};
% Inputs stacked vertically: Q above K, V below K.
\node[value, minimum height=0.8cm,minimum width=0.7cm] (K) at (0, 0) {$K$};
\node[value, minimum height=1.2cm,minimum width=0.7cm] (Q) [above=0.5cm of K] {$Q$};
\node[value, minimum height=0.8cm,minimum width=1.0cm] (V) [below=0.5cm of K] {$V$};
% Processing chain to the right of K: transpose-product, softmax, A, product, Y.
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (att) [right=0.5cm of K] {$\cdot\transpose$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (sm) [right=0.25cm of att] {$\softmax$};
\node[value, minimum height=1.2cm,minimum width=0.8cm] (A) [right=0.5cm of sm] {$A$};
\node[operation,minimum height=0.4cm,minimum width=0.4cm] (prod) [right=0.5cm of A] {$\cdot$};
\node[value, minimum height=1.2cm,minimum width=1.0cm] (Y) [right=0.5cm of prod] {$Y$};
% Connectors. Q and V enter their operation from above/below via an elbow (-|).
\draw[v2f,rounded corners=1mm] (K) -- (att);
\draw[v2f,rounded corners=1mm] (Q) -| (att);
\draw[f2f,rounded corners=1mm] (att) -- (sm);
\draw[f2v,rounded corners=1mm] (sm) -- ([xshift=-1pt]A.west);
\draw[v2f,rounded corners=1mm] (A) -- (prod);
\draw[v2f,rounded corners=1mm] (V) -| (prod);
\draw[f2v,rounded corners=1mm] (prod) -- ([xshift=-1pt]Y.west);
% Coloured strips just above the top edge (width) and just left of the left
% edge (height) of each matrix. Edges with the same colour have the same length:
%   yellow: width of Q and K
%   red:    height of K and V, width of A
%   cyan:   height of Q, A and Y
%   orange: width of V and Y
\draw[very thick,yellow] ([yshift=1pt]Q.north west) -- ([yshift=1pt]Q.north east);
\draw[very thick,yellow] ([yshift=1pt]K.north west) -- ([yshift=1pt]K.north east);
\draw[very thick,orange] ([yshift=1pt]V.north west) -- ([yshift=1pt]V.north east);
\draw[very thick,orange] ([yshift=1pt]Y.north west) -- ([yshift=1pt]Y.north east);
\draw[very thick,red] ([xshift=-1pt]V.north west) -- ([xshift=-1pt]V.south west);
\draw[very thick,red] ([xshift=-1pt]K.north west) -- ([xshift=-1pt]K.south west);
\draw[very thick,cyan] ([xshift=-1pt]Q.north west) -- ([xshift=-1pt]Q.south west);
\draw[very thick,cyan] ([xshift=-1pt]Y.north west) -- ([xshift=-1pt]Y.south west);
\draw[very thick,cyan] ([xshift=-1pt]A.north west) -- ([xshift=-1pt]A.south west);
\draw[very thick,red] ([yshift=1pt]A.north west) -- ([yshift=1pt]A.north east);
\end{tikzpicture}
\end{document}
#import "@preview/cetz:0.3.2": canvas, draw
#set page(width: auto, height: auto, margin: 8pt)
#canvas({
  import draw: rect, content, line, set-style

  // Horizontal distance between successive columns of the diagram.
  let h-sep = 1.2

  // Colours shared by borders, connectors and operation fills.
  let edge-gray = rgb(50%, 50%, 50%)
  let op-fill = rgb(30%, 80%, 80%, 30%)

  // Draw a labelled matrix as a rectangle, with optional coloured strips just
  // above its top edge (the matrix width) and just left of its left edge (the
  // matrix height). Strips of the same colour mark dimensions that must agree,
  // so every rectangle carrying a given colour must have that side length.
  //   pos         (x, y) of the top-left corner
  //   size        (width, height)
  //   label       element name of the rectangle and the text drawn on it
  //   top-color   colour of the top strip, or none to omit it
  //   left-color  colour of the left strip, or none to omit it
  //   style       style dictionary forwarded to `rect`
  let matrix(
    pos,
    size,
    label,
    top-color: none,
    left-color: none,
    style: (stroke: rgb(50%, 50%, 50%), fill: white, thickness: 1.5pt),
  ) = {
    let (x, y) = pos
    let (w, h) = size
    // Gap between a strip and the rectangle so the two strokes don't overlap.
    let offset = 0.1
    rect(pos, (x + w, y - h), ..style, name: label)
    content(label, $#label$)
    if top-color != none {
      // The strip overhangs both corners slightly (0.02).
      line(
        (x - 0.02, y + offset),
        (x + w + 0.02, y + offset),
        stroke: (paint: top-color, thickness: 2pt),
      )
    }
    if left-color != none {
      line(
        (x - offset, y + 0.02),
        (x - offset, y - h - 0.02),
        stroke: (paint: left-color, thickness: 2pt),
      )
    }
  }

  // Style for data matrices.
  // NOTE(review): `thickness` does not appear to be a CeTZ style key for
  // `rect`/`line`, so it probably has no effect; if a 1.5pt border is intended,
  // put the width on the stroke instead (e.g. `edge-gray + 1.5pt`) — confirm.
  let value-style = (
    stroke: edge-gray,
    fill: white,
    thickness: 1.5pt,
  )

  // Frame shared by all three operation nodes, so they get the same 0.75pt
  // border and tint.
  let operation-style = (
    frame: "rect",
    stroke: edge-gray + .75pt,
    fill: op-fill,
  )

  // Connector that starts with a bar (leaving a matrix) and has no arrowhead.
  let edge-style = (
    mark: (start: "|", offset: 0.075, scale: 1.3),
    stroke: edge-gray,
    thickness: 1.5pt,
  )

  // Connector that starts with a bar and ends in a filled stealth arrowhead.
  let arrow-style = (
    mark: (
      start: (symbol: "|", offset: 0.075, scale: 1.3),
      end: (symbol: "stealth", offset: 0.15, scale: 0.45),
      fill: edge-gray,
    ),
    stroke: edge-gray,
    thickness: 1.5pt,
  )

  // Title and equation
  content((4, 2.5), text(weight: "bold", size: 1.2em)[Single-head attention], name: "title")
  content(
    (4, -2.75),
    $"Attention"(Q, K, V) = "softmax"_"row" ( (Q K^top) / sqrt(d)) V$,
    name: "equation",
  )

  // Input matrices. Colour legend (equal colour => equal side length):
  //   yellow: width of Q and K
  //   red:    height of K and V, width of A
  //   cyan:   height of Q, A and Y
  //   orange: width of V and Y
  matrix(
    (0, 2.7),
    (0.7, 1.8),
    "Q",
    top-color: rgb("#FFFF00"),
    left-color: rgb("#00FFFF"),
    style: value-style,
  )
  matrix(
    (0, 0.4),
    (0.7, 0.8),
    "K",
    top-color: rgb("#FFFF00"),
    left-color: rgb("#FF0000"),
    style: value-style,
  )
  // V's height must equal K's height and A's width (all marked red): 0.8.
  matrix(
    (0, -1),
    (1.0, 0.8),
    "V",
    top-color: rgb("#FFA500"),
    left-color: rgb("#FF0000"),
    style: value-style,
  )

  // Processing chain to the right of K: transpose-product, softmax, A, product, Y.
  content(
    (h-sep + 0.4, 0),
    $dot.op^top$,
    ..operation-style,
    padding: (5pt, 3pt, 1pt),
    name: "att",
  )
  content(
    (2 * h-sep + 0.6, 0),
    [softmax],
    ..operation-style,
    padding: (2pt, 3pt, 3pt),
    name: "softmax",
  )
  matrix(
    (3 * h-sep + 0.7, 0.9),
    (0.8, 1.8),
    "A",
    top-color: rgb("#FF0000"),
    left-color: rgb("#00FFFF"),
    style: value-style,
  )
  content(
    (4 * h-sep + 1, 0),
    $dot.op$,
    ..operation-style,
    padding: (1pt, 4pt, 2pt),
    name: "prod",
  )
  matrix(
    (5 * h-sep + 0.7, 0.9),
    (1.0, 1.8),
    "Y",
    top-color: rgb("#FFA500"),
    left-color: rgb("#00FFFF"),
    style: value-style,
  )

  // Connectors. The elbow points `(a, "-|", b)` take their y from `a` and
  // their x from `b`, giving right-angled routes into the operation nodes.
  line("K.east", "att.west", ..edge-style, name: "k-to-att")
  line("Q.east", ("Q.east", "-|", "att.north"), "att.north", ..edge-style, name: "q-to-att")
  line("V.east", ("V.east", "-|", "prod.south"), "prod.south", ..edge-style, name: "v-to-prod")
  line("att.east", "softmax.west", stroke: edge-gray, name: "att-to-sm")
  line("softmax.east", "A.west", ..arrow-style, name: "sm-to-a")
  line("A.east", "prod.west", stroke: edge-gray, name: "a-to-prod")
  line("prod.east", "Y.west", ..arrow-style, name: "prod-to-y")
})