
Mathematical Foundations

This document provides a complete mathematical treatment of all algorithms and techniques used in NeuroShard's distributed LLM training system. Every equation is explained with intuition and derivation.

1. Training Objective

1.1 Language Modeling Loss

NeuroShard trains a causal language model to predict the next token. Given a sequence of tokens $x_1, x_2, \ldots, x_T$, the objective is to maximize the average log-likelihood:

$$\mathcal{L}(\theta) = \frac{1}{T}\sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})$$

Where:

  • $\theta$ = model parameters
  • $x_t$ = token at position $t$
  • $x_{<t}$ = all tokens before position $t$
  • $P_\theta$ = probability distribution produced by the model

1.2 Cross-Entropy Loss

The model outputs logits $z \in \mathbb{R}^V$ (where $V$ is the vocabulary size), converted to probabilities via softmax:

$$P(x_t = k \mid x_{<t}) = \frac{\exp(z_k)}{\sum_{j=1}^{V}\exp(z_j)} = \mathrm{softmax}(z)_k$$

The cross-entropy loss for a single token with true label $y$ is:

$$\mathcal{L}_{\mathrm{CE}} = -\log P(y) = -z_y + \log\sum_{j=1}^{V}\exp(z_j)$$

Gradient with respect to the logits:

$$\frac{\partial \mathcal{L}_{\mathrm{CE}}}{\partial z_k} = P(k) - \mathbb{1}[k=y] = \mathrm{softmax}(z)_k - \mathbb{1}[k=y]$$

This elegant result shows the gradient is simply the difference between predicted probability and the one-hot target.
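For concreteness, here is a minimal NumPy sketch of this gradient (an illustrative helper, not part of NeuroShard's codebase):

```python
import numpy as np

def cross_entropy_grad(z, y):
    """Gradient of cross-entropy w.r.t. logits: softmax(z) - one_hot(y)."""
    z = z - z.max()                  # stabilize the exponentials
    p = np.exp(z) / np.exp(z).sum()  # softmax probabilities
    p[y] -= 1.0                      # subtract the one-hot target
    return p

grad = cross_entropy_grad(np.array([2.0, 0.5, -1.0]), y=0)
# grad[0] < 0 (the true logit is pushed up); all other entries are > 0
```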


2. The DiLoCo Algorithm

DiLoCo (Distributed Low-Communication training) is a two-level optimization algorithm that reduces communication by orders of magnitude.

2.1 Algorithm Overview

Inner Loop (local, no communication):

$$\theta_{t+1}^{(i)} = \theta_t^{(i)} - \eta_{\text{inner}}\, g_t^{(i)}$$

Where $g_t^{(i)} = \nabla_\theta \mathcal{L}(\theta_t^{(i)}, B_t^{(i)})$ is the gradient on node $i$ for batch $B_t^{(i)}$.

Pseudo-Gradient Computation (after $H$ inner steps):

$$\Delta\theta^{(i)} = \theta_0^{(i)} - \theta_H^{(i)} = \sum_{t=0}^{H-1} \eta_{\text{inner}}\, g_t^{(i)}$$

Aggregation (across $N$ nodes):

$$\overline{\Delta\theta} = \mathrm{Aggregate}\big(\Delta\theta^{(1)}, \Delta\theta^{(2)}, \ldots, \Delta\theta^{(N)}\big)$$

Outer Loop (Nesterov momentum update, treating $\overline{\Delta\theta}$ as a gradient and therefore stepping against it):

$$\theta_{\text{new}} = \theta_0 - \eta_{\text{outer}} \cdot \mathrm{Nesterov}(\overline{\Delta\theta})$$

2.2 Why Pseudo-Gradients Approximate True Gradients

Over $H$ inner steps, the pseudo-gradient accumulates:

$$\Delta\theta = \eta_{\text{inner}} \sum_{t=0}^{H-1} g_t$$

By the law of large numbers, as $H \to \infty$ (treating the inner iterates as approximately stationary, so each $g_t$ is a near-unbiased estimate of the same gradient):

$$\frac{1}{H}\sum_{t=0}^{H-1} g_t \to \mathbb{E}[g] = \nabla\mathcal{L}(\theta)$$

Therefore:

$$\Delta\theta \approx H\,\eta_{\text{inner}}\, \nabla\mathcal{L}(\theta)$$

The pseudo-gradient points in the same direction as the true gradient, scaled by $H\,\eta_{\text{inner}}$.

2.3 Convergence Guarantee

Under standard assumptions ($L$-smooth loss, bounded gradient variance $\sigma^2$):

$$\mathbb{E}[\mathcal{L}(\theta_T)] - \mathcal{L}(\theta^*) \le O\!\left(\frac{1}{\sqrt{TH}}\right)$$

This matches the convergence rate of synchronous SGD while requiring $H\times$ less communication.


3. The Inner Optimizer: AdamW

The inner loop uses AdamW, which combines Adam's adaptive learning rates with decoupled weight decay.

3.1 Algorithm

Given gradient $g_t = \nabla_\theta \mathcal{L}(\theta_t)$:

Moment estimates:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t \quad \text{(first moment / mean)}$$
$$v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \quad \text{(second moment / variance)}$$

Bias correction:

$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$

Update with decoupled weight decay:

$$\theta_{t+1} = \theta_t - \eta\left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon} + \lambda\theta_t\right)$$

3.2 Hyperparameters

| Parameter | Symbol | Default | Purpose |
| --- | --- | --- | --- |
| Learning rate | $\eta$ | $10^{-4}$ | Step size |
| First moment decay | $\beta_1$ | $0.9$ | Gradient momentum |
| Second moment decay | $\beta_2$ | $0.95$ | Variance estimation |
| Epsilon | $\epsilon$ | $10^{-8}$ | Numerical stability |
| Weight decay | $\lambda$ | $0.1$ | Decoupled regularization |

3.3 Intuition

  • First moment ($m_t$): exponential moving average of gradients → provides momentum
  • Second moment ($v_t$): exponential moving average of squared gradients → adapts the learning rate per parameter
  • Bias correction: compensates for initialization at zero (important early in training)
  • Decoupled weight decay: unlike L2 regularization, applies decay directly to the weights, not through the gradients
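The update is compact enough to state in a few lines. Below is a self-contained sketch of one AdamW step using the defaults from the table above (illustrative only; `adamw_step` is a hypothetical helper, and `t` counts steps from 1):

```python
import numpy as np

def adamw_step(theta, g, m, v, t, lr=1e-4, b1=0.9, b2=0.95, eps=1e-8, wd=0.1):
    """One AdamW step; weight decay is applied to theta directly, not via g."""
    m = b1 * m + (1 - b1) * g        # first moment: EMA of gradients
    v = b2 * v + (1 - b2) * g**2     # second moment: EMA of squared gradients
    m_hat = m / (1 - b1**t)          # bias correction (t starts at 1)
    v_hat = v / (1 - b2**t)
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * theta)
    return theta, m, v
```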

4. The Outer Optimizer: Nesterov Momentum

The outer loop applies Nesterov accelerated gradient descent to pseudo-gradients.

4.1 Standard Momentum

Classical momentum (with the pseudo-gradient $\Delta\theta_t$ playing the role of the gradient):

$$v_t = \mu v_{t-1} + \Delta\theta_t$$
$$\theta_{t+1} = \theta_t - \eta\, v_t$$

4.2 Nesterov Momentum (Look-Ahead)

Nesterov momentum evaluates the update at a "look-ahead" point:

$$v_t = \mu v_{t-1} + \Delta\theta_t$$
$$\theta_{t+1} = \theta_t - \eta\,(\mu v_t + \Delta\theta_t)$$

Expanded form:

$$\theta_{t+1} = \theta_t - \eta\mu\,(\mu v_{t-1} + \Delta\theta_t) - \eta\,\Delta\theta_t$$

4.3 Why Nesterov Works Better

The key insight is that Nesterov momentum makes a correction based on where momentum will take us, not where we currently are:

Standard:   evaluate gradient at θ → combine with momentum → update
Nesterov:   step to the look-ahead point first → evaluate there → corrected update

This "look-ahead" property provides:

  • Faster convergence near minima
  • Better handling of curved loss surfaces
  • Automatic slowdown when overshooting

4.4 Implementation

```python
# Nesterov momentum update on the aggregated pseudo-gradient
v = μ * v + Δθ                   # update velocity
θ = θ - η * (μ * v + Δθ)         # apply with look-ahead
```

This is equivalent to:

$$\theta_{t+1} = \theta_t - \eta\mu^2 v_{t-1} - \eta\,(1+\mu)\,\Delta\theta_t$$

5. Byzantine-Tolerant Aggregation

When aggregating gradients from potentially malicious nodes, we need robust methods.

5.1 Problem Formulation

Given $N$ gradient contributions $\{\Delta\theta^{(1)}, \ldots, \Delta\theta^{(N)}\}$, of which up to $f$ may be Byzantine (arbitrarily corrupted), find an aggregate $\overline{\Delta\theta}$ such that training still converges.

5.2 Simple Mean (Vulnerable)

$$\overline{\Delta\theta} = \frac{1}{N}\sum_{i=1}^{N}\Delta\theta^{(i)}$$

Vulnerability: a single Byzantine node can submit $\Delta\theta^{(\text{bad})} = M$ for arbitrarily large $M$, corrupting the mean.

5.3 Coordinate-Wise Median

For each parameter $j$:

$$\overline{\Delta\theta}_j = \mathrm{median}\big(\Delta\theta_j^{(1)}, \ldots, \Delta\theta_j^{(N)}\big)$$

Robustness: tolerates up to $\lfloor(N-1)/2\rfloor$ Byzantine nodes.

Limitation: higher variance than the mean; ignores correlation between coordinates.

5.4 Trimmed Mean

Remove the top and bottom $\alpha$ fraction of values in each coordinate, then average:

$$\overline{\Delta\theta}_j = \frac{1}{N-2k}\sum_{i=k+1}^{N-k}\Delta\theta_j^{\,\mathrm{sorted}(i)}$$

Where $k = \lfloor\alpha N\rfloor$.

Default: $\alpha = 0.1$ (remove the top 10% and bottom 10%)

Robustness: tolerates up to an $\alpha$ fraction of Byzantine nodes.
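In NumPy this is a few lines; the sketch below (an illustrative `trimmed_mean` helper, not NeuroShard's actual implementation) sorts each coordinate across nodes and averages the middle slice:

```python
import numpy as np

def trimmed_mean(deltas, alpha=0.1):
    """Coordinate-wise trimmed mean over a list of pseudo-gradient vectors."""
    X = np.sort(np.stack(deltas), axis=0)  # sort each coordinate across nodes
    k = int(alpha * len(deltas))           # number trimmed from each end
    return X[k:len(deltas) - k].mean(axis=0)
```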

5.5 Krum

Select the gradient closest to the majority.

Score function (for each gradient $i$):

$$S(i) = \sum_{j\in\mathcal{N}_i}\big\|\Delta\theta^{(i)} - \Delta\theta^{(j)}\big\|^2$$

Where $\mathcal{N}_i$ is the set of the $N-f-2$ nearest neighbors of $i$.

Selection:

$$i^* = \arg\min_i S(i), \qquad \overline{\Delta\theta} = \Delta\theta^{(i^*)}$$

Robustness: provably robust when $N \ge 2f + 3$.

Theorem (Blanchard et al., 2017): if at most $f$ of $N$ gradients are Byzantine, Krum selects a gradient $\Delta\theta^{(i^*)}$ such that:

$$\mathbb{E}\,\big\|\Delta\theta^{(i^*)} - \nabla\mathcal{L}\big\|^2 \le (2f+2)\,\sigma^2$$

where $\sigma^2$ is the variance of the honest gradients.
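A direct (unoptimized) sketch of Krum follows; `krum` is an illustrative helper that scores every candidate against its $N-f-2$ nearest neighbors exactly as defined above:

```python
import numpy as np

def krum(deltas, f):
    """Return the pseudo-gradient with the lowest Krum score (requires N >= 2f + 3)."""
    X = np.stack(deltas)                                  # (N, d)
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)  # pairwise squared distances
    n_neighbors = len(deltas) - f - 2
    scores = [np.sort(row)[1:n_neighbors + 1].sum()      # skip the self-distance (0)
              for row in d2]
    return X[int(np.argmin(scores))]
```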

5.6 Multi-Krum

Average the top $m$ gradients by Krum score:

$$\overline{\Delta\theta} = \frac{1}{m}\sum_{i\in\mathcal{M}}\Delta\theta^{(i)}$$

Where $\mathcal{M}$ contains the $m = N - f$ indices with the lowest Krum scores.

Benefit: Lower variance than Krum while maintaining robustness.

5.7 Geometric Median

Find the point minimizing the sum of Euclidean distances:

$$\overline{\Delta\theta} = \arg\min_x \sum_{i=1}^{N}\big\|x - \Delta\theta^{(i)}\big\|_2$$

Weiszfeld Algorithm (iterative solution):

$$x^{(t+1)} = \frac{\sum_{i=1}^{N} \Delta\theta^{(i)} \,/\, \big\|x^{(t)} - \Delta\theta^{(i)}\big\|_2}{\sum_{i=1}^{N} 1 \,/\, \big\|x^{(t)} - \Delta\theta^{(i)}\big\|_2}$$

Robustness: optimal breakdown point of $\lfloor(N-1)/2\rfloor$.
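The Weiszfeld iteration is a few lines of NumPy; this sketch (illustrative `geometric_median` helper) initializes at the mean and re-weights by inverse distance:

```python
import numpy as np

def geometric_median(deltas, iters=50, eps=1e-8):
    """Weiszfeld iteration for the geometric median of pseudo-gradients."""
    X = np.stack(deltas)
    x = X.mean(axis=0)                               # initialize at the mean
    for _ in range(iters):
        dist = np.linalg.norm(X - x, axis=1) + eps   # avoid division by zero
        w = 1.0 / dist
        x = (w[:, None] * X).sum(axis=0) / w.sum()   # inverse-distance weighting
    return x
```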

5.8 Comparison Table

| Method | Byzantine Tolerance | Variance | Complexity |
| --- | --- | --- | --- |
| Mean | $0$ | Lowest | $O(N)$ |
| Median | $\lfloor(N-1)/2\rfloor$ | High | $O(N\log N)$ |
| Trimmed Mean | $\lfloor\alpha N\rfloor$ | Low | $O(N\log N)$ |
| Krum | $\lfloor(N-3)/2\rfloor$ | Very High | $O(N^2 d)$ |
| Multi-Krum | $\lfloor(N-3)/2\rfloor$ | Medium | $O(N^2 d)$ |
| Geometric Median | $\lfloor(N-1)/2\rfloor$ | Low | $O(N \cdot \text{iter})$ |

Where d is the number of parameters.


6. Gradient Validation

Before aggregation, incoming gradients are validated.

6.1 Cosine Similarity Check

Measures alignment between a submitted gradient $g_s$ and a reference gradient $g_r$:

$$\cos(g_s, g_r) = \frac{g_s \cdot g_r}{\|g_s\|_2\, \|g_r\|_2}$$

Rejection criterion: $\cos(g_s, g_r) < \tau$ (default $\tau = 0.3$)

Intuition: Honest gradients should point in similar directions (same optimization target). Anti-correlated gradients suggest malicious intent.

6.2 Magnitude Ratio Check

$$\rho = \frac{\|g_s\|_2}{\|g_r\|_2}$$

Rejection criterion: $\rho > \rho_{\max}$ or $\rho < \rho_{\min}$ (default: a $10\times$ range)

Intuition: Gradients should have similar scale. Extreme magnitudes suggest scaling attacks.

6.3 Variance Ratio Check

$$\frac{\mathrm{Var}(g_s)}{\mathrm{Var}(g_r)} > V_{\max}$$

Intuition: Abnormally high variance suggests noise injection.
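Putting the three checks together (a sketch only: `validate_gradient` is a hypothetical helper, the $10\times$ ratio bounds are assumed symmetric, and the variance threshold `v_max=10.0` is an assumed default not specified above):

```python
import numpy as np

def validate_gradient(g_s, g_r, tau=0.3, rho_max=10.0, v_max=10.0):
    """Return True if the submitted gradient g_s passes all three checks."""
    cos = g_s @ g_r / (np.linalg.norm(g_s) * np.linalg.norm(g_r))
    if cos < tau:
        return False                 # direction check (cosine similarity)
    rho = np.linalg.norm(g_s) / np.linalg.norm(g_r)
    if rho > rho_max or rho < 1.0 / rho_max:
        return False                 # magnitude-ratio check
    if np.var(g_s) / np.var(g_r) > v_max:
        return False                 # variance-ratio check
    return True
```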


7. Gradient Compression

For bandwidth efficiency, gradients are compressed before transmission.

7.1 Top-K Sparsification

Keep only the $k$ largest-magnitude elements:

$$\mathrm{TopK}(g, k) = \big\{(i, g_i) : i \in \mathrm{argtop}_k(|g|)\big\}$$

Sparsity: $k = 0.1d$ (keep 10%)

Error bound: the approximation error equals the energy of the discarded elements:

$$\big\|g - \mathrm{TopK}(g,k)\big\|_2^2 = \sum_{i\notin \mathrm{TopK}} g_i^2$$

7.2 Quantization

Map floating-point values to $b$-bit integers:

$$q(x) = \mathrm{round}\!\left(x \cdot \frac{2^{b-1}-1}{\max|x|}\right)$$

Dequantization:

$$\hat{x} = q(x)\cdot\frac{\max|x|}{2^{b-1}-1}$$

Quantization error (per element):

$$|x - \hat{x}| \le \frac{\max|x|}{2^{b}-2}$$

For 8-bit quantization with $\max|x| = 1$:

$$|x - \hat{x}| \le \frac{1}{254} \approx 0.4\%$$
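Combining both steps, a compression round trip looks roughly like this (illustrative `compress`/`decompress` helpers, with the 10% sparsity and 8-bit symmetric quantization described above):

```python
import numpy as np

def compress(g, sparsity=0.1, bits=8):
    """Top-K sparsification followed by symmetric b-bit quantization."""
    k = max(1, int(sparsity * g.size))
    idx = np.argsort(np.abs(g))[-k:]            # indices of the k largest |g_i|
    values = g[idx]
    scale = np.abs(values).max() / (2 ** (bits - 1) - 1)
    scale = max(scale, 1e-12)                   # guard against an all-zero gradient
    q = np.round(values / scale).astype(np.int8)
    return idx, q, scale                        # transmit these three

def decompress(idx, q, scale, d):
    """Rebuild a dense vector from the sparse, quantized representation."""
    g_hat = np.zeros(d)
    g_hat[idx] = q.astype(np.float64) * scale
    return g_hat
```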

7.3 Why Compression Works

Theorem (Stich et al., 2018): SGD with compressed gradients converges at rate:

$$\mathbb{E}[\mathcal{L}(\theta_T)] - \mathcal{L}(\theta^*) \le O\!\left(\frac{1}{\sqrt{T}} + \frac{\omega}{T}\right)$$

Where $\omega$ is the compression ratio. The extra $\omega/T$ term vanishes asymptotically.

Intuition:

  1. SGD gradients are already noisy (mini-batch variance)
  2. Compression error is much smaller than mini-batch noise
  3. Averaging across nodes cancels compression errors (Central Limit Theorem)

8. Model Architecture Mathematics

8.1 RMS Normalization

Root Mean Square Layer Normalization:

$$\mathrm{RMSNorm}(x) = \frac{x}{\mathrm{RMS}(x)} \cdot \gamma$$

Where:

$$\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}$$

Compared to LayerNorm:

$$\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sigma}\cdot\gamma + \beta$$

RMSNorm omits the mean subtraction and bias, making it:

  • ~10% faster to compute
  • More stable for very deep networks
  • Empirically equivalent performance
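The whole operation reduces to two NumPy lines (a sketch; `rms_norm` is an illustrative helper operating over the last dimension):

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm over the last dimension: x / RMS(x) * gamma (no mean, no bias)."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * gamma
```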

8.2 Rotary Position Embeddings (RoPE)

RoPE encodes position through rotation in 2D subspaces.

Rotation matrix for position $m$ and frequency $\theta_i$:

$$R_{\theta_i,\, m} = \begin{pmatrix}\cos(m\theta_i) & -\sin(m\theta_i)\\ \sin(m\theta_i) & \cos(m\theta_i)\end{pmatrix}$$

Frequency schedule:

$$\theta_i = 10000^{-2i/d}$$

Application to query/key vectors (treating pairs of dimensions as independent 2D subspaces):

$$\mathrm{RoPE}(x, m) = \big(R_{\theta_0,\,m} \oplus R_{\theta_1,\,m} \oplus \cdots\big)\, x$$

Key property (relative position awareness):

$$\big\langle \mathrm{RoPE}(q,m),\; \mathrm{RoPE}(k,n)\big\rangle = \big\langle q,\; R_{\theta,\,n-m}\, k\big\rangle$$

The attention score depends only on the relative position $(n-m)$, not on absolute positions.

Complex-number formulation (equivalent, more elegant):

$$\mathrm{RoPE}(x, m) = x \odot e^{im\theta}$$

Where $x$ is viewed as a vector of complex numbers and $\odot$ is element-wise multiplication.
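In real arithmetic the rotation is applied pairwise; the sketch below (illustrative `rope` helper) rotates the even/odd dimension pairs of a single vector at position `m`:

```python
import numpy as np

def rope(x, m, base=10000.0):
    """Rotate each 2D subspace of x by the position-dependent angle m * theta_i."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)  # theta_i = 10000^(-2i/d)
    angles = m * theta
    x1, x2 = x[..., 0::2], x[..., 1::2]        # the two halves of each pair
    out = np.empty_like(x)
    out[..., 0::2] = x1 * np.cos(angles) - x2 * np.sin(angles)
    out[..., 1::2] = x1 * np.sin(angles) + x2 * np.cos(angles)
    return out
```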

8.3 Grouped Query Attention (GQA)

Standard multi-head attention has H heads for Q, K, and V. GQA uses fewer KV heads.

Projections:

$$Q = xW_Q \in \mathbb{R}^{B\times L\times H\times d_h}$$
$$K = xW_K \in \mathbb{R}^{B\times L\times G\times d_h}$$
$$V = xW_V \in \mathbb{R}^{B\times L\times G\times d_h}$$

Where $G < H$ is the number of KV groups.

Head expansion (repeat KV heads to match query heads):

$$K' = \mathrm{repeat}(K,\, H/G), \qquad V' = \mathrm{repeat}(V,\, H/G)$$

Attention computation:

$$\mathrm{Attention}(Q, K', V') = \mathrm{softmax}\!\left(\frac{Q K'^{\top}}{\sqrt{d_h}}\right)V'$$

Memory savings: the KV cache shrinks by a factor of $H/G$ (e.g., $4\times$ for $H=8$, $G=2$).
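A single-sequence sketch of causal GQA (illustrative `gqa_attention` helper; shapes omit the batch dimension for clarity):

```python
import numpy as np

def gqa_attention(Q, K, V):
    """Q: (L, H, dh); K, V: (L, G, dh) with H divisible by G."""
    H, G, dh = Q.shape[1], K.shape[1], Q.shape[-1]
    K = np.repeat(K, H // G, axis=1)                   # expand KV heads to (L, H, dh)
    V = np.repeat(V, H // G, axis=1)
    scores = np.einsum('qhd,khd->hqk', Q, K) / np.sqrt(dh)
    L = Q.shape[0]
    scores += np.triu(np.full((L, L), -np.inf), k=1)   # causal mask on future positions
    scores -= scores.max(axis=-1, keepdims=True)       # stabilize the softmax
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    return np.einsum('hqk,khd->qhd', p, V)             # back to (L, H, dh)
```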

8.4 Scaled Dot-Product Attention

$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V$$

Where:

  • $Q \in \mathbb{R}^{L_q\times d_k}$ = queries
  • $K \in \mathbb{R}^{L_k\times d_k}$ = keys
  • $V \in \mathbb{R}^{L_k\times d_v}$ = values
  • $M$ = causal mask ($-\infty$ for future positions)

Why scale by $\sqrt{d_k}$?

If the components of $q$ and $k$ have unit variance, then:

$$\mathrm{Var}(q\cdot k) = d_k$$

Scaling by $\sqrt{d_k}$ restores unit variance:

$$\mathrm{Var}\!\left(\frac{q\cdot k}{\sqrt{d_k}}\right) = 1$$

This prevents softmax saturation (extreme probabilities) which would cause vanishing gradients.

8.5 SwiGLU Activation

A gated linear unit with SiLU (Swish) activation:

$$\mathrm{SwiGLU}(x) = \mathrm{SiLU}(xW_{\text{gate}}) \odot (xW_{\text{up}})$$

Where:

$$\mathrm{SiLU}(x) = x\,\sigma(x) = \frac{x}{1+e^{-x}}$$

Full FFN block:

$$\mathrm{FFN}(x) = \big(\mathrm{SiLU}(xW_{\text{gate}}) \odot (xW_{\text{up}})\big)\, W_{\text{down}}$$

Why gating helps:

  • Allows the network to selectively pass information
  • Smoother gradients than ReLU
  • Empirically better performance for LLMs
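As a concrete sketch, the full SwiGLU FFN block is three matrix multiplies and one gate (illustrative `swiglu_ffn` helper):

```python
import numpy as np

def swiglu_ffn(x, W_gate, W_up, W_down):
    """FFN(x) = (SiLU(x W_gate) * (x W_up)) W_down."""
    g = x @ W_gate
    silu = g / (1.0 + np.exp(-g))           # SiLU(g) = g * sigmoid(g)
    return (silu * (x @ W_up)) @ W_down     # gate the up-projection, then project down
```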

Comparison of activations:

| Activation | Formula | Gradient |
| --- | --- | --- |
| ReLU | $\max(0,x)$ | $\mathbb{1}[x>0]$ |
| GELU | $x\,\Phi(x)$ | Smooth |
| SiLU/Swish | $x\,\sigma(x)$ | $\sigma(x)\big(1 + x(1-\sigma(x))\big)$ |

9. Transformer Forward Pass

9.1 Single Layer

For input $x \in \mathbb{R}^{B\times L\times d}$:

```python
# Pre-norm attention
h = x + Attention(RMSNorm(x))

# Pre-norm FFN
out = h + FFN(RMSNorm(h))
```

Mathematically:

$$h = x + \mathrm{Attention}(\mathrm{RMSNorm}(x))$$
$$\mathrm{out} = h + \mathrm{FFN}(\mathrm{RMSNorm}(h))$$

9.2 Full Forward Pass

```python
# Embedding
h = embed(tokens)

# Transformer layers
for block in blocks:
    h = block(h)

# Output
logits = lm_head(rms_norm(h))
```

9.3 Parameter Count

For a model with:

  • $d$ = hidden dimension
  • $L$ = number of layers
  • $H$ = attention heads
  • $G$ = KV heads
  • $d_h$ = head dimension $= d/H$
  • $d_{\mathrm{ff}}$ = FFN intermediate dimension
  • $V$ = vocabulary size

Per-layer parameters:

| Component | Parameters |
| --- | --- |
| Q projection | $d \times d$ |
| K projection | $d \times (G\,d_h)$ |
| V projection | $d \times (G\,d_h)$ |
| O projection | $d \times d$ |
| Gate projection | $d \times d_{\mathrm{ff}}$ |
| Up projection | $d \times d_{\mathrm{ff}}$ |
| Down projection | $d_{\mathrm{ff}} \times d$ |
| RMSNorm (×2) | $2d$ |

Total:

$$P = Vd + L\big(2d^2 + 2dG d_h + 3d\,d_{\mathrm{ff}} + 2d\big) + d + Vd$$

Simplified (assuming $G = H/4$, $d_{\mathrm{ff}} = 4d$, tied embeddings):

$$P \approx Vd + L\big(2.5d^2 + 12d^2\big) \approx Vd + 14.5\,L d^2$$
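The total is easy to compute directly; this sketch (illustrative `param_count` helper) mirrors the table and formula above:

```python
def param_count(d, L, H, G, d_ff, V, tied_embeddings=True):
    """Total parameter count; head dimension d_h = d / H."""
    d_h = d // H
    per_layer = (2 * d * d             # Q and O projections
                 + 2 * d * G * d_h     # K and V projections
                 + 3 * d * d_ff        # gate, up, and down projections
                 + 2 * d)              # two RMSNorms
    total = V * d + L * per_layer + d  # embeddings + layers + final norm
    if not tied_embeddings:
        total += V * d                 # separate LM head
    return total
```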

10. Backpropagation Through Transformers

10.1 Gradient Flow

The gradient of the loss with respect to the input of layer $l$:

$$\frac{\partial\mathcal{L}}{\partial h_l} = \frac{\partial\mathcal{L}}{\partial h_{l+1}}\left(I + \frac{\partial\,\mathrm{Block}_l}{\partial h_l}\right)$$

The residual connection (the identity term $I$) ensures gradients flow directly to earlier layers, preventing vanishing gradients.

10.2 Gradient Clipping

Before applying gradients, clip the global norm:

$$g' = \begin{cases} g & \text{if } \|g\|_2 \le c \\[4pt] c\, \dfrac{g}{\|g\|_2} & \text{otherwise} \end{cases}$$

Where c is the maximum norm (default: 1.0).

Purpose: Prevents exploding gradients from destabilizing training.
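In code, the clip is a single norm check (illustrative `clip_grad_norm` helper with the default $c = 1.0$):

```python
import numpy as np

def clip_grad_norm(g, max_norm=1.0):
    """Rescale g if its global L2 norm exceeds max_norm."""
    norm = np.linalg.norm(g)
    if norm > max_norm:
        g = g * (max_norm / norm)
    return g
```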


11. Complete Training Algorithm

Putting it all together:

Algorithm: NeuroShard DiLoCo Training

Inputs:

  • Model $f_\theta$ with parameters $\theta$
  • Inner optimizer (AdamW) with learning rate $\eta_{\text{inner}}$
  • Outer optimizer (Nesterov) with learning rate $\eta_{\text{outer}}$, momentum $\mu$
  • Inner steps $H$, nodes $N$, aggregation function $\mathrm{Agg}$

For each outer step $k = 1, 2, \ldots$:

  1. Save initial weights: $\theta_0^{(i)} \leftarrow \theta$ for all nodes $i$

  2. Inner loop (on each node i independently):

    for t = 0 to H-1:
        Sample batch B_t^{(i)}
        Compute loss: L = CrossEntropy(f_θ(B_t), labels)
        Compute gradient: g_t = ∇_θ L
        Clip gradient: g_t = clip(g_t, max_norm)
        Update: θ = AdamW(θ, g_t)
  3. Compute pseudo-gradient:

    $\Delta\theta^{(i)} = \theta_0^{(i)} - \theta^{(i)}$
  4. Compress (optional):

    $\Delta\theta^{(i)}_{\text{compressed}} = \mathrm{Quantize}\big(\mathrm{TopK}(\Delta\theta^{(i)})\big)$
  5. Exchange via gossip protocol

  6. Validate each received gradient:

    for each peer gradient Δθ^{(j)}:
        if cosine_sim(Δθ^{(j)}, Δθ^{(i)}) < τ: reject
        if magnitude_ratio out of bounds: reject
  7. Aggregate:

    $\overline{\Delta\theta} = \mathrm{TrimmedMean}\big(\{\Delta\theta^{(i)}\}_{\text{valid}}\big)$
  8. Outer update (Nesterov):

    v = μ * v + Δθ_bar
    θ = θ_0 - η_outer * (μ * v + Δθ_bar)
  9. Broadcast new θ to all nodes
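A compact sketch of one outer step, wiring together the helpers sketched in earlier sections (all names are illustrative; `node.train_locally` is a hypothetical interface, and using the local pseudo-gradient as the validation reference is an assumption):

```python
def outer_step(theta, nodes, v, eta_outer, mu, H):
    """One DiLoCo outer step: local training, validation, robust aggregation, Nesterov."""
    deltas = []
    for node in nodes:
        theta_i = node.train_locally(theta, steps=H)   # inner AdamW loop (step 2)
        deltas.append(theta - theta_i)                 # pseudo-gradient (step 3)
    ref = deltas[0]                                    # assumed reference gradient
    valid = [d for d in deltas if validate_gradient(d, ref)]   # step 6
    delta_bar = trimmed_mean(valid)                    # step 7
    v = mu * v + delta_bar                             # step 8: Nesterov outer update
    theta = theta - eta_outer * (mu * v + delta_bar)
    return theta, v
```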


12. Convergence Analysis

12.1 Assumptions

  1. L-smoothness: $\|\nabla\mathcal{L}(\theta) - \nabla\mathcal{L}(\phi)\| \le L\,\|\theta - \phi\|$
  2. Bounded variance: $\mathbb{E}\big[\|g - \nabla\mathcal{L}\|^2\big] \le \sigma^2$
  3. Bounded gradients: $\|\nabla\mathcal{L}(\theta)\| \le G$

12.2 Main Result

Theorem: Under the above assumptions, DiLoCo with $N$ nodes, $H$ inner steps, and appropriate learning rates achieves:

$$\frac{1}{T}\sum_{k=1}^{T}\mathbb{E}\big[\|\nabla\mathcal{L}(\theta_k)\|^2\big] \le O\!\left(\frac{\mathcal{L}(\theta_0) - \mathcal{L}^*}{\eta\, T H} + \frac{\eta L \sigma^2}{N} + \eta^2 L^2 H \sigma^2\right)$$

Optimal learning rate: $\eta = O\!\left(\sqrt{\dfrac{N}{T H L \sigma^2}}\right)$

Resulting convergence rate:

$$O\!\left(\sqrt{\frac{L\,\big(\mathcal{L}(\theta_0) - \mathcal{L}^*\big)\,\sigma^2}{N T H}}\right)$$

This shows:

  • Linear speedup with N nodes ✓
  • Convergence improves with more inner steps H
  • Same asymptotic rate as synchronous SGD ✓

13. Summary of Key Equations

| Concept | Equation |
| --- | --- |
| Cross-Entropy Loss | $\mathcal{L} = -\log \mathrm{softmax}(z)_y$ |
| AdamW Update | $\theta \leftarrow \theta - \eta\,\big(\hat{m}/(\sqrt{\hat{v}}+\epsilon) + \lambda\theta\big)$ |
| Nesterov Momentum | $\theta \leftarrow \theta - \eta\,(\mu v + \Delta\theta)$ |
| Pseudo-Gradient | $\Delta\theta = \theta_0 - \theta_H$ |
| Trimmed Mean | $\bar{x} = \mathrm{mean}\big(x_{(k+1)}, \ldots, x_{(n-k)}\big)$ |
| Krum Score | $S(i) = \sum_{j\in\mathcal{N}_i}\lVert g_i - g_j\rVert^2$ |
| RMSNorm | $\hat{x} = x/\mathrm{RMS}(x)\cdot\gamma$ |
| RoPE | $\mathrm{RoPE}(x,m) = x \odot e^{im\theta}$ |
| Attention | $\mathrm{softmax}\big(QK^{\top}/\sqrt{d_k}\big)\,V$ |
| SwiGLU | $\mathrm{SiLU}(xW_g)\odot(xW_u)$ |

References

  1. DiLoCo: Douillard et al., "DiLoCo: Distributed Low-Communication Training of Language Models" (2023)
  2. AdamW: Loshchilov & Hutter, "Decoupled Weight Decay Regularization" (2019)
  3. Nesterov: Nesterov, "A method for solving the convex programming problem with convergence rate O(1/k²)" (1983)
  4. Krum: Blanchard et al., "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent" (2017)
  5. RoPE: Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021)
  6. GQA: Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models" (2023)
  7. SwiGLU: Shazeer, "GLU Variants Improve Transformer" (2020)
  8. Gradient Compression: Stich et al., "Sparsified SGD with Memory" (2018)

Released under the Apache License 2.0.