import argparse
import gc
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import hydra
import torch



# Presumably maps a model kind to a dict of config-set names (inferred from the
# name and annotation only).
# NOTE(review): this map is empty and not referenced anywhere in this chunk;
# confirm where (or whether) it is populated before relying on it.
MODEL_KIND_TO_CONFIG_SET_MAP: dict[str, dict[str, str]] = {}

# Config-set name suffix -> model-kind label.
# ``normalize_model_kind`` tests these with ``str.endswith`` in insertion order
# and uses the first match. If a new suffix is ever itself a suffix of an
# existing one, list the longer one first. Currently no entry is a suffix of
# another (e.g. "_1BA" does not end with "_1B").
# NOTE(review): "1p5BNG" keeps the "p" spelling while "0.5B"/"0.5BA" use a dot;
# confirm the inconsistency is intentional before matching on these labels.
MODEL_KIND_SUFFIX = {
    "_1B": "1B",
    "_1BA": "1BA",
    "_2B": "2B",
    "_2BA": "2BA",
    "_0p5B": "0.5B",
    "_0p5BA": "0.5BA",
    "_1p5BNG": "1p5BNG",
    
}

def get_hydra_run_path(_model_kind):
    """Placeholder for deriving a Hydra run path from a model kind.

    Always raises; callers are expected to pass ``--hydra-run-path`` instead.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "provide --hydra-run-path explicitly"
    raise NotImplementedError(message)


def normalize_model_kind(config_set: str, default: str = "1p5B") -> tuple[str, str]:
    """Split a config-set name into ``(model_kind, base_name)``.

    Any name containing ``"swiglu"`` short-circuits to ``("swiglu", "swiglu")``.
    Otherwise the first suffix of ``MODEL_KIND_SUFFIX`` (in insertion order)
    that ``config_set`` ends with is stripped, and its label is returned as the
    model kind. When nothing matches, ``(default, config_set)`` is returned.
    """
    if "swiglu" in config_set:
        return "swiglu", "swiglu"
    matches = (
        (label, config_set[: -len(suffix)])
        for suffix, label in MODEL_KIND_SUFFIX.items()
        if config_set.endswith(suffix)
    )
    return next(matches, (default, config_set))


def str_to_bool(text):
    """Parse a permissive boolean string for argparse ``type=``.

    Accepts (case-insensitively, surrounding whitespace ignored) true/1/yes/y
    and false/0/no/n. Non-string inputs are converted with ``str`` first.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    normalized = str(text).strip().lower()
    verdicts = {
        "true": True, "1": True, "yes": True, "y": True,
        "false": False, "0": False, "no": False, "n": False,
    }
    verdict = verdicts.get(normalized)
    if verdict is None:
        raise argparse.ArgumentTypeError("boolean value expected")
    return verdict


def parse_csv_list(text: Optional[str]) -> list[str]:
    """Split a comma-separated string into stripped, non-empty items.

    ``None`` yields an empty list.
    """
    if text is None:
        return []
    return [piece for piece in map(str.strip, text.split(",")) if piece]


def parse_layers_per_impl(text: Optional[str], num_impls: int) -> list[Optional[list[int]]]:
    """Parse ``--layers-per-implementation`` into one layer list per implementation.

    Groups are separated by ``;`` and must match ``num_impls`` in number. Each
    group is a comma-separated list of integers or inclusive ranges
    ``start-end``. An empty group, ``none`` (any case), or a group that yields
    no items becomes ``None``. A ``None`` ``text`` yields ``num_impls`` Nones.

    Raises:
        ValueError: on a group-count mismatch, a descending range, or a
            non-integer token.
    """
    if text is None:
        return [None] * int(num_impls)

    groups = text.split(";")
    if len(groups) != int(num_impls):
        raise ValueError(
            "--layers-per-implementation must have the same number of entries "
            "as --mlp-implementations"
        )

    def expand(token: str) -> list[int]:
        # A token containing "-" is an inclusive range; anything else is one index.
        if "-" not in token:
            return [int(token)]
        lo_text, hi_text = token.split("-", 1)
        lo, hi = int(lo_text.strip()), int(hi_text.strip())
        if hi < lo:
            raise ValueError(
                "range end must be >= start in "
                "--layers-per-implementation"
            )
        return list(range(lo, hi + 1))

    parsed: list[Optional[list[int]]] = []
    for group in map(str.strip, groups):
        if not group or group.lower() == "none":
            parsed.append(None)
            continue
        indices = [
            idx for token in parse_csv_list(group) for idx in expand(token)
        ]
        parsed.append(indices or None)
    return parsed


