from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import torch
from flax.core import FrozenDict, freeze


def _unwrap_per_model(tree):
    """If tree == {'model_0': {...}}, return the inner dict; else return tree unchanged."""
    if isinstance(tree, dict) and "model_0" in tree and len(tree) == 1:
        return tree["model_0"]
    return tree

def _to_frozen(x):
    """Return ``x`` as a Flax ``FrozenDict``; values that already are one pass through untouched."""
    if isinstance(x, FrozenDict):
        return x
    return freeze(x)

def load_for_inference_torch(
    model_file: str,
) -> Tuple[Any, Dict, Optional[Dict]]:
    """Load a checkpoint written with ``torch.save`` and extract Flax variables.

    The Flax module itself is NOT reconstructed here; the caller builds the
    architecture (optionally from the returned ``arch_kwargs``).

    Accepted layouts for ``blob["model"]``:
      * ``(state, batch_stats)``, where ``state`` either exposes a ``.params``
        attribute (e.g. a Flax TrainState) or is a mapping with a ``"params"`` key;
      * a mapping ``{"params": ..., "batch_stats": ...}`` (``batch_stats`` optional).
    Mappings may be plain dicts or other ``Mapping`` types such as ``FrozenDict``.
    A single-entry ``{"model_0": ...}`` wrapper around params/batch_stats is stripped.

    Args:
        model_file: path handed to ``torch.load``.

    Returns:
        ``(arch_kwargs, params, batch_stats)``. ``arch_kwargs`` is
        ``blob["arch_kwargs"]`` or ``None``. ``params`` is a ``FrozenDict``.
        ``batch_stats`` is a ``FrozenDict`` or ``None``.

    Raises:
        ValueError: if the checkpoint is not a mapping, has no ``"model"`` entry,
            or the payload layout / params cannot be recognised.
    """
    # SECURITY NOTE: weights_only=False makes torch.load fully unpickle the file,
    # which can execute arbitrary code. Only load checkpoints from trusted sources.
    blob = torch.load(model_file, map_location="cpu", weights_only=False)
    if not isinstance(blob, Mapping):
        raise ValueError(
            f"Saved file is not a dict-like checkpoint (got {type(blob).__name__})."
        )

    model_payload = blob.get("model", None)
    if model_payload is None:
        raise ValueError("Saved file has no 'model' key.")

    if isinstance(model_payload, tuple) and len(model_payload) == 2:
        state_like, batch_stats = model_payload
        # state_like can be a TrainState-like object or a mapping with 'params'.
        params = getattr(state_like, "params", None)
        if params is None and isinstance(state_like, Mapping):
            params = state_like.get("params", None)
        if params is None:
            raise ValueError("Could not find params in saved state.")
    elif isinstance(model_payload, Mapping):
        params = model_payload.get("params", None)
        batch_stats = model_payload.get("batch_stats", None)
        if params is None:
            raise ValueError("model dict missing 'params'.")
    else:
        raise ValueError("Unrecognized saved 'model' format.")

    # Strip an optional {"model_0": ...} wrapper, then freeze for Flax.
    params = _to_frozen(_unwrap_per_model(params))
    if batch_stats is not None:
        batch_stats = _to_frozen(_unwrap_per_model(batch_stats))

    # Architecture kwargs are optional; present only if stored at save time.
    arch_kwargs = blob.get("arch_kwargs", None)
    return arch_kwargs, params, batch_stats


# works with: arch = ResNet9(...), params_frozen = FrozenDict, batch_stats_frozen = FrozenDict|None

def make_model_fn_from_state(
    arch,                    # e.g., ResNet9(num_classes=10, ...)
    params_frozen: FrozenDict,
    batch_stats_frozen: FrozenDict | None = None,
    *,
    standardize: bool = False,
    mean: jnp.ndarray | None = None,   # shape (C,)
    std:  jnp.ndarray | None = None,   # shape (C,)
):
    """Build a jitted inference function ``x_nhwc -> arch.apply(...)`` output.

    NOTE(review): this module defines ``make_model_fn_from_state`` a second time
    further down. That later definition rebinds the name at import time, so this
    version is not what callers receive. Consider deleting one of the two.

    Args:
        arch: Flax module whose ``apply`` is called with ``train=False, mutable=False``.
        params_frozen: the ``"params"`` collection.
        batch_stats_frozen: optional ``"batch_stats"`` collection; omitted when ``None``.
        standardize: if True, inputs are mapped to ``(x - mean) / std`` first.
        mean, std: per-channel statistics, reshaped to broadcast over the last axis.

    Returns:
        A ``jax.jit``-compiled callable taking an NHWC array.
    """
    if standardize:
        assert mean is not None and std is not None, "Provide mean/std when standardize=True"
        # Reshape to (1, 1, 1, C) so the stats broadcast over N, H, W.
        mean = jnp.asarray(mean, dtype=jnp.float32).reshape(1, 1, 1, -1)
        # Floor std so a zero channel deviation cannot cause division by zero.
        std = jnp.maximum(jnp.asarray(std, dtype=jnp.float32).reshape(1, 1, 1, -1), 1e-8)

    variables = {"params": params_frozen}
    if batch_stats_frozen is not None:
        variables["batch_stats"] = batch_stats_frozen

    def model_fn(x_nhwc: jnp.ndarray) -> jnp.ndarray:
        inputs = (x_nhwc - mean) / std if standardize else x_nhwc
        return arch.apply(variables, inputs, train=False, mutable=False)

    return jax.jit(model_fn)


