Variational autoencoder architecture. The earliest type of generative machine learning model. Inspired by https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf.
\documentclass[tikz]{standalone}
\usepackage{xstring}
\usetikzlibrary{fit,positioning}
\newcommand\drawNodes[2]{
% #1 (str): namespace
% #2 (list[list[str]]): list of labels to print in the node of each neuron
\foreach \neurons [count=\lyrIdx] in #2 {
\StrCount{\neurons}{,}[\layerLen] % use xstring package to save each layer size into \layerLen macro
\foreach \n [count=\nIdx] in \neurons
\node[neuron] (#1-\lyrIdx-\nIdx) at (2*\lyrIdx, \layerLen/2-1.4*\nIdx) {\n};
}
}
\newcommand\denselyConnectNodes[2]{
% #1 (str): namespace
% #2 (list[int]): number of nodes in each layer
\foreach \n [count=\lyrIdx, remember=\lyrIdx as \previdx, remember=\n as \prevn] in #2 {
\foreach \y in {1,...,\n} {
\ifnum \lyrIdx > 1
\foreach \x in {1,...,\prevn}
\draw[->] (#1-\previdx-\x) -- (#1-\lyrIdx-\y);
\fi
}
}
}
\begin{document}
\begin{tikzpicture}[
shorten >=1pt, shorten <=1pt,
neuron/.style={circle, draw, minimum size=4ex, thick},
legend/.style={font=\large\bfseries},
]
% encoder
\drawNodes{encoder}{{{,,,,}, {,,,}, {,,}}}
\denselyConnectNodes{encoder}{{5, 4, 3}}
% decoder
\begin{scope}[xshift=11cm]
\drawNodes{decoder}{{{,,}, {,,,}, {,,,,}}}
\denselyConnectNodes{decoder}{{3, 4, 5}}
\end{scope}
% mu, sigma, sample nodes
\foreach \idx in {1,...,3} {
\coordinate[neuron, right=2 of encoder-3-2, yshift=\idx cm,, fill=yellow, fill opacity=0.2] (mu-\idx);
\coordinate[neuron, right=2 of encoder-3-2, yshift=-\idx cm, fill=blue, fill opacity=0.1] (sigma-\idx);
\coordinate[neuron, right=4 of encoder-3-2, yshift=\idx cm-2cm, fill=green, fill opacity=0.1] (sample-\idx);
}
% mu, sigma, sample boxes
\node [label=$\mu$, fit=(mu-1) (mu-3), draw, fill=yellow, opacity=0.45] (mu) {};
\node [label=$\sigma$, fit=(sigma-1) (sigma-3), draw, fill=blue, opacity=0.3] (sigma) {};
\node [label=sample, fit=(sample-1) (sample-3), draw, fill=green, opacity=0.3] (sample) {};
% mu, sigma, sample connections
\draw[->] (mu.east) edge (sample.west) (sigma.east) -- (sample.west);
\foreach \a in {1,2,3}
\foreach \b in {1,2,3} {
\draw[->] (encoder-3-\a) -- (mu-\b);
\draw[->] (encoder-3-\a) -- (sigma-\b);
\draw[->] (sample-\a) -- (decoder-1-\b);
}
% input + output labels
\foreach \idx in {1,...,5} {
\node[left=0 of encoder-1-\idx] {$x_\idx$};
\node[right=0 of decoder-3-\idx] {$\hat x_\idx$};
}
\node[above=0.1 of encoder-1-1] {input};
\node[above=0.1 of decoder-3-1] {output};
\end{tikzpicture}
\end{document}
#import "@preview/cetz:0.3.2": canvas, draw
#set page(width: auto, height: auto, margin: 8pt)
#canvas({
import draw: line, circle, content, rect
let node-style = (stroke: black + 1pt, fill: white)
let layer-sep = 2 // Horizontal separation between layers
let node-sep = 1.4 // Vertical separation between nodes
let arrow-style = (stroke: .5pt, mark: (end: "stealth", fill: black, scale: .3))
// Helper function to draw a layer of nodes
let draw-layer(x, nodes, prefix: "") = {
let top-y = nodes / 2
let bottom-y = nodes / 2 - node-sep * (nodes - 1)
for ii in range(nodes) {
circle(
(x, nodes / 2 - node-sep * ii),
radius: 0.3,
name: prefix + str(ii + 1),
..node-style,
)
}
// Create named points for the layer bounds
circle((x, top-y), radius: 0, name: prefix + "-top", fill: none)
circle((x, bottom-y), radius: 0, name: prefix + "-bottom", fill: none)
}
// Helper to connect all nodes between layers
let connect-layers(from-prefix, to-prefix, from-nodes, to-nodes) = {
for ii in range(from-nodes) {
for jj in range(to-nodes) {
line(
(from-prefix + str(ii + 1)),
(to-prefix + str(jj + 1)),
..arrow-style,
)
}
}
}
// Draw encoder
draw-layer(0, 5, prefix: "e1") // Input layer
draw-layer(layer-sep, 4, prefix: "e2") // Hidden layer
draw-layer(layer-sep * 2, 3, prefix: "e3") // Output layer
// Connect encoder layers
connect-layers("e1", "e2", 5, 4)
connect-layers("e2", "e3", 4, 3)
// Draw mu nodes
let mu-x = layer-sep * 3
for ii in range(3) {
circle(
(mu-x, 1.5 + ii),
radius: 0.4,
name: "mu" + str(ii + 1),
fill: rgb(100%, 100%, 0%, 20%),
..node-style,
)
}
// Draw sigma nodes
for ii in range(3) {
circle(
(mu-x, -1.5 - ii),
radius: 0.4,
name: "sigma" + str(ii + 1),
fill: rgb(0%, 0%, 100%, 10%),
..node-style,
)
}
// Draw sample nodes
let sample-x = mu-x + layer-sep
for ii in range(3) {
circle(
(sample-x, ii - 1),
radius: 0.4,
name: "sample" + str(ii + 1),
fill: rgb(0%, 100%, 0%, 10%),
..node-style,
)
}
// Draw boxes around mu, sigma, sample nodes
rect(
(mu-x - 0.5, 1),
(mu-x + 0.5, 4),
fill: rgb(100%, 100%, 0%, 45%),
name: "mu-box",
stroke: .1pt,
)
content("mu-box.north", $mu$, anchor: "south", padding: 3pt)
rect(
(mu-x - 0.5, -4),
(mu-x + 0.5, -1),
fill: rgb(0%, 0%, 100%, 30%),
name: "sigma-box",
stroke: .1pt,
)
content("sigma-box.north", $sigma$, anchor: "south", padding: 3pt)
rect(
(sample-x - 0.5, -1.5),
(sample-x + 0.5, 1.5),
fill: rgb(0%, 100%, 0%, 30%),
name: "sample-box",
stroke: .1pt,
)
content("sample-box.north", text(size: 0.8em)[Sample], anchor: "south", padding: 3pt)
// Connect encoder to mu and sigma
for ii in range(3) {
for jj in range(3) {
line(("e3" + str(ii + 1)), ("mu" + str(jj + 1)), ..arrow-style)
line(("e3" + str(ii + 1)), ("sigma" + str(jj + 1)), ..arrow-style)
}
}
// Connect mu and sigma nodes to sample nodes
line("mu-box", "sample-box", ..arrow-style)
line("sigma-box", "sample-box", ..arrow-style)
// Draw decoder (mirrored structure of encoder)
let decoder-x = sample-x + layer-sep
draw-layer(decoder-x, 3, prefix: "d1")
draw-layer(decoder-x + layer-sep, 4, prefix: "d2")
draw-layer(decoder-x + layer-sep * 2, 5, prefix: "d3")
// Connect decoder layers
connect-layers("d1", "d2", 3, 4)
connect-layers("d2", "d3", 4, 5)
// Connect sample to decoder
for ii in range(3) {
for jj in range(3) {
line(("sample" + str(ii + 1)), ("d1" + str(jj + 1)), ..arrow-style)
}
}
// Add input and output labels
for ii in range(5) {
content(
"e1" + str(ii + 1) + ".west",
$x_#(ii + 1)$,
anchor: "east",
padding: 3pt,
)
content(
"d3" + str(ii + 1) + ".east",
$hat(x)_#(ii + 1)$,
anchor: "west",
padding: 3pt,
)
}
content("e11.north", text(weight: "regular")[Input], anchor: "south", padding: 5pt)
content("d31.north", text(weight: "regular")[Output], anchor: "south", padding: 5pt)
})