from functools import partial
import typing as _t

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
import optax.contrib
import pydantic
from jax import checkpoint as remat
from jax.debug import print as id_print

from multimodel_train_state import MultiModelTrainState
import logging
from jax.debug import print as jprint  # at top

# Preferred microbatch size (examples per vmapped chunk) for the per-example
# gradient computation on the DP path in apply_model_perms. That function picks
# a microbatch that divides the per-model batch size, trying this value first.
# Tuning note: 64 is a good starting point on an A100; 128 is also worth trying.
_DEFAULT_DP_MICROBATCH = 64

# -----------------------
# Config
# -----------------------
class DPParams(pydantic.BaseModel):
    """Differential-privacy settings for the DP-SGD optimizer path.

    ``noise_multiplier`` and ``l2_norm_clip`` are forwarded to
    ``optax.contrib.dpsgd`` by ``create_train_state``. ``delta`` is not read
    anywhere in the code shown here; presumably it is used for privacy
    accounting by callers (confirm).
    """

    noise_multiplier: float  # passed to optax.contrib.dpsgd(noise_multiplier=...)
    l2_norm_clip: float  # passed to optax.contrib.dpsgd(l2_norm_clip=...)
    delta: float  # not used in this module; presumably the DP delta (confirm with callers)


# -----------------------
# Helpers
# -----------------------
def map_over_models(f):
    """Lift ``f`` so it is applied once per model subtree.

    The returned callable takes one or more trees shaped like
    ``{"model_0": sub_0, "model_1": sub_1, ...}``. It calls ``f(sub_i, ...)``
    for every model and returns ``{"model_i": f(...)}``. A node counts as a
    per-model subtree when it is a ``dict`` that has no ``"model_0"`` key.
    """
    def _is_per_model_subtree(node):
        if not isinstance(node, dict):
            return False
        return "model_0" not in node

    return partial(jax.tree_util.tree_map, f, is_leaf=_is_per_model_subtree)


def _model_index(key):
    """Return ``i`` for a key of the form ``"model_<i>"``, else ``None``."""
    if isinstance(key, str) and key.startswith("model_"):
        suffix = key[len("model_"):]
        if suffix.isdigit():
            return int(suffix)
    return None


def flatten_over_models(tree):
    """Stack per-model outputs into a single array, ordered by model index.

    JAX flattens dicts in sorted-key order, which is lexicographic
    ("model_10" < "model_2"). ``tree_leaves`` alone would therefore scramble
    the model order once there are more than 10 models. When every top-level
    key is ``"model_<i>"``, the leaves are gathered in numeric ``i`` order
    instead. Any other tree falls back to plain ``tree_leaves`` order.
    """
    if isinstance(tree, dict) and tree and all(_model_index(k) is not None for k in tree):
        ordered_keys = sorted(tree, key=_model_index)
        leaves = [leaf for k in ordered_keys for leaf in jax.tree_util.tree_leaves(tree[k])]
    else:
        leaves = jax.tree_util.tree_leaves(tree)
    return jnp.array(leaves)


def convert_to_per_model_tree(array, num_models, in_axes=1):
    """Split ``array`` along axis ``in_axes`` into ``{"model_i": slice_i}``.

    ``"model_i"`` always receives position ``i`` of that axis. The dict is
    built by name rather than by unflattening, whose sorted-key leaf order
    would assign slice 2 to "model_10" for more than 10 models.

    Raises:
        ValueError: if ``array.shape[in_axes] != num_models``.
    """
    per_model = jnp.moveaxis(array, in_axes, 0)
    if per_model.shape[0] != num_models:
        raise ValueError(
            f"expected {num_models} models along axis {in_axes}, got {per_model.shape[0]}"
        )
    return {f"model_{i}": per_model[i] for i in range(num_models)}


