from __future__ import annotations

import typing as _t
from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
import optax
from multimodel_train_state import MultiModelTrainState
from tqdm import tqdm
import numpy as np

# BN-aware helpers from your unrolled module
from unrolled_canaries import flatten_over_models, get_logits, train_model, _eval_model_mean


# -----------------------
# Small jitted wrapper to run training for a (possibly truncated) perms tensor
# -----------------------
@partial(jax.jit, static_argnames=("use_dp",))
def _run_train_jitted(
    state: MultiModelTrainState,
    batch_stats: dict,
    perms: jnp.ndarray,
    images: jnp.ndarray,
    labels: jnp.ndarray,
    use_dp: bool,
):
    """Run ``train_model`` over ``perms`` and keep only the trained state.

    Args:
        state: Training state forwarded to ``train_model`` as its first argument.
        batch_stats: Batch statistics forwarded next to ``state``.
        perms: Index tensor driving the training schedule; the original
            annotation gives it as ``[num_epochs, steps, batch, num_models]``
            (may be a truncated slice of the full schedule).
        images: Training images (HWC per the original annotation).
        labels: Training targets matching ``images``.
        use_dp: Forwarded to ``train_model``; static under ``jax.jit``, so each
            distinct value compiles separately.

    Returns:
        ``(state, batch_stats)`` -- the first two of the five values returned
        by ``train_model``; the remaining three are discarded.
    """
    # No test set is handed to train_model; the original author's intent is to
    # skip evaluation during canary search so the compiled program stays small.
    training_kwargs = dict(
        train_images=images,
        train_targets=labels,
        test_images=None,
        test_targets=None,
        use_dp=use_dp,
        verbose=False,
    )
    trained_state, trained_batch_stats, _, _, _ = train_model(
        state, batch_stats, perms, **training_kwargs
    )
    return trained_state, trained_batch_stats


# -----------------------
# Loss that trains only the tail (already pre-sliced)
# -----------------------
@partial(
    jax.jit,
    static_argnames=("loss_type", "loss_agg", "is_mlp", "use_dp"),
)
def _canary_loss_tail_only(
    canary: jnp.ndarray,
    # fixtures (warm states + tail perms)
    warm_state_in: MultiModelTrainState,
    warm_bs_in: dict,
    perms_tail_in: jnp.ndarray,
    state_out_final: MultiModelTrainState,   # OUT is precomputed full; no training here
    bs_out_final: dict,
    # data/meta
    canary_label: int,
    train_images_full: jnp.ndarray,
    train_targets_full: jnp.ndarray,
    canary_idx: int,
    is_mlp: bool,
    use_dp: bool = False,
    loss_type: str = "l2",
    loss_agg: str = "max",
    # optional test set for evaluation
    test_images: _t.Optional[jnp.ndarray] = None,
    test_targets: _t.Optional[jnp.ndarray] = None,
    report_eval: bool = False,   # dynamic (traced) flag
):
    """Canary objective: train the IN models on the tail only and compare to OUT.

    The canary is written into the training set at ``canary_idx``, the IN
    models are trained from ``warm_state_in`` over ``perms_tail_in``, and their
    logits on the canary are compared with those of the precomputed OUT models.

    Args:
        canary: Candidate canary input; the value being optimised.
        warm_state_in: IN training state to resume from.
        warm_bs_in: Batch statistics paired with ``warm_state_in``.
        perms_tail_in: Permutation slice used for the IN tail training.
        state_out_final: Fully trained OUT state; used as-is, never retrained.
        bs_out_final: Batch statistics paired with ``state_out_final``.
        canary_label: Label stored for the canary (cast to ``int16``).
        train_images_full: Training images; row ``canary_idx`` is replaced.
        train_targets_full: Training targets; row ``canary_idx`` is replaced.
        canary_idx: Index of the slot that receives the canary.
        is_mlp: Static. When true the canary is flattened to ``(1, -1)`` for
            inference, otherwise only a leading batch axis is added.
        use_dp: Static. Forwarded to the tail training run.
        loss_type: Static. ``"l2"`` or ``"hinge"``.
        loss_agg: Static. ``"max"`` takes the maximum distance; any other
            value takes the mean.
        test_images: Optional test inputs for the auxiliary evaluation.
        test_targets: Optional test targets for the auxiliary evaluation.
        report_eval: Traced flag selecting whether the evaluation runs.

    Returns:
        ``(loss, aux)`` where ``loss`` is the negated aggregated distance and
        ``aux`` is a float32 array of two entries, ``[mean test loss, mean
        test accuracy]``, or two NaNs when evaluation was skipped or no test
        set was supplied.

    Raises:
        NotImplementedError: If ``loss_type`` is neither ``"l2"`` nor ``"hinge"``.
    """
    # Splice the canary and its label into the training data.
    canary_label = jnp.array(canary_label, dtype=jnp.int16)
    train_images = train_images_full.at[canary_idx].set(canary)
    train_targets = train_targets_full.at[canary_idx].set(canary_label)

    # IN: run only the tail from the warm state (gradients flow to the canary).
    state_in, bs_in = _run_train_jitted(
        warm_state_in, warm_bs_in, perms_tail_in, train_images, train_targets, use_dp=use_dp,
    )

    # OUT: reuse the precomputed state (independent of the canary).
    state_out, bs_out = state_out_final, bs_out_final

    # Evaluate logits on the canary with train=False.
    canary_infer = canary.reshape(1, -1) if is_mlp else canary[None, :]
    # [:, 0, :] selects the single canary example; presumably the flattened
    # logits are (models, batch, classes) -- TODO confirm against helper.
    logits_in = flatten_over_models(get_logits(state_in, canary_infer, bs_in, train=False))[:, 0, :]
    logits_out = flatten_over_models(get_logits(state_out, canary_infer, bs_out, train=False))[:, 0, :]

    if loss_type == "l2":
        dist = jnp.mean(optax.l2_loss(logits_in, logits_out), axis=1)
    elif loss_type == "hinge":
        from lira_utils import lira_hinge_loss
        dist = lira_hinge_loss(logits_in, canary_label[None]) - \
               lira_hinge_loss(logits_out, canary_label[None])
    else:
        raise NotImplementedError(f"loss_type {loss_type!r} not implemented")

    # Negated so that minimising the loss maximises the IN/OUT separation.
    loss = -(jnp.max(dist) if loss_agg == "max" else jnp.mean(dist))

    # Optional test-set evaluation, returned as aux (has_aux=True callers do
    # not differentiate it).
    def _do_eval(_):
        test_loss_pm, test_acc_pm = _eval_model_mean(state_in, bs_in, test_images, test_targets)
        # fixed-shape small aux payload
        return jnp.array([jnp.mean(test_loss_pm), jnp.mean(test_acc_pm)], dtype=jnp.float32)

    def _skip_eval(_):
        # NaNs indicate "not computed this step"
        return jnp.array([jnp.nan, jnp.nan], dtype=jnp.float32)

    # Whether a test set exists is known at trace time. jax.lax.cond traces
    # BOTH branches, so _do_eval must not be handed to it when the test set is
    # None -- decide that part in Python and only branch on the traced flag.
    if test_images is None or test_targets is None:
        aux = _skip_eval(None)
    else:
        aux = jax.lax.cond(report_eval, _do_eval, _skip_eval, operand=None)

    return loss, aux  # <- return aux for has_aux=True

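# -----------------------
# Canary search driver: precomputes the warm IN state and the final OUT state
# once, then repeatedly runs a jitted step that updates only the canary.
# All big arrays are passed to that step as dynamic args.
# -----------------------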
def fit_canary(
    *,
    canary_params: jnp.ndarray,                          # HWC
    optimizer: optax.GradientTransformation,
    canary_search_steps: int,
    canary_label: int,
    output_path,                                         # path-like (npz)
    fixtures: Tuple[MultiModelTrainState, dict, jnp.ndarray,
                    MultiModelTrainState, dict, jnp.ndarray],
    summary_writer,
    train_images_full: jnp.ndarray,
    train_targets_full: jnp.ndarray,
    test_images: Optional[jnp.ndarray],
    test_targets: Optional[jnp.ndarray],
    is_mlp: bool,
    canary_idx: int,
    clip_canary: bool = True,
    clip_min: float | jnp.ndarray = 0.0,
    clip_max: float | jnp.ndarray = 1.0,
    loss_type: str = "l2",
    loss_agg: str = "max",
    use_dp: bool = False,
    tbptt_k_steps: int = 4,      # choose small K (e.g., 4-16)
    eval_every_k: int = 10,
) -> jnp.ndarray:
    """Fit a canary with a one-off warm-up and tail-only backprop per step.

    The permutation schedule is split along its step axis into a head and a
    tail of ``tbptt_k_steps`` steps. The head is trained once on the original
    data; the OUT models are additionally trained over their tail once. Each
    search step then retrains only the IN tail with the current canary
    spliced in and takes one optimizer step on the canary.

    Args:
        canary_params: Initial canary. A private copy is optimised, so the
            caller's array is left untouched.
        optimizer: Optax transformation used to update the canary.
        canary_search_steps: Number of search steps to attempt.
        canary_label: Label assigned to the canary.
        output_path: Destination handed to ``np.savez``.
        fixtures: ``(init_state_in, batch_stats_in, model_perms_in,
            init_state_out, batch_stats_out, model_perms_out)``; both perms
            tensors must be 4-D with identical shapes.
        summary_writer: Optional writer exposing ``scalar``, ``image`` and
            ``flush``; ``None`` disables logging.
        train_images_full: Training images.
        train_targets_full: Training targets.
        test_images: Optional test inputs for periodic evaluation.
        test_targets: Optional test targets for periodic evaluation.
        is_mlp: Forwarded to the loss to select the canary inference shape.
        canary_idx: Training-set slot that receives the canary.
        clip_canary: Whether to clip the canary after every update.
        clip_min: Lower clip bound.
        clip_max: Upper clip bound.
        loss_type: Forwarded to the loss (``"l2"`` or ``"hinge"``).
        loss_agg: Forwarded to the loss (``"max"`` or mean otherwise).
        use_dp: Forwarded to the per-step IN tail training.
        tbptt_k_steps: Number of trailing steps per epoch kept in the
            differentiated tail; capped at the number of steps per epoch.
        eval_every_k: Request test evaluation every this many steps;
            values ``<= 0`` disable it.

    Returns:
        The canary after the last executed step. If the canary becomes NaN
        the loop stops early and those NaN-containing parameters are
        returned; the saved snapshots exclude that step.

    Raises:
        ValueError: If ``tbptt_k_steps < 1`` or the perms have no steps.
    """
    (init_state_in,  batch_stats_in,  model_perms_in,
     init_state_out, batch_stats_out, model_perms_out) = fixtures

    # -----------------------
    # Host-side: split perms into head/tail using Python ints (static slices)
    # -----------------------
    assert model_perms_in.ndim == 4 and model_perms_out.ndim == 4
    assert model_perms_in.shape == model_perms_out.shape
    steps_per_epoch = int(model_perms_in.shape[1])
    if int(tbptt_k_steps) < 1:
        raise ValueError("tbptt_k_steps must be >= 1")
    if steps_per_epoch < 1:
        raise ValueError("model_perms_in must contain at least one step per epoch")
    k_tail = min(int(tbptt_k_steps), steps_per_epoch)
    head_len = steps_per_epoch - k_tail

    # NOTE(review): the split is along axis 1 for every epoch, so with more
    # than one epoch the head steps of all epochs run before any tail step --
    # confirm this reordering is the intended approximation.
    perms_tail_in  = model_perms_in[:, head_len:, ...]   # [E, K, B, M]
    perms_head_in  = model_perms_in[:, :head_len, ...]   # [E, S-K, B, M]
    perms_tail_out = model_perms_out[:, head_len:, ...]
    perms_head_out = model_perms_out[:, :head_len, ...]

    # -----------------------
    # Warm-up once (no grad wrt canary)
    #   IN  warm-up: original dataset (no canary) on the head
    #   OUT warm-up + tail: compute the full OUT model once and cache it
    # NOTE(review): these runs hard-code use_dp=False even when the per-step
    # tail uses use_dp=True -- confirm that is intentional.
    # -----------------------
    if head_len > 0:
        warm_state_in, warm_bs_in = _run_train_jitted(
            init_state_in, batch_stats_in, perms_head_in,
            train_images_full, train_targets_full, use_dp=False,
        )
        warm_state_out, warm_bs_out = _run_train_jitted(
            init_state_out, batch_stats_out, perms_head_out,
            train_images_full, train_targets_full, use_dp=False,
        )
    else:
        warm_state_in, warm_bs_in = init_state_in, batch_stats_in
        warm_state_out, warm_bs_out = init_state_out, batch_stats_out

    # Finish OUT once with its tail too (still independent of the canary).
    state_out_final, bs_out_final = _run_train_jitted(
        warm_state_out, warm_bs_out, perms_tail_out,
        train_images_full, train_targets_full, use_dp=False,
    )

    # -----------------------
    # Jitted step (don't capture huge arrays as constants)
    # -----------------------
    # _step donates its first two arguments. Optimise a private copy so the
    # caller's canary buffer is not invalidated by the first step.
    canary_params = jnp.array(canary_params, copy=True)
    canary_shape, canary_dtype = canary_params.shape, canary_params.dtype
    opt_state = optimizer.init(canary_params)

    @partial(
        jax.jit,
        static_argnames=("canary_label", "canary_idx", "loss_type", "loss_agg", "is_mlp", "use_dp", "clip"),
        donate_argnums=(0, 1),   # donate canary_params & opt_state
    )
    def _step(
        c_params, opt_state,
        warm_state_in, warm_bs_in, perms_tail_in,
        state_out_final, bs_out_final,
        train_images_full, train_targets_full,
        test_images, test_targets,
        report_eval: bool,
        *,
        canary_label: int, canary_idx: int, loss_type: str, loss_agg: str, is_mlp: bool, use_dp: bool,
        clip: bool, clip_lo, clip_hi,
    ):
        """One canary update: loss + grad wrt the canary, optimizer step, optional clip."""
        def _loss(cp):
            return _canary_loss_tail_only(
                cp,
                warm_state_in, warm_bs_in, perms_tail_in,
                state_out_final, bs_out_final,
                canary_label,
                train_images_full, train_targets_full,
                canary_idx, is_mlp,
                use_dp=use_dp,
                loss_type=loss_type, loss_agg=loss_agg,
                test_images=test_images, test_targets=test_targets,
                report_eval=report_eval,  # dynamic flag
            )

        (loss_val, aux), grads = jax.value_and_grad(_loss, has_aux=True)(c_params)
        updates, opt_state = optimizer.update(grads, opt_state, c_params)
        c_params = optax.apply_updates(c_params, updates)
        if clip:
            c_params = jnp.clip(c_params, clip_lo, clip_hi)
        return c_params, opt_state, loss_val, aux  # aux = [test_loss_mean, test_acc_mean]

    # Static jit arguments must be hashable; plain ints also keep the
    # compilation cache key stable across numpy/JAX scalar inputs.
    static_label = int(canary_label)
    static_idx = int(canary_idx)

    # -----------------------
    # Host loop
    # -----------------------
    per_step_canaries = []
    per_step_losses = []
    per_step_test_metrics = []  # (step, test_loss_mean, test_acc_mean)

    # The context manager closes the bar even if a step raises.
    with tqdm(total=canary_search_steps, desc="Canary Search", unit="step") as pbar:
        for t in range(canary_search_steps):
            report_eval = (eval_every_k > 0) and (t % eval_every_k == 0)

            canary_params, opt_state, loss_value, aux = _step(
                canary_params, opt_state,
                warm_state_in, warm_bs_in, perms_tail_in,
                state_out_final, bs_out_final,
                train_images_full, train_targets_full,
                test_images, test_targets,
                report_eval,
                canary_label=static_label,
                canary_idx=static_idx,
                loss_type=loss_type,
                loss_agg=loss_agg,
                is_mlp=bool(is_mlp),
                use_dp=bool(use_dp),
                clip=bool(clip_canary), clip_lo=clip_min, clip_hi=clip_max,
            )

            loss_host = float(loss_value)
            # Pull aux to the host once; entries are NaN when not computed.
            aux_host = np.asarray(aux)
            test_loss_mean, test_acc_mean = map(float, aux_host)
            have_metrics = report_eval and not np.isnan(aux_host).any()

            if jnp.any(jnp.isnan(canary_params)):
                break

            # Snapshot: owned host copy, safe from the next step's donation.
            canary_host = np.array(jax.device_get(canary_params), copy=True)
            per_step_canaries.append(canary_host)
            per_step_losses.append(loss_host)

            if report_eval:
                per_step_test_metrics.append((t, test_loss_mean, test_acc_mean))

            # progress + TB
            postfix = {"canary_loss": f"{loss_host:.4f}"}
            if have_metrics:
                postfix["test_acc"] = f"{test_acc_mean:.3f}"
            pbar.set_postfix(**postfix)
            pbar.update(1)

            if summary_writer is not None:
                summary_writer.scalar("canary_loss", loss_host, t)
                if have_metrics:
                    summary_writer.scalar("test/acc_mean", test_acc_mean, t)
                    summary_writer.scalar("test/loss_mean", test_loss_mean, t)
                # Log the host snapshot rather than the device buffer that the
                # next step donates.
                summary_writer.image("canary", canary_host, t)
                summary_writer.flush()

    # np.stack rejects an empty list, which happens when no step completed
    # (zero search steps, or NaN on the very first update).
    if per_step_canaries:
        canaries_arr = np.stack(per_step_canaries, axis=0)
    else:
        canaries_arr = np.empty((0, *canary_shape), dtype=canary_dtype)

    if per_step_test_metrics:
        metrics_arr = np.asarray(per_step_test_metrics, dtype=np.float32)
    else:
        metrics_arr = np.empty((0, 3), dtype=np.float32)

    np.savez(
        output_path,
        per_step_canaries=canaries_arr,
        per_step_canary_losses=np.asarray(per_step_losses, dtype=np.float32),
        per_k_test_metrics=metrics_arr,
    )
    return canary_params
