from typing import Dict, Any
from pathlib import Path
import math

# Absolute location of the "src" directory; Path.absolute() anchors the
# relative path to the working directory in effect when this module is imported.
root_dir = Path("src").absolute()

# Default argument set for the "stories" setup. The code that consumes these
# keys is not in this file, so the section headings below only group keys by
# name and make no claim about how each value is interpreted.
StoriesBaseArgs: Dict[str, Any] = {
    # --- data / logging / output ---
    "data_dirs": [root_dir / "data" / "stories"],
    "timestamp": None,
    "log_level": "INFO",
    "res_dir": None,
    # --- batch limits and seeding ---
    "aux_batch_limit": None,
    "core_batch_limit": None,
    "seed": -1,
    # --- optimisation ---
    "batch_size": 128,
    "epochs": 1,
    "lr": 0.005,
    "lr_schedule": True,
    # --- run switches and label sets ---
    "arbsub": False,
    "test_ood": False,
    "do_compile": True,
    "aux_labels": [],
    "core_labels": None,
    # --- model shape ---
    "ctx_len": 256,
    "num_layers": 8,
    "embed_dim": 512,
    "mlp_dim": 4 * 512,  # four times embed_dim
    # --- staging / distributed / routing ---
    "stages": [],
    "do_cleanup_distributed": True,
    "accumulation_steps": 1,
    "optimize_routed_training": False,
}

# Default argument set for the "realistic" setup. It carries the same keys as
# StoriesBaseArgs with different values. The consumer of these keys is not in
# this file, so the section headings below only group keys by name.
RealisticBaseArgs: Dict[str, Any] = {
    # --- data / logging / output ---
    "data_dirs": [root_dir / "data" / corpus for corpus in ("fineweb", "bigcode", "arxiv")],
    "timestamp": None,
    "log_level": "DEBUG",
    "res_dir": None,
    # --- batch limits and seeding ---
    "aux_batch_limit": 0.05,
    "core_batch_limit": "optimal",
    "seed": -1,
    # --- optimisation ---
    "batch_size": 16,
    "epochs": 1,
    "lr": 0.0012,
    "lr_schedule": True,
    # --- label sets and run switches ---
    "core_labels": None,
    "arbsub": False,
    "test_ood": False,
    "do_compile": True,
    "aux_labels": ["bigcode", "biology", "nuclear", "cyber"],
    # --- model shape ---
    "ctx_len": 1024,
    "num_layers": 20,
    "embed_dim": 1536,
    "mlp_dim": 4 * 1536,  # four times embed_dim
    # --- staging / distributed / routing ---
    "stages": [],
    "do_cleanup_distributed": True,
    "accumulation_steps": 1,
    "optimize_routed_training": True,
}

def _calc_exact_params(embed_dim: int, num_layers: int, vocab_size: int = 50304,
                       num_heads: int = 8, num_kv: int = 2) -> int:
    """Calculate exact parameter count matching base.py architecture."""
    d = embed_dim
    L = num_layers
    V = vocab_size
    head_dim = d // num_heads
    
    # Per-layer params
    # Attention with GQA
    attn_q = d * d + d  # Q projection + bias
    attn_kv = d * (2 * num_kv * head_dim) + (2 * num_kv * head_dim)  # KV + bias
    attn_o = d * d + d  # O projection + bias
    # MLP
    mlp_fc = d * 4 * d + 4 * d  # up projection + bias
    mlp_proj = 4 * d * d + d  # down projection + bias
    # Norms (RMSNorm has d params each)
    norms = d + d
    
    layer_params = attn_q + attn_kv + attn_o + mlp_fc + mlp_proj + norms
    
    # Global params
    embed = V * d
    unembed = V * d + V  # weight + bias
    final_norm = d
    
    return embed + unembed + final_norm + L * layer_params


def calc_realistic_model_params(target_params: int) -> Dict[str, int]:
    """
    Pick model dimensions whose exact parameter count is closest to a target.

    A continuous estimate of the width is obtained first, then nearby widths
    (multiples of 32) are searched together with every layer count that keeps
    num_layers / embed_dim within ±5% of 0.015. The pair whose exact count
    (via _calc_exact_params) is nearest to ``target_params`` wins; ties go to
    the smaller width, then the smaller depth.

    Args:
        target_params: Desired total parameter count.

    Returns:
        Dict with "embed_dim", "num_layers" and "mlp_dim" (always 4 * embed_dim).
        If no (width, depth) pair in the search window satisfies the ratio
        constraint, the result is max(128, estimated width) with 2 layers;
        this covers every target too small for the minimum width of 128,
        including targets below the vocabulary size.
    """
    V = 50304  # vocab size
    r = 0.015  # target shape ratio (layers per unit of width)
    tol = 0.05  # shape tolerance

    # Parameters left once the unembedding bias (V entries) is set aside.
    budget = target_params - V

    # Solve 10.5*r*d³ + 2*V*d ≈ budget for an initial estimate via Newton's method.
    # The starting guess clamps the budget at 0: a negative number raised to 1/3
    # is a complex value in Python, which previously broke round() below with a
    # TypeError for any target smaller than V. With the clamp such targets fall
    # through to the same minimum-size fallback as other tiny targets.
    d = (max(budget, 0) / (10.5 * r)) ** (1/3)
    for _ in range(10):
        f = 10.5 * r * d**3 + 2 * V * d - budget
        d -= f / (31.5 * r * d**2 + 2 * V)

    # Search nearby embed_dims (multiples of 32 for Rotary compatibility: head_dim must be divisible by 4)
    ALIGN = 32
    base_d = ALIGN * round(d / ALIGN)
    candidates = []
    for embed_dim in range(max(128, base_d - 3 * ALIGN), base_d + 4 * ALIGN, ALIGN):
        # Integer layer counts inside the tolerated ratio band; may be empty
        # for widths where the band contains no integer.
        L_min = max(2, math.ceil(embed_dim * r * (1 - tol)))
        L_max = math.floor(embed_dim * r * (1 + tol))
        for num_layers in range(L_min, L_max + 1):
            params = _calc_exact_params(embed_dim, num_layers, V)
            candidates.append((abs(params - target_params), embed_dim, num_layers))

    if candidates:
        # Tuples compare element-wise: smallest error first, then width, then depth.
        _, best_d, best_L = min(candidates)
    else:
        best_d, best_L = max(128, base_d), 2
    return {"embed_dim": best_d, "num_layers": best_L, "mlp_dim": best_d * 4}