from __future__ import annotations
import math as _m
from typing import Optional, Tuple
import torch
from torch import nn
from .tools.gumbel import gumbel_topk_binary
from .tools.spec import GateSpec
from .backends import get_backend
from .tools.apply import apply_ffn_rowcol_gates, apply_attn_rowcol_gates
from .cul import ControlUnitLearner
from .tools.data_utils import fetch_requests_for_batch, load_sid_to_req_map
from .tools.tau import get_tau


def freeze_all_params(m: nn.Module):
    """Disable gradient tracking on every parameter of *m*.

    Each tensor yielded by ``m.parameters()`` gets ``requires_grad = False``;
    the number of tensors touched is printed. Nothing is returned.
    """
    frozen = 0
    for frozen, param in enumerate(m.parameters(), start=1):
        param.requires_grad = False
    print(f"[Freeze] Froze {frozen} tensors in base model")


def compute_kl(rho_phi: torch.Tensor, prior_keep: float) -> torch.Tensor:
    """Return the per-row KL divergence of Bernoulli(rho_phi) from Bernoulli(prior_keep).

    Both the predicted probabilities and the constant prior are clamped to
    ``[1e-6, 1 - 1e-6]`` so neither logarithm sees 0. The element-wise
    divergence is summed over dim 1, so *rho_phi* must have at least two
    dimensions.
    """
    lo = 1e-6
    hi = 1 - lo
    q = rho_phi.clamp(lo, hi)
    p = torch.full_like(q, prior_keep).clamp(lo, hi)

    keep_term = q * torch.log(q / p)
    drop_term = (1 - q) * torch.log((1 - q) / (1 - p))
    return (keep_term + drop_term).sum(dim=1)


def save_meta_and_state(save_dir, run_id: str, ctl: nn.Module, meta: dict):
    """Persist the control unit's weights and its run metadata.

    Two files are written into *save_dir* (created if missing, including
    parents):

    * ``<run_id>.pth`` -- a ``torch.save`` of a dict holding
      ``ctl.state_dict()`` under ``"control_unit_state_dict"`` and *meta*
      under ``"meta"``.
    * ``<run_id>.meta.json`` -- *meta* as UTF-8 JSON with 2-space indent.

    Both destinations are printed once the writes have finished.
    """
    import json
    from pathlib import Path

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    pth_path = out_dir / f"{run_id}.pth"
    json_path = out_dir / f"{run_id}.meta.json"

    checkpoint = {"control_unit_state_dict": ctl.state_dict(), "meta": meta}
    torch.save(checkpoint, pth_path)

    with json_path.open("w", encoding="utf-8") as fh:
        json.dump(meta, fh, indent=2)

    print(f"[Save] Control unit saved to {pth_path}")
    print(f"[Save] Meta JSON saved to {json_path}")


def slugify(s: str) -> str:
    """Reduce *s* to a token that is safe to embed in file and directory names.

    The input is passed through ``str()`` first, so non-string values are
    accepted. Alphanumeric characters (per ``str.isalnum``) and ``-``, ``_``,
    ``.`` are kept; every other character becomes ``_``. Runs of underscores
    (including ones already present in the input) collapse to a single ``_``
    and leading/trailing underscores are removed, so the result may be empty.
    """
    kept_punct = ("-", "_", ".")
    collapsed = []
    for ch in str(s):
        mapped = ch if (ch.isalnum() or ch in kept_punct) else "_"
        # Skip an underscore that would directly follow another one.
        if mapped == "_" and collapsed and collapsed[-1] == "_":
            continue
        collapsed.append(mapped)
    return "".join(collapsed).strip("_")


def _apply_gates(
    gates_1d: torch.Tensor,
    mask_scope: str,
    ffn_gate_spec: Optional[GateSpec],
    attn_gate_spec: Optional[GateSpec],
    wrappers_ffn,
    wrappers_attn,
) -> None:
    """Install one flat gate vector on the patched FFN and/or attention wrappers."""
    if mask_scope == "ffn":
        apply_ffn_rowcol_gates(gates_1d, ffn_gate_spec, wrappers_ffn)
    elif mask_scope == "attn":
        apply_attn_rowcol_gates(gates_1d, attn_gate_spec, wrappers_attn)
    elif mask_scope == "both":
        # The flat vector is laid out as [FFN gates | attention gates].
        n_ffn = ffn_gate_spec.total_gates
        apply_ffn_rowcol_gates(gates_1d[:n_ffn], ffn_gate_spec, wrappers_ffn)
        apply_attn_rowcol_gates(gates_1d[n_ffn:], attn_gate_spec, wrappers_attn)