def parse_dtype(dtype_name: str):
    """Map a dtype alias (bf16/bfloat16, fp16/float16/half, fp32/float32) to a torch dtype.

    Matching is case-insensitive and ignores surrounding whitespace.

    Raises:
        ValueError: for an unrecognised alias.
    """
    aliases = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
        "half": torch.float16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    key = str(dtype_name).strip().lower()
    if key not in aliases:
        raise ValueError(f"unsupported dtype: {dtype_name}")
    return aliases[key]


def free_cuda_memory():
    """Best-effort release of cached CUDA memory; a no-op without CUDA.

    Runs Python garbage collection, then ``empty_cache``, ``ipc_collect`` and
    ``synchronize``, in that order.
    """
    if torch.cuda.is_available():
        gc.collect()
        for step in (
            torch.cuda.empty_cache,
            torch.cuda.ipc_collect,
            torch.cuda.synchronize,
        ):
            step()


def resolve_run_dir(_config_set: str, _model_kind: str, model_path: Optional[str]):
    """Return ``model_path`` as a ``Path``; the config-set/model-kind args are unused.

    Raises:
        ValueError: when ``model_path`` is ``None``.
    """
    if model_path is not None:
        return Path(model_path)
    raise ValueError("provide --model-path")


def resolve_state_dict_file(run_dir: Path | str) -> Path:
    run_dir = Path(run_dir)
    if run_dir.is_file():
        return run_dir

    candidates = [
        run_dir / "model.safetensors",
        run_dir / "pytorch_model.bin",
    ]
    for p in candidates:
        if p.exists():
            return p

    raise FileNotFoundError(
        f"no state dict found in {run_dir} "
        f"(tried: {', '.join([c.name for c in candidates])})"
    )


def get_core_model(model: torch.nn.Module) -> torch.nn.Module:
    """Return ``model.model`` when it exists and is an ``nn.Module``, else ``model``.

    This unwraps one level of wrapper that stores the backbone as ``.model``.
    """
    inner = getattr(model, "model", None)
    return inner if isinstance(inner, torch.nn.Module) else model


def get_mlp_modules(model: torch.nn.Module) -> list[torch.nn.Module]:
    """Find the MLP sub-modules of ``model``, trying three strategies in order.

    1. ``layer.mlp`` for every layer in ``core.layers`` or
       ``core.model.layers``, where ``core`` comes from ``get_core_model``.
    2. Every submodule whose attribute name is ``mlp``, at any depth,
       including a direct child named ``mlp``.
    3. Every module whose class name ends with ``mlp`` (case-insensitive).

    The first strategy that finds anything wins. Returns an empty list when
    none do.
    """
    core = get_core_model(model)

    # Strategy 1: decoder-style ``layers`` container with ``.mlp`` children.
    layers = None
    if hasattr(core, "layers"):
        layers = core.layers
    elif hasattr(core, "model") and hasattr(core.model, "layers"):
        layers = core.model.layers
    if layers is not None:
        mlps = [layer.mlp for layer in layers if hasattr(layer, "mlp")]
        if mlps:
            return mlps

    # Strategy 2: qualified names from named_modules() look like "a.b.mlp".
    # A direct child is named just "mlp" with no leading dot, so check that
    # case explicitly.
    mlps = [
        module
        for name, module in model.named_modules()
        if name == "mlp" or name.endswith(".mlp")
    ]
    if mlps:
        return mlps

    # Strategy 3: fall back to class names such as "LlamaMLP" / "MLP".
    return [
        module
        for _, module in model.named_modules()
        if module.__class__.__name__.lower().endswith("mlp")
    ]


