from __future__ import annotations

import argparse
import copy
import json
import logging
import math
import random
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm

# wandb is an optional dependency: when it cannot be imported the name is bound
# to None, and later code checks `wandb is None` before attempting to log.
try:
    import wandb
    # An importable `wandb` that lacks an `init` attribute is treated as unusable.
    # NOTE(review): presumably this guards against a stray local `wandb/`
    # directory being imported instead of the real package — confirm.
    if not hasattr(wandb, "init"):
        wandb = None
except ImportError:
    wandb = None

from data import (
    MaskPruningCollator,
    build_chatml_training_dataset,
    build_training_dataset,
    load_dataset_specs,
)
from models import load_llama3_model
from models.mask_predictor import MaskPredictorOutput
from pruning.ffn_prune import prune_llama_ffn
from pruning.ffn_mask import clear_all_masks, set_layer_mask
from pruning.soft_topk import hard_topk_mask
from utils.config import load_yaml_config

logger = logging.getLogger("persona-pruner.train")


def _parse_chatml_overrides(entries):
    parsed = []
    for item in entries:
        parts = item.split(":")
        path = parts[0]
        split = parts[1] if len(parts) > 1 and parts[1] else "train"
        task = parts[2] if len(parts) > 2 and parts[2] else Path(path).stem
        parsed.append({"path": path, "split": split, "task_name": task})
    return parsed