def load_for_inference(arch, model_blob):
    """Rebuild a minimal ``(state_like, batch_stats)`` pair from a saved model blob.

    Args:
        arch: instantiated Flax module (e.g. ``ResNet9(...)``). Only its ``apply``
            attribute is used.
        model_blob: mapping with a ``"params"`` entry and, optionally, a
            ``"batch_stats"`` entry.

    Returns:
        ``(state_like, batch_stats)``. ``state_like`` is a ``SimpleNamespace``
        exposing ``.apply_fn`` (``arch.apply``) and ``.params`` (device-resident).
        ``batch_stats`` is the device-resident batch-stats tree, or ``None`` when
        the blob has none.

    Raises:
        KeyError: if ``model_blob`` has no ``"params"`` entry.
    """
    # jax.device_put accepts whole pytrees, so no explicit tree_map is needed.
    params = jax.device_put(model_blob["params"])
    # Models without BatchNorm have no batch_stats; treat a missing key as None.
    batch_stats_host = model_blob.get("batch_stats")
    batch_stats = None if batch_stats_host is None else jax.device_put(batch_stats_host)

    # Minimal state-like object: predict() only needs .apply_fn and .params.
    state_like = SimpleNamespace(apply_fn=arch.apply, params=params)
    return (state_like, batch_stats)

def to_jax_nhwc(tensor: torch.Tensor):
    """Convert an NCHW torch tensor to an NHWC JAX array.

    The tensor is detached from the autograd graph first, because
    ``Tensor.numpy()`` raises on tensors that require grad. It is then moved to
    CPU before the host copy.

    Args:
        tensor: 4-D tensor laid out as (N, C, H, W).

    Returns:
        A ``jax.numpy`` array laid out as (N, H, W, C).
    """
    host = tensor.detach().permute(0, 2, 3, 1).contiguous().cpu().numpy()
    return jnp.array(host)

def to_torch_nchw(jax_arr: jnp.ndarray, device=None):
    """Convert an NHWC JAX array to an NCHW torch tensor.

    Args:
        jax_arr: array laid out as (N, H, W, C).
        device: optional torch device; when given, the result is moved there.

    Returns:
        A contiguous (N, C, H, W) ``torch.Tensor`` (on CPU unless ``device`` is set).
    """
    # np.array copies to a fresh, writable host buffer; torch.from_numpy shares it.
    host = np.array(jax_arr)
    nchw = torch.from_numpy(host).permute(0, 3, 1, 2).contiguous()
    if device is None:
        return nchw
    return nchw.to(device)

def one_hot(labels: jnp.ndarray, num_classes: int):
    """One-hot encode integer ``labels`` into ``num_classes`` columns via ``jax.nn.one_hot``."""
    encoded = jax.nn.one_hot(labels, num_classes=num_classes)
    return encoded


import jax
import jax.numpy as jnp
from flax.core import FrozenDict

def _has_vars(tree):
    try:
        return len(jax.tree_util.tree_leaves(tree)) > 0
    except Exception:
        return False

def make_model_fn_from_state(
    arch,                      # e.g., ResNet9(...), must match training config
    state_like,                # TrainState-like (apply_fn+params) OR FrozenDict params
    batch_stats=None,          # FrozenDict or dict of batch_stats (or None)
    *,
    standardize: bool = False,
    mean=None,                 # shape (C,) if standardize=True
    std=None,                  # shape (C,) if standardize=True
):
    """Build a jitted inference function ``x_nhwc -> model output``.

    ``state_like`` may be either:
      * an object exposing ``.apply_fn`` and ``.params`` (e.g. a TrainState-like
        object), in which case its ``apply_fn`` is used and ``arch`` is ignored; or
      * a params tree, in which case ``arch.apply`` is used.

    Args:
        arch: Flax module used when ``state_like`` is a bare params tree.
        state_like: TrainState-like object or params tree (see above).
        batch_stats: optional ``"batch_stats"`` collection. It is passed to apply
            only when it is not ``None`` and contains at least one leaf.
        standardize: if True, inputs are mapped to ``(x - mean) / std`` first.
        mean, std: per-channel statistics, reshaped to broadcast over the last
            axis. ``std`` is floored at 1e-8.

    Returns:
        A ``jax.jit``-compiled callable. It calls apply with ``train=False,
        mutable=False`` (output dtype is whatever the model produces; presumably
        float32 logits).

    Raises:
        ValueError: if ``standardize`` is True but ``mean`` or ``std`` is missing.
    """
    # Decide which apply function to use: the state's own, or the module's.
    if hasattr(state_like, "apply_fn") and hasattr(state_like, "params"):
        apply_fn = state_like.apply_fn
        params = state_like.params
    else:
        apply_fn = arch.apply
        params = state_like  # expected: params tree (e.g. FrozenDict)

    if standardize:
        # Explicit check: an `assert` would be stripped under `python -O`.
        if mean is None or std is None:
            raise ValueError("Provide mean/std when standardize=True")
        mean = jnp.asarray(mean, dtype=jnp.float32).reshape(1, 1, 1, -1)
        std = jnp.maximum(jnp.asarray(std, dtype=jnp.float32).reshape(1, 1, 1, -1), 1e-8)

    # Build the variables dict once. Empty batch_stats trees are left out.
    variables = {"params": params}
    if batch_stats is not None and _has_vars(batch_stats):
        variables["batch_stats"] = batch_stats

    def model_fn(x_nhwc: jnp.ndarray) -> jnp.ndarray:
        x = (x_nhwc - mean) / std if standardize else x_nhwc
        return apply_fn(variables, x, train=False, mutable=False)

    return jax.jit(model_fn)



# @partial(jax.jit, static_argnames=("norm", "targeted", "num_classes"))
def _single_loss_logits(logits, labels_onehot, targeted: bool):
    # returns scalar loss per-example: negative log-probability of true class (cross-entropy)
    # logits shape (num_classes,)
    # labels_onehot shape (num_classes,)
    logp = jax.nn.log_softmax(logits)
    loss = -jnp.sum(logp * labels_onehot)    # cross-entropy
    return -loss if targeted else loss

import jax.numpy as jnp


def fast_gradient_method_jax(
    model_fn,
    x_jax: jnp.ndarray,   # (N,H,W,C) float32 in [0,1] (or standardized; see clip)
    eps: float,
    norm: float,          # jnp.inf or 2
    *,
    clip_min: float | None = None,
    clip_max: float | None = None,
    y: jnp.ndarray | None = None,    # (N, num_classes) one-hot or probs
    targeted: bool = False,
    batch_chunk: int = 64,
    num_classes: int = 10,
) -> jnp.ndarray:
    assert norm in (jnp.inf, 2.0), "Only L_inf or L2 supported."

    # If no labels are given, use model predictions (avoid label leaking).
    if y is None:
        def _pred_fn(xx):
            return jnp.argmax(model_fn(xx), axis=-1)
        preds = []
        for s in range(0, x_jax.shape[0], batch_chunk):
            preds.append(_pred_fn(x_jax[s:s+batch_chunk]))
        y_idx = jnp.concatenate(preds, axis=0)
        y = jax.nn.one_hot(y_idx, num_classes)

    def loss_for_batch(xx, yy):
        logits = model_fn(xx)
        logp = jax.nn.log_softmax(logits)
        loss = -(logp * yy).sum(axis=-1).mean()
        return -loss if targeted else loss

    grad_fn = jax.grad(loss_for_batch)

    adv_list = []
    for s in range(0, x_jax.shape[0], batch_chunk):
        xb = x_jax[s:s+batch_chunk]
        yb = y[s:s+batch_chunk]

        g = grad_fn(xb, yb)
        if norm == jnp.inf:
            perturb = eps * jnp.sign(g)
        else:  # L2
            flat = jnp.reshape(g, (g.shape[0], -1))
            denom = jnp.linalg.norm(flat, ord=2, axis=1, keepdims=True)
            denom = jnp.maximum(denom, 1e-12)
            unit = flat / denom
            perturb = jnp.reshape(unit, g.shape) * eps

        adv = xb + perturb
        if (clip_min is not None) and (clip_max is not None):
            adv = jnp.clip(adv, clip_min, clip_max)
        adv_list.append(adv)

    return jnp.concatenate(adv_list, axis=0)
