Mathematical Foundations
This document provides a complete mathematical treatment of all algorithms and techniques used in NeuroShard's distributed LLM training system. Every equation is explained with intuition and derivation.
1. Training Objective
1.1 Language Modeling Loss
NeuroShard trains a causal language model to predict the next token. Given a sequence of tokens $x_1, x_2, \ldots, x_T$, the model is trained to minimize the negative log-likelihood:

$$\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$$

Where:
- $\theta$ = model parameters
- $x_t$ = token at position $t$
- $x_{<t}$ = all tokens before position $t$
- $p_\theta$ = probability distribution produced by the model
1.2 Cross-Entropy Loss
The model outputs logits $z \in \mathbb{R}^{|V|}$ over the vocabulary $V$, which are converted to probabilities via softmax:

$$p_i = \frac{e^{z_i}}{\sum_{j=1}^{|V|} e^{z_j}}$$
The cross-entropy loss for a single token with true label $y$ is:

$$\mathcal{L} = -\log p_y = -z_y + \log \sum_{j} e^{z_j}$$

Gradient with respect to the logits:

$$\frac{\partial \mathcal{L}}{\partial z_i} = p_i - \mathbb{1}[i = y]$$

This result shows the gradient is simply the difference between the predicted probability and the one-hot target.
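A quick NumPy check of this identity; the vocabulary size and logit values below are arbitrary:

```python
import numpy as np

def cross_entropy_and_grad(z, y):
    """Cross-entropy loss and its gradient w.r.t. the logits z for true label y."""
    z = z - z.max()                      # numerical stability
    p = np.exp(z) / np.exp(z).sum()      # softmax probabilities
    loss = -np.log(p[y])
    one_hot = np.zeros_like(p)
    one_hot[y] = 1.0
    grad = p - one_hot                   # dL/dz_i = p_i - 1[i == y]
    return loss, grad

z = np.array([2.0, -1.0, 0.5, 0.1])      # arbitrary logits over a 4-token vocabulary
loss, grad = cross_entropy_and_grad(z, y=2)

# Finite-difference check that grad matches dL/dz
eps = 1e-6
for i in range(len(z)):
    z_eps = z.copy(); z_eps[i] += eps
    numeric = (cross_entropy_and_grad(z_eps, y=2)[0] - loss) / eps
    assert abs(numeric - grad[i]) < 1e-4
```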
2. The DiLoCo Algorithm
DiLoCo (Distributed Low-Communication training) is a two-level optimization algorithm that reduces communication by orders of magnitude.
2.1 Algorithm Overview
Inner Loop (local, no communication). Starting from the shared weights $\theta^{(t)}$, each node $i$ runs $H$ steps of AdamW on its local data:

$$\theta_{h+1}^{(i)} = \text{AdamW}\!\left(\theta_h^{(i)},\, \nabla_\theta \mathcal{L}\big(\theta_h^{(i)}; B_h^{(i)}\big)\right), \quad h = 0, \ldots, H-1$$

Where $\theta_0^{(i)} = \theta^{(t)}$, $H$ is the number of inner steps, and $B_h^{(i)}$ is a local mini-batch.

Pseudo-Gradient Computation (after $H$ inner steps):

$$\Delta\theta^{(i)} = \theta_H^{(i)} - \theta^{(t)}$$

Aggregation (across $N$ nodes):

$$\overline{\Delta\theta} = \frac{1}{N} \sum_{i=1}^{N} \Delta\theta^{(i)}$$

(or a Byzantine-tolerant aggregation rule; see Section 5).

Outer Loop (Nesterov momentum update):

$$v^{(t+1)} = \mu v^{(t)} + \overline{\Delta\theta}, \qquad \theta^{(t+1)} = \theta^{(t)} + \eta_{\text{outer}} \left(\mu v^{(t+1)} + \overline{\Delta\theta}\right)$$

Note that $\Delta\theta^{(i)}$ is the displacement produced by $H$ descent steps, so it already points in the descent direction and is added to the weights.
2.2 Why Pseudo-Gradients Approximate True Gradients
Over $H$ inner steps with inner learning rate $\eta_{\text{inner}}$ (ignoring Adam's per-parameter scaling), the local weights move by approximately:

$$\Delta\theta^{(i)} \approx -\eta_{\text{inner}} \sum_{h=0}^{H-1} g_h^{(i)}, \qquad g_h^{(i)} = \nabla_\theta \mathcal{L}\big(\theta_h^{(i)}; B_h^{(i)}\big)$$

By the law of large numbers, as $H$ grows the average of the mini-batch gradients approaches the full-batch gradient:

$$\frac{1}{H} \sum_{h=0}^{H-1} g_h^{(i)} \approx \nabla \mathcal{L}\big(\theta^{(t)}\big)$$

Therefore:

$$\Delta\theta^{(i)} \approx -\eta_{\text{inner}} H \, \nabla \mathcal{L}\big(\theta^{(t)}\big)$$

The pseudo-gradient is aligned with the negative of the true gradient (the descent direction), scaled by $\eta_{\text{inner}} H$.
2.3 Convergence Guarantee
Under standard assumptions (an $L$-smooth loss and bounded gradient variance $\sigma^2$), DiLoCo with $N$ nodes, $H$ inner steps, and $T$ outer rounds satisfies:

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\left[\left\|\nabla \mathcal{L}\big(\theta^{(t)}\big)\right\|^2\right] \leq \mathcal{O}\!\left(\frac{1}{\sqrt{NHT}}\right)$$

This matches the convergence rate of synchronous SGD while communicating only once every $H$ steps.
3. The Inner Optimizer: AdamW
The inner loop uses AdamW, which combines Adam's adaptive learning rates with decoupled weight decay.
3.1 Algorithm
Given the gradient $g_t = \nabla_\theta \mathcal{L}(\theta_t)$ at step $t$:

Moment estimates:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$
$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$

Bias correction:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

Update with decoupled weight decay:

$$\theta_{t+1} = \theta_t - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t \right)$$
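A minimal NumPy sketch of a single AdamW step following the equations above. The $\beta$ and weight-decay values match the defaults in Section 3.2; the learning rate and epsilon are illustrative placeholders:

```python
import numpy as np

def adamw_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.95, eps=1e-8, wd=0.1):
    """One AdamW update. Returns new (theta, m, v). t is 1-indexed."""
    m = beta1 * m + (1 - beta1) * g          # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g**2       # second moment (per-parameter variance)
    m_hat = m / (1 - beta1**t)               # bias correction
    v_hat = v / (1 - beta2**t)
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * theta)  # decoupled decay
    return theta, m, v

# Usage: minimize f(theta) = ||theta||^2 / 2 for a few steps
theta = np.ones(4)
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 101):
    g = theta                                 # gradient of ||theta||^2 / 2
    theta, m, v = adamw_step(theta, g, m, v, t)
```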
3.2 Hyperparameters
| Parameter | Symbol | Default | Purpose |
|---|---|---|---|
| Learning rate | $\eta$ | | Step size |
| First moment decay | $\beta_1$ | 0.9 | Gradient momentum |
| Second moment decay | $\beta_2$ | 0.95 | Variance estimation |
| Epsilon | $\epsilon$ | | Numerical stability |
| Weight decay | $\lambda$ | 0.1 | Regularization strength |
3.3 Intuition
- First moment ($m_t$): Exponential moving average of gradients → provides momentum
- Second moment ($v_t$): Exponential moving average of squared gradients → adapts the learning rate per parameter
- Bias correction: Compensates for initialization at zero (important early in training)
- Decoupled weight decay: Unlike L2 regularization, applies decay directly to weights, not through gradients
4. The Outer Optimizer: Nesterov Momentum
The outer loop applies Nesterov accelerated gradient descent to pseudo-gradients.
4.1 Standard Momentum
Classical momentum accumulates a velocity and steps along it:

$$v_{t+1} = \mu v_t - \eta \nabla \mathcal{L}(\theta_t), \qquad \theta_{t+1} = \theta_t + v_{t+1}$$
4.2 Nesterov Momentum (Look-Ahead)
Nesterov momentum evaluates the gradient at a "look-ahead" point:

$$v_{t+1} = \mu v_t - \eta \nabla \mathcal{L}(\theta_t + \mu v_t), \qquad \theta_{t+1} = \theta_t + v_{t+1}$$

Expanded form (the equivalent reparameterization commonly used in implementations, with the gradient $g_t$ taken at the current iterate):

$$v_{t+1} = \mu v_t + g_t, \qquad \theta_{t+1} = \theta_t - \eta \left( g_t + \mu v_{t+1} \right)$$
4.3 Why Nesterov Works Better
The key insight is that Nesterov momentum makes a correction based on where momentum will take us, not where we currently are:
    Standard:  evaluate gradient at θ         → combine with momentum → update
    Nesterov:  look ahead to θ + μv → evaluate gradient there → correct the update

This "look-ahead" property provides:
- Faster convergence near minima
- Better handling of curved loss surfaces
- Automatic slowdown when overshooting
4.4 Implementation
    # Nesterov momentum update applied to the aggregated pseudo-gradient Δθ
    # (Δθ already points in the descent direction, so it is added)
    v = μ * v + Δθ                  # Update velocity
    θ = θ + η * (μ * v + Δθ)        # Apply with look-ahead

This is equivalent to the look-ahead formulation in Section 4.2, with the aggregated pseudo-gradient $\overline{\Delta\theta}$ playing the role of the (negative) gradient step.
5. Byzantine-Tolerant Aggregation
When aggregating gradients from potentially malicious nodes, we need robust methods.
5.1 Problem Formulation
Given $N$ submitted gradients $g_1, \ldots, g_N \in \mathbb{R}^d$, of which up to $f$ may be Byzantine (arbitrary or adversarial), compute a robust aggregate $\hat{g}$.
5.2 Simple Mean (Vulnerable)
$$\hat{g} = \frac{1}{N} \sum_{i=1}^{N} g_i$$

Vulnerability: A single Byzantine node can set its $g_i$ to an arbitrarily large value and drag the mean anywhere it wants.
5.3 Coordinate-Wise Median
For each parameter coordinate $j$, take the median across nodes:

$$\hat{g}_j = \text{median}\left(g_{1,j}, g_{2,j}, \ldots, g_{N,j}\right)$$

Robustness: Tolerates up to $\lfloor (N-1)/2 \rfloor$ Byzantine nodes.
Limitation: High variance compared to mean; ignores correlation between coordinates.
5.4 Trimmed Mean
Remove the top and bottom $k = \lfloor \beta N \rfloor$ values in each coordinate (for a trim fraction $\beta$), then average the rest:

$$\hat{g}_j = \frac{1}{N - 2k} \sum_{i=k+1}^{N-k} g_{(i),j}$$

Where $g_{(1),j} \leq g_{(2),j} \leq \cdots \leq g_{(N),j}$ are the sorted values of coordinate $j$.

Robustness: Tolerates up to $k = \lfloor \beta N \rfloor$ Byzantine values per coordinate.
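A minimal NumPy sketch of coordinate-wise median and trimmed-mean aggregation; the trim count `k` and the toy gradients are illustrative:

```python
import numpy as np

def coordinate_median(grads):
    """grads: (N, d) array of gradients from N nodes. Coordinate-wise median."""
    return np.median(grads, axis=0)

def trimmed_mean(grads, k):
    """Drop the k largest and k smallest values in each coordinate, average the rest."""
    n = grads.shape[0]
    assert 2 * k < n, "cannot trim more values than we have"
    sorted_grads = np.sort(grads, axis=0)        # sort each coordinate independently
    return sorted_grads[k:n - k].mean(axis=0)

# Example: 9 honest gradients near [1, -1], one Byzantine outlier
rng = np.random.default_rng(0)
honest = rng.normal(loc=[1.0, -1.0], scale=0.1, size=(9, 2))
byzantine = np.array([[1e6, -1e6]])
grads = np.vstack([honest, byzantine])

print(grads.mean(axis=0))          # wrecked by the outlier
print(coordinate_median(grads))    # ~[1, -1]
print(trimmed_mean(grads, k=1))    # ~[1, -1]
```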
5.5 Krum
Select the gradient closest to the majority.
Score function (for each gradient $g_i$): the sum of squared distances to its $N - f - 2$ nearest neighbors:

$$s(i) = \sum_{j \in \mathcal{N}_i} \left\| g_i - g_j \right\|^2$$

Where $\mathcal{N}_i$ is the set of the $N - f - 2$ gradients closest to $g_i$ (excluding $g_i$ itself).

Selection:

$$\hat{g} = g_{i^*}, \qquad i^* = \arg\min_i s(i)$$

Robustness: Provably robust when $N \geq 2f + 3$.

Theorem (Blanchard et al., 2017): If at most $f$ of the $N$ gradients are Byzantine and $N \geq 2f + 3$, then Krum is $(\alpha, f)$-Byzantine resilient: the expected output lies within an angle $\alpha < \pi/2$ of the true gradient, so $\langle \mathbb{E}[\hat{g}], \nabla\mathcal{L} \rangle > 0$ and gradient descent still makes progress,

where $\alpha$ depends on the number of Byzantine nodes and the variance of the honest gradients.
5.6 Multi-Krum
Average the $m$ gradients with the lowest Krum scores:

$$\hat{g} = \frac{1}{m} \sum_{i \in \mathcal{M}} g_i$$

Where $\mathcal{M}$ is the set of indices of the $m$ lowest-scoring gradients (commonly $m = N - f$).
Benefit: Lower variance than Krum while maintaining robustness.
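A NumPy sketch of the Krum score, Krum selection, and Multi-Krum averaging described above. It builds the full pairwise distance matrix, which is fine for small `N`; the example values of `f` and `m` are illustrative:

```python
import numpy as np

def krum_scores(grads, f):
    """Krum score for each of the N gradients: sum of squared distances
    to its N - f - 2 nearest neighbors."""
    n = grads.shape[0]
    k = n - f - 2
    assert k >= 1, "need N >= f + 3"
    diffs = grads[:, None, :] - grads[None, :, :]
    dist2 = (diffs ** 2).sum(axis=-1)          # pairwise squared distances
    np.fill_diagonal(dist2, np.inf)            # exclude self-distance
    nearest = np.sort(dist2, axis=1)[:, :k]    # k smallest distances per row
    return nearest.sum(axis=1)

def krum(grads, f):
    """Return the single gradient with the lowest Krum score."""
    return grads[np.argmin(krum_scores(grads, f))]

def multi_krum(grads, f, m):
    """Average the m gradients with the lowest Krum scores."""
    best = np.argsort(krum_scores(grads, f))[:m]
    return grads[best].mean(axis=0)

# Example: 8 honest gradients, 2 Byzantine
rng = np.random.default_rng(1)
grads = np.vstack([rng.normal(1.0, 0.1, size=(8, 4)),
                   rng.normal(50.0, 0.1, size=(2, 4))])
print(krum(grads, f=2))            # close to the honest cluster
print(multi_krum(grads, f=2, m=6)) # lower variance than single Krum
```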
5.7 Geometric Median
Find the point minimizing the sum of Euclidean distances to all submitted gradients:

$$\hat{g} = \arg\min_{z} \sum_{i=1}^{N} \left\| z - g_i \right\|_2$$

Weiszfeld Algorithm (iterative solution):

$$z^{(k+1)} = \frac{\sum_{i=1}^{N} g_i / \left\| z^{(k)} - g_i \right\|_2}{\sum_{i=1}^{N} 1 / \left\| z^{(k)} - g_i \right\|_2}$$

Robustness: Optimal breakdown point of $1/2$: up to half of the inputs can be corrupted before the estimate can be pulled arbitrarily far.
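A NumPy sketch of the Weiszfeld iteration; the convergence tolerance and the epsilon guard against zero distances are implementation details, not part of the algorithm statement:

```python
import numpy as np

def geometric_median(grads, iters=100, eps=1e-8):
    """Weiszfeld iteration for the geometric median of N gradient vectors."""
    z = grads.mean(axis=0)                       # start from the ordinary mean
    for _ in range(iters):
        dist = np.linalg.norm(grads - z, axis=1)
        dist = np.maximum(dist, eps)             # avoid division by zero at data points
        w = 1.0 / dist                           # inverse-distance weights
        z_new = (w[:, None] * grads).sum(axis=0) / w.sum()
        if np.linalg.norm(z_new - z) < eps:
            break
        z = z_new
    return z

# The Byzantine outlier barely moves the geometric median
grads = np.vstack([np.random.default_rng(2).normal(0.0, 0.1, size=(9, 3)),
                   np.full((1, 3), 1e6)])
print(geometric_median(grads))                   # close to the honest cluster near 0
```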
5.8 Comparison Table
| Method | Byzantine Tolerance | Variance | Complexity |
|---|---|---|---|
| Mean | 0 | Lowest | $O(Nd)$ |
| Median | $\lfloor (N-1)/2 \rfloor$ | High | $O(Nd \log N)$ |
| Trimmed Mean | $\lfloor \beta N \rfloor$ | Low | $O(Nd \log N)$ |
| Krum | $f$ with $N \geq 2f+3$ | Very High | $O(N^2 d)$ |
| Multi-Krum | $f$ with $N \geq 2f+3$ | Medium | $O(N^2 d)$ |
| Geometric Median | $\lfloor (N-1)/2 \rfloor$ | Low | $O(Nd)$ per iteration |

Where $N$ is the number of nodes and $d$ is the number of parameters.
6. Gradient Validation
Before aggregation, incoming gradients are validated.
6.1 Cosine Similarity Check
Measures the alignment between a submitted gradient $g_j$ and the validating node's own reference gradient $g_{\text{ref}}$:

$$\cos(g_j, g_{\text{ref}}) = \frac{\langle g_j, g_{\text{ref}} \rangle}{\|g_j\| \, \|g_{\text{ref}}\|}$$

Rejection criterion: reject $g_j$ if $\cos(g_j, g_{\text{ref}}) < \tau$, where $\tau$ is the similarity threshold.
Intuition: Honest gradients should point in similar directions (same optimization target). Anti-correlated gradients suggest malicious intent.
6.2 Magnitude Ratio Check
Rejection criterion: reject $g_j$ if its norm ratio falls outside configured bounds,

$$\frac{\|g_j\|}{\|g_{\text{ref}}\|} < \rho_{\min} \quad \text{or} \quad \frac{\|g_j\|}{\|g_{\text{ref}}\|} > \rho_{\max}$$
Intuition: Gradients should have similar scale. Extreme magnitudes suggest scaling attacks.
6.3 Variance Ratio Check
Compares the element-wise variance of the submitted gradient to that of the reference and rejects when the ratio

$$r = \frac{\operatorname{Var}(g_j)}{\operatorname{Var}(g_{\text{ref}})}$$

exceeds the configured bound.
Intuition: Abnormally high variance suggests noise injection.
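A sketch of the three validation checks in NumPy. The threshold values (`cos_tau`, `mag_bounds`, `var_bound`) are illustrative placeholders, not NeuroShard's actual defaults:

```python
import numpy as np

def validate_gradient(g, g_ref, cos_tau=0.0, mag_bounds=(0.1, 10.0), var_bound=10.0):
    """Return True if the peer gradient g passes all checks against the reference g_ref."""
    # 1. Cosine similarity: honest gradients should point in a similar direction
    cos = g @ g_ref / (np.linalg.norm(g) * np.linalg.norm(g_ref) + 1e-12)
    if cos < cos_tau:
        return False
    # 2. Magnitude ratio: reject scaling attacks
    ratio = np.linalg.norm(g) / (np.linalg.norm(g_ref) + 1e-12)
    if not (mag_bounds[0] <= ratio <= mag_bounds[1]):
        return False
    # 3. Variance ratio: reject noise injection
    if np.var(g) / (np.var(g_ref) + 1e-12) > var_bound:
        return False
    return True

g_ref = np.random.default_rng(3).normal(size=1000)
assert validate_gradient(g_ref + 0.01, g_ref)          # near-duplicate passes
assert not validate_gradient(-g_ref, g_ref)            # anti-correlated gradient rejected
assert not validate_gradient(100.0 * g_ref, g_ref)     # scaling attack rejected
```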
7. Gradient Compression
For bandwidth efficiency, gradients are compressed before transmission.
7.1 Top-K Sparsification
Keep only the $k$ largest-magnitude elements of the gradient and zero out the rest:

$$\tilde{g}_i = \begin{cases} g_i & \text{if } |g_i| \text{ is among the top-}k \\ 0 & \text{otherwise} \end{cases}$$

Sparsity: only $k$ of the $d$ values (plus their indices) are transmitted, a compression ratio of roughly $d/k$.

Error bound: The approximation error equals the energy of the discarded elements, and is bounded by:

$$\|g - \tilde{g}\|^2 = \sum_{i \notin \text{top-}k} g_i^2 \leq \left(1 - \frac{k}{d}\right) \|g\|^2$$
7.2 Quantization
Map floating-point values to $b$-bit integers:

$$q_i = \operatorname{round}\!\left(\frac{g_i - g_{\min}}{\Delta}\right), \qquad \Delta = \frac{g_{\max} - g_{\min}}{2^b - 1}$$

Dequantization:

$$\hat{g}_i = g_{\min} + q_i\, \Delta$$

Quantization error (per element):

$$|g_i - \hat{g}_i| \leq \frac{\Delta}{2}$$

For 8-bit quantization ($b = 8$), $\Delta = (g_{\max} - g_{\min}) / 255$, so the per-element error is at most $(g_{\max} - g_{\min}) / 510$.
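A NumPy sketch of both compression steps (top-k followed by 8-bit quantization of the kept values). The function names and the choice to transmit values, indices, minimum, and scale are illustrative, not NeuroShard's wire format:

```python
import numpy as np

def top_k_sparsify(g, k):
    """Keep the k largest-magnitude entries; return (values, indices)."""
    idx = np.argpartition(np.abs(g), -k)[-k:]
    return g[idx], idx

def densify(values, idx, d):
    g = np.zeros(d, dtype=values.dtype)
    g[idx] = values
    return g

def quantize_8bit(x):
    """Uniform 8-bit quantization over [x.min(), x.max()]."""
    lo, hi = x.min(), x.max()
    delta = (hi - lo) / 255 if hi > lo else 1.0
    q = np.round((x - lo) / delta).astype(np.uint8)
    return q, lo, delta

def dequantize_8bit(q, lo, delta):
    return lo + q.astype(np.float32) * delta

g = np.random.default_rng(4).normal(size=10_000).astype(np.float32)
vals, idx = top_k_sparsify(g, k=1_000)         # keep 10% of the entries
q, lo, delta = quantize_8bit(vals)              # 1 byte per kept value
g_hat = densify(dequantize_8bit(q, lo, delta), idx, g.size)

print(np.abs(vals - dequantize_8bit(q, lo, delta)).max(), delta / 2)  # error vs the Δ/2 bound
print(np.linalg.norm(g - g_hat) / np.linalg.norm(g))                  # overall relative error
```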
7.3 Why Compression Works
Theorem (Stich et al., 2018): SGD with top-$k$ sparsified gradients and error feedback ("memory") converges at the same asymptotic rate as uncompressed SGD; the compression factor only affects lower-order terms of the bound.

Where the compression factor is $d/k$, the ratio of total to transmitted coordinates.
Intuition:
- SGD gradients are already noisy (mini-batch variance)
- Compression error is much smaller than mini-batch noise
- Averaging across nodes cancels compression errors (Central Limit Theorem)
8. Model Architecture Mathematics
8.1 RMS Normalization
Root Mean Square Layer Normalization:

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \odot \gamma, \qquad \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}$$

Where:
- $x \in \mathbb{R}^d$ = the input activation vector
- $\gamma \in \mathbb{R}^d$ = a learned per-dimension scale
- $\epsilon$ = a small constant for numerical stability
Compared to LayerNorm:

$$\text{LayerNorm}(x) = \frac{x - \mu}{\sigma} \odot \gamma + \beta, \qquad \mu = \frac{1}{d}\sum_i x_i, \quad \sigma = \sqrt{\frac{1}{d}\sum_i (x_i - \mu)^2 + \epsilon}$$

RMSNorm omits the mean subtraction and the bias term, making it:
- ~10% faster to compute
- More stable for very deep networks
- Empirically equivalent performance
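A minimal NumPy version of the RMSNorm formula above; the epsilon value is illustrative:

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm over the last dimension: x / RMS(x) * gamma (no mean subtraction, no bias)."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * gamma

x = np.random.default_rng(5).normal(size=(2, 8))   # (batch, hidden)
gamma = np.ones(8)
y = rms_norm(x, gamma)
print(np.sqrt((y ** 2).mean(axis=-1)))             # ~1.0 per row: unit RMS after normalization
```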
8.2 Rotary Position Embeddings (RoPE)
RoPE encodes position through rotation in 2D subspaces.
Rotation matrix for position $m$ and dimension pair $i$:

$$R_{m,i} = \begin{pmatrix} \cos(m\,\omega_i) & -\sin(m\,\omega_i) \\ \sin(m\,\omega_i) & \cos(m\,\omega_i) \end{pmatrix}$$

Frequency schedule:

$$\omega_i = 10000^{-2i/d_{\text{head}}}, \qquad i = 0, 1, \ldots, d_{\text{head}}/2 - 1$$

Application to query/key vectors (treating consecutive pairs of dimensions as 2D points):

$$\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix} = R_{m,i} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}$$

Key property (relative position awareness):

$$\langle R_{m,i}\, q,\; R_{n,i}\, k \rangle = \langle R_{(m-n),i}\, q,\; k \rangle$$

The attention score depends only on the relative position $m - n$, not on the absolute positions $m$ and $n$.

Complex number formulation (equivalent, more elegant):

$$\tilde{q}_j = \left(q_{2j} + \mathrm{i}\, q_{2j+1}\right) e^{\,\mathrm{i}\, m \omega_j}$$

Where $\mathrm{i}$ is the imaginary unit and the real and imaginary parts of $\tilde{q}_j$ give the rotated dimension pair.
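A NumPy sketch of RoPE applied to a `(seq, d_head)` array, followed by a check of the relative-position property. The base of 10000 is the standard RoPE default and is assumed here:

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Apply rotary position embeddings to x of shape (seq, d_head), d_head even.
    Consecutive dimension pairs (2i, 2i+1) are rotated by angle m * omega_i."""
    seq, d = x.shape
    omega = base ** (-np.arange(0, d, 2) / d)            # (d/2,) frequencies
    angles = positions[:, None] * omega[None, :]         # (seq, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin            # 2D rotation of each pair
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out

# Relative-position property: shifting all positions by the same offset
# leaves the query-key dot products unchanged.
rng = np.random.default_rng(6)
q, k = rng.normal(size=(4, 8)), rng.normal(size=(4, 8))
pos = np.arange(4, dtype=np.float64)
scores_a = rope(q, pos) @ rope(k, pos).T
scores_b = rope(q, pos + 7) @ rope(k, pos + 7).T
assert np.allclose(scores_a, scores_b)
```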
8.3 Grouped Query Attention (GQA)
Standard multi-head attention has $n_h$ query heads and $n_h$ key/value heads. GQA uses fewer key/value heads $n_{kv} < n_h$, with each KV head shared by a group of $n_h / n_{kv}$ query heads.

Projections:

$$Q = x W_Q \in \mathbb{R}^{T \times n_h d_h}, \qquad K = x W_K \in \mathbb{R}^{T \times n_{kv} d_h}, \qquad V = x W_V \in \mathbb{R}^{T \times n_{kv} d_h}$$

Where $d_h$ is the per-head dimension and $T$ is the sequence length.

Head expansion (repeat each KV head $n_h / n_{kv}$ times so every query head has a matching key/value head):

$$K_{\text{exp}}[:, j, :] = K\big[:, \lfloor j / (n_h / n_{kv}) \rfloor, :\big], \qquad j = 0, \ldots, n_h - 1$$

Attention computation: identical to standard multi-head attention (Section 8.4), using the expanded keys and values.

Memory savings: the KV cache is reduced by a factor of $n_h / n_{kv}$.
8.4 Scaled Dot-Product Attention
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_h}} + M\right) V$$

Where:
- $Q$ = queries
- $K$ = keys
- $V$ = values
- $M$ = causal mask ($-\infty$ for future positions, $0$ otherwise)

Why scale by $\sqrt{d_h}$?

If the components of $q$ and $k$ are independent with zero mean and unit variance, the dot product $q \cdot k$ has variance $d_h$, so its magnitude grows with the head dimension.

Scaling by $1/\sqrt{d_h}$ keeps the logits at unit variance regardless of the head dimension.

This prevents softmax saturation (extreme probabilities) which would cause vanishing gradients.
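A NumPy sketch combining Sections 8.3 and 8.4: KV-head expansion for grouped-query attention followed by causal scaled dot-product attention. The head counts and dimensions are illustrative:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gqa_attention(q, k, v):
    """q: (n_h, T, d_h); k, v: (n_kv, T, d_h) with n_h % n_kv == 0.
    Expands KV heads to match query heads, then applies causal attention."""
    n_h, T, d_h = q.shape
    n_kv = k.shape[0]
    group = n_h // n_kv
    k = np.repeat(k, group, axis=0)                     # (n_h, T, d_h) expanded KV heads
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_h)    # (n_h, T, T) scaled dot products
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)    # True above the diagonal = future
    scores = np.where(mask, -np.inf, scores)            # causal mask
    return softmax(scores) @ v                          # (n_h, T, d_h)

rng = np.random.default_rng(7)
q = rng.normal(size=(8, 16, 32))    # 8 query heads
k = rng.normal(size=(2, 16, 32))    # 2 KV heads -> 4x smaller KV cache
v = rng.normal(size=(2, 16, 32))
print(gqa_attention(q, k, v).shape)  # (8, 16, 32)
```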
8.5 SwiGLU Activation
A gated linear unit with SiLU (Swish) activation:

$$\text{SwiGLU}(x) = \text{SiLU}(x W_{\text{gate}}) \odot (x W_{\text{up}})$$

Where:
- $\text{SiLU}(z) = z \cdot \sigma(z)$, with $\sigma$ the logistic sigmoid
- $W_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{d \times d_{\text{ff}}}$
- $\odot$ = element-wise product

Full FFN block:

$$\text{FFN}(x) = \big(\text{SiLU}(x W_{\text{gate}}) \odot (x W_{\text{up}})\big)\, W_{\text{down}}$$
Why gating helps:
- Allows the network to selectively pass information
- Smoother gradients than ReLU
- Empirically better performance for LLMs
Comparison of activations:
| Activation | Formula | Gradient |
|---|---|---|
| ReLU | $\max(0, x)$ | $\mathbb{1}[x > 0]$ (discontinuous at 0) |
| GELU | $x\,\Phi(x)$ | Smooth |
| SiLU/Swish | $x\,\sigma(x)$ | Smooth |
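A NumPy sketch of the full SwiGLU FFN block; the dimensions and random weights are illustrative:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))          # x * sigmoid(x)

def swiglu_ffn(x, w_gate, w_up, w_down):
    """SwiGLU feed-forward block: (SiLU(x W_gate) * (x W_up)) W_down."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

d, d_ff = 64, 172                           # illustrative sizes
rng = np.random.default_rng(8)
x = rng.normal(size=(4, d))                 # (tokens, hidden)
w_gate, w_up = rng.normal(size=(d, d_ff)), rng.normal(size=(d, d_ff))
w_down = rng.normal(size=(d_ff, d))
print(swiglu_ffn(x, w_gate, w_up, w_down).shape)   # (4, 64)
```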
9. Transformer Forward Pass
9.1 Single Layer
For input $x \in \mathbb{R}^{T \times d}$, a single pre-norm transformer layer computes:

    # Pre-norm attention
    h = x + Attention(RMSNorm(x))
    # Pre-norm FFN
    out = h + FFN(RMSNorm(h))

Mathematically:

$$h = x + \text{Attention}(\text{RMSNorm}(x)), \qquad \text{out} = h + \text{FFN}(\text{RMSNorm}(h))$$
9.2 Full Forward Pass
    # Embedding
    h_0 = Embed(tokens)
    # Transformer layers
    for l in range(L):
        h_{l+1} = TransformerBlock_l(h_l)
    # Output
    logits = LMHead(RMSNorm(h_L))

9.3 Parameter Count
For a model with:
- $d$ = hidden dimension
- $L$ = number of layers
- $n_h$ = attention (query) heads
- $n_{kv}$ = KV heads
- $d_h$ = head dimension = $d / n_h$
- $d_{\text{ff}}$ = FFN intermediate dimension
- $|V|$ = vocabulary size
Per-layer parameters:
| Component | Parameters |
|---|---|
| Q projection | $d \cdot n_h d_h = d^2$ |
| K projection | $d \cdot n_{kv} d_h$ |
| V projection | $d \cdot n_{kv} d_h$ |
| O projection | $n_h d_h \cdot d = d^2$ |
| Gate projection | $d \cdot d_{\text{ff}}$ |
| Up projection | $d \cdot d_{\text{ff}}$ |
| Down projection | $d_{\text{ff}} \cdot d$ |
| RMSNorm (×2) | $2d$ |
Total:

$$P_{\text{layer}} = 2d^2 + 2\, d\, n_{kv} d_h + 3\, d\, d_{\text{ff}} + 2d, \qquad P \approx L \cdot P_{\text{layer}} + |V| d \ (\text{embedding}) + |V| d \ (\text{LM head, or shared if tied})$$

Simplified (assuming $n_{kv} d_h = d$, i.e., full-size K/V projections, and dropping the small norm terms):

$$P \approx L\left(4d^2 + 3\, d\, d_{\text{ff}}\right) + 2|V| d$$
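A small calculator for the count above; the example configuration is illustrative, not NeuroShard's actual model shape:

```python
def param_count(d, L, n_kv, d_h, d_ff, vocab, tied_embeddings=False):
    """Approximate parameter count for the architecture described above."""
    attn = d * d + 2 * d * n_kv * d_h + d * d      # Q, K, V, O projections
    ffn = 3 * d * d_ff                             # gate, up, down projections
    norms = 2 * d                                  # two RMSNorms per layer
    per_layer = attn + ffn + norms
    embed = vocab * d * (1 if tied_embeddings else 2)   # embedding (+ LM head if untied)
    return L * per_layer + embed

# Illustrative config: d=2048, 24 layers, 4 KV heads of dim 128, d_ff=5632, 32k vocab
print(f"{param_count(2048, 24, 4, 128, 5632, 32_000):,}")
```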
10. Backpropagation Through Transformers
10.1 Gradient Flow
The gradient of the loss with respect to the input of layer $l$ follows the chain rule through all layers above it:

$$\frac{\partial \mathcal{L}}{\partial h_l} = \frac{\partial \mathcal{L}}{\partial h_{l+1}} \cdot \frac{\partial h_{l+1}}{\partial h_l} = \frac{\partial \mathcal{L}}{\partial h_{l+1}} \left( I + \frac{\partial F_l(h_l)}{\partial h_l} \right)$$

The residual connection ($h_{l+1} = h_l + F_l(h_l)$) contributes the identity term $I$, which lets gradients flow unattenuated through deep stacks and prevents vanishing gradients.
10.2 Gradient Clipping
Before applying gradients, clip the global norm:

$$g \leftarrow g \cdot \min\!\left(1, \frac{c}{\|g\|_2}\right)$$

Where $c$ is the maximum allowed global norm and $\|g\|_2$ is computed over all parameters jointly.
Purpose: Prevents exploding gradients from destabilizing training.
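A minimal sketch of global-norm clipping across several parameter tensors (the same operation PyTorch exposes as `torch.nn.utils.clip_grad_norm_`):

```python
import numpy as np

def clip_global_norm(grads, max_norm):
    """Scale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-12))
    return [g * scale for g in grads], total

grads = [np.full(10, 3.0), np.full(5, -4.0)]
clipped, norm_before = clip_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum((g ** 2).sum() for g in clipped))
print(norm_before, norm_after)   # large joint norm -> clipped to 1.0
```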
11. Complete Training Algorithm
Putting it all together:
Algorithm: NeuroShard DiLoCo Training
Inputs:
- Model $f_\theta$ with parameters $\theta$
- Inner optimizer (AdamW) with learning rate $\eta_{\text{inner}}$
- Outer optimizer (Nesterov) with learning rate $\eta_{\text{outer}}$ and momentum $\mu$
- Inner steps $H$, nodes $N$, aggregation function $\text{Agg}$
For each outer step $t = 1, \ldots, T$:

Save initial weights: $\theta_0 \leftarrow \theta$ on all nodes.

Inner loop (on each node $i$ independently):

    for h = 0 to H-1:
        Sample batch B_h^{(i)}
        Compute loss:     L = CrossEntropy(f_θ(B_h), labels)
        Compute gradient: g_h = ∇_θ L
        Clip gradient:    g_h = clip(g_h, max_norm)
        Update:           θ = AdamW(θ, g_h)

Compute pseudo-gradient: $\Delta\theta^{(i)} = \theta - \theta_0$

Compress (optional): top-$k$ sparsification and/or quantization (Section 7)

Exchange via gossip protocol.

Validate each received gradient:

    for each peer gradient Δθ^{(j)}:
        if cosine_sim(Δθ^{(j)}, Δθ^{(i)}) < τ:  reject
        if magnitude_ratio out of bounds:       reject

Aggregate: $\overline{\Delta\theta} = \text{Agg}\big(\{\Delta\theta^{(j)} : j \text{ accepted}\}\big)$

Outer update (Nesterov):

    v = μ * v + Δθ_bar
    θ = θ_0 + η_outer * (μ * v + Δθ_bar)

Broadcast the new $\theta$ to all nodes.
12. Convergence Analysis
12.1 Assumptions
- L-smoothness: $\|\nabla \mathcal{L}(\theta) - \nabla \mathcal{L}(\theta')\| \leq L \|\theta - \theta'\|$
- Bounded variance: $\mathbb{E}\left[\|g - \nabla \mathcal{L}(\theta)\|^2\right] \leq \sigma^2$
- Bounded gradients: $\|\nabla \mathcal{L}(\theta)\| \leq G$
12.2 Main Result
Theorem: Under the above assumptions, DiLoCo with $N$ nodes, $H$ inner steps per round, $T$ outer rounds, and outer learning rate $\eta$ satisfies a bound of the form:

$$\frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\left[\|\nabla \mathcal{L}(\theta^{(t)})\|^2\right] \leq \mathcal{O}\!\left(\frac{1}{\eta T}\right) + \mathcal{O}\!\left(\frac{\eta \sigma^2}{N H}\right)$$

Optimal learning rate: choosing $\eta$ to balance the two terms gives, up to constants involving $L$ and $\sigma$:

$$\eta^* \propto \sqrt{\frac{NH}{T}}$$

Resulting convergence rate:

$$\mathcal{O}\!\left(\frac{1}{\sqrt{NHT}}\right)$$
This shows:
- Linear speedup with $N$ nodes ✓
- Convergence improves with more inner steps $H$ ✓
- Same asymptotic rate as synchronous SGD ✓
13. Summary of Key Equations
| Concept | Equation |
|---|---|
| Cross-Entropy Loss | $\mathcal{L} = -\log p_y = -z_y + \log \sum_j e^{z_j}$ |
| AdamW Update | $\theta_{t+1} = \theta_t - \eta \big( \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda \theta_t \big)$ |
| Nesterov Momentum | $v_{t+1} = \mu v_t + g_t,\ \ \theta_{t+1} = \theta_t - \eta (g_t + \mu v_{t+1})$ |
| Pseudo-Gradient | $\Delta\theta^{(i)} = \theta_H^{(i)} - \theta^{(t)}$ |
| Trimmed Mean | $\hat{g}_j = \frac{1}{N-2k} \sum_{i=k+1}^{N-k} g_{(i),j}$ |
| Krum Score | $s(i) = \sum_{j \in \mathcal{N}_i} \lVert g_i - g_j \rVert^2$ |
| RMSNorm | $\big(x / \sqrt{\tfrac{1}{d}\sum_i x_i^2 + \epsilon}\big) \odot \gamma$ |
| RoPE | $\tilde{q}_j = (q_{2j} + \mathrm{i}\, q_{2j+1})\, e^{\mathrm{i} m \omega_j}$ |
| Attention | $\text{softmax}\!\big(QK^\top / \sqrt{d_h} + M\big)\, V$ |
| SwiGLU | $\text{SiLU}(x W_{\text{gate}}) \odot (x W_{\text{up}})$ |
References
- DiLoCo: Douillard et al., "DiLoCo: Distributed Low-Communication Training of Language Models" (2023)
- AdamW: Loshchilov & Hutter, "Decoupled Weight Decay Regularization" (2019)
- Nesterov: Nesterov, "A method for solving the convex programming problem with convergence rate O(1/k²)" (1983)
- Krum: Blanchard et al., "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent" (2017)
- RoPE: Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021)
- GQA: Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (2023)
- SwiGLU: Shazeer, "GLU Variants Improve Transformer" (2020)
- Gradient Compression: Stich et al., "Sparsified SGD with Memory" (2018)