def _sanitize_for_wandb(value):
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, torch.dtype):
        return str(value)
    if isinstance(value, torch.device):
        return str(value)
    if isinstance(value, dict):
        return {str(k): _sanitize_for_wandb(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [_sanitize_for_wandb(v) for v in value]
    return value


@contextmanager
def apply_layer_masks(model, mask_tensor: torch.Tensor):
    """Install per-layer masks on ``model`` for the duration of the ``with`` body.

    ``mask_tensor`` is indexed as ``mask_tensor[:, layer_idx, :]``, so dimension 1
    enumerates layers; each slice is handed to ``set_layer_mask``. All masks are
    removed with ``clear_all_masks`` on exit, whether the body finishes normally
    or raises.

    Mask installation happens inside the ``try`` so that a failure while setting
    a later layer does not leave the masks of earlier layers attached to the
    model.
    """
    try:
        for layer_idx in range(mask_tensor.size(1)):
            set_layer_mask(model, layer_idx, mask_tensor[:, layer_idx, :])
        yield
    finally:
        clear_all_masks(model)


class DirectMaskPredictor(nn.Module):
    """Learn mask scores directly without a prompt encoder/MLP.

    The only trainable state is ``scores``, a ``(num_layers, intermediate_size)``
    parameter initialised from a normal distribution. Masks are derived from
    these scores by a hard top-k selection; they do not depend on any input.
    """

    def __init__(
        self,
        predictor_cfg: Dict[str, object],
        model_config,
        device: Optional[torch.device] = None,
    ):
        """Build the score parameter.

        Args:
            predictor_cfg: Mapping read for ``num_layers``, ``intermediate_size``,
                ``target_intermediate_size``, ``score_init_std`` (default 0.02)
                and ``score_init_mean`` (default 0.0).
            model_config: Optional object whose ``num_hidden_layers`` and
                ``intermediate_size`` fill in sizes missing from ``predictor_cfg``.
            device: Stored on ``self.device``. When omitted, CUDA is chosen only
                if it is actually available, otherwise CPU. The parameter itself
                is not moved here; callers place the module with ``.to(...)``.

        Raises:
            ValueError: If ``num_layers`` or ``intermediate_size`` cannot be
                resolved from either source.
        """
        super().__init__()
        self.cfg = predictor_cfg
        # Do not advertise a CUDA device on hosts that have none.
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        num_layers = predictor_cfg.get("num_layers")
        intermediate_size = predictor_cfg.get("intermediate_size")
        if model_config is not None:
            # Explicit predictor settings win; the model config only fills gaps.
            num_layers = num_layers or model_config.num_hidden_layers
            intermediate_size = intermediate_size or model_config.intermediate_size
        if num_layers is None or intermediate_size is None:
            raise ValueError("Both num_layers and intermediate_size must be provided via predictor_cfg or model_config.")

        self.num_layers = num_layers
        self.intermediate_size = intermediate_size
        # A missing (or zero) target means "keep everything".
        target_size = predictor_cfg.get("target_intermediate_size")
        self.target_intermediate_size = target_size or intermediate_size

        init_std = float(predictor_cfg.get("score_init_std", 0.02))
        init_mean = float(predictor_cfg.get("score_init_mean", 0.0))
        self.scores = nn.Parameter(torch.empty(num_layers, intermediate_size))
        nn.init.normal_(self.scores, mean=init_mean, std=init_std)

    def _scores_to_masks(self, scores: torch.Tensor) -> MaskPredictorOutput:
        """Convert raw scores into a ``MaskPredictorOutput``.

        ``scores`` is reshaped to ``(-1, num_layers, intermediate_size)``, so the
        module's own 2-D parameter gains a leading dimension of size 1.
        """
        scores = scores.reshape(-1, self.num_layers, self.intermediate_size)

        # Never request more units than exist in a layer.
        target_k = min(self.target_intermediate_size, self.intermediate_size)
        hard_mask = hard_topk_mask(scores, target_k)
        # `scores - scores.detach()` is zero in value but has an identity gradient
        # with respect to `scores`, so the forward value matches `hard_mask`
        # while the backward pass still reaches the score parameter.
        soft_mask = hard_mask + scores - scores.detach()
        keep_ratio = hard_mask.mean(dim=-1)
        topk_indices = torch.topk(scores, k=target_k, dim=-1).indices

        return MaskPredictorOutput(
            hard_mask=hard_mask,
            soft_mask=soft_mask,
            keep_ratio=keep_ratio,
            scores=scores,
            topk_indices=topk_indices,
        )

    def forward_from_features(
        self,
        features: Optional[torch.Tensor],
        batch_size: Optional[int] = None,
    ) -> MaskPredictorOutput:
        """Return masks computed from the learned scores.

        ``features`` and ``batch_size`` are accepted but deliberately ignored:
        the output depends only on ``self.scores``.
        """
        del features, batch_size
        return self._scores_to_masks(self.scores)


def parse_args(argv: Optional[list] = None) -> argparse.Namespace:
    """Parse the command-line options of the mask-predictor training script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what a no-argument call did before;
            passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``. Options whose default is ``None``
        mean "not overridden on the command line".
    """
    parser = argparse.ArgumentParser(description="Train the mask predictor using mixed-task data.")
    add = parser.add_argument
    prompt_modes = ["system_user", "system_only", "user_only"]

    # --- configuration / output locations ---
    add(
        "--config",
        type=Path,
        default=Path("configs/mask_training.yaml"),
        help="Training configuration YAML file.",
    )
    add(
        "--output-dir",
        type=Path,
        default=None,
        help="Optional override for output directory.",
    )
    add(
        "--target-intermediate-size",
        type=int,
        default=None,
        help="Override for mask predictor target_intermediate_size. If unset, use config value.",
    )

    # --- dataset overrides (repeatable) ---
    add(
        "--chatml-file",
        action="append",
        default=None,
        help="Override chatml_files entries. Format: path[:split[:task_name]]",
    )
    add(
        "--val-chatml-file",
        action="append",
        default=None,
        help="Override validation chatml files. Format: path[:split[:task_name]]",
    )
    add(
        "--test-chatml-file",
        action="append",
        default=None,
        help="Override test chatml files. Format: path[:split[:task_name]]",
    )

    # --- run naming and prompt construction ---
    add(
        "--wandb-run-name",
        type=str,
        default=None,
        help="Override wandb run name (useful for persona-specific runs).",
    )
    add(
        "--prompt-text-mode",
        type=str,
        default=None,
        choices=prompt_modes,
        help="Override data.prompt_text_mode for mask predictor input text.",
    )
    add(
        "--llm-prompt-mode",
        type=str,
        default=None,
        choices=prompt_modes,
        help="Override data.llm_prompt_mode for ChatML encoding to LLM.",
    )
    # NOTE: the default is 42 rather than None, and main() checks
    # `args.seed is not None`, so data.seed from the config file is replaced
    # on every run, even when --seed is not given.
    add(
        "--seed",
        type=int,
        default=42,
        help="Override data.seed and seed Python/Torch RNGs.",
    )

    # --- optimisation hyper-parameters ---
    add(
        "--learning-rate",
        type=float,
        default=None,
        help="Override training learning_rate.",
    )
    add(
        "--learning-rate-mask",
        type=float,
        default=None,
        help="Override mask learning rate (defaults to training.learning_rate).",
    )
    add(
        "--learning-rate-llm",
        type=float,
        default=None,
        help="Override LLM learning rate (defaults to training.learning_rate).",
    )
    add(
        "--warmup-ratio",
        type=float,
        default=None,
        help="Override training warmup_ratio.",
    )
    add(
        "--warmup-ratio-llm",
        type=float,
        default=None,
        help="Override LLM warmup_ratio (defaults to warmup_ratio).",
    )
    add(
        "--mask-scheduler-scale",
        type=float,
        default=1.0,
        help="Scale factor for mask LR scheduler steps (lower = faster decay).",
    )
    add(
        "--log-ema-decay",
        type=float,
        default=0.98,
        help="EMA decay for logging trend metrics.",
    )

    # --- export switches ---
    add(
        "--export-only-final-epoch",
        action="store_true",
        help="If set, export only the final epoch model.",
    )
    add(
        "--export-initial-mask",
        action="store_true",
        help="If set, export the initial (pre-backprop) mask as epoch_00.",
    )
    add(
        "--export-mask-only",
        action="store_true",
        help="If set, export only top-k indices (no pruned model weights).",
    )
    return parser.parse_args(argv)


def setup_logging():
    """Call ``logging.basicConfig`` with INFO level and the script's line format."""
    line_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
    logging.basicConfig(format=line_format, level=logging.INFO)


def save_mask_checkpoint(
    base_dir: Path,
    mask_predictor: torch.nn.Module,
    epoch: Optional[int] = None,
    metadata: Optional[dict] = None,
    base_model: Optional[torch.nn.Module] = None,
    tokenizer=None,
):
    """Write a mask-predictor checkpoint, optionally with metadata and base model.

    Args:
        base_dir: Checkpoint root. Used directly when ``epoch`` is ``None``;
            otherwise files go into the ``epoch_XX`` subdirectory.
        mask_predictor: Module whose ``state_dict()`` is saved as
            ``mask_predictor.pt``.
        epoch: Optional epoch number selecting the subdirectory.
        metadata: When truthy, dumped as indented JSON to ``meta.json``.
        base_model: When given, saved with ``save_pretrained`` into a
            ``base_model`` subdirectory.
        tokenizer: Saved next to ``base_model`` if it offers ``save_pretrained``;
            ignored when ``base_model`` is ``None``.
    """
    if epoch is None:
        target_dir = base_dir
    else:
        target_dir = base_dir / f"epoch_{epoch:02d}"
    target_dir.mkdir(parents=True, exist_ok=True)

    weights_path = target_dir / "mask_predictor.pt"
    torch.save(mask_predictor.state_dict(), weights_path)

    if metadata:
        meta_path = target_dir / "meta.json"
        with open(meta_path, "w", encoding="utf-8") as meta_file:
            json.dump(metadata, meta_file, indent=2)

    # Nothing else to write unless a base model was supplied.
    if base_model is None:
        return

    hf_dir = target_dir / "base_model"
    base_model.save_pretrained(hf_dir)
    can_save_tokenizer = tokenizer is not None and hasattr(tokenizer, "save_pretrained")
    if can_save_tokenizer:
        tokenizer.save_pretrained(hf_dir)


def _resolve_unique_export_dir(root: Path, run_name: str) -> Path:
    """Return a unique export directory, appending a numeric suffix if needed."""
    base = root / run_name
    if not base.exists():
        return base
    suffix = 0
    while True:
        candidate = root / f"{run_name}_{suffix}"
        if not candidate.exists():
            return candidate
        suffix += 1


def main() -> None:
    setup_logging()
    args = parse_args()
    cfg = load_yaml_config(args.config)
    if args.chatml_file or args.val_chatml_file or args.test_chatml_file:
        cfg.setdefault("data", {})
        if args.chatml_file:
            cfg["data"]["chatml_files"] = _parse_chatml_overrides(args.chatml_file)
        if args.val_chatml_file:
            cfg["data"]["chatml_val_files"] = _parse_chatml_overrides(args.val_chatml_file)
        if args.test_chatml_file:
            cfg["data"]["chatml_test_files"] = _parse_chatml_overrides(args.test_chatml_file)
    if args.seed is not None:
        seed = int(args.seed)
        cfg.setdefault("data", {})
        cfg["data"]["seed"] = seed
        random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    if "model" in cfg:
        model_cfg = cfg["model"]
    else:
        model_cfg_path = cfg.get("model_config", "configs/model.yaml")
        model_cfg = load_yaml_config(model_cfg_path)

    if args.target_intermediate_size is not None:
        model_cfg.setdefault("mask_predictor", {})
        model_cfg["mask_predictor"]["target_intermediate_size"] = args.target_intermediate_size

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Loading base model from %s", model_cfg.get("base_model"))
    llama_bundle = load_llama3_model(model_cfg)
    base_model_config_dict = copy.deepcopy(llama_bundle.model.config.to_dict())
    base_model_config_cls = llama_bundle.model.config.__class__
    export_dtype = next(llama_bundle.model.parameters()).dtype

    freeze_base_model = bool(model_cfg.get("freeze_base_model", True))
    if freeze_base_model:
        for param in llama_bundle.model.parameters():
            param.requires_grad_(False)
        llama_bundle.model.to(device)
        llama_bundle.model.eval()
        logger.info("Base model parameters frozen.")
    else:
        llama_bundle.model.to(device)
        llama_bundle.model.train()
        logger.info("Base model parameters will be updated during training.")

    mask_cfg = cfg.get("mask_predictor", model_cfg.get("mask_predictor", {}))
    if args.target_intermediate_size is not None:
        mask_cfg = dict(mask_cfg)
        mask_cfg["target_intermediate_size"] = args.target_intermediate_size
    mask_predictor = DirectMaskPredictor(mask_cfg, llama_bundle.model.config, device)
    mask_predictor.to(device)

    for param in mask_predictor.parameters():
        param.requires_grad_(True)

    data_cfg = cfg.get("data", {})
    if args.prompt_text_mode is not None:
        data_cfg = dict(data_cfg)
        data_cfg["prompt_text_mode"] = args.prompt_text_mode
    if args.llm_prompt_mode is not None:
        data_cfg = dict(data_cfg)
        data_cfg["llm_prompt_mode"] = args.llm_prompt_mode
    system_prompt = data_cfg.get(
        "system_prompt",
        "You are a helpful assistant.",
    )
    cutoff_len = int(data_cfg.get("cutoff_len", 2048))
    prompt_max_length = int(data_cfg.get("prompt_max_length", 512))
    prompt_text_mode = data_cfg.get("prompt_text_mode", "system_user")
    llm_prompt_mode = data_cfg.get("llm_prompt_mode", prompt_text_mode)
    append_eos = bool(data_cfg.get("append_eos_to_response", False))
    seed = int(data_cfg.get("seed", 42))

    logger.info("Building training dataset...")
    chatml_files = data_cfg.get("chatml_files")
    chatml_val_files = data_cfg.get("chatml_val_files")
    if chatml_files:
        dataset = build_chatml_training_dataset(
            files=chatml_files,
            tokenizer=llama_bundle.tokenizer,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            default_system_prompt=system_prompt,
            prompt_text_mode=prompt_text_mode,
            llm_prompt_mode=llm_prompt_mode,
            append_eos=append_eos,
            seed=seed,
        )
        dataset_names = [
            item.get("task_name") or Path(item["path"]).stem for item in chatml_files
        ]
    else:
        specs = load_dataset_specs(data_cfg)
        dataset = build_training_dataset(
            specs=specs,
            tokenizer=llama_bundle.tokenizer,
            system_prompt=system_prompt,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            seed=seed,
        )
        dataset_names = [spec.name for spec in specs]

    val_dataset = None
    test_dataset = None
    if len(dataset) > 0:
        sample0 = dataset[0]
        input_len = len(sample0["input_ids"])
        label_len = sum(1 for t in sample0["labels"] if t != -100)
        prompt_preview = sample0["prompt_text"].replace("\n", "\\n")
        prompt_preview = prompt_preview[:200] + ("..." if len(prompt_preview) > 200 else "")
        logger.info(
            "First training sample | prompt_text_mode=%s | task=%s | input_len=%d | label_len=%d | prompt_preview=\"%s\"",
            prompt_text_mode,
            sample0.get("task", ""),
            input_len,
            label_len,
            prompt_preview,
        )
    if chatml_val_files:
        val_dataset = build_chatml_training_dataset(
            files=chatml_val_files,
            tokenizer=llama_bundle.tokenizer,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            default_system_prompt=system_prompt,
            prompt_text_mode=prompt_text_mode,
            llm_prompt_mode=llm_prompt_mode,
            append_eos=append_eos,
            seed=seed,
        )
    elif "validation_datasets" in data_cfg:
        val_specs = load_dataset_specs({"datasets": data_cfg["validation_datasets"]})
        val_dataset = build_training_dataset(
            specs=val_specs,
            tokenizer=llama_bundle.tokenizer,
            system_prompt=system_prompt,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            seed=seed,
        )
    
    chatml_test_files = data_cfg.get("chatml_test_files")
    if chatml_test_files:
        test_dataset = build_chatml_training_dataset(
            files=chatml_test_files,
            tokenizer=llama_bundle.tokenizer,
            cutoff_len=cutoff_len,
            prompt_max_length=prompt_max_length,
            default_system_prompt=system_prompt,
            prompt_text_mode=prompt_text_mode,
            llm_prompt_mode=llm_prompt_mode,
            append_eos=append_eos,
            seed=seed,
        )
        logger.info("Loaded %d test samples.", len(test_dataset))
    if len(dataset) == 0:
        raise RuntimeError("No training samples were created. Please check dataset configuration.")
    logger.info("Loaded %d training samples.", len(dataset))
    task_counts = {}
    for sample in dataset.samples:
        task_counts[sample.task] = task_counts.get(sample.task, 0) + 1
    logger.info("Task distribution: %s", task_counts)

    cached_features = None

    collator = MaskPruningCollator(
        pad_token_id=llama_bundle.tokenizer.pad_token_id or llama_bundle.tokenizer.eos_token_id,
        label_pad_token_id=-100,
        pad_to_multiple_of=8,
    )

    train_batch_size = int(cfg["training"]["batch_size"])
    num_workers = int(data_cfg.get("num_workers", 0))
    train_loader = DataLoader(
        dataset,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = None
    if val_dataset is not None and len(val_dataset) > 0:
        val_loader = DataLoader(
            val_dataset,
            batch_size=train_batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=num_workers,
            pin_memory=True,
        )
    
    test_loader = None
    if test_dataset is not None and len(test_dataset) > 0:
        test_loader = DataLoader(
            test_dataset,
            batch_size=train_batch_size,
            shuffle=False,
            collate_fn=collator,
            num_workers=num_workers,
            pin_memory=True,
        )

    train_cfg = cfg["training"]
    num_epochs = int(train_cfg.get("num_epochs", 1))
    grad_accum = int(train_cfg.get("grad_accumulation_steps", 1))
    learning_rate = float(train_cfg.get("learning_rate", 5e-5))
    if args.learning_rate is not None:
        learning_rate = float(args.learning_rate)
    learning_rate_mask = learning_rate if args.learning_rate_mask is None else float(args.learning_rate_mask)
    learning_rate_llm = learning_rate if args.learning_rate_llm is None else float(args.learning_rate_llm)
    mask_scheduler_scale = float(args.mask_scheduler_scale)
    if mask_scheduler_scale <= 0:
        raise ValueError("mask_scheduler_scale must be positive.")
    log_ema_decay = float(args.log_ema_decay)
    if not 0.0 < log_ema_decay < 1.0:
        raise ValueError("log_ema_decay must be between 0 and 1.")
    weight_decay = float(train_cfg.get("weight_decay", 0.0))
    warmup_ratio = float(train_cfg.get("warmup_ratio", 0.0))
    if args.warmup_ratio is not None:
        warmup_ratio = float(args.warmup_ratio)
    warmup_ratio_llm = warmup_ratio
    if args.warmup_ratio_llm is not None:
        warmup_ratio_llm = float(args.warmup_ratio_llm)
    max_grad_norm = float(train_cfg.get("max_grad_norm", 1.0))
    logging_steps = int(train_cfg.get("logging_steps", 10))
    save_every = train_cfg.get("save_every")
    save_each_epoch = bool(train_cfg.get("save_each_epoch", False))
    wandb_cfg = train_cfg.get("wandb", {})
    if args.wandb_run_name:
        wandb_cfg = dict(wandb_cfg)
        wandb_cfg["run_name"] = args.wandb_run_name
    use_wandb = bool(wandb_cfg.get("enable", False))
    wandb_run = None
    if use_wandb:
        if wandb is None:
            logger.warning("wandb logging requested but wandb is not installed. Disable wandb or install the package.")
        else:
            resolved_cfg = {
                "model": _sanitize_for_wandb(model_cfg),
                "data": _sanitize_for_wandb(data_cfg),
                "training": _sanitize_for_wandb(train_cfg),
            }
            wandb_run = wandb.init(
                project=wandb_cfg.get("project", "mask-pruning"),
                name=wandb_cfg.get("run_name"),
                entity=wandb_cfg.get("entity"),
                tags=wandb_cfg.get("tags", None),
                config={
                    "learning_rate": learning_rate,
                    "learning_rate_mask": learning_rate_mask,
                    "learning_rate_llm": learning_rate_llm,
                    "mask_scheduler_scale": mask_scheduler_scale,
                    "log_ema_decay": log_ema_decay,
                    "batch_size": train_batch_size,
                    "grad_accum": grad_accum,
                    "num_epochs": num_epochs,
                    "weight_decay": weight_decay,
                    "warmup_ratio": warmup_ratio,
                    "warmup_ratio_llm": warmup_ratio_llm,
                    "cutoff_len": cutoff_len,
                    "prompt_max_length": prompt_max_length,
                    "datasets": dataset_names,
                    "task_distribution": task_counts,
                    "config_file": str(args.config),
                    "config_raw": _sanitize_for_wandb(cfg),
                    "config_resolved": resolved_cfg,
                    "args": _sanitize_for_wandb(vars(args)),
                },
            )

    mask_params = [p for p in mask_predictor.parameters() if p.requires_grad]
    llm_params = []
    if not freeze_base_model:
        llm_params = [p for p in llama_bundle.model.parameters() if p.requires_grad]
    trainable_params = mask_params + llm_params
    if len(trainable_params) == 0:
        raise RuntimeError("No trainable parameters found for optimization.")

    optimizer_mask = AdamW(mask_params, lr=learning_rate_mask, weight_decay=weight_decay)
    optimizer_llm = AdamW(llm_params, lr=learning_rate_llm, weight_decay=weight_decay) if llm_params else None
    total_steps = math.ceil(len(train_loader) / grad_accum) * num_epochs
    warmup_steps_llm = int(total_steps * warmup_ratio_llm)
    scheduler_llm = None
    if optimizer_llm is not None:
        scheduler_llm = get_cosine_schedule_with_warmup(optimizer_llm, warmup_steps_llm, total_steps)
    mask_total_steps = max(1, int(total_steps * mask_scheduler_scale))
    warmup_steps_mask = int(mask_total_steps * warmup_ratio)
    scheduler_mask = get_cosine_schedule_with_warmup(optimizer_mask, warmup_steps_mask, mask_total_steps)

    output_dir = Path(args.output_dir or train_cfg.get("output_dir", "outputs/mask_predictor"))
    output_dir.mkdir(parents=True, exist_ok=True)

    run_name = wandb_cfg.get("run_name") or output_dir.name
    export_root = Path("exports")
    export_run_dir = _resolve_unique_export_dir(export_root, run_name)
    export_run_dir.mkdir(parents=True, exist_ok=True)
    export_only_final_epoch = bool(args.export_only_final_epoch)
    export_initial_mask = bool(args.export_initial_mask)
    export_mask_only = bool(args.export_mask_only)
    if export_mask_only and not freeze_base_model:
        raise RuntimeError("export-mask-only requires freeze_base_model=true (LLM weights would be lost).")

    mask_predictor.train()
    global_step = 0
    mask_step = 0
    llm_step = 0
    optimizer_mask.zero_grad()
    if optimizer_llm is not None:
        optimizer_llm.zero_grad()
    training_start = time.perf_counter()
    prev_hard_mask = None
    ema_loss = None
    ema_lambda_mean = None
    ema_topk_overlap = None

    def run_eval(loader):
        """Run evaluation on given loader (validation or test)."""
        if loader is None:
            return None
        was_predictor_train = mask_predictor.training
        was_llm_train = llama_bundle.model.training
        mask_predictor.eval()
        llama_bundle.model.eval()
        total_loss = 0.0
        batches = 0
        with torch.no_grad():
            for batch in loader:
                batch.pop("prompt_text")
                batch.pop("task", None)
                inputs = {k: v.to(device) for k, v in batch.items()}
                mask_output = mask_predictor.forward_from_features(
                    cached_features,
                    batch_size=inputs["input_ids"].size(0),
                )
                mask_tensor = mask_output.soft_mask.to(device)
                with apply_layer_masks(llama_bundle.model, mask_tensor):
                    outputs = llama_bundle.model(**inputs)
                    lm_loss = outputs["loss"]
                loss = lm_loss
                total_loss += loss.item()
                batches += 1
        if was_predictor_train:
            mask_predictor.train()
        if not freeze_base_model and was_llm_train:
            llama_bundle.model.train()
        return total_loss / max(batches, 1)

    def run_validation():
        return run_eval(val_loader)

    def run_test():
        return run_eval(test_loader)

    def compute_grad_norm(params) -> float:
        """Compute L2 norm of gradients for logging."""
        total_sq = 0.0
        for p in params:
            if p.grad is None:
                continue
            param_norm = p.grad.data.norm(2).item()
            total_sq += param_norm * param_norm
        return total_sq ** 0.5

    def compute_mask_step_stats(mask_output: torch.nn.Module, prev_hard_mask: Optional[torch.Tensor]):
        with torch.inference_mode():
            target_k = min(mask_predictor.target_intermediate_size, mask_predictor.intermediate_size)
            lambda_scores = mask_output.scores.float()
            stats = {
                "mask/lambda_mean": lambda_scores.mean().item(),
                "mask/lambda_std": lambda_scores.std().item(),
                "mask/lambda_near_zero": (lambda_scores < 0.05).float().mean().item(),
                "mask/lambda_near_one": (lambda_scores > 0.95).float().mean().item(),
            }

            hard_mask = mask_output.hard_mask
            if hard_mask.dim() == 3:
                hard_mask = hard_mask[0]
            hard_mask_cpu = (hard_mask > 0.5).detach().cpu()
            if prev_hard_mask is not None:
                overlap = (hard_mask_cpu & prev_hard_mask).float().sum(dim=-1)
                denom = hard_mask_cpu.float().sum(dim=-1).clamp(min=1.0)
                overlap_ratio = (overlap / denom).mean().item()
                stats["mask/topk_overlap"] = overlap_ratio
                stats["mask/topk_change"] = 1.0 - overlap_ratio
        return stats, hard_mask_cpu

    def log_lambda_saturation(mask_output: torch.nn.Module, epoch_step: int) -> None:
        """Monitor lambda saturation during training."""
        with torch.inference_mode():
            target_k = min(mask_predictor.target_intermediate_size, mask_predictor.intermediate_size)
            lambda_scores = mask_output.scores.float()
            near_zero = (lambda_scores < 0.05).float().mean().item()
            near_one = (lambda_scores > 0.95).float().mean().item()
            mean_val = lambda_scores.mean().item()
            topk_mask = hard_topk_mask(lambda_scores, target_k).bool()
            topk_vals = lambda_scores[topk_mask]
            non_topk_vals = lambda_scores[~topk_mask]
            topk_near_zero = (topk_vals < 0.05).float().mean().item() if topk_vals.numel() else 0.0
            topk_near_one = (topk_vals > 0.95).float().mean().item() if topk_vals.numel() else 0.0
            non_topk_near_zero = (non_topk_vals < 0.05).float().mean().item() if non_topk_vals.numel() else 0.0
            non_topk_near_one = (non_topk_vals > 0.95).float().mean().item() if non_topk_vals.numel() else 0.0
            topk_mean = topk_vals.mean().item() if topk_vals.numel() else 0.0
            non_topk_mean = non_topk_vals.mean().item() if non_topk_vals.numel() else 0.0

        log_entry = {
            "epoch": epoch + 1,
            "epoch_step": epoch_step,
            "lambda_near_zero": near_zero,
            "lambda_near_one": near_one,
            "lambda_mean": mean_val,
            "topk_lambda_near_zero": topk_near_zero,
            "topk_lambda_near_one": topk_near_one,
            "non_topk_lambda_near_zero": non_topk_near_zero,
            "non_topk_lambda_near_one": non_topk_near_one,
            "topk_lambda_mean": topk_mean,
            "non_topk_lambda_mean": non_topk_mean,
        }
        logger.info("Lambda saturation check: %s", json.dumps(log_entry))
        if wandb_run is not None:
            wandb_run.log(
                {
                    "mask/lambda_near_zero": near_zero,
                    "mask/lambda_near_one": near_one,
                    "mask/lambda_mean": mean_val,
                    "mask/topk_lambda_near_zero": topk_near_zero,
                    "mask/topk_lambda_near_one": topk_near_one,
                    "mask/non_topk_lambda_near_zero": non_topk_near_zero,
                    "mask/non_topk_lambda_near_one": non_topk_near_one,
                    "mask/topk_lambda_mean": topk_mean,
                    "mask/non_topk_lambda_mean": non_topk_mean,
                    "mask/epoch_step": epoch_step,
                },
                step=global_step,
            )

    def export_pruned_epoch(epoch_index: int, is_final: bool) -> None:
        if export_only_final_epoch and not is_final:
            return
        export_dir = export_run_dir / f"epoch_{epoch_index:02d}"
        export_dir.mkdir(parents=True, exist_ok=True)

        was_predictor_train = mask_predictor.training
        mask_predictor.eval()
        with torch.inference_mode():
            mask_output = mask_predictor.forward_from_features(
                cached_features,
                batch_size=1,
            )
        if was_predictor_train:
            mask_predictor.train()

        topk_indices = mask_output.topk_indices.detach().cpu()

        if export_mask_only:
            torch.save(topk_indices, export_dir / "topk_indices.pt")
            meta = {
                "mask_only": True,
                "epoch": epoch_index,
                "target_intermediate_size": int(mask_predictor.target_intermediate_size),
                "base_model": str(model_cfg.get("base_model")),
            }
            with open(export_dir / "mask_export.json", "w", encoding="utf-8") as f:
                json.dump(meta, f, indent=2)
            return

        state_cpu = {k: v.detach().cpu() for k, v in llama_bundle.model.state_dict().items()}
        export_config = base_model_config_cls.from_dict(base_model_config_dict)
        export_model = llama_bundle.model.__class__(export_config)
        export_model.to(dtype=export_dtype)
        export_model.load_state_dict(state_cpu, strict=True)
        prune_llama_ffn(export_model, topk_indices)
        export_model.config.torch_dtype = export_dtype
        export_model.save_pretrained(export_dir)
        llama_bundle.tokenizer.save_pretrained(export_dir)

    if export_initial_mask:
        logger.info("Exporting initial (pre-backprop) mask to epoch_00.")
        export_pruned_epoch(0, is_final=False)

    for epoch in range(num_epochs):
        logger.info("Starting epoch %d / %d", epoch + 1, num_epochs)
        epoch_task_counts = {}
        running_loss = 0.0
        micro_batches = 0
        epoch_desc = f"Epoch {epoch + 1}/{num_epochs}"
        epoch_steps = len(train_loader)
        half_step = max(1, epoch_steps // 2)
        for step, batch in enumerate(tqdm(train_loader, desc=epoch_desc, leave=False), start=1):
            batch.pop("prompt_text")
            tasks = batch.pop("task")

            inputs = {k: v.to(device) for k, v in batch.items()}
            mask_output = mask_predictor.forward_from_features(
                cached_features,
                batch_size=inputs["input_ids"].size(0),
            )
            mask_tensor = mask_output.soft_mask.to(device)

            with apply_layer_masks(llama_bundle.model, mask_tensor):
                outputs = llama_bundle.model(**inputs)
                lm_loss = outputs["loss"]

            loss = lm_loss

            loss = loss / grad_accum
            loss.backward()
            running_loss += loss.item()
            micro_batches += 1
            for t in tasks:
                epoch_task_counts[t] = epoch_task_counts.get(t, 0) + 1

            if step == half_step or step == epoch_steps:
                log_lambda_saturation(mask_output, step)

            if step % grad_accum == 0:
                next_step = global_step + 1
                log_grad = (next_step % logging_steps == 0)
                grad_norm_mask = None
                grad_norm_llm = None
                if log_grad:
                    grad_norm_mask = compute_grad_norm(mask_predictor.parameters())
                    if freeze_base_model:
                        grad_norm_llm = 0.0
                    else:
                        grad_norm_llm = compute_grad_norm(
                            p for p in llama_bundle.model.parameters() if p.requires_grad
                        )

                torch.nn.utils.clip_grad_norm_(mask_params, max_grad_norm)
                if llm_params:
                    torch.nn.utils.clip_grad_norm_(llm_params, max_grad_norm)
                optimizer_mask.step()
                if optimizer_llm is not None:
                    optimizer_llm.step()
                mask_step += 1
                if optimizer_llm is not None:
                    llm_step += 1
                if mask_step <= mask_total_steps:
                    scheduler_mask.step()
                if scheduler_llm is not None:
                    scheduler_llm.step()
                optimizer_mask.zero_grad()
                if optimizer_llm is not None:
                    optimizer_llm.zero_grad()
                global_step += 1

                if global_step % logging_steps == 0:
                    mask_lr = scheduler_mask.get_last_lr()[0]
                    llm_lr = scheduler_llm.get_last_lr()[0] if scheduler_llm is not None else learning_rate_llm
                    log_entry: Dict[str, float] = {
                        "global_step": global_step,
                        "loss": loss.item() * grad_accum,
                        "language_model_loss": lm_loss.item(),
                        "learning_rate": mask_lr,
                        "learning_rate/mask": mask_lr,
                    }
                    if llm_params:
                        log_entry["learning_rate/llm"] = llm_lr
                    if grad_norm_mask is not None and grad_norm_llm is not None:
                        log_entry["grad_norm/mask"] = grad_norm_mask
                        log_entry["grad_norm/llm"] = grad_norm_llm
                        log_entry["grad_norm/ratio"] = grad_norm_mask / (grad_norm_llm + 1e-12)
                    mask_stats, prev_hard_mask = compute_mask_step_stats(mask_output, prev_hard_mask)
                    log_entry.update(mask_stats)
                    ema_loss = loss.item() * grad_accum if ema_loss is None else (
                        log_ema_decay * ema_loss + (1.0 - log_ema_decay) * (loss.item() * grad_accum)
                    )
                    log_entry["loss/ema"] = ema_loss
                    if "mask/lambda_mean" in mask_stats:
                        lambda_mean = mask_stats["mask/lambda_mean"]
                        ema_lambda_mean = lambda_mean if ema_lambda_mean is None else (
                            log_ema_decay * ema_lambda_mean + (1.0 - log_ema_decay) * lambda_mean
                        )
                        log_entry["mask/lambda_mean_ema"] = ema_lambda_mean
                    if "mask/topk_overlap" in mask_stats:
                        overlap_val = mask_stats["mask/topk_overlap"]
                        ema_topk_overlap = overlap_val if ema_topk_overlap is None else (
                            log_ema_decay * ema_topk_overlap + (1.0 - log_ema_decay) * overlap_val
                        )
                        log_entry["mask/topk_overlap_ema"] = ema_topk_overlap
                    task_count_log = {}
                    for t in tasks:
                        task_count_log[t] = task_count_log.get(t, 0) + 1
                    for task_name, count in task_count_log.items():
                        log_entry[f"train/task_count/{task_name}"] = count

                    logger.info("Step %s", json.dumps(log_entry))
                    if wandb_run is not None:
                        wandb_run.log(log_entry, step=global_step)

        logger.info("Completed epoch %d", epoch + 1)
        epoch_loss_value = (running_loss / max(micro_batches, 1)) * grad_accum
        logger.info("Epoch %d summary: loss=%.4f, tasks=%s", epoch + 1, epoch_loss_value, epoch_task_counts)
        val_loss = run_validation()
        if val_loss is not None:
            logger.info("Validation loss after epoch %d: %.4f", epoch + 1, val_loss)
        test_loss = run_test()
        if test_loss is not None:
            logger.info("Test loss after epoch %d: %.4f", epoch + 1, test_loss)
        if wandb_run is not None:
            epoch_log = {
                "epoch": epoch + 1,
                "epoch_loss": epoch_loss_value,
            }
            if val_loss is not None:
                epoch_log["validation_loss"] = val_loss
            if test_loss is not None:
                epoch_log["test_loss"] = test_loss
            for task_name, count in epoch_task_counts.items():
                epoch_log[f"epoch/task_count/{task_name}"] = count
            wandb_run.log(epoch_log, step=global_step)

        export_pruned_epoch(epoch + 1, is_final=(epoch + 1 == num_epochs))

        if save_each_epoch:
            epoch_metadata = {
                "epoch": epoch + 1,
                "epoch_loss": epoch_loss_value,
                "task_counts": epoch_task_counts,
                "global_step": global_step,
            }
            if test_loss is not None:
                epoch_metadata["test_loss"] = test_loss
            save_mask_checkpoint(
                output_dir,
                mask_predictor,
                epoch=epoch + 1,
                metadata=epoch_metadata,
            )

    total_seconds = time.perf_counter() - training_start
    logger.info("Training complete. Saving mask predictor weights to %s (elapsed %.2f s)", output_dir, total_seconds)
    final_metadata = {
        "num_epochs": num_epochs,
        "total_steps": global_step,
        "task_distribution": task_counts,
        "elapsed_seconds": total_seconds,
    }
    save_mask_checkpoint(
        output_dir,
        mask_predictor,
        metadata=final_metadata,
    )
    if wandb_run is not None:
        wandb_run.finish()


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