def build_batch_from_dataset(
    cfg,
    tokenizer,
    batch_size: int,
    seq_len: int,
    device: torch.device,
):
    """Build one padded ``(input_ids, attention_mask)`` batch from the training set.

    ``cfg.make_dataset_fn`` is instantiated through Hydra with ``tokenizer``.
    Its ``"train_dataset"`` entry supplies the samples. The first
    ``batch_size`` samples are taken via ``.take()`` when that works, and by
    integer indexing otherwise. Each sample is truncated to ``seq_len`` and
    right-padded with the tokenizer's pad id (0 if unset), with mask value 0
    on padding. A missing ``attention_mask`` defaults to all ones.

    Returns:
        Two ``torch.long`` tensors on ``device``. Their row count is however
        many samples were obtained (presumably ``batch_size``; confirm for
        datasets shorter than that).
    """
    train_ds = hydra.utils.instantiate(
        cfg.make_dataset_fn, tokenizer=tokenizer
    )["train_dataset"]

    try:
        samples = list(train_ds.take(batch_size))
    except Exception:
        # Datasets without a working ``.take`` fall back to indexing.
        samples = [train_ds[i] for i in range(batch_size)]

    pad_id = tokenizer.pad_token_id
    pad_id = 0 if pad_id is None else pad_id

    id_rows = []
    mask_rows = []
    for sample in samples:
        ids = list(sample["input_ids"])[:seq_len]
        mask = list(sample.get("attention_mask", [1] * len(ids)))[:seq_len]
        shortfall = seq_len - len(ids)
        if shortfall > 0:
            ids += [pad_id] * shortfall
            mask += [0] * shortfall
        id_rows.append(ids)
        mask_rows.append(mask)

    return (
        torch.tensor(id_rows, dtype=torch.long, device=device),
        torch.tensor(mask_rows, dtype=torch.long, device=device),
    )


def build_seq_len_list(seq_len: int, seq_len_step: Optional[int]) -> list[int]:
    """Return the sequence lengths to sweep.

    Without a step this is ``[seq_len]``. With a step ``s`` it is
    ``s, 2s, ...`` up to ``seq_len``, with ``seq_len`` appended when it is not
    a multiple of ``s``.

    Raises:
        ValueError: when the step is <= 0 or larger than ``seq_len``.
    """
    total = int(seq_len)
    if seq_len_step is None:
        return [total]
    stride = int(seq_len_step)
    if stride <= 0:
        raise ValueError("--seq-len-step must be > 0")
    if stride > total:
        raise ValueError("--seq-len-step must be <= --seq-len")
    lengths = list(range(stride, total + 1, stride))
    # The range stops at the largest multiple of stride <= total.
    if total % stride:
        lengths.append(total)
    return lengths


def slice_batch_for_seq_len(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    seq_len: int,
    batch_size: int,
):
    """Take the leading ``batch_size`` rows and ``seq_len`` columns of both tensors.

    Both slices are returned as contiguous tensors. Standard slicing clamps
    when the tensors are smaller than requested.
    """
    cols = int(seq_len)
    rows = int(batch_size)
    window = (slice(None, rows), slice(None, cols))
    return input_ids[window].contiguous(), attention_mask[window].contiguous()


