import os
import re
import glob
import json
import logging
import torch
from nanochat.common import get_base_dir
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.common import setup_default_logging
# Import-time side effect: apply nanochat's default logging setup
# (defined in nanochat.common.setup_default_logging — see there for details).
setup_default_logging()
# Module-level logger; rank-filtered logging goes through log0.
logger = logging.getLogger(__name__)
def log0(message):
    """Emit ``message`` at INFO level, but only from the rank-0 process.

    The rank is taken from the ``RANK`` environment variable; when it is
    unset the process is treated as rank 0, so the message is logged.
    """
    current_rank = int(os.environ.get('RANK', 0))
    if current_rank != 0:
        return
    logger.info(message)
def _detect_hc_from_state_dict(model_data):
    has_hc_keys = any(".branch.attn." in k or ".branch.mlp." in k for k in model_data.keys())
    has_static_alpha = any(".static_alpha" in k for k in model_data.keys())
    if not (has_hc_keys and has_static_alpha):
        return False, 1, None
    has_dynamic_alpha_scale = any(".dynamic_alpha_scale" in k for k in model_data.keys())
    has_pre_branch_scale = any(".pre_branch_scale" in k for k in model_data.keys())
    if has_dynamic_alpha_scale:
        hc_family = "HC"
    elif has_pre_branch_scale:
        hc_family = "mHC_variant"
    else:
        hc_family = "HC"
    for key, value in model_data.items():
        if ".attn.static_alpha" in key and value.dim() >= 1:
            if value.dim() == 2:
                num_streams = value.shape[0]
                log0(f"Detected num_residual_streams={num_streams} from static_alpha shape {value.shape}")
                return True, num_streams, hc_family
            elif value.dim() == 1:
                beta_key = key.replace("static_alpha", "static_beta")
                if beta_key in model_data:
                    num_streams = model_data[beta_key].shape[0]
                    log0(f"Detected num_residual_streams={num_streams} from static_beta shape")
                    return True, num_streams, hc_family
    log0("Warning: Detected HyperConnections structure but couldn't determine num_residual_streams, defaulting to 4")
    return True, 4, hc_family
def _patch_missing_config_keys(model_config_kwargs, model_data=None):
    if "window_pattern" not in model_config_kwargs:
        model_config_kwargs["window_pattern"] = "L"
        log0(f"Patching missing window_pattern in model config to 'L'")
    uses_hc = False
    num_streams = 1
    hc_family = None
    if model_data is not None and ("num_residual_streams" not in model_config_kwargs or "hc_type" not in model_config_kwargs):
        uses_hc, num_streams, hc_family = _detect_hc_from_state_dict(model_data)
    if "num_residual_streams" not in model_config_kwargs:
        model_config_kwargs["num_residual_streams"] = num_streams
        if uses_hc:
            log0(f"Patching missing num_residual_streams in model config to {num_streams} (detected from checkpoint)")
        else:
            log0(f"Patching missing num_residual_streams in model config to 1 (standard residuals)")
    if "hc_type" not in model_config_kwargs:
        if uses_hc:
            if hc_family == "mHC_variant":
                raise ValueError(
                    "Checkpoint uses mHC/mHC-lite/KromHC HyperConnections but 'hc_type' is missing from config. "
                    "Cannot distinguish between these variants from checkpoint alone. "
                    "Please ensure the model was saved with hc_type in the config, or manually specify it."
                )
            else:
                model_config_kwargs["hc_type"] = "HC"
                log0(f"Patching missing hc_type in model config to 'HC' (detected original HC/HC_fix from checkpoint)")
        else:
            model_config_kwargs["hc_type"] = "HC"
            log0(f"Patching missing hc_type in model config to 'HC' (default for standard model)")
