
NeuroLLM Model Architecture

NeuroLLM is a custom transformer architecture designed for distributed training and inference across heterogeneous hardware.

Core Design

Each transformer block contains:

  • RMSNorm (pre-norm)
  • GQA Attention + RoPE + FlashAttention
  • RMSNorm (pre-norm)
  • SwiGLU FFN (up_proj, gate, down)

Architecture Specification

python
@dataclass
class ModelArchitecture:
    num_layers: int = 16        # Depth
    hidden_dim: int = 1024      # Width  
    num_heads: int = 8          # Attention heads
    num_kv_heads: int = 2       # GQA key-value heads
    ffn_dim: int = 4096         # FFN intermediate (4x hidden)
    vocab_size: int = 50257     # GPT-2 vocabulary
    max_seq_len: int = 2048     # Context length
    dropout: float = 0.0        # No dropout for inference
    rope_base: float = 10000.0  # RoPE theta

Key Innovations

1. RMSNorm (Pre-Normalization)

Uses Root Mean Square Layer Normalization instead of LayerNorm:

python
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS = sqrt(mean(x^2))
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
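
A quick side-by-side with nn.LayerNorm (purely illustrative; assumes the RMSNorm class above plus standard torch imports are in scope):

python
import torch
import torch.nn as nn

x = torch.randn(2, 16, 1024)

rms = RMSNorm(1024)                                # rescales by the RMS only
ln = nn.LayerNorm(1024, elementwise_affine=False)  # subtracts the mean first

print(rms(x).shape, ln(x).shape)                   # both (2, 16, 1024)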

Benefits:

  • Faster than LayerNorm (no mean subtraction)
  • More stable gradients for deep networks
  • Used in LLaMA, Gemma, etc.

2. Rotary Position Embeddings (RoPE)

Position information is encoded by rotating query and key vectors with position-dependent angles:

python
def apply_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings to x."""
    # x: (batch, heads, seq_len, head_dim)
    # freqs: (seq_len, head_dim // 2)
    
    x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    x_rotated = x_complex * freqs_complex
    return torch.view_as_real(x_rotated).flatten(-2)

Properties:

  • Relative position awareness
  • Extrapolates to longer sequences
  • No learnable parameters

Precomputation:

python
def precompute_rope_freqs(dim: int, max_seq: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq)
    angles = torch.outer(positions, freqs)
    return angles
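
A minimal usage sketch tying the two helpers together (assuming apply_rope and precompute_rope_freqs above are in scope): the angles are precomputed once for the full context length and sliced to the current sequence length.

python
import torch

head_dim, max_seq = 128, 2048
freqs = precompute_rope_freqs(head_dim, max_seq)   # (2048, 64) angles

# Queries laid out as (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 16, head_dim)
q_rot = apply_rope(q, freqs[:16])                  # slice angles to seq_len
assert q_rot.shape == q.shape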

3. Grouped Query Attention (GQA)

Reduces KV-cache memory by sharing each key-value head across a group of query heads:

Standard MHA:  Q[8 heads] x K[8 heads] x V[8 heads]
GQA (4:1):     Q[8 heads] x K[2 heads] x V[2 heads]

python
class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads        # 8 query heads
        self.num_kv_heads = config.num_kv_heads  # 2 kv heads
        self.head_dim = config.hidden_dim // config.num_heads

        self.q_proj = nn.Linear(config.hidden_dim, config.num_heads * self.head_dim)
        self.k_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(config.hidden_dim, config.num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(config.num_heads * self.head_dim, config.hidden_dim)

    def forward(self, x, freqs, mask=None):
        B, L, _ = x.shape

        # Project and move heads before the sequence dim: (B, heads, L, head_dim)
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to queries and keys
        q = apply_rope(q, freqs)
        k = apply_rope(k, freqs)

        # Expand KV heads to match query heads
        k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        # Fused attention kernel (FlashAttention when available)
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

        # Back to (B, L, hidden_dim) for the output projection
        return self.o_proj(attn.transpose(1, 2).reshape(B, L, -1))

Memory savings: a 4x reduction in KV-cache size for the 8:2 head ratio used here (illustrated below).
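
As a back-of-the-envelope sketch of that saving (hypothetical numbers: batch 1, full 2048-token context, float16 cache, head_dim 128 from the default configuration):

python
# Per-layer KV-cache size: K and V are each (batch, kv_heads, seq_len, head_dim).
batch, seq_len, head_dim = 1, 2048, 128
bytes_per_elem = 2   # float16

def kv_cache_bytes(num_kv_heads: int) -> int:
    return 2 * batch * num_kv_heads * seq_len * head_dim * bytes_per_elem

print(kv_cache_bytes(8) / 2**20)   # MHA, 8 KV heads: 8.0 MiB per layer
print(kv_cache_bytes(2) / 2**20)   # GQA, 2 KV heads: 2.0 MiB per layer (4x smaller)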

4. SwiGLU Activation

Gated Linear Unit with SiLU (Swish) activation:

python
class SwiGLU(nn.Module):
    def __init__(self, hidden_dim: int, ffn_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(hidden_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.gate_proj(x))  # Swish gating
        up = self.up_proj(x)              # Linear projection
        return self.down_proj(gate * up)  # Gated output

Benefits:

  • Improves quality over ReLU/GELU feed-forward layers in LLM training
  • Smooth gradients
  • Used in LLaMA, PaLM, etc.

Transformer Block

Complete block combining all components:

python
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = GroupedQueryAttention(config)
        self.feed_forward = SwiGLU(config.hidden_dim, config.ffn_dim)
        self.attention_norm = RMSNorm(config.hidden_dim)
        self.ffn_norm = RMSNorm(config.hidden_dim)

    def forward(self, x, freqs, mask=None):
        # Pre-norm attention with residual
        h = x + self.attention(self.attention_norm(x), freqs, mask)
        # Pre-norm FFN with residual
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
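
A small smoke test wiring the pieces on this page together (illustrative only; assumes ModelArchitecture, precompute_rope_freqs, and TransformerBlock are in scope, and omits the causal mask a real decoder would pass):

python
import torch

config = ModelArchitecture()
head_dim = config.hidden_dim // config.num_heads

# Precompute RoPE angles once for the full context, slice per forward pass
freqs = precompute_rope_freqs(head_dim, config.max_seq_len, config.rope_base)

block = TransformerBlock(config)
x = torch.randn(2, 16, config.hidden_dim)   # (batch, seq_len, hidden_dim)
out = block(x, freqs[:16])                  # no mask: bidirectional smoke test
assert out.shape == x.shape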

Architecture Scaling

Model dimensions scale with the total memory available across the network:

Network Memory   Hidden Dim   Layers   FFN Dim   Heads   Parameters
40 GB            1024         16       4096      8       ~350M
200 GB           2048         24       8192      16      ~3.5B
800 GB           3072         32       12288     24      ~9.2B
4 TB             5120         48       20480     40      ~45B
8 TB             7168         64       28672     56      ~123B

Scaling formula:

python
def calculate_optimal_architecture(total_memory_gb: float) -> ModelArchitecture:
    # Width grows as sqrt of memory
    hidden = int(256 * math.sqrt(total_memory_gb / 10))
    hidden = ((hidden + 63) // 64) * 64  # Align to 64
    
    # Depth grows logarithmically
    layers = min(64, max(8, int(8 * math.log2(total_memory_gb / 10 + 1))))
    
    # Heads scale with width
    heads = max(4, hidden // 128)
    
    return ModelArchitecture(
        num_layers=layers,
        hidden_dim=hidden,
        num_heads=heads,
        num_kv_heads=max(1, heads // 4),
        ffn_dim=hidden * 4
    )

Sharding Strategy

Model layers are split across participating nodes in proportion to each node's available memory.

Assignment algorithm:

python
def calculate_layer_assignment(
    node_capacities: Dict[str, float],
    architecture: ModelArchitecture
) -> Dict[str, List[int]]:
    """Assign layers to nodes proportionally to their memory."""
    memory_per_layer = estimate_memory_per_layer(architecture)

    assignments = {}
    layer_idx = 0

    # Sort nodes for deterministic assignment across the network
    for node_id, memory in sorted(node_capacities.items()):
        # Each node takes floor(memory / memory_per_layer) layers,
        # capped so no more layers are assigned than the model has
        num_layers = min(int(memory / memory_per_layer),
                         architecture.num_layers - layer_idx)
        assignments[node_id] = list(range(layer_idx, layer_idx + num_layers))
        layer_idx += num_layers

    return assignments
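
A usage sketch with invented node IDs and memory budgets (hypothetical values, purely illustrative):

python
# Hypothetical per-node memory budgets in GB; IDs are made up for the example.
node_capacities = {"node-a": 0.6, "node-b": 0.4, "node-c": 0.5}

arch = ModelArchitecture()
assignment = calculate_layer_assignment(node_capacities, arch)

for node_id, layers in sorted(assignment.items()):
    print(node_id, layers)   # each node gets a contiguous block of layer indices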

Forward Pass

Distributed forward with activation streaming:

python
async def distributed_forward(tokens: torch.Tensor) -> torch.Tensor:
    # Step 1: Embedding (Driver)
    x = embedding(tokens)
    
    # Step 2: Stream through layers
    for layer_idx in range(num_layers):
        peer = router.get_peer_for_layer(layer_idx)
        x = await peer.forward_layer(layer_idx, x)
    
    # Step 3: Final projection (Validator)
    x = final_norm(x)
    logits = lm_head(x)
    
    return logits

Memory Estimation

Per-layer memory calculation:

python
def estimate_memory_per_layer(arch: ModelArchitecture) -> float:
    """Estimate memory per transformer layer in GB."""
    h = arch.hidden_dim
    f = arch.ffn_dim
    heads = arch.num_heads
    kv_heads = arch.num_kv_heads
    head_dim = h // heads
    
    # Attention parameters
    attn_params = (
        h * heads * head_dim +    # Q
        h * kv_heads * head_dim + # K
        h * kv_heads * head_dim + # V
        heads * head_dim * h      # O
    )
    
    # FFN parameters
    ffn_params = (
        h * f +  # gate_proj
        h * f +  # up_proj
        f * h    # down_proj
    )
    
    # Norms
    norm_params = h * 2
    
    total_params = attn_params + ffn_params + norm_params
    
    # 4 bytes per param (float32)
    bytes_per_layer = total_params * 4
    
    # Add 50% overhead for gradients and activations
    return bytes_per_layer * 1.5 / (1024 ** 3)
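
As a sanity check derived from the formula above (not a measured figure), the default configuration (hidden 1024, FFN 4096, 8 query / 2 KV heads) comes to roughly 15.2M parameters and about 0.085 GB per layer:

python
arch = ModelArchitecture()

# Attention: ~2.6M params, FFN: ~12.6M params, norms: 2048 params
# -> ~15.2M params * 4 bytes * 1.5 overhead ≈ 0.085 GB
print(round(estimate_memory_per_layer(arch), 3))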

Inference Optimization

KV Cache

Stores key-value states for autoregressive generation:

python
class KVCache:
    def __init__(self, max_batch: int, max_seq: int, num_kv_heads: int, head_dim: int):
        self.k_cache = torch.zeros(max_batch, num_kv_heads, max_seq, head_dim)
        self.v_cache = torch.zeros(max_batch, num_kv_heads, max_seq, head_dim)
        self.seq_len = 0

    def update(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = k.size(2)
        self.k_cache[:, :, self.seq_len:self.seq_len + seq_len] = k
        self.v_cache[:, :, self.seq_len:self.seq_len + seq_len] = v
        self.seq_len += seq_len
        return self.k_cache[:, :, :self.seq_len], self.v_cache[:, :, :self.seq_len]
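
A minimal decode-loop sketch (illustrative; one cache per layer, sized with the GQA head count and head_dim of the default configuration):

python
import torch

cache = KVCache(max_batch=1, max_seq=2048, num_kv_heads=2, head_dim=128)

# Prefill with an 8-token prompt...
k, v = torch.randn(1, 2, 8, 128), torch.randn(1, 2, 8, 128)
k_all, v_all = cache.update(k, v)          # shapes: (1, 2, 8, 128)

# ...then append one token per decode step
k, v = torch.randn(1, 2, 1, 128), torch.randn(1, 2, 1, 128)
k_all, v_all = cache.update(k, v)          # shapes: (1, 2, 9, 128)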

Flash Attention

Uses PyTorch's scaled_dot_product_attention, which automatically selects the fastest available backend (a backend-pinning sketch follows the list):

  • FlashAttention-2 (when available)
  • Memory-efficient attention (for longer sequences)
  • Standard attention (fallback)
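
When a specific kernel is required, PyTorch 2.3+ exposes torch.nn.attention.sdpa_kernel to pin the backend. A minimal sketch (assumes a CUDA device and an input layout FlashAttention supports, such as float16):

python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict dispatch to FlashAttention; errors if it cannot handle these inputs
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)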
