"""
Generic bilevel max–max optimizer with implicit differentiation (HVP + CG) + CRN + Consistent Eval + Bounds + Masks.

Key features:
- Works for arbitrary shapes of a (agent vars) and t (contract/outer vars).
- Agnostic to expectation computation (analytic or Monte Carlo).
- Correct SPD fix for inner MAX problems in the CG solve (uses -H_aa and -g_a).
- Optional projections and/or simple bound boxes for a and t (active-set projected updates).
- Training CRN: reuse one MC batch per outer step (less variance).
- Consistent evaluation: single held-out MC batch used for logs and final summary.
- NEW: Per-setting masks:
    • t_train_mask  — which t coordinates are allowed to update (0 = frozen).
    • t_metric_mask — which t coordinates are measured in t-distance logs.

Returns:
    (t, a, u1_values, u2_values, trace)
"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import contextlib
import inspect
import torch
from tqdm import trange

# `compute_errors` is an optional project helper: when the import fails for any
# reason the name is bound to None so the module still loads.
# NOTE(review): the name is not referenced in the visible part of this file.
try:
    from utils.csv_utils import compute_errors  # optional
except Exception:
    compute_errors = None

# Short aliases used in the annotations below.
Tensor = torch.Tensor          # any torch tensor
Number = Union[float, int]     # plain Python scalar


def optimize_bilevel(
    # User-provided payoff functions (expected utilities).
    u1: Optional[Callable[..., Tensor]] = None,   # principal objective g(a, t, **params)
    u2: Optional[Callable[..., Tensor]] = None,   # agent objective f(a, t, **params)
    C:  Optional[Callable[..., Tensor]] = None,   # optional penalty added to OUTER objective g

    # Variables (must be leaf tensors with requires_grad=True).
    t: Tensor = None,      # ANY shape: outer var(s)
    a: Tensor = None,      # ANY shape: inner var(s)

    # Loop + stepsizes
    outer_steps: int = 100000,
    inner_lr: float = 5e-3,
    outer_lr: float = 1e-3,
    inner_max_steps: int = 1000,
    inner_grad_tol: float = 1e-6,

    # Implicit-diff linear solve (CG) settings
    cg_iters: int = 20,
    cg_damping: float = 1e-4,
    grad_clip: Optional[float] = None,

    # Optional projections (user-supplied)
    project_a: Optional[Callable[[Tensor], Tensor]] = None,
    project_t: Optional[Callable[[Tensor], Tensor]] = None,

    # Simple bound boxes (used if no projector provided)
    a_bounds: Optional[Tuple[Tensor, Tensor]] = None,  # (lb, ub) same shape as a or broadcastable
    t_bounds: Optional[Tuple[Tensor, Tensor]] = None,  # (lb, ub) same shape as t or broadcastable
    bound_mode: str = "project",

    # NEW: masks
    t_train_mask: Optional[Tensor] = None,   # same shape as t (or broadcastable)
    t_metric_mask: Optional[Tensor] = None,  # same shape as t (or broadcastable)

    # Optional ground-truth optima (for diagnostics)
    a_star: Optional[Union[Number, Tensor]] = None,
    t_star: Optional[Union[Number, Tensor]] = None,

    # Optional theoretical optimum function
    get_theoretical_optimum_fn: Optional[Callable[..., Tuple]] = None,

    # Optional labels for t entries in logs
    t_labels: Optional[List[str]] = None,

    # Model constants
    setting_parameters: Optional[Dict[str, Any]] = None,

    # Optional per-step kwargs injector (user-defined)
    per_step_kwargs_fn: Optional[Callable[[int], Dict[str, Any]]] = None,

    # -------- Training CRN controls --------
    use_crn: bool = False,
    crn_refresh: int = 1,
    crn_antithetic: bool = True,
    nsamples_key: str = "nsamples",
    make_z: Optional[Callable[[int, Dict[str, Any], torch.dtype, torch.device], Tensor]] = None,
    base_seed: int = 0,

    # -------- Consistent evaluation (for logs & final summary) --------
    log_with_eval: bool = False,
    eval_make_z: Optional[Callable[[int, Dict[str, Any], torch.dtype, torch.device], Tensor]] = None,
    eval_nsamples: int = 4096,
    eval_seed: int = 12345,
) -> Tuple[Tensor, Tensor, List[float], List[float], List[Dict[str, float]]]:
    """Maximize ``u1 (+ C)`` over ``t`` while ``a`` maximizes ``u2`` for the current ``t``.

    Each outer step (1) runs projected gradient ascent on ``u2`` in ``a``,
    (2) forms the implicit hypergradient ``g_t - H_ta H_aa^{-1} g_a`` where the
    linear system is solved by damped conjugate gradient on ``-H_aa``, and
    (3) takes an ascent step on ``t`` followed by projection. ``t`` and ``a`` are
    updated IN PLACE.

    Args:
        u1, u2: Objectives called as ``fn(a, t, **kwargs)``; non-scalar results are
            summed. Keyword arguments come from ``setting_parameters`` merged with
            the per-step extras and are filtered down to what ``fn`` accepts.
        C: Optional penalty with the same call convention, added to ``u1``. If it
            raises, a warning is emitted and the step continues without it.
        t, a: Leaf tensors with ``requires_grad=True`` (outer / inner variables).
        outer_steps: Number of outer iterations.
        inner_lr, outer_lr: Ascent step sizes for ``a`` and ``t``.
        inner_max_steps: Cap on inner ascent iterations per outer step.
        inner_grad_tol: Inner loop stops once ``||d u2 / d a|| <= inner_grad_tol``.
        cg_iters, cg_damping: Iteration cap and Tikhonov damping of the CG solve.
        grad_clip: If positive, the (masked) hypergradient norm is clipped to it.
        project_a, project_t: User projectors; take precedence over the bounds.
        a_bounds, t_bounds: ``(lb, ub)`` boxes (tensors or numbers) turned into
            clamp projectors when no user projector is given.
        bound_mode: ``"project"`` uses the active-set step before projecting; any
            other value takes a plain step and then projects.
        t_train_mask: Multiplies the hypergradient (0 freezes a coordinate).
        t_metric_mask: Multiplies ``t - t_star`` in the ``err_t_l2`` log entry.
        a_star, t_star: Optional reference optima used only for diagnostics.
        get_theoretical_optimum_fn: Called after the loop as
            ``fn(t=..., setting_parameters=...)`` to fill in missing references;
            the first two items of a returned tuple/list are used.
        t_labels: Names for the flattened entries of ``t`` in the trace.
        setting_parameters: Base keyword arguments for the objectives.
        per_step_kwargs_fn: ``step -> dict`` of extra keyword arguments.
        use_crn, make_z, crn_refresh, base_seed, nsamples_key, crn_antithetic:
            When ``use_crn`` and ``make_z`` are set, one sample ``z`` is drawn per
            group of ``crn_refresh`` steps under seed ``base_seed + group`` with
            ``setting_parameters[nsamples_key]`` (default 1024) samples, optionally
            concatenated with ``-z`` along dim 0, and passed as kwarg ``z``.
        log_with_eval, eval_make_z, eval_nsamples, eval_seed: When enabled, one
            fixed ``z`` is drawn once and used for all logging and the final summary.

    Returns:
        ``(t, a, u1_values, u2_values, trace)`` where the two lists hold the logged
        objective per outer step and ``trace`` holds one dict per outer step.

    Raises:
        AssertionError: ``t``/``a`` missing or not requiring grad.
        TypeError: ``u1`` or ``u2`` is not callable.
    """
    import warnings  # function-scope import keeps the file's import block untouched

    assert t is not None and a is not None, "Pass leaf Tensors for `t` and `a`."
    assert t.requires_grad and a.requires_grad, "`t` and `a` must require grad."
    if not callable(u1) or not callable(u2):
        # Previously surfaced as a late "'NoneType' object is not callable".
        raise TypeError("`u1` and `u2` must be callables with signature fn(a, t, **params).")

    base_settings: Dict[str, Any] = dict(setting_parameters or {})
    warned_keys = set()

    # ---------- helpers ----------
    def _warn_once(key: str, message: str) -> None:
        """Emit `message` as a RuntimeWarning the first time `key` is seen."""
        if key not in warned_keys:
            warned_keys.add(key)
            warnings.warn(message, RuntimeWarning, stacklevel=3)

    def _merge_kwargs(extra: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return base settings overlaid with `extra` (extra wins on conflicts)."""
        merged = dict(base_settings)
        if extra:
            merged.update(extra)
        return merged

    def _filter_kwargs_for_fn(fn: Callable, kw: Dict[str, Any]) -> Dict[str, Any]:
        """Pass only what `fn` can accept (robust to functions without **kwargs)."""
        try:
            sig = inspect.signature(fn)
        except (TypeError, ValueError):
            return kw  # signature not introspectable: pass everything through
        params = sig.parameters
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
            return kw
        return {k: v for k, v in kw.items() if k in params}

    def _call_fn(fn: Callable[..., Tensor], a_: Tensor, t_: Tensor, extra: Optional[Dict[str, Any]]) -> Tensor:
        """Call `fn(a_, t_, **kwargs)` with merged and signature-filtered kwargs."""
        kw = _filter_kwargs_for_fn(fn, _merge_kwargs(extra))
        return fn(a_, t_, **kw) if kw else fn(a_, t_)

    def _maybe_sum(x: Tensor) -> Tensor:
        """Reduce a non-scalar objective value to a scalar by summation."""
        return x if x.ndim == 0 else x.sum()

    def _penalty(a_: Tensor, t_: Tensor, extra: Optional[Dict[str, Any]], where: str) -> Optional[Tensor]:
        """Evaluate `C` best-effort; return None (with a warning) if absent or failing."""
        if C is None:
            return None
        try:
            return _maybe_sum(_call_fn(C, a_, t_, extra))
        except Exception as exc:  # stay best-effort, but never silently
            _warn_once(f"C-{where}", f"penalty `C` failed during {where} ({exc!r}); continuing without it.")
            return None

    def _clip_norm_(g: Tensor, maxnorm: Optional[float]) -> Tensor:
        """Rescale `g` to norm `maxnorm` if it is larger; no-op for None / non-positive."""
        if maxnorm is None or maxnorm <= 0:
            return g
        n = g.norm()
        if n > maxnorm:
            g = g * (maxnorm / (n + 1e-12))
        return g

    def _grads_or_zeros(
        out: Tensor,
        wrt: Tuple[Tensor, ...],
        grad_out: Optional[Tensor] = None,
        retain_graph: Optional[bool] = None,
        create_graph: bool = False,
    ) -> Tuple[Tensor, ...]:
        """`torch.autograd.grad` that returns zeros where `out` does not depend on an input."""
        if not out.requires_grad:
            return tuple(torch.zeros_like(w) for w in wrt)
        grads = torch.autograd.grad(
            out,
            wrt,
            grad_outputs=grad_out,
            retain_graph=retain_graph,
            create_graph=create_graph,
            allow_unused=True,
        )
        return tuple(torch.zeros_like(w) if g is None else g for g, w in zip(grads, wrt))

    # ---- bounds → default projectors ----
    def _make_projector_from_bounds(bounds: Optional[Tuple[Tensor, Tensor]]):
        """Build a clamp projector from `(lb, ub)`; numbers are accepted as well as tensors."""
        if bounds is None:
            return None
        lb, ub = (torch.as_tensor(b) for b in bounds)

        def _proj(x: Tensor) -> Tensor:
            return x.clamp(lb.to(x), ub.to(x))

        # Snapshot read by the active-set step to detect coordinates on the box.
        _proj._bounds = (lb.detach().clone(), ub.detach().clone())
        return _proj

    if project_a is None:
        project_a = _make_projector_from_bounds(a_bounds)
    if project_t is None:
        project_t = _make_projector_from_bounds(t_bounds)

    def _project_a(x: Tensor) -> Tensor:
        return x if project_a is None else project_a(x)

    def _project_t(x: Tensor) -> Tensor:
        return x if project_t is None else project_t(x)

    def _active_set_update(x: Tensor, grad: Tensor, lr: float, projector) -> Tensor:
        """Ascent step that suppresses only moves pointing OUT of the bound box.

        A coordinate on its lower bound may still move up, and one on its upper
        bound may still move down. Freezing every coordinate that touches a bound
        would pin it there forever, since the iterate is clamped after each step.
        Projectors without a `_bounds` attribute get a plain step.
        """
        step = lr * grad
        if projector is None or not hasattr(projector, "_bounds"):
            return x + step
        lb, ub = projector._bounds
        lb = lb.to(x)
        ub = ub.to(x)
        tol = 1e-10
        pushes_below = (x <= lb + tol) & (step < 0)
        pushes_above = (x >= ub - tol) & (step > 0)
        return x + torch.where(pushes_below | pushes_above, torch.zeros_like(step), step)

    @contextlib.contextmanager
    def _seeded(seed: int):
        """Temporarily seed RNGs for reproducible sampling without polluting global state."""
        cpu_state = torch.get_rng_state()
        cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
        torch.manual_seed(int(seed))
        try:
            yield
        finally:
            torch.set_rng_state(cpu_state)
            if cuda_state is not None:
                torch.cuda.set_rng_state_all(cuda_state)

    def _build_step_extra(step_idx: int) -> Dict[str, Any]:
        """
        Build a SINGLE extra kwargs dict to be reused across the entire outer step.
        If CRN is enabled and a `make_z` sampler is provided, create a z and (optionally)
        antithetic augment it, then attach to kwargs as 'z'.
        """
        user_extra = per_step_kwargs_fn(step_idx) if per_step_kwargs_fn else {}
        extra = dict(user_extra) if user_extra else {}

        if use_crn and make_z is not None:
            # Steps in the same refresh group share a seed, hence the same z.
            group_id = step_idx // max(int(crn_refresh), 1)
            with _seeded(int(base_seed + group_id)):
                ns = int(base_settings.get(nsamples_key, 1024))
                z = make_z(ns, base_settings, dtype=t.dtype, device=t.device)
            if crn_antithetic and z.is_floating_point():
                z = torch.cat([z, -z], dim=0)
            extra["z"] = z

        return extra

    def _build_eval_extra() -> Optional[Dict[str, Any]]:
        """Build a fixed eval kwargs dict used for logging and final summary."""
        if not log_with_eval or eval_make_z is None:
            return None
        with _seeded(int(eval_seed)):
            z_eval = eval_make_z(int(eval_nsamples), base_settings, dtype=t.dtype, device=t.device)
        return {"z": z_eval}

    def _inner_solve(a_start: Tensor, t_fixed: Tensor, step_extra: Optional[Dict[str, Any]]) -> Tuple[Tensor, float]:
        """Projected gradient ascent on `u2` in `a` with `t` held fixed.

        Returns the final iterate (detached) and the gradient norm measured at the
        start of the last iteration executed; NaN when `inner_max_steps` is 0.
        """
        a_loc = a_start.detach().clone().requires_grad_(True)
        t_const = t_fixed.detach()
        gn = float("nan")  # stays NaN when the loop body never runs
        for _ in range(inner_max_steps):
            u2_val = _maybe_sum(_call_fn(u2, a_loc, t_const, step_extra))
            (grad_a,) = _grads_or_zeros(u2_val, (a_loc,))
            gn = grad_a.detach().norm().item()
            if gn <= inner_grad_tol:
                break
            with torch.no_grad():
                a_loc.copy_(_project_a(a_loc + inner_lr * grad_a))
        return a_loc.detach(), gn

    def _cg_solve(matvec: Callable[[Tensor], Tensor], rhs: Tensor, iters: int, damping: float) -> Tensor:
        """Approximately solve `(A + damping * I) x = rhs` by conjugate gradient from x = 0."""
        x = torch.zeros_like(rhs)
        r = rhs.clone()  # residual at x = 0 is rhs; avoids a matvec on a zero vector
        p = r.clone()
        rs = torch.dot(r.flatten(), r.flatten())
        for _ in range(iters):
            Ap = matvec(p) + damping * p
            denom = torch.dot(p.flatten(), Ap.flatten())
            if denom <= 0:
                # Non-positive curvature: the operator is not positive definite
                # along p, so a CG step is meaningless. Keep the current iterate.
                break
            alpha = rs / (denom + 1e-12)
            x = x + alpha * p
            r = r - alpha * Ap
            rs_new = torch.dot(r.flatten(), r.flatten())
            if rs_new.sqrt() < 1e-8:
                break
            p = r + (rs_new / (rs + 1e-12)) * p
            rs = rs_new
        return x

    def _as_mask(mask) -> Tensor:
        """Return `mask` on t's dtype/device, or all ones when no mask was given."""
        if mask is None:
            return torch.ones_like(t).detach()
        return torch.as_tensor(mask).to(t).detach()

    # ---------- logs ----------
    u1_values: List[float] = []
    u2_values: List[float] = []
    trace: List[Dict[str, float]] = []

    # Prepare masks
    t_train_mask = _as_mask(t_train_mask)
    t_metric_mask = _as_mask(t_metric_mask)

    # Initial projection to ensure feasibility
    with torch.no_grad():
        a.copy_(_project_a(a))
        t.copy_(_project_t(t))

    # Build a single held-out eval batch (optional)
    eval_extra = _build_eval_extra()

    # ---------- main loop ----------
    for step in trange(outer_steps, desc="Outer optimization", leave=True):
        # Build ONE extra dict per step (CRN lives here)
        step_extra = _build_step_extra(step)

        # (1) Inner solve (use SAME step_extra for all inner iterations)
        a_inner, grad_norm = _inner_solve(a, t, step_extra)

        # (2) Implicit hypergradient (again, SAME step_extra)
        a_star_t = a_inner.detach().requires_grad_(True)
        f_val = _maybe_sum(_call_fn(u2, a_star_t, t, step_extra))
        g_val = _maybe_sum(_call_fn(u1, a_star_t, t, step_extra))
        pen_val = _penalty(a_star_t, t, step_extra, "training")
        if pen_val is not None:
            g_val = g_val + pen_val

        g_a, g_t = _grads_or_zeros(g_val, (a_star_t, t), retain_graph=True)
        # create_graph so that f_a can be differentiated again (HVP, mixed term).
        (f_a,) = _grads_or_zeros(f_val, (a_star_t,), create_graph=True)

        def hvp_aa(v: Tensor) -> Tensor:
            (Hv,) = _grads_or_zeros(f_a, (a_star_t,), grad_out=v, retain_graph=True)
            return -Hv  # SPD fix for MAX inner

        # Solve (-H_aa + damping I) v = -g_a, i.e. v ≈ H_aa^{-1} g_a.
        v = _cg_solve(hvp_aa, -g_a.detach(), iters=cg_iters, damping=cg_damping)

        # d/dt <f_a, v> with v held constant gives H_ta v.
        dot_fa_v = torch.dot(f_a.flatten(), v.flatten())
        (HatT_v,) = _grads_or_zeros(dot_fa_v, (t,), retain_graph=False)

        # Mask BEFORE clipping so frozen coordinates do not inflate the norm; the
        # mask is applied exactly once regardless of `bound_mode`.
        hypergrad_t = (g_t - HatT_v) * t_train_mask
        if grad_clip is not None:
            hypergrad_t = _clip_norm_(hypergrad_t, grad_clip)

        with torch.no_grad():
            if bound_mode == "project":
                t.copy_(_active_set_update(t, hypergrad_t, outer_lr, project_t))
            else:
                t.add_(outer_lr * hypergrad_t)
            t.copy_(_project_t(t))
            a.copy_(a_star_t)  # warm-start the next inner solve
        t.grad = None

        # (3) Logs — use eval batch if provided, else training step_extra
        with torch.no_grad():
            extra_for_log = eval_extra if eval_extra is not None else step_extra
            a_now, t_now = a.detach(), t.detach()
            u1_log = float(_maybe_sum(_call_fn(u1, a_now, t_now, extra_for_log)))
            u2_log = float(_maybe_sum(_call_fn(u2, a_now, t_now, extra_for_log)))
            pen_log = _penalty(a_now, t_now, extra_for_log, "logging")
            c_log = 0.0 if pen_log is None else float(pen_log)

        u1_values.append(u1_log)
        u2_values.append(u2_log)

        entry: Dict[str, float] = {
            "step": float(step),
            "u1": u1_log,
            "u2": u2_log,
            "C": c_log,
            "inner_grad_norm": grad_norm,
        }
        for idx, val in enumerate(t.detach().flatten().tolist()):
            key = t_labels[idx] if (t_labels and idx < len(t_labels)) else f"t[{idx}]"
            entry[key] = float(val)

        # -------- Errors if reference stars provided --------
        if a_star is not None:
            if isinstance(a_star, (int, float)):
                # Mean mirrors the t branch below and equals the single entry when
                # `a` has one element; `.item()` would raise on a multi-element `a`.
                entry["err_a_abs"] = abs(float(a.detach().mean().item()) - float(a_star))
            elif isinstance(a_star, torch.Tensor):
                if a_star.numel() == 1 and a.numel() == 1:
                    entry["err_a_abs"] = abs(float(a.detach().item()) - float(a_star.to(a).item()))
                elif a_star.shape == a.shape:
                    entry["err_a_l2"] = float(torch.norm(a.detach() - a_star.to(a)).item())

        if t_star is not None:
            if isinstance(t_star, (int, float)):
                entry["err_t_abs"] = abs(float(t.detach().mean().item()) - float(t_star))
            elif isinstance(t_star, torch.Tensor) and t_star.shape == t.shape:
                # Masked L2 distance
                diff = (t.detach() - t_star.to(t)) * t_metric_mask
                entry["err_t_l2"] = float(torch.norm(diff).item())

        trace.append(entry)

    # Optionally compute theoretical stars if missing
    if get_theoretical_optimum_fn is not None and (a_star is None or t_star is None):
        try:
            opt = get_theoretical_optimum_fn(t=t.detach(), setting_parameters=base_settings)
        except Exception as exc:  # diagnostics only: report and carry on
            _warn_once("theory", f"get_theoretical_optimum_fn failed ({exc!r}); no reference optimum.")
            opt = None
        if isinstance(opt, (tuple, list)) and len(opt) >= 2:
            if a_star is None:
                a_star = opt[0]
            if t_star is None:
                t_star = opt[1]

    # Evaluate final u1 and u2 with the SAME eval batch for consistency
    with torch.no_grad():
        extra_final = eval_extra
        final_u1 = float(_maybe_sum(_call_fn(u1, a.detach(), t.detach(), extra_final)))
        final_u2 = float(_maybe_sum(_call_fn(u2, a.detach(), t.detach(), extra_final)))

    # Evaluate reference u1*, u2* on the SAME eval batch
    u1_star_val, u2_star_val = None, None
    if a_star is not None and t_star is not None:
        try:
            with torch.no_grad():
                # References may be plain numbers; the objectives expect tensors.
                a_ref = torch.as_tensor(a_star).to(a)
                t_ref = torch.as_tensor(t_star).to(t)
                u1_star_val = float(_maybe_sum(_call_fn(u1, a_ref, t_ref, extra_final)))
                # For u2*, evaluate agent utility at final t (or t_star if you prefer)
                u2_star_val = float(_maybe_sum(_call_fn(u2, a_ref, t.detach(), extra_final)))
        except Exception as exc:  # diagnostics only: report and carry on
            _warn_once("ref-eval", f"could not evaluate objectives at the reference optimum ({exc!r}).")

    # Print summary
    print("==== Optimization summary ====")
    print(f"a* (ref): {a_star}")
    print(f"t* (ref): {t_star}")
    print(f"final a: {a.detach().cpu().numpy()}")
    print(f"final t: {t.detach().cpu().numpy()}")
    print(f"final u1 (eval): {final_u1}" + (f"   (u1* (eval) ≈ {u1_star_val})" if u1_star_val is not None else ""))
    print(f"final u2 (eval): {final_u2}" + (f"   (u2* (eval) ≈ {u2_star_val})" if u2_star_val is not None else ""))
    print("================================")

    return t, a, u1_values, u2_values, trace