def _patch_missing_keys(model_data, model_config):
    n_layer = model_config.n_layer
    use_hyper_connections = model_config.num_residual_streams > 1
    if not use_hyper_connections:
        if "resid_lambdas" not in model_data:
            model_data["resid_lambdas"] = torch.ones(n_layer)
            log0(f"Patching missing resid_lambdas in model data to 1.0")
        if "x0_lambdas" not in model_data:
            model_data["x0_lambdas"] = torch.zeros(n_layer)
            log0(f"Patching missing x0_lambdas in model data to 0.0")
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
    """Write a checkpoint for ``step`` into ``checkpoint_dir``.

    Rank 0 writes ``model_<step>.pt`` (via ``torch.save``) and
    ``meta_<step>.json``. Any rank that passes ``optimizer_data`` also writes
    its own ``optim_<step>_rank<rank>.pt``. Step numbers are zero-padded to
    six digits.
    """
    step_tag = f"{step:06d}"

    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        model_file = os.path.join(checkpoint_dir, f"model_{step_tag}.pt")
        torch.save(model_data, model_file)
        logger.info(f"Saved model parameters to: {model_file}")
        meta_file = os.path.join(checkpoint_dir, f"meta_{step_tag}.json")
        with open(meta_file, "w", encoding="utf-8") as fh:
            json.dump(meta_data, fh, indent=2)
        logger.info(f"Saved metadata to: {meta_file}")

    if optimizer_data is None:
        return
    # Create the directory here too, so ranks other than 0 do not rely on
    # rank 0 having created it first.
    os.makedirs(checkpoint_dir, exist_ok=True)
    optim_file = os.path.join(checkpoint_dir, f"optim_{step_tag}_rank{rank:d}.pt")
    torch.save(optimizer_data, optim_file)
    logger.info(f"Saved optimizer state to: {optim_file}")
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
    """Read back the files written by ``save_checkpoint`` for ``step``.

    Returns:
        ``(model_data, optimizer_data, meta_data)``. Tensors are mapped to
        ``device``. ``optimizer_data`` is ``None`` unless ``load_optimizer``
        is true, in which case this rank's ``optim_<step>_rank<rank>.pt`` is
        loaded. ``meta_data`` is the parsed JSON metadata.
    """
    step_tag = f"{step:06d}"

    def in_dir(filename):
        return os.path.join(checkpoint_dir, filename)

    model_data = torch.load(in_dir(f"model_{step_tag}.pt"), map_location=device)
    optimizer_data = (
        torch.load(in_dir(f"optim_{step_tag}_rank{rank:d}.pt"), map_location=device)
        if load_optimizer
        else None
    )
    with open(in_dir(f"meta_{step_tag}.json"), "r", encoding="utf-8") as fh:
        meta_data = json.load(fh)
    return model_data, optimizer_data, meta_data
def build_model(checkpoint_dir, step, device, phase, config_overrides=None):
    """Instantiate a GPT from a saved checkpoint and pair it with the tokenizer.

    Args:
        checkpoint_dir: directory holding ``model_<step>.pt`` / ``meta_<step>.json``.
        step: checkpoint step to load.
        device: a ``torch.device``. On ``cpu``/``mps``, bfloat16 weights are
            upcast to float32.
        phase: ``"train"`` or ``"eval"``; selects ``model.train()`` / ``model.eval()``.
        config_overrides: optional mapping of config keys to values; entries
            whose value is ``None`` are ignored.

    Returns:
        ``(model, tokenizer, meta_data)``. Overrides and back-compat patches are
        written into ``meta_data["model_config"]`` itself, so the returned
        metadata reflects the config the model was actually built with.
    """
    assert phase in ["train", "eval"], f"Invalid phase: {phase}"
    model_data, _unused_optimizer, meta_data = load_checkpoint(checkpoint_dir, step, device, load_optimizer=False)

    # One pass: optionally upcast bf16, then drop the "_orig_mod." prefix that
    # torch.compile-wrapped modules put on state-dict keys.
    upcast_bf16 = device.type in {"cpu", "mps"}
    cleaned = {}
    for name, tensor in model_data.items():
        if upcast_bf16 and tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        cleaned[name.removeprefix("_orig_mod.")] = tensor
    model_data = cleaned

    model_config_kwargs = meta_data["model_config"]
    for key, value in (config_overrides or {}).items():
        if value is None:
            continue
        model_config_kwargs[key] = value
        log0(f"Config override: {key}={value}")

    _patch_missing_config_keys(model_config_kwargs, model_data)
    log0(f"Building model with config: {model_config_kwargs}")
    model_config = GPTConfig(**model_config_kwargs)
    _patch_missing_keys(model_data, model_config)

    # Build on the meta device (no allocation), materialize on the target
    # device, then load the real weights. init_weights() runs first,
    # presumably to set anything the state dict does not cover
    # (e.g. non-persistent buffers) — TODO confirm against GPT.
    with torch.device("meta"):
        model = GPT(model_config)
    model.to_empty(device=device)
    model.init_weights()
    model.load_state_dict(model_data, strict=True, assign=True)

    if phase == "eval":
        model.eval()
    else:
        model.train()

    tokenizer = get_tokenizer()
    assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
    return model, tokenizer, meta_data