def decode_token_for_label(tokenizer, token_id: int) -> str:
    """Render a single token id as a double-quoted, newline-escaped label.

    Decoding uses ``clean_up_tokenization_spaces=False``. Any exception during
    decoding (or while converting ``token_id`` to int) yields
    ``"<decode_error>"`` instead.
    """
    try:
        decoded = tokenizer.decode(
            [int(token_id)], clean_up_tokenization_spaces=False
        )
    except Exception:
        decoded = "<decode_error>"
    return '"' + decoded.replace("\n", "\\n") + '"'


def safe_mean_and_std(total: float, total_sq: float, count: int) -> Tuple[float, float]:
    """Compute mean and population std from a running sum and sum of squares.

    Negative variance from floating-point cancellation is clamped to 0.
    Returns ``(nan, nan)`` when ``count`` is not positive.
    """
    if count > 0:
        n = float(count)
        mu = total / n
        variance = total_sq / n - mu * mu
        variance = 0.0 if variance < 0.0 else variance
        return mu, math.sqrt(variance)
    nan = float("nan")
    return nan, nan


class ActivationStatsAccumulator:
    """Accumulate statistics of per-token nonzero-activation counts across layers.

    Each ``update`` call supplies, for one layer, a count per
    (batch, position) cell together with the token ids and an optional
    attention mask. Cells whose mask value is <= 0 are excluded. Statistics
    are kept per layer (mean/std/min/max), per position (overall and per
    layer), and per token id (overall and per layer). The ``*_rows`` methods
    export them as lists of plain dicts.

    Running sums are kept in float64. The inputs are integers, and float64
    holds integers exactly up to 2**53. Float32 is exact only up to 2**24, a
    bound that per-batch sums (and squares of counts above 4096) easily exceed.
    """

    def __init__(self, num_layers: int, seq_len: int):
        """Create empty accumulators for ``num_layers`` layers and ``seq_len`` positions."""
        self.num_layers = int(num_layers)
        self.seq_len = int(seq_len)
        n_layers = self.num_layers
        n_pos = self.seq_len

        # Per-layer scalar moments and extrema over all valid cells.
        self.layer_sum = [0.0] * n_layers
        self.layer_sum_sq = [0.0] * n_layers
        self.layer_count = [0] * n_layers
        self.layer_min = [float("inf")] * n_layers
        self.layer_max = [float("-inf")] * n_layers

        # Per-(layer, position) and per-position sums / valid-cell counts.
        self.layer_position_sum = torch.zeros(
            (n_layers, n_pos), dtype=torch.float64)
        self.layer_position_count = torch.zeros(
            (n_layers, n_pos), dtype=torch.float64)
        self.position_sum = torch.zeros(n_pos, dtype=torch.float64)
        self.position_count = torch.zeros(n_pos, dtype=torch.float64)

        # Per-token-id sums / occurrence counts, overall and per layer.
        self.token_sum: Dict[int, float] = defaultdict(float)
        self.token_count: Dict[int, int] = defaultdict(int)
        self.layer_token_sum: List[Dict[int, float]] = [
            defaultdict(float) for _ in range(n_layers)
        ]
        self.layer_token_count: List[Dict[int, int]] = [
            defaultdict(int) for _ in range(n_layers)
        ]

    def update(
        self,
        layer_idx: int,
        activation_counts: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> None:
        """Fold one layer's counts for one batch into the running statistics.

        Args:
            layer_idx: Layer the counts belong to.
            activation_counts: Count per cell, indexed here as
                (batch, position).
            input_ids: Token id per cell; same shape as ``activation_counts``.
            attention_mask: Optional per-cell mask. Cells with value <= 0 are
                ignored, and counts are multiplied by the mask.

        Raises:
            ValueError: when the batch has more positions than ``seq_len``.
                Nothing is accumulated in that case.
        """
        # Move to CPU before upcasting: some accelerators lack float64.
        counts = activation_counts.detach().cpu().to(torch.float64)
        tokens = input_ids.detach().cpu().to(torch.long)
        num_positions = counts.shape[1]
        if num_positions > self.seq_len:
            raise ValueError(
                f"batch has {num_positions} positions but the accumulator "
                f"was sized for seq_len={self.seq_len}"
            )

        if attention_mask is not None:
            mask = attention_mask.detach().cpu().to(torch.float64)
            counts = counts * mask
            valid_per_pos = mask.sum(dim=0)
            total_valid = int(mask.sum().item())
            keep: Optional[torch.Tensor] = (mask > 0).reshape(-1)
        else:
            valid_per_pos = torch.full(
                (num_positions,),
                float(counts.shape[0]),
                dtype=torch.float64,
            )
            total_valid = counts.numel()
            keep = None

        flat_counts = counts.reshape(-1)
        flat_tokens = tokens.reshape(-1)
        if keep is not None:
            flat_counts = flat_counts[keep]
            flat_tokens = flat_tokens[keep]

        # Layer moments. Masked-out cells were zeroed above, so summing the
        # full tensor matches summing only the valid cells.
        self.layer_sum[layer_idx] += counts.sum().item()
        self.layer_sum_sq[layer_idx] += (counts ** 2).sum().item()
        self.layer_count[layer_idx] += total_valid
        if flat_counts.numel() > 0:
            self.layer_min[layer_idx] = min(
                self.layer_min[layer_idx], float(flat_counts.min().item()))
            self.layer_max[layer_idx] = max(
                self.layer_max[layer_idx], float(flat_counts.max().item()))

        # Positions: a shorter batch fills only the leading positions instead
        # of being broadcast across all of them.
        counts_by_pos = counts.sum(dim=0)
        self.layer_position_sum[layer_idx, :num_positions] += counts_by_pos
        self.layer_position_count[layer_idx, :num_positions] += valid_per_pos
        self.position_sum[:num_positions] += counts_by_pos
        self.position_count[:num_positions] += valid_per_pos

        if flat_tokens.numel() == 0:
            return

        # Per-token aggregation with one transfer to Python per vector,
        # instead of one .item() call per token id.
        token_sums = torch.bincount(flat_tokens, weights=flat_counts)
        token_hits = torch.bincount(flat_tokens)
        present = torch.nonzero(token_hits, as_tuple=False).flatten()
        layer_sums = self.layer_token_sum[layer_idx]
        layer_hits = self.layer_token_count[layer_idx]
        for tid, total, cnt in zip(
            present.tolist(),
            token_sums[present].tolist(),
            token_hits[present].tolist(),
        ):
            self.token_sum[tid] += total
            self.token_count[tid] += cnt
            layer_sums[tid] += total
            layer_hits[tid] += cnt

    def layer_rows(self) -> List[Dict]:
        """One row per layer: mean/std/min/max/sums of counts and valid-cell total.

        Layers with no valid cells get NaN for mean/std/min/max.
        """
        rows = []
        for i in range(self.num_layers):
            mean, std = safe_mean_and_std(
                self.layer_sum[i],
                self.layer_sum_sq[i],
                self.layer_count[i],
            )
            min_val = (
                float("nan")
                if self.layer_min[i] == float("inf")
                else float(self.layer_min[i])
            )
            max_val = (
                float("nan")
                if self.layer_max[i] == float("-inf")
                else float(self.layer_max[i])
            )
            rows.append({
                "layer_idx": int(i),
                "mean_nonzero": float(mean),
                "std_nonzero": float(std),
                "min_nonzero": min_val,
                "max_nonzero": max_val,
                "sum_nonzero": float(self.layer_sum[i]),
                "sum_nonzero_sq": float(self.layer_sum_sq[i]),
                "positions": int(self.layer_count[i]),
            })
        return rows

    def token_rows(
        self,
        tokenizer,
    ) -> List[Dict]:
        """One row per token id seen, aggregated over all layers."""
        rows = []
        for tid, total in self.token_sum.items():
            cnt = self.token_count.get(tid, 0)
            if cnt <= 0:
                continue
            rows.append({
                "token_id": int(tid),
                "token": decode_token_for_label(tokenizer, tid),
                "mean_nonzero": float(total / cnt),
                "total_nonzero": float(total),
                "occurrences": int(cnt),
            })
        return rows

    def token_rows_per_layer(
        self,
        tokenizer,
    ) -> List[Dict]:
        """One row per (layer, token id) pair seen."""
        rows: List[Dict] = []
        for layer_idx in range(self.num_layers):
            layer_sum = self.layer_token_sum[layer_idx]
            layer_cnt = self.layer_token_count[layer_idx]
            for tid, total in layer_sum.items():
                cnt = layer_cnt.get(tid, 0)
                if cnt <= 0:
                    continue
                rows.append({
                    "layer_idx": int(layer_idx),
                    "token_id": int(tid),
                    "token": decode_token_for_label(tokenizer, tid),
                    "mean_nonzero": float(total / cnt),
                    "total_nonzero": float(total),
                    "occurrences": int(cnt),
                })
        return rows

    def position_rows(self) -> List[Dict]:
        """One row per position with at least one valid cell, over all layers."""
        rows = []
        totals = self.position_sum.tolist()
        counts = self.position_count.tolist()
        for pos, (total, cnt) in enumerate(zip(totals, counts)):
            if cnt <= 0.0:
                continue
            rows.append({
                "position_idx": int(pos),
                "mean_nonzero": float(total / cnt),
                "total_nonzero": float(total),
                "occurrences": int(cnt),
            })
        return rows

    def position_rows_per_layer(self) -> List[Dict]:
        """One row per (layer, position) pair with at least one valid cell."""
        rows: List[Dict] = []
        all_totals = self.layer_position_sum.tolist()
        all_counts = self.layer_position_count.tolist()
        for layer_idx, (totals, counts) in enumerate(zip(all_totals, all_counts)):
            for pos, (total, cnt) in enumerate(zip(totals, counts)):
                if cnt <= 0.0:
                    continue
                rows.append({
                    "layer_idx": int(layer_idx),
                    "position_idx": int(pos),
                    "mean_nonzero": float(total / cnt),
                    "total_nonzero": float(total),
                    "occurrences": int(cnt),
                })
        return rows


class ActivationAnalysisTracker:
    """Route per-layer MLP activations of a forward pass into an accumulator.

    Expected usage, inferred from the guards in ``record_activations``
    (confirm against callers):

    1. ``set_current_batch`` with the batch's token ids and mask.
    2. ``start_forward_pass``.
    3. One ``record_activations`` call per layer.
    4. ``finalize_forward_pass``.

    Per-pass summaries are stored on ``self.last_forward_summary``. Padded
    positions (mask <= 0) are excluded from both the per-token mean and the
    L0 average.
    """

    def __init__(self, num_layers: int, seq_len: int, *, track_l0: bool = False):
        """Create a tracker backed by an ``ActivationStatsAccumulator``.

        Args:
            num_layers: Number of layers forwarded to the accumulator.
            seq_len: Number of positions forwarded to the accumulator.
            track_l0: Also record a per-layer mean active-unit count.
        """
        self.accumulator = ActivationStatsAccumulator(num_layers, seq_len)
        self.active = False
        self.current_input_ids: Optional[torch.Tensor] = None
        self.current_attention_mask: Optional[torch.Tensor] = None
        self.last_device: Optional[torch.device] = None
        # Per-cell count of active units summed over the layers recorded so far.
        self.current_nonzero_sum: Optional[torch.Tensor] = None
        self.track_l0 = bool(track_l0)
        self.current_l0_per_layer: list[torch.Tensor] = []
        self.last_forward_summary: Optional[Dict] = None

    def register_layer(self, layer_idx: int, width: int, device: torch.device):
        """No-op hook; presumably kept so layers can call it unconditionally (confirm)."""
        pass

    def start_forward_pass(self) -> None:
        """Enable recording until ``finalize_forward_pass`` is called."""
        self.active = True

    def set_current_batch(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> None:
        """Store CPU copies of the batch and reset the per-pass buffers."""
        self.current_input_ids = input_ids.detach().to(torch.long).cpu()
        self.current_attention_mask = (
            attention_mask.detach().to(torch.float32).cpu()
            if attention_mask is not None else None
        )
        self.current_nonzero_sum = torch.zeros_like(
            self.current_input_ids, dtype=torch.float32
        )
        if self.track_l0:
            self.current_l0_per_layer = []
        self.last_forward_summary = None

    def record_activations(
        self,
        layer_idx: int,
        absolute_activations: torch.Tensor,
    ) -> None:
        """Count elements > 0 along the last dim and accumulate them for ``layer_idx``.

        Does nothing unless a pass is active. Also does nothing beyond
        remembering the device when no batch has been set.
        """
        if not self.active:
            return
        self.last_device = absolute_activations.device
        with torch.no_grad():
            counts = (absolute_activations > 0).sum(dim=-1)
        if self.current_input_ids is None:
            return
        self.accumulator.update(
            layer_idx=layer_idx,
            activation_counts=counts,
            input_ids=self.current_input_ids,
            attention_mask=self.current_attention_mask,
        )
        counts_cpu = counts.detach().cpu()
        if self.current_nonzero_sum is not None:
            self.current_nonzero_sum += counts_cpu
        if self.track_l0:
            self._append_l0(counts_cpu)

    def _append_l0(self, counts_cpu: torch.Tensor) -> None:
        """Append this layer's mean active-unit count over valid (unmasked) tokens."""
        if self.current_attention_mask is not None:
            valid = self.current_attention_mask > 0
            total_nonzero = counts_cpu[valid].sum()
            num_tokens = int(valid.sum().item())
        else:
            total_nonzero = counts_cpu.sum()
            num_tokens = counts_cpu.numel()
        if num_tokens > 0:
            self.current_l0_per_layer.append(
                total_nonzero.to(torch.float32) / float(num_tokens)
            )

    def finalize_forward_pass(self) -> Tuple[torch.Tensor, Dict]:
        """End the pass and store a summary on ``self.last_forward_summary``.

        The summary holds ``mean_nonzero_per_token``, one value per batch row:
        the active-unit count summed over layers, averaged over that row's
        valid tokens. When L0 tracking is on and something was recorded, it
        also holds ``mean_l0_per_layer``, the mean of the per-layer L0 values.

        Returns:
            A zero scalar on the last seen device (CPU if none) and an empty
            dict. NOTE(review): the summary is not returned here; presumably
            callers read ``last_forward_summary``.
        """
        self.active = False
        mean_l0_per_layer: Optional[torch.Tensor] = None
        if self.track_l0 and self.current_l0_per_layer:
            mean_l0_per_layer = torch.stack(self.current_l0_per_layer).mean()
        if (
            self.current_nonzero_sum is not None
            and self.current_input_ids is not None
        ):
            if self.current_attention_mask is not None:
                valid = self.current_attention_mask > 0
            else:
                valid = torch.ones_like(
                    self.current_input_ids, dtype=torch.bool)
            token_counts = valid.sum(dim=1).clamp(min=1)
            # Zero padded positions so the numerator covers the same tokens
            # as the denominator.
            masked_sum = (self.current_nonzero_sum * valid).sum(dim=1)
            mean_nonzero = (masked_sum / token_counts).to(torch.float32)
            self.last_forward_summary = {
                "mean_nonzero_per_token": mean_nonzero,
            }
            if mean_l0_per_layer is not None:
                self.last_forward_summary["mean_l0_per_layer"] = mean_l0_per_layer
        device = (
            self.last_device
            if self.last_device is not None else torch.device("cpu")
        )
        return torch.tensor(0.0, device=device), {}
