May 12, 2023

Matrix Multiplication for AI: Notes and Intuition

While learning AI and transformers, I realized something interesting: most AI systems are basically doing vector-by-matrix and matrix-by-matrix multiplications over and over again. If you understand matrix multiplication deeply, a lot of ML suddenly becomes much easier.

Why Matrix Multiplication Feels Core to AI

Examples I keep seeing everywhere:

AI System             Formula
--------------------  -------------------
Linear regression     y = Xw
Neural network layer  y = Wx + b
Transformer queries   Q = XW_Q
Transformer keys      K = XW_K
Transformer values    V = XW_V
Attention scores      QK^T
Context mixing        softmax(QK^T)V

So most deep learning is basically:

matrix multiplication + nonlinear activation.
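As a minimal sketch of that claim (sizes and names are my own illustration, not from any particular architecture), a two-layer network really is just matmul + nonlinearity, repeated:

```python
import numpy as np

def relu(z):
    # elementwise nonlinearity: max(z, 0)
    return np.maximum(z, 0.0)

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 1))    # input: 3 features
W1 = rng.standard_normal((5, 3))   # layer 1 weights: 3 -> 5
b1 = rng.standard_normal((5, 1))
W2 = rng.standard_normal((2, 5))   # layer 2 weights: 5 -> 2
b2 = rng.standard_normal((2, 1))

h = relu(W1 @ x + b1)  # matrix multiplication + nonlinear activation
y = W2 @ h + b2        # another matmul (no activation on the output here)
```

Stacking more layers just repeats the same two-step pattern with new weight matrices.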

What Matrix Multiplication Really Means (My Intuition)

I try to think of matrix multiplication in a few ways:

View                Meaning
------------------  -----------------------------
Linear combination  weighted sum of features
Projection          move data to a new space
Feature mixing      combine information
Geometry            rotation / scaling of vectors
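One way to make the "linear combination" view concrete: A @ x is exactly a weighted sum of A's columns, with x supplying the weights. A small sanity check (toy sizes of my own choosing):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 2))
x = np.array([2.0, 3.0])

# Matrix-vector product...
y = A @ x

# ...equals a weighted sum of A's columns, weighted by x's entries.
manual = x[0] * A[:, 0] + x[1] * A[:, 1]

assert np.allclose(y, manual)
```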

Example:

Embedding -> attention space

Q = XW_Q

Here the weight matrix W_Q decides how to mix the embedding features into query features.

Python Walkthrough (All Concepts)

import numpy as np

# 1) Linear regression: y = Xw
X = np.random.randn(4, 3)   # 4 samples, 3 features
w = np.random.randn(3, 1)
y = X @ w

# 2) Neural network layer: y = Wx + b
W = np.random.randn(5, 3)   # 5 outputs, 3 inputs
b = np.random.randn(5, 1)
x = np.random.randn(3, 1)
layer_out = W @ x + b

# 3) Linear combination (feature mixing)
v1 = np.array([[1.0], [0.0], [0.0]])
v2 = np.array([[0.0], [1.0], [0.0]])
basis = np.concatenate([v1, v2], axis=1)  # 3 x 2
coeffs = np.array([[2.0], [3.0]])         # 2 x 1
combo = basis @ coeffs                    # 3 x 1

# 4) Projection onto a subspace (orthonormal basis U)
U, _ = np.linalg.qr(np.random.randn(3, 2))  # 3 x 2, columns orthonormal
v = np.random.randn(3, 1)
proj = U @ (U.T @ v)

# 5) Geometry (2D rotation)
theta = np.deg2rad(30)
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
v2d = np.array([[1.0], [0.0]])
rotated = R @ v2d

# 6) Transformers: Q, K, V and attention
n, d, dk, dv = 3, 4, 4, 4
X = np.random.randn(n, d)
Wq = np.random.randn(d, dk)
Wk = np.random.randn(d, dk)
Wv = np.random.randn(d, dv)

Q = X @ Wq
K = X @ Wk
V = X @ Wv

def softmax(a):
    # subtract the row max for numerical stability before exponentiating
    a = a - a.max(axis=-1, keepdims=True)
    exp = np.exp(a)
    return exp / exp.sum(axis=-1, keepdims=True)

scores = Q @ K.T       # (n, n); real transformers also divide by sqrt(dk)
A = softmax(scores)    # attention weights; each row sums to 1
context = A @ V        # (n, dv) mixture of value vectors

# 7) Shape reasoning example: (1024, 4096) @ (4096, 4096) -> (1024, 4096)
shape_example = (np.random.randn(1024, 4096) @ np.random.randn(4096, 4096)).shape

What AI Researchers Are Comfortable With

Strong ML engineers can immediately reason about shapes like:

(1024 x 4096) x (4096 x 4096)

They know instantly:

  • compute cost
  • memory cost
  • resulting tensor shape

Understanding shapes is extremely important in transformers.
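A rough back-of-the-envelope for that example, using the standard dense-matmul cost (2mkn FLOPs) and assuming float32 storage. This is a sketch of the mental arithmetic, not a profiler:

```python
m, k, n = 1024, 4096, 4096

# One multiply + one add per step of the inner loop -> 2*m*k*n FLOPs.
flops = 2 * m * k * n

# Memory to hold both inputs and the output, at 4 bytes per float32.
bytes_per_float = 4
memory = (m * k + k * n + m * n) * bytes_per_float

print(f"output shape : ({m}, {n})")
print(f"FLOPs        : {flops / 1e9:.1f} GFLOPs")   # ~34.4 GFLOPs
print(f"memory       : {memory / 1e6:.1f} MB")      # ~100.7 MB
```

Real-world numbers also depend on dtype (bf16 halves the memory) and on whether activations are cached, but the shape-driven estimate is the starting point.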

My Personal Take

If someone truly understands:

  • matrix multiplication
  • dot products
  • vector projections
  • gradients

they probably understand 60-70% of modern deep learning math already.

The rest is mainly:

  • optimization
  • probability
  • architecture design

Mental Model I Like

Matrix multiplication = information mixing.

Every neural network layer is basically mixing information between dimensions.

That simple idea explains most of deep learning.
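The "information mixing" idea can be made literal: in a matmul, every output dimension reads from every input dimension. A toy check (hypothetical sizes):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4)       # 4 input dimensions
W = rng.standard_normal((3, 4))  # mixing matrix: 4 dims -> 3 dims

y = W @ x

# Each output dimension j is a weighted blend of ALL input dimensions.
for j in range(3):
    assert np.isclose(y[j], sum(W[j, i] * x[i] for i in range(4)))
```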


Thanks for reading! If you want to see future content, you can follow me on Twitter or connect with me on LinkedIn.


Support My Content

If you find my content helpful, consider supporting a humanitarian cause (building homes for elderly people in rural Terai region of Nepal) that I am planning with your donation:

Ethereum (ETH)

0xB62409A5B227D2aE7D8C66fdaA5EEf4eB4E37959

Thank you for your support!