def find_largest_model(checkpoints_dir):
    """Pick a model tag (subdirectory name) from ``checkpoints_dir``.

    Tags that start with ``d<digits>`` are ranked by that number, and the
    largest wins. If no tag has that form, the most recently modified
    subdirectory is chosen. Ties go to the earliest entry in ``os.listdir``
    order.

    Raises:
        FileNotFoundError: if ``checkpoints_dir`` has no subdirectories.
    """
    tags = [
        entry
        for entry in os.listdir(checkpoints_dir)
        if os.path.isdir(os.path.join(checkpoints_dir, entry))
    ]
    if not tags:
        raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")

    by_depth = []
    for tag in tags:
        depth_match = re.match(r"d(\d+)", tag)
        if depth_match is not None:
            by_depth.append((int(depth_match.group(1)), tag))
    if by_depth:
        # max() keeps the first maximal item, matching a stable descending sort.
        return max(by_depth, key=lambda pair: pair[0])[1]

    return max(tags, key=lambda tag: os.path.getmtime(os.path.join(checkpoints_dir, tag)))
def find_last_step(checkpoint_dir):
    """Return the highest step among ``model_<step>.pt`` files in ``checkpoint_dir``.

    Steps are compared as integers. Comparing the zero-padded strings goes
    wrong once a step needs more than six digits: "1000000" sorts before
    "999999" lexicographically. Files matching ``model_*.pt`` whose step part
    is not a plain run of ASCII digits (e.g. ``model_best.pt``) are ignored
    rather than crashing ``int()``.

    Raises:
        FileNotFoundError: if no model checkpoint with an integer step exists.
    """
    steps = []
    for path in glob.glob(os.path.join(checkpoint_dir, "model_*.pt")):
        step_match = re.fullmatch(r"model_([0-9]+)\.pt", os.path.basename(path))
        if step_match is not None:
            steps.append(int(step_match.group(1)))
    if not steps:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
    return max(steps)
def load_model_from_dir(checkpoints_dir, device, phase, model_tag=None, step=None, config_overrides=None):
    """Resolve a model tag and step under ``checkpoints_dir``, then build the model.

    If ``model_tag`` is omitted, ``find_largest_model`` chooses one. If ``step``
    is omitted, the latest step from ``find_last_step`` is used.

    Returns:
        ``(model, tokenizer, meta_data)`` as produced by ``build_model``.
    """
    if model_tag is None:
        model_tag = find_largest_model(checkpoints_dir)
        log0(f"No model tag provided, guessing model tag: {model_tag}")
    checkpoint_dir = os.path.join(checkpoints_dir, model_tag)

    step = find_last_step(checkpoint_dir) if step is None else step
    assert step is not None, f"No checkpoints found in {checkpoint_dir}"

    log0(f"Loading model from {checkpoint_dir} with step {step}")
    model, tokenizer, meta_data = build_model(
        checkpoint_dir, step, device, phase, config_overrides=config_overrides
    )
    return model, tokenizer, meta_data
def load_model(source, *args, **kwargs):
    """Load a model from one of the standard checkpoint families under the base dir.

    ``source`` must be ``"base"``, ``"mid"``, ``"sft"`` or ``"rl"``. Any other
    value raises ``KeyError``, and does so before ``get_base_dir()`` is called.
    Remaining arguments are forwarded to ``load_model_from_dir``
    (``device, phase, model_tag=..., step=..., config_overrides=...``).
    """
    subdir_for_source = {
        "base": "base_checkpoints",
        "mid": "mid_checkpoints",
        "sft": "chatsft_checkpoints",
        "rl": "chatrl_checkpoints",
    }
    # Look up the subdirectory first so an invalid source fails before touching the base dir.
    subdir = subdir_for_source[source]
    checkpoints_dir = os.path.join(get_base_dir(), subdir)
    return load_model_from_dir(checkpoints_dir, *args, **kwargs)