# NeuroLLM Model Architecture
NeuroLLM is a custom transformer architecture designed for distributed training and inference across heterogeneous hardware.
## Core Design
Each transformer block contains:
- RMSNorm (pre-norm)
- GQA Attention + RoPE + FlashAttention
- RMSNorm (pre-norm)
- SwiGLU FFN (gate_proj, up_proj, down_proj)
## Architecture Specification
```python
from dataclasses import dataclass


@dataclass
class ModelArchitecture:
    num_layers: int = 16          # Depth
    hidden_dim: int = 1024        # Width
    num_heads: int = 8            # Attention heads
    num_kv_heads: int = 2         # GQA key-value heads
    ffn_dim: int = 4096           # FFN intermediate (4x hidden)
    vocab_size: int = 50257       # GPT-2 vocabulary
    max_seq_len: int = 2048       # Context length
    dropout: float = 0.0          # No dropout for inference
    rope_base: float = 10000.0    # RoPE theta
```
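For orientation, a quick sketch of how the configuration is consumed by the modules below; the derived names (`head_dim`, `kv_group_size`) are illustrative and follow directly from the fields above:

```python
arch = ModelArchitecture()

head_dim = arch.hidden_dim // arch.num_heads          # 1024 / 8 = 128
kv_group_size = arch.num_heads // arch.num_kv_heads   # 8 / 2 = 4 query heads per KV head
```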
## Key Innovations

### 1. RMSNorm (Pre-Normalization)
Uses Root Mean Square Layer Normalization instead of LayerNorm:
```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS = sqrt(mean(x^2) + eps)
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
```

Benefits:
- Faster than LayerNorm (no mean subtraction)
- More stable gradients for deep networks
- Used in LLaMA, Gemma, etc.
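A quick sanity check (a minimal sketch using the class above): with the scale initialized to ones, each output feature vector has unit RMS:

```python
x = torch.randn(2, 16, 1024)
norm = RMSNorm(1024)
y = norm(x)

# With weight initialized to ones, each feature vector has RMS ~= 1.
print(y.pow(2).mean(dim=-1).sqrt())  # values close to 1.0, shape (2, 16)
```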
### 2. Rotary Position Embeddings (RoPE)
Position information is encoded via rotation matrices applied to queries and keys:
```python
def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to x."""
    # x:     (batch, seq_len, heads, head_dim)
    # freqs: (seq_len, head_dim // 2) rotation angles
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)  # unit-magnitude rotations
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)     # broadcast over batch and heads
    x_rotated = x_complex * freqs_complex
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)
```

Properties:
- Relative position awareness
- Extrapolates to longer sequences
- No learnable parameters
Precomputation:
```python
def precompute_rope_freqs(dim: int, max_seq: int, theta: float = 10000.0) -> torch.Tensor:
    # Per-pair inverse frequencies; the outer product with positions gives
    # a (max_seq, dim // 2) table of rotation angles.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq, dtype=torch.float32)
    angles = torch.outer(positions, freqs)
    return angles
```
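The relative-position property can be checked numerically with the two helpers above: the score between a rotated query and key depends only on their offset, not on their absolute positions (a minimal sketch with arbitrary sizes):

```python
torch.manual_seed(0)
head_dim, max_seq = 64, 32
freqs = precompute_rope_freqs(head_dim, max_seq)

q_vec = torch.randn(head_dim)
k_vec = torch.randn(head_dim)

def rotated_score(q_pos: int, k_pos: int) -> float:
    # Place the same q/k content at different absolute positions, rotate, then dot.
    x = torch.zeros(1, max_seq, 1, head_dim)
    x[0, q_pos, 0] = q_vec
    x[0, k_pos, 0] = k_vec
    x_rot = apply_rope(x, freqs)
    return torch.dot(x_rot[0, q_pos, 0], x_rot[0, k_pos, 0]).item()

# Same offset (3) at different absolute positions -> (near-)identical scores.
print(rotated_score(2, 5), rotated_score(10, 13))
```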
### 3. Grouped Query Attention (GQA)

Reduces KV cache memory by sharing each key-value head across a group of query heads:
```
Standard MHA: Q[8 heads] x K[8 heads] x V[8 heads]
GQA (4:1):    Q[8 heads] x K[2 heads] x V[2 heads]
```

```python
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads        # 8 query heads
        self.num_kv_heads = config.num_kv_heads  # 2 kv heads
        self.head_dim = config.hidden_dim // config.num_heads

        self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim)
        self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim)

    def forward(self, x, freqs, mask=None):
        B, L, _ = x.shape
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim)

        # Apply RoPE (layout: batch, seq_len, heads, head_dim)
        q = apply_rope(q, freqs[:L])
        k = apply_rope(k, freqs[:L])

        # Expand KV heads to match query heads
        k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)
        v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=2)

        # scaled_dot_product_attention expects (batch, heads, seq_len, head_dim)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

        return self.o_proj(attn.transpose(1, 2).reshape(B, L, -1))
```

Memory savings: a 4x reduction in KV cache size for the 8:2 head ratio.
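The saving follows from simple arithmetic; a sketch using the default architecture (16 layers, head_dim 128, 2048-token context) and fp16 cache entries:

```python
def kv_cache_bytes(num_kv_heads: int, head_dim: int = 128, layers: int = 16,
                   seq_len: int = 2048, batch: int = 1, bytes_per_elem: int = 2) -> int:
    # K and V each store (batch, kv_heads, seq_len, head_dim) per layer.
    return 2 * batch * num_kv_heads * seq_len * head_dim * bytes_per_elem * layers

mha = kv_cache_bytes(num_kv_heads=8)  # full multi-head attention
gqa = kv_cache_bytes(num_kv_heads=2)  # grouped query attention (8:2)
print(mha / 2**20, gqa / 2**20, mha / gqa)  # 128.0 MiB vs 32.0 MiB -> 4.0x
```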
### 4. SwiGLU Activation
Gated Linear Unit with SiLU (Swish) activation:
```python
class SwiGLU(nn.Module):
    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.gate_proj(x))   # Swish gating
        up = self.up_proj(x)               # Linear projection
        return self.down_proj(gate * up)   # Gated output
```

Benefits:
- Outperforms ReLU/GELU FFNs on language-modeling quality at comparable parameter counts
- Smooth gradients
- Used in LLaMA, PaLM, etc.
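One practical note: the gated FFN uses three weight matrices rather than the usual two, which matters when comparing ffn_dim settings (a quick sketch with the default dimensions):

```python
hidden_dim, ffn_dim = 1024, 4096

swiglu_params = 3 * hidden_dim * ffn_dim   # gate_proj + up_proj + down_proj
vanilla_params = 2 * hidden_dim * ffn_dim  # classic up + down FFN

print(swiglu_params, vanilla_params)  # 12582912 vs 8388608
```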
## Transformer Block
Complete block combining all components:
```python
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = GroupedQueryAttention(config)
        self.feed_forward = SwiGLU(config.hidden_dim, config.ffn_dim)
        self.attention_norm = RMSNorm(config.hidden_dim)
        self.ffn_norm = RMSNorm(config.hidden_dim)

    def forward(self, x, freqs, mask=None):
        # Pre-norm attention with residual
        h = x + self.attention(self.attention_norm(x), freqs, mask)
        # Pre-norm FFN with residual
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
```
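Putting the pieces together, a single block can be exercised end to end; a minimal sketch using the default configuration, the RoPE table from `precompute_rope_freqs`, and a causal mask:

```python
arch = ModelArchitecture()
block = TransformerBlock(arch)

freqs = precompute_rope_freqs(arch.hidden_dim // arch.num_heads, arch.max_seq_len)

x = torch.randn(2, 128, arch.hidden_dim)               # (batch, seq_len, hidden_dim)
mask = torch.ones(128, 128, dtype=torch.bool).tril()   # causal mask (True = attend)
out = block(x, freqs, mask=mask)
print(out.shape)  # torch.Size([2, 128, 1024])
```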
## Architecture Scaling

Model dimensions scale with the total memory available across the network:
| Network Memory | Hidden Dim | Layers | FFN Dim | Heads | Parameters |
|---|---|---|---|---|---|
| 40 GB | 1024 | 16 | 4096 | 8 | ~350M |
| 200 GB | 2048 | 24 | 8192 | 16 | ~3.5B |
| 800 GB | 3072 | 32 | 12288 | 24 | ~9.2B |
| 4 TB | 5120 | 48 | 20480 | 40 | ~45B |
| 8 TB | 7168 | 64 | 28672 | 56 | ~123B |
Scaling formula:
```python
import math


def calculate_optimal_architecture(total_memory_gb: float) -> ModelArchitecture:
    # Width grows as sqrt of memory
    hidden = int(256 * math.sqrt(total_memory_gb / 10))
    hidden = ((hidden + 63) // 64) * 64  # Align to 64

    # Depth grows logarithmically
    layers = min(64, max(8, int(8 * math.log2(total_memory_gb / 10 + 1))))

    # Heads scale with width
    heads = max(4, hidden // 128)

    return ModelArchitecture(
        num_layers=layers,
        hidden_dim=hidden,
        num_heads=heads,
        num_kv_heads=max(1, heads // 4),
        ffn_dim=hidden * 4,
    )
```
## Sharding Strategy

Layers are distributed across nodes, with each node serving a contiguous range of layers.
Assignment algorithm:
```python
from typing import Dict, List


def calculate_layer_assignment(
    node_capacities: Dict[str, float],
    architecture: ModelArchitecture,
) -> Dict[str, List[int]]:
    """Assign layers to nodes proportionally to their memory."""
    total_memory = sum(node_capacities.values())
    memory_per_layer = estimate_memory_per_layer(architecture)

    assignments = {}
    layer_idx = 0

    # Sort nodes for deterministic assignment
    for node_id, memory in sorted(node_capacities.items()):
        num_layers = int(memory / memory_per_layer)
        # Never assign layers beyond the model's depth
        num_layers = min(num_layers, architecture.num_layers - layer_idx)
        assignments[node_id] = list(range(layer_idx, layer_idx + num_layers))
        layer_idx += num_layers

    return assignments
```
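For illustration, a hypothetical two-node call (the node IDs and per-node memory figures in GB are made up; the exact split depends on `estimate_memory_per_layer`, defined in the Memory Estimation section below):

```python
arch = ModelArchitecture()
nodes = {"node-a": 1.0, "node-b": 0.5}  # hypothetical capacities in GB

assignments = calculate_layer_assignment(nodes, arch)
print(assignments)  # e.g. {'node-a': [0, 1, ..., 10], 'node-b': [11, ..., 15]}
```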
## Forward Pass

Distributed forward with activation streaming:
```python
async def distributed_forward(tokens: torch.Tensor) -> torch.Tensor:
    # Step 1: Embedding (Driver)
    x = embedding(tokens)

    # Step 2: Stream through layers
    for layer_idx in range(num_layers):
        peer = router.get_peer_for_layer(layer_idx)
        x = await peer.forward_layer(layer_idx, x)

    # Step 3: Final projection (Validator)
    x = final_norm(x)
    logits = lm_head(x)
    return logits
```
## Memory Estimation

Per-layer memory calculation:
```python
def estimate_memory_per_layer(arch: ModelArchitecture) -> float:
    """Estimate memory per transformer layer in GB."""
    h = arch.hidden_dim
    f = arch.ffn_dim
    heads = arch.num_heads
    kv_heads = arch.num_kv_heads
    head_dim = h // heads

    # Attention parameters
    attn_params = (
        h * heads * head_dim +     # Q
        h * kv_heads * head_dim +  # K
        h * kv_heads * head_dim +  # V
        heads * head_dim * h       # O
    )

    # FFN parameters
    ffn_params = (
        h * f +  # gate_proj
        h * f +  # up_proj
        f * h    # down_proj
    )

    # Norms
    norm_params = h * 2

    total_params = attn_params + ffn_params + norm_params

    # 4 bytes per param (float32)
    bytes_per_layer = total_params * 4

    # Add 50% overhead for gradients and activations
    return bytes_per_layer * 1.5 / (1024 ** 3)
```
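Applied to the default ModelArchitecture, this estimate comes out to roughly 15M parameters and about 0.085 GB per layer:

```python
arch = ModelArchitecture()
per_layer_gb = estimate_memory_per_layer(arch)
print(f"{per_layer_gb:.3f} GB per layer")  # ~0.085 GB with the default config
```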
## Inference Optimization

### KV Cache
Stores key-value states for autoregressive generation:
```python
from typing import Tuple


class KVCache:
    def __init__(self, max_batch: int, max_seq: int, num_kv_heads: int, head_dim: int):
        self.k_cache = torch.zeros(max_batch, num_kv_heads, max_seq, head_dim)
        self.v_cache = torch.zeros(max_batch, num_kv_heads, max_seq, head_dim)
        self.seq_len = 0

    def update(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # k, v: (batch, num_kv_heads, new_tokens, head_dim)
        seq_len = k.size(2)
        self.k_cache[:, :, self.seq_len:self.seq_len + seq_len] = k
        self.v_cache[:, :, self.seq_len:self.seq_len + seq_len] = v
        self.seq_len += seq_len
        return self.k_cache[:, :, :self.seq_len], self.v_cache[:, :, :self.seq_len]
```
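Typical decode-loop usage (a sketch with hypothetical sizes): prefill once with the prompt, then append one token per step:

```python
cache = KVCache(max_batch=1, max_seq=2048, num_kv_heads=2, head_dim=128)

# Prefill with a 16-token prompt
k, v = torch.randn(1, 2, 16, 128), torch.randn(1, 2, 16, 128)
k_all, v_all = cache.update(k, v)
print(k_all.shape)  # torch.Size([1, 2, 16, 128])

# Decode one token: only the new K/V is written, the full history is returned
k1, v1 = torch.randn(1, 2, 1, 128), torch.randn(1, 2, 1, 128)
k_all, v_all = cache.update(k1, v1)
print(k_all.shape)  # torch.Size([1, 2, 17, 128])
```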
### Flash Attention

Uses PyTorch's `scaled_dot_product_attention`, which automatically selects:
- FlashAttention-2 (when available)
- Memory-efficient attention (for longer sequences)
- Standard attention (fallback)
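In recent PyTorch versions (2.3+), the backend can also be pinned explicitly with `torch.nn.attention.sdpa_kernel`; a sketch, assuming a CUDA device and half-precision inputs (FlashAttention's requirements):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 128, device="cuda", dtype=torch.float16)

# Restrict SDPA to the FlashAttention kernel; raises if the inputs/device
# don't support it instead of silently falling back to another backend.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```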
## Next Steps
- Dynamic Scaling — How architecture adapts
- DiLoCo Protocol — Distributed training
- Swarm Aggregation — Byzantine tolerance