def _extract_component(tree_of_tuples, idx: int):
    """Replace every tuple leaf of ``tree_of_tuples`` by its ``idx``-th element.

    Only dicts are traversed. Any non-dict node, tuples included, is treated
    as a leaf, and non-tuple leaves are returned unchanged.
    """
    def _pick(node):
        if isinstance(node, tuple):
            return node[idx]
        return node

    def _stop_below_dicts(node):
        return not isinstance(node, dict)

    return jax.tree_util.tree_map(_pick, tree_of_tuples, is_leaf=_stop_below_dicts)


# -----------------------
# Model perms
# -----------------------
def generate_model_perms(key, num_models, num_epochs, num_samples, batch_size,
                         sample_non_canaries: bool, canary_idx: int | None):
    """Generate model permutations for the training process."""
    key, key_datasets = jax.random.split(key)
    if not sample_non_canaries:
        training_datasets = jnp.repeat(jnp.arange(num_samples).reshape(1, -1), num_models, axis=0)
    else:
        training_datasets = jnp.empty((num_models, num_samples // 2), dtype=int)
        model_dataset_keys = jax.random.split(key_datasets, num_models)
        for model_idx in range(num_models):
            if canary_idx is None:
                current_dataset = jax.random.permutation(model_dataset_keys[model_idx], num_samples)[: num_samples // 2]
            else:
                current_dataset = jnp.concatenate(
                    [
                        jax.random.choice(
                            model_dataset_keys[model_idx],
                            jnp.delete(jnp.arange(num_samples), canary_idx),
                            shape=(num_samples // 2 - 1,),
                            replace=False,
                        ),
                        jnp.array([canary_idx]),
                    ]
                )
            training_datasets = training_datasets.at[model_idx].set(current_dataset)

    steps_per_epoch = training_datasets.shape[1] // batch_size
    key, key_batching = jax.random.split(key)
    key_model_batches = jax.random.split(key_batching, num_models)

    model_perms = jnp.empty((num_epochs, steps_per_epoch, batch_size, num_models), dtype=int)
    for model_idx in range(num_models):
        key_epochs = jax.random.split(key_model_batches[model_idx], num_epochs)
        for epoch_idx in range(num_epochs):
            key_subset, key_shuffle = jax.random.split(key_epochs[epoch_idx])
            if canary_idx is None:
                current_samples = jax.random.choice(
                    key_subset,
                    training_datasets[model_idx],
                    shape=(steps_per_epoch * batch_size,),
                    replace=False,
                )
            else:
                current_samples = jnp.concatenate(
                    [
                        jax.random.choice(
                            key_subset,
                            training_datasets[model_idx, :-1],
                            shape=(steps_per_epoch * batch_size - 1,),
                            replace=False,
                        ),
                        jnp.array([canary_idx]),
                    ]
                )
            current_samples = jax.random.permutation(key_shuffle, current_samples)
            model_perms = model_perms.at[epoch_idx, :, :, model_idx].set(
                current_samples.reshape(steps_per_epoch, batch_size)
            )
    return model_perms


# -----------------------
# State init (with BN)
# -----------------------
def create_train_state(rngs, learning_rate, momentum, num_models,
                       arch: nn.Module, image_shape, dp_params: DPParams | None, 
                       weight_decay: float | None = None):
    """Initialise ``num_models`` independent copies of ``arch`` with one optimizer each.

    Args:
        rngs: sequence of PRNG keys. Each key is split into an init key and a
            seed key, and an integer seed is drawn from every seed key.
            Models ``0..num_models-1`` use the corresponding entries.
        learning_rate: optimizer learning rate.
        momentum: SGD momentum (non-DP path only).
        num_models: number of models to initialise.
        arch: Flax module, initialised with a ``[1, *image_shape]`` dummy input
            and ``train=True``.
        image_shape: per-example input shape.
        dp_params: when given, every model uses ``optax.contrib.dpsgd``, and
            ``momentum`` and ``weight_decay`` are NOT applied.
        weight_decay: on the non-DP path, ``optax.add_decayed_weights`` runs
            before ``optax.sgd``. The decay term therefore also passes through
            momentum (coupled L2-style decay, not decoupled).
            ``None`` or a value <= 0 disables it.

    Returns:
        ``(state, batch_stats)``: a ``MultiModelTrainState`` keyed by
        ``"model_i"``, and a dict of per-model batch statistics (``{}`` for
        models without BatchNorm).
    """
    decay = float(weight_decay) if weight_decay is not None else 0.0

    def _optimizer_for(seed: int):
        if dp_params is not None:
            return optax.contrib.dpsgd(
                learning_rate=learning_rate,
                l2_norm_clip=dp_params.l2_norm_clip,
                noise_multiplier=dp_params.noise_multiplier,
                seed=seed,
                momentum=None,
            )
        decay_step = optax.add_decayed_weights(decay) if decay > 0 else optax.identity()
        return optax.chain(decay_step, optax.sgd(learning_rate, momentum))

    split_pairs = [jax.random.split(rng) for rng in rngs]
    init_keys, seed_keys = zip(*split_pairs)
    seeds = [int(jax.random.randint(k, (), 0, 2**31 - 1)) for k in seed_keys]

    sample_input = jnp.ones([1, *image_shape])
    params, batch_stats, txs = {}, {}, {}
    for i in range(num_models):
        name = f"model_{i}"
        variables = arch.init(init_keys[i], sample_input, train=True)
        params[name] = variables["params"]
        batch_stats[name] = variables.get("batch_stats", {})
        txs[name] = _optimizer_for(seeds[i])

    state = MultiModelTrainState.create(apply_fn=arch.apply, params=params, txs=txs)
    return state, batch_stats


# -----------------------
# Apply (with BN)
# -----------------------
def _apply_train(apply_fn, params, batch_stats, images):
    """Forward pass in training mode.

    Returns ``(logits, batch_stats_out)``. With non-empty ``batch_stats`` the
    collection is marked mutable and its updated value is returned. Otherwise
    the model runs immutably and the input ``batch_stats`` object is returned
    as-is.
    """
    if not batch_stats:
        return apply_fn({"params": params}, images, train=True, mutable=False), batch_stats

    logits, updated = apply_fn(
        {"params": params, "batch_stats": batch_stats},
        images,
        train=True,
        mutable=["batch_stats"],
    )
    return logits, updated["batch_stats"]


def _apply_eval(apply_fn, params, batch_stats, images):
    """Forward pass in inference mode (``train=False``, nothing mutable); returns logits.

    ``batch_stats`` is included in the variables only when it is non-empty.
    """
    variables = {"params": params}
    if batch_stats:
        variables["batch_stats"] = batch_stats
    return apply_fn(variables, images, train=False, mutable=False)


def get_logits(state, images, batch_stats, train: bool):
    """Return ``{"model_i": logits}`` for every model on the same ``images``.

    With ``train=True`` the forward pass runs in training mode, and any
    updated batch statistics are discarded. Otherwise it runs in inference
    mode.
    """
    def _per_model(model_params, model_bs):
        if train:
            logits, _ = _apply_train(state.apply_fn, model_params, model_bs, images)
            return logits
        return _apply_eval(state.apply_fn, model_params, model_bs, images)

    return map_over_models(_per_model)(state.params, batch_stats)


# -----------------------
# Loss/Grad
# -----------------------
def _pick_microbatch(n: int, pref: int) -> int:
    """Pick a microbatch size for ``n`` examples, capped at ``pref``.

    Tries ``pref`` first, then descending powers of two. Returns the first
    candidate that is positive, at most ``min(n, pref)``, and divides ``n``
    exactly, so the batch splits into equal chunks. Falls back to 1.
    """
    limit = min(n, pref)
    for cand in (pref, 128, 64, 32, 16, 8, 4, 2, 1):
        if 0 < cand <= limit and n % cand == 0:
            return cand
    return 1


def _per_example_grads(example_loss, params, xs, ys, microbatch: int):
    """Stack per-example gradients of ``example_loss`` over the leading axis of ``xs``/``ys``.

    ``example_loss(params, x_i, y_i)`` must return a scalar for one example.
    Examples are processed ``microbatch`` at a time by a vmapped ``jax.grad``
    inside ``lax.fori_loop``. Results are written into a preallocated buffer
    of shape ``(n, *param.shape)`` per parameter leaf. ``microbatch`` must
    divide ``n``. Works for inputs of any rank.
    """
    n = xs.shape[0]
    vm_grad = jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0), out_axes=0)
    buf0 = jax.tree_util.tree_map(lambda p: jnp.zeros((n,) + p.shape, p.dtype), params)
    # Zero start offsets for all non-batch axes (rank-agnostic dynamic_slice).
    x_tail_zeros = (0,) * (xs.ndim - 1)
    y_tail_zeros = (0,) * (ys.ndim - 1)

    def body(i, buf):
        start = i * microbatch
        x_mb = jax.lax.dynamic_slice(xs, (start, *x_tail_zeros), (microbatch, *xs.shape[1:]))
        y_mb = jax.lax.dynamic_slice(ys, (start, *y_tail_zeros), (microbatch, *ys.shape[1:]))
        g_mb = vm_grad(params, x_mb, y_mb)

        def write(buf_leaf, g_leaf):
            # buf_leaf: (n, *param_shape); g_leaf: (microbatch, *param_shape)
            return jax.lax.dynamic_update_slice(buf_leaf, g_leaf, (start,) + (0,) * (g_leaf.ndim - 1))

        return jax.tree_util.tree_map(write, buf, g_mb)

    return jax.lax.fori_loop(0, n // microbatch, body, buf0)


def apply_model_perms(
    state: MultiModelTrainState,
    images: jnp.ndarray,
    labels: jnp.ndarray,
    perm: jnp.ndarray,
    batch_stats: dict,
    *,
    per_sample_gradients: bool,  # MUST be a Python bool (static)
    label_smoothing: float = 0.0,  # static
    num_classes: int = 10,  # static; width of the one-hot targets
) -> tuple[dict, jnp.ndarray, dict]:
    """Compute gradients, batch accuracy and new batch_stats for every model.

    Args:
        state: multi-model train state; ``state.params`` is keyed by
            ``"model_i"`` and ``state.apply_fn`` is the shared forward function.
        images, labels: full training arrays, indexed by ``perm``.
        perm: index array of shape ``(batch_size, num_models)``; column ``i``
            selects model ``i``'s minibatch.
        batch_stats: per-model BatchNorm statistics (``{}`` if none).
        per_sample_gradients: ``True`` selects the DP path. It returns
            per-example gradients with a leading batch axis, computed with
            inference-mode BN and frozen ``batch_stats``. ``False`` returns
            ordinary mean-loss gradients from a training-mode pass, which also
            updates ``batch_stats``.
        label_smoothing: if > 0, targets become ``(1 - ls) * onehot + ls / num_classes``.
        num_classes: number of classes used for one-hot encoding.

    Returns:
        ``(grads_tree, acc, new_batch_stats)``. ``acc`` holds one accuracy per
        model on this batch. It is measured in inference mode with the
        pre-update params and the new batch statistics.
    """
    num_models = perm.shape[-1]
    perms_per_model = convert_to_per_model_tree(perm, num_models, in_axes=1)
    ls = float(label_smoothing)

    def per_model_step(params, model_perm, model_bs):
        model_images = images[model_perm]
        labels_int32 = labels[model_perm].astype(jnp.int32)
        model_labels_1h = jax.nn.one_hot(labels_int32, num_classes)
        if ls > 0.0:
            model_labels_1h = (1.0 - ls) * model_labels_1h + ls / num_classes

        @remat
        def loss_fn_train(p, bs, x, y):
            logits, new_bs = _apply_train(state.apply_fn, p, bs, x)
            loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y))
            return loss, new_bs

        def loss_fn_eval(p, bs, x, y):
            logits = _apply_eval(state.apply_fn, p, bs, x)
            return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y))

        if per_sample_gradients:
            # DP path: per-example gradients, computed in microbatches.
            n = model_images.shape[0]  # static under jit
            assert isinstance(n, int), "Batch size must be static for DP microbatching."
            mb = _pick_microbatch(n, _DEFAULT_DP_MICROBATCH)
            grads = _per_example_grads(
                lambda p, x_i, y_i: loss_fn_eval(p, model_bs, x_i[None], y_i[None]),
                params,
                model_images,
                model_labels_1h,
                mb,
            )
            new_bs = model_bs  # BN statistics are not updated on the DP path
        else:
            (_, new_bs), grads = jax.value_and_grad(
                lambda p, b: loss_fn_train(p, b, model_images, model_labels_1h),
                has_aux=True,
            )(params, model_bs)

        logits_eval = _apply_eval(state.apply_fn, params, new_bs, model_images)
        acc = jnp.mean(jnp.argmax(logits_eval, -1) == labels_int32)
        return grads, acc, new_bs

    per_model_results = map_over_models(per_model_step)(
        state.params, perms_per_model, batch_stats
    )

    grads_tree = _extract_component(per_model_results, 0)
    acc_tree = _extract_component(per_model_results, 1)
    bs_tree = _extract_component(per_model_results, 2)

    return grads_tree, flatten_over_models(acc_tree), bs_tree


