from typing import Any, Dict, List, Tuple
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

import dihedral
import DFT


def make_full_eval_grid(p: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Build the exhaustive evaluation set of index pairs and product labels.

    The element list comes from ``DFT.make_irreps_Dn(p)`` (its second return
    value is ignored here).  Every row ``(i, j)`` of the index grid produced by
    ``dihedral.build_index_grid`` is labelled with the position, in that same
    element list, of ``dihedral.mult(G[i], G[j], p)``.

    Args:
        p: Parameter forwarded to ``DFT.make_irreps_Dn`` and ``dihedral.mult``
            (presumably the order parameter of the dihedral group -- confirm
            against those modules).

    Returns:
        ``(x_eval, y_eval)`` where ``x_eval`` is the grid exactly as returned by
        ``dihedral.build_index_grid`` and ``y_eval`` is an ``int32`` array
        holding one label per grid row.
    """
    elements, _ = DFT.make_irreps_Dn(p)
    # Reverse lookup: element -> its position in the element list.
    position_of = dict(zip(elements, range(len(elements))))

    x_eval = dihedral.build_index_grid(len(elements), arity=2)

    # The product table is filled on the host; each grid row is one (i, j) pair.
    labels = []
    for left, right in np.asarray(x_eval):
        product = dihedral.mult(elements[left], elements[right], p)
        labels.append(position_of[product])

    y_eval = jnp.array(labels, dtype=jnp.int32)
    return x_eval, y_eval


def make_train_and_test_batches(
    p: int,
    batch_size: int,
    k: int,
    random_seed_ints: List[int],
    *,
    test_batch_size: int | None = None,
    shuffle_test: bool = True,
    drop_remainder: bool = True,
) -> Tuple[
    List[Tuple[jnp.ndarray, jnp.ndarray]],
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
    jnp.ndarray,
]:
    train_list: List[Tuple[jnp.ndarray, jnp.ndarray]] = []
    x_train_stack = []
    y_train_stack = []
    x_test_stack = []
    y_test_stack = []

    for seed in random_seed_ints:
        x_tr, y_tr, x_te, y_te = dihedral.make_dihedral_dataset_with_test(
            p,
            batch_size,
            k,
            seed,
            test_batch_size=test_batch_size,
            shuffle_test=shuffle_test,
            drop_remainder=drop_remainder,
        )
        train_list.append((x_tr, y_tr))
        x_train_stack.append(x_tr)
        y_train_stack.append(y_tr)
        x_test_stack.append(x_te)
        y_test_stack.append(y_te)

    x_train_batches = jnp.stack(x_train_stack, axis=0)
    y_train_batches = jnp.stack(y_train_stack, axis=0)

    K_tests = [arr.shape[0] for arr in x_test_stack]
    B_tests = [arr.shape[1] for arr in x_test_stack]
    if len(set(K_tests)) != 1 or len(set(B_tests)) != 1:
        raise ValueError(f"Test batch shapes differ across seeds: K={K_tests}, B={B_tests}")

    x_test_batches = jnp.stack(x_test_stack, axis=0)
    y_test_batches = jnp.stack(y_test_stack, axis=0)

    print("x_train_batches.shape =", x_train_batches.shape)
    print("y_train_batches.shape =", y_train_batches.shape)
    print("x_test_batches.shape  =", x_test_batches.shape)
    print("y_test_batches.shape  =", y_test_batches.shape)
    print(f"Number of train batches per model: {x_train_batches.shape[1]}")
    print(f"Number of test  batches per model: {x_test_batches.shape[1]}")

    dataset_size_bytes = x_train_batches.size * x_train_batches.dtype.itemsize
    dataset_size_mb = dataset_size_bytes / (1024**2)
    print(f"Train dataset size per model: {dataset_size_mb:.2f} MB")

    return train_list, x_train_batches, y_train_batches, x_test_batches, y_test_batches


@partial(jax.jit, static_argnames=("shuffle_within_batch", "debug", "samples_to_check"))
def shuffle_batches_for_epoch(
    x_batches: jnp.ndarray,
    y_batches: jnp.ndarray,
    epoch: int,
    seeds: jnp.ndarray,
    shuffle_within_batch: bool = True,
    debug: bool = False,
    samples_to_check: int = 5,
):
    """Reorder every model's batches (and optionally items) for one epoch.

    Axis 0 of the inputs indexes the model, axis 1 the batch and axis 2 the
    item inside a batch; ``x_batches`` carries one further trailing axis.
    Each model gets its own permutation of the batch axis, drawn from
    ``fold_in(PRNGKey(seed), epoch)``.  When ``shuffle_within_batch`` is true a
    second permutation of the item axis is drawn from the seed XOR-ed with
    ``0xBEEF``; that one permutation is applied to all batches of the model.
    ``x`` and ``y`` are always gathered with the same indices, so they stay
    aligned.

    Args:
        x_batches: Inputs, indexed ``[model, batch, item, feature]``.
        y_batches: Labels, indexed ``[model, batch, item]``.
        epoch: Folded into the per-model key so each epoch reshuffles.
        seeds: One integer seed per model.
        shuffle_within_batch: Also permute the item axis (static under jit).
        debug: Run ``_debug_check_alignment`` on the result (static under jit).
        samples_to_check: Sample bound forwarded to the debug check (static).

    Returns:
        ``(x_shuffled, y_shuffled)`` with the same shapes as the inputs.
    """
    n_models = x_batches.shape[0]
    n_batches = x_batches.shape[1]
    n_items = x_batches.shape[2]

    def _epoch_keys(seed_vec):
        # One key per model, made epoch-specific via fold_in.
        return jax.vmap(lambda s: jax.random.fold_in(jax.random.PRNGKey(s), epoch))(seed_vec)

    def _permutations(keys, length):
        return jax.vmap(lambda key: jax.random.permutation(key, length))(keys)

    x_index_shape = (n_models, n_batches, n_items, 1)
    y_index_shape = (n_models, n_batches, n_items)

    # Stage 1: permute the batch axis independently for every model.
    perms_k = _permutations(_epoch_keys(seeds), n_batches)
    batch_idx_x = jnp.broadcast_to(perms_k[:, :, None, None], x_index_shape)
    batch_idx_y = jnp.broadcast_to(perms_k[:, :, None], y_index_shape)
    x_shuf = jnp.take_along_axis(x_batches, batch_idx_x, axis=1)
    y_shuf = jnp.take_along_axis(y_batches, batch_idx_y, axis=1)

    # Stage 2 (optional): permute the item axis with an independent key stream.
    perms_b = None
    if shuffle_within_batch:
        item_seeds = jnp.bitwise_xor(seeds.astype(jnp.uint32), jnp.uint32(0xBEEF))
        perms_b = _permutations(_epoch_keys(item_seeds), n_items)
        item_idx_x = jnp.broadcast_to(perms_b[:, None, :, None], x_index_shape)
        item_idx_y = jnp.broadcast_to(perms_b[:, None, :], y_index_shape)
        x_shuf = jnp.take_along_axis(x_shuf, item_idx_x, axis=2)
        y_shuf = jnp.take_along_axis(y_shuf, item_idx_y, axis=2)

    if debug:
        _debug_check_alignment(
            x_batches,
            y_batches,
            x_shuf,
            y_shuf,
            perms_k=perms_k,
            perms_b=perms_b,
            samples_to_check=samples_to_check,
        )

    return x_shuf, y_shuf


def _debug_check_alignment(x_in, y_in, x_out, y_out, *, perms_k, perms_b, samples_to_check: int):
    """Spot-check that shuffled outputs match the inputs under the permutations.

    For the first ``samples_to_check`` indices along each of the model, batch
    and item axes (fewer if an axis is shorter), the output entry at
    ``[i, j, k]`` is compared with the input entry at
    ``[i, perms_k[i, j], perms_b[i, k]]``; the item index is left unchanged
    when ``perms_b`` is ``None``.  The combined flag is printed with
    ``jax.debug.print`` and handed to a host callback that raises
    ``AssertionError`` when it is false.

    Args:
        x_in: Inputs before shuffling, indexed ``[model, batch, item, ...]``.
        y_in: Labels before shuffling, indexed ``[model, batch, item]``.
        x_out: Inputs after shuffling.
        y_out: Labels after shuffling.
        perms_k: Per-model batch permutation, indexed ``[model, batch]``.
        perms_b: Per-model item permutation, indexed ``[model, item]``, or
            ``None`` if the item axis was not shuffled.
        samples_to_check: Upper bound on indices inspected per axis.
    """
    n_models, n_batches, n_items = x_in.shape[0], x_in.shape[1], x_in.shape[2]
    model_limit = min(n_models, samples_to_check)
    batch_limit = min(n_batches, samples_to_check)
    item_limit = min(n_items, samples_to_check)

    all_match = jnp.array(True, dtype=jnp.bool_)
    for model in range(model_limit):
        for dst_batch in range(batch_limit):
            # The source batch depends only on (model, dst_batch).
            src_batch = perms_k[model, dst_batch]
            for dst_item in range(item_limit):
                if perms_b is None:
                    src_item = dst_item
                else:
                    src_item = perms_b[model, dst_item]

                x_same = jnp.all(
                    x_in[model, src_batch, src_item, :] == x_out[model, dst_batch, dst_item, :]
                )
                y_same = jnp.array(
                    y_in[model, src_batch, src_item] == y_out[model, dst_batch, dst_item]
                )
                all_match = jnp.logical_and(all_match, jnp.logical_and(x_same, y_same))

    jax.debug.print("[shuffle] alignment ok? {}", all_match)

    def _raise_if_misaligned(flag):
        # Runs on the host, where the flag is a concrete value.
        if not bool(flag):
            raise AssertionError("[shuffle] debug check failed: x/y misaligned.")

    jax.debug.callback(_raise_if_misaligned, all_match)


@jax.jit
def train_epoch(states, x_batches, y_batches, initial_metrics):
    """Run one pass over all batches, updating every model in lockstep.

    Axis 0 of ``x_batches`` / ``y_batches`` is the model axis and axis 1 the
    batch axis.  The batch axis is moved to the front and consumed by
    ``jax.lax.scan``; at each step ``state.train_step(metric, (x, y))`` is
    applied to all models at once through ``jax.vmap``.

    Args:
        states: Pytree of per-model states with a leading model axis; each
            state must provide ``train_step(metrics, (x, y))`` returning
            ``(new_state, new_metrics)``.
        x_batches: 4-D inputs indexed ``[model, batch, ...]``.
        y_batches: 3-D labels indexed ``[model, batch, ...]``.
        initial_metrics: Per-model metrics used as the starting carry.

    Returns:
        ``(new_states, new_metrics)`` after the final batch.
    """

    def _update_one_model(state, metric, xb, yb):
        return state.train_step(metric, (xb, yb))

    update_all_models = jax.vmap(_update_one_model, in_axes=0)

    def _scan_body(carry, batch):
        current_states, current_metrics = carry
        xb, yb = batch
        next_carry = update_all_models(current_states, current_metrics, xb, yb)
        return next_carry, None

    # scan iterates over the leading axis, so put the batch axis first.
    batch_major_x = jnp.transpose(x_batches, (1, 0, 2, 3))
    batch_major_y = jnp.transpose(y_batches, (1, 0, 2))

    final_carry, _ = jax.lax.scan(
        _scan_body,
        (states, initial_metrics),
        (batch_major_x, batch_major_y),
    )
    new_states, new_metrics = final_carry
    return new_states, new_metrics


@jax.jit
def eval_model(states, x_batches, y_batches, initial_metrics):
    """Accumulate evaluation metrics for every model over all batches.

    Axis 0 of ``x_batches`` / ``y_batches`` is the model axis and axis 1 the
    batch axis.  The batch axis is moved to the front and consumed by
    ``jax.lax.scan``; at each step ``state.eval_step(metric, (x, y))`` is
    applied to all models at once through ``jax.vmap``.  The states themselves
    are never updated.

    Args:
        states: Pytree of per-model states with a leading model axis; each
            state must provide ``eval_step(metrics, (x, y))`` returning the
            updated metrics.
        x_batches: 4-D inputs indexed ``[model, batch, ...]``.
        y_batches: 3-D labels indexed ``[model, batch, ...]``.
        initial_metrics: Per-model metrics used as the starting carry.

    Returns:
        The metrics carry after the final batch.
    """

    def _evaluate_one_model(state, metric, xb, yb):
        return state.eval_step(metric, (xb, yb))

    evaluate_all_models = jax.vmap(_evaluate_one_model, in_axes=0)

    def _scan_body(running_metrics, batch):
        xb, yb = batch
        # The states are constant across steps; only the metrics are carried.
        return evaluate_all_models(states, running_metrics, xb, yb), None

    # scan iterates over the leading axis, so put the batch axis first.
    batch_major_x = jnp.transpose(x_batches, (1, 0, 2, 3))
    batch_major_y = jnp.transpose(y_batches, (1, 0, 2))

    final_metrics, _ = jax.lax.scan(
        _scan_body,
        initial_metrics,
        (batch_major_x, batch_major_y),
    )
    return final_metrics