def train_cul(
    base_model: nn.Module,
    dataset,
    train_data,
    valid_data,
    test_data,
    *,
    epochs: int = 30,
    batch_size: int = 32,
    eval_batch_size: int = 1024,
    lr: float = 1e-4,
    plm: str = "answerdotai/ModernBERT-base",
    rho: float = 0.10,
    tau: float = 1.0,
    lambda_kl: float = 1.0,
    device: torch.device = torch.device("cpu"),
    shared_mask: bool = False,
    plm_tune: str = "all",
    plm_last_layers: str = "-1:",
    mask_scope: str = "ffn",
    backend: str = "sasrec",
    tau_schedule: str = "linear",
    tau_start: float = 0.7,
    tau_end: float = 0.3,
    data_path: Optional[str] = None,
    dataset_name: Optional[str] = None,
    base_config: Optional[dict] = None,
    save_root: Optional[str] = None,
    ckpt_tag: Optional[str] = None,
) -> Tuple[str, dict]:
    """Train a ``ControlUnitLearner`` that emits gate masks for *base_model*.

    *base_model* is patched in place through the backend hooks
    (``patch_with_maskable_ffn`` / ``patch_with_maskable_attn_out``). For every
    training batch the control unit turns the batch's request texts into
    ``d`` gate logits, a binary top-``k`` mask is drawn with
    ``gumbel_topk_binary`` at the scheduled temperature, the mask is applied
    to the patched wrappers, and ``base_model.calculate_loss`` is evaluated.
    The optimised objective is ``pred_loss + lambda_kl * kl`` where ``kl`` is
    the batch-mean of ``compute_kl(sigmoid(logits), k / d)`` divided by ``d``.

    Only parameters of the control unit that require gradients are handed to
    the optimiser; this function does not freeze *base_model* itself, so that
    is left to the caller.

    Args:
        base_model: Model to be gated; must provide ``calculate_loss``.
        dataset: Passed to ``fetch_requests_for_batch``; its ``dataset_name``
            attribute (if set) names the output files.
        train_data: Sized iterable of interaction batches.
        valid_data, test_data: Accepted but not used by this function.
        epochs: Number of passes over *train_data*.
        batch_size, eval_batch_size: Accepted but not used by this function.
        lr: Adam learning rate.
        plm: Model identifier forwarded to ``ControlUnitLearner``.
        rho: Fraction of gates removed; ``k = max(1, ceil((1 - rho) * d))``
            gates are kept.
        tau: Accepted but not used; the temperature comes from ``get_tau``
            with *tau_schedule*, *tau_start* and *tau_end*.
        lambda_kl: Weight of the KL term.
        device: Device for the control unit and each batch.
        shared_mask: If true, one mask is sampled per batch from the
            batch-mean logits; otherwise one mask is sampled per row and the
            per-row losses are averaged.
        plm_tune, plm_last_layers: Forwarded to ``ctl.set_finetune``.
        mask_scope: ``"ffn"``, ``"attn"`` or ``"both"``.
        backend: Name resolved through ``get_backend``.
        data_path, dataset_name: When both are given, a session-to-request
            map is loaded with ``load_sid_to_req_map``.
        base_config, ckpt_tag: Accepted but not used by this function.
        save_root: Output directory; defaults to ``saved_<dataset slug>``.

    Returns:
        ``(run_id, meta)`` -- the identifier used for the saved files and the
        metadata dict that was written next to the weights.

    Raises:
        ValueError: If *mask_scope* is unknown, if the backend reports no
            gates, or if an epoch sees no batches in *train_data*.
    """
    hooks = get_backend(backend)

    ffn_gate_spec: Optional[GateSpec] = None
    attn_gate_spec: Optional[GateSpec] = None
    wrappers_ffn = None
    wrappers_attn = None

    if mask_scope == "ffn":
        ffn_gate_spec = hooks.build_ffn_gate_spec(base_model)
        wrappers_ffn = hooks.patch_with_maskable_ffn(base_model)
        d = ffn_gate_spec.total_gates
    elif mask_scope == "attn":
        attn_gate_spec = hooks.build_attn_gate_spec(base_model)
        wrappers_attn = hooks.patch_with_maskable_attn_out(base_model)
        d = attn_gate_spec.total_gates
    elif mask_scope == "both":
        ffn_gate_spec = hooks.build_ffn_gate_spec(base_model)
        attn_gate_spec = hooks.build_attn_gate_spec(base_model)
        wrappers_ffn = hooks.patch_with_maskable_ffn(base_model)
        wrappers_attn = hooks.patch_with_maskable_attn_out(base_model)
        d = ffn_gate_spec.total_gates + attn_gate_spec.total_gates
    else:
        raise ValueError(f"Unknown mask_scope {mask_scope}")

    # Without gates the keep ratio k / d below is undefined; fail with a clear message.
    if d <= 0:
        raise ValueError(f"No gates available for mask_scope={mask_scope!r} (d={d})")

    # Keep at least one gate even for rho close to 1.
    k = max(1, int(_m.ceil((1 - rho) * d)))
    prior_keep = k / d
    print(f"[Mask][gates] d={d} | rho={rho:.3f} -> k={k} (k/d={prior_keep:.3f}) | backend={backend}")

    ctl = ControlUnitLearner(plm, out_dim=d).to(device)
    ctl.set_finetune(plm_tune, plm_last_layers)
    optimizer = torch.optim.Adam(
        (p for p in ctl.parameters() if p.requires_grad), lr=lr
    )

    sid2req = None
    if data_path is not None and dataset_name is not None:
        sid2req = load_sid_to_req_map(data_path, dataset_name)

    total_steps = epochs * len(train_data)
    global_step = 0

    for ep in range(1, epochs + 1):
        base_model.train(True)
        ctl.train(True)
        running = {"loss": 0.0, "pred": 0.0, "kl": 0.0}
        num_batches = 0

        for num_batches, interaction in enumerate(train_data, start=1):
            interaction = interaction.to(device)

            cur_tau = get_tau(
                step=global_step,
                total_steps=total_steps,
                kind=tau_schedule,
                t0=tau_start,
                t1=tau_end,
            )

            req_texts = fetch_requests_for_batch(interaction, dataset, sid2req)
            logits = ctl(req_texts, device=device)

            rho_phi = torch.sigmoid(logits)
            kl_per_sample = compute_kl(rho_phi, prior_keep)
            # compute_kl sums over gates; dividing by d makes it a per-gate average.
            kl = kl_per_sample.mean() / d

            if shared_mask:
                # One mask for the whole batch, sampled from the batch-mean logits.
                gates_1d = gumbel_topk_binary(logits.mean(dim=0), k=k, tau=cur_tau)
                _apply_gates(
                    gates_1d, mask_scope,
                    ffn_gate_spec, attn_gate_spec, wrappers_ffn, wrappers_attn,
                )
                pred_loss = base_model.calculate_loss(interaction)
            else:
                # One mask per row: the gates are re-installed before each
                # single-row loss evaluation.
                B = logits.size(0)
                loss_sum = 0.0
                for b in range(B):
                    gates_1d = gumbel_topk_binary(logits[b], k=k, tau=cur_tau)
                    _apply_gates(
                        gates_1d, mask_scope,
                        ffn_gate_spec, attn_gate_spec, wrappers_ffn, wrappers_attn,
                    )
                    idx = torch.tensor([b], device=interaction["user_id"].device)
                    loss_sum = loss_sum + base_model.calculate_loss(interaction[idx])
                pred_loss = loss_sum / B

            loss = pred_loss + lambda_kl * kl
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

            running["loss"] += float(loss.detach().item())
            running["pred"] += float(pred_loss.detach().item())
            running["kl"] += float(kl.detach().item())

            if num_batches % 50 == 0:
                print(
                    f"[E{ep} S{num_batches}] loss={running['loss']/num_batches:.4f} "
                    f"pred={running['pred']/num_batches:.4f} kl={running['kl']/num_batches:.4f} "
                    f"| tau={cur_tau:.4f} | B={interaction['user_id'].size(0)}"
                )

        # An empty loader would otherwise surface as an unbound-variable error
        # (or a division by zero) in the epoch summary below.
        if num_batches == 0:
            raise ValueError("train_data yielded no batches; cannot train the control unit")

        print(
            f"[Epoch {ep}] avg_loss={running['loss']/num_batches:.4f} "
            f"pred={running['pred']/num_batches:.4f} kl={running['kl']/num_batches:.4f}"
        )

    dataset_tag = getattr(dataset, 'dataset_name', None) or (dataset_name or 'dataset')
    save_dir = save_root or f"saved_{slugify(dataset_tag)}"
    from time import strftime
    ts = strftime("%Y%m%d-%H%M%S")
    run_id = f"cul_{slugify(dataset_tag)}_{backend}_{mask_scope}_k-{k}_{ts}"

    meta = {
        "dataset": dataset_tag,
        "backend": backend,
        "rho": rho,
        "k": k,
        "prior_keep": prior_keep,
        "lambda_kl": lambda_kl,
        "mask_scope": mask_scope,
        "shared_mask": bool(shared_mask),
    }
    save_meta_and_state(save_dir, run_id, ctl, meta)
    return run_id, meta