# -----------------------
# Training
# -----------------------
@partial(jax.jit, static_argnames=("use_dp", "label_smoothing"), donate_argnums=(0,))
def train_epoch(
    state: MultiModelTrainState,
    batch_stats: dict,
    train_images: jnp.ndarray,
    train_targets: jnp.ndarray,
    perms: jnp.ndarray,
    *,
    use_dp: bool,                 # static
    label_smoothing: float = 0.0, # static
):
    """Run one epoch: one optimizer step per leading-axis entry of ``perms``.

    ``state`` is donated. Returns ``(state, batch_stats, train_acc)``, where
    ``train_acc`` is the per-model batch accuracy averaged over the epoch's
    steps.
    """

    def _train_step(carry, step_perm):
        cur_state, cur_bs = carry
        grads, step_acc, updated_bs = apply_model_perms(
            cur_state,
            train_images,
            train_targets,
            step_perm,
            cur_bs,
            per_sample_gradients=use_dp,
            label_smoothing=label_smoothing,
        )
        return (cur_state.apply_gradients(grads=grads), updated_bs), step_acc

    (final_state, final_bs), step_accs = jax.lax.scan(_train_step, (state, batch_stats), perms)
    return final_state, final_bs, step_accs.mean(axis=0)


@jax.jit
def _eval_model_mean(state, batch_stats, images, targets):
    """Evaluate every model on ``(images, targets)`` in inference mode.

    Returns ``(losses, accuracies)``. Each array has one entry per model,
    holding the mean softmax cross-entropy against 10-class one-hot targets
    and the mean top-1 accuracy, stacked by ``flatten_over_models``.
    """
    targets_1h = jax.nn.one_hot(targets, 10)

    def _mean_xent(logits):
        return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=targets_1h))

    def _mean_acc(logits):
        return jnp.mean(jnp.argmax(logits, -1) == targets)

    logits_by_model = get_logits(state, images, batch_stats, train=False)
    losses = flatten_over_models(map_over_models(_mean_xent)(logits_by_model))
    accs = flatten_over_models(map_over_models(_mean_acc)(logits_by_model))
    return losses, accs


@partial(jax.jit, static_argnames=("use_dp", "verbose", "label_smoothing"), donate_argnums=(0,))
def _train_model_jitted(
    state: MultiModelTrainState,
    batch_stats: dict,
    epoch_perms: jnp.ndarray,
    train_images: jnp.ndarray,
    train_targets: jnp.ndarray,
    test_images: _t.Optional[jnp.ndarray],
    test_targets: _t.Optional[jnp.ndarray],
    *,
    use_dp: bool,                 # static
    verbose: bool,                # static
    label_smoothing: float = 0.0, # static
):
    """Multi-epoch, BN-aware training of all models (one ``lax.scan`` over epochs).

    ``use_dp``, ``verbose`` and ``label_smoothing`` are static; ``state`` is
    donated. After each epoch the models are evaluated on the test set, if
    one is given.

    Returns:
        ``(state, batch_stats, mean_test_loss, mean_test_acc, mean_train_acc)``.
        Each mean is a per-model average over epochs. The test metrics stay
        zeros when ``test_images`` is None, and with zero epochs they are also
        zeros because the divisor is clamped to 1. The train mean is divided
        by the raw epoch count.
    """
    num_epochs = epoch_perms.shape[0]
    num_models = epoch_perms.shape[-1]
    has_test = test_images is not None

    def _run_epoch(carry, perms_for_epoch):
        st, bs, loss_sum, acc_sum, train_sum = carry
        st, bs, train_acc = train_epoch(
            st, bs, train_images, train_targets, perms_for_epoch,
            use_dp=use_dp, label_smoothing=label_smoothing,
        )
        train_sum = train_sum + train_acc

        if has_test:
            test_loss, test_acc = _eval_model_mean(st, bs, test_images, test_targets)
            loss_sum = loss_sum + test_loss
            acc_sum = acc_sum + test_acc
            if verbose:
                id_print("test_acc_mean: {x:.3f}", x=test_acc.mean())
        if verbose:
            id_print("train_acc_mean: {x:.3f}", x=train_acc.mean())

        return (st, bs, loss_sum, acc_sum, train_sum), None

    zeros = jnp.zeros((num_models,))
    init_carry = (state, batch_stats, zeros, zeros, zeros)
    (state, batch_stats, loss_sum, acc_sum, train_sum), _ = jax.lax.scan(
        _run_epoch, init_carry, epoch_perms
    )

    denom = jnp.asarray(num_epochs, dtype=train_sum.dtype)
    clamped = jnp.maximum(denom, 1.0)
    return state, batch_stats, loss_sum / clamped, acc_sum / clamped, train_sum / denom


def train_model(
    state: MultiModelTrainState,
    batch_stats: dict,
    epoch_perms: jnp.ndarray,
    *,
    train_images: jnp.ndarray,
    train_targets: jnp.ndarray,
    test_images: _t.Optional[jnp.ndarray],
    test_targets: _t.Optional[jnp.ndarray],
    use_dp: bool = False,
    verbose: bool = True,
    label_smoothing: float = 0.0,
):
    """Non-jitted entry point for multi-epoch training of all models.

    ``use_dp``, ``verbose`` and ``label_smoothing`` are static arguments of
    the jitted trainer, and static jit arguments must be hashable plain
    values. They are therefore coerced to Python ``bool``/``bool``/``float``
    here, so e.g. a ``jnp.bool_`` or a 0-d JAX array does not reach JIT.

    ``state`` is donated to the jitted call (``donate_argnums=(0,)``), so the
    passed-in object's buffers may be invalidated; use the returned state.

    Returns:
        ``(state, batch_stats, mean_test_loss, mean_test_acc, mean_train_acc)``.
    """
    use_dp_bool = bool(use_dp)
    verbose_bool = bool(verbose)
    label_smoothing_float = float(label_smoothing)

    return _train_model_jitted(
        state,
        batch_stats,
        epoch_perms,
        train_images,
        train_targets,
        test_images,
        test_targets,
        use_dp=use_dp_bool,
        verbose=verbose_bool,
        label_smoothing=label_smoothing_float,
    )