import jax
import jax.numpy as jnp
import math
import numpy as np

from typing import List, Tuple

def dn_elements(n: int):
    """Enumerate the elements of D_n and an element -> position lookup.

    Elements are ``('rot', 0..n-1)`` followed by ``('ref', 0..n-1)``.

    Returns:
        (G, idx): the element list and a dict mapping each element to its
        index in G.
    """
    elements = [(kind, k) for kind in ('rot', 'ref') for k in range(n)]
    lookup = dict(zip(elements, range(len(elements))))
    return elements, lookup

def mult(g, h, n):
    """Return the product ``g * h`` in the dihedral group D_n.

    Elements are ``(tag, k)`` tuples with tag ``'rot'`` or ``'ref'``; the
    resulting index is reduced mod n. The rules below are consistent with
    reading ``('rot', k)`` as r^k and ``('ref', k)`` as s*r^k
    (equivalently r^(-k)*s).

    Raises:
        ValueError: if either element's tag is not ``'rot'`` or ``'ref'``.
    """
    tg, k = g
    th, l = h
    if tg == 'rot' and th == 'rot':
        return ('rot', (k + l) % n)
    if tg == 'rot' and th == 'ref':
        # r^k * s r^l = s r^(l-k)
        return ('ref', (l - k) % n)
    if tg == 'ref' and th == 'rot':
        return ('ref', (k + l) % n)
    if tg == 'ref' and th == 'ref':
        # s r^k * s r^l = r^(l-k)
        return ('rot', (l - k) % n)
    raise ValueError(f"Unknown dihedral element tag(s) in mult: {g!r}, {h!r}")

def idx_mul(i: int, j: int, G, idx, p: int) -> int:
    """Multiply the D_p elements at positions i and j of G; return the product's position.

    ``idx`` must map each element of G back to its position (as built
    alongside G elsewhere in this module).
    """
    product = mult(G[i], G[j], p)
    return idx[product]

def _select_test_indices(train_idx, total_pairs, test_is_full, seed, shuffle_test):
    """Return the flat pair indices used for evaluation (optionally shuffled)."""
    if test_is_full:
        # Train already uses every pair, so the test set is all pairs and overlaps train.
        test_idx = np.arange(total_pairs, dtype=np.int32)
    else:
        mask = np.ones(total_pairs, dtype=bool)
        mask[train_idx] = False
        test_idx = np.nonzero(mask)[0]
    if shuffle_test:
        # NumPy generator seeded from `seed`, separate from the JAX key used for the train split.
        rng = np.random.default_rng(seed ^ 0xBEEF)
        rng.shuffle(test_idx)
    return test_idx


def _batch_test_indices(test_idx, batch, drop_remainder):
    """Return (indices, num_batches) such that len(indices) == num_batches * batch."""
    count = len(test_idx)
    if drop_remainder:
        num = count // batch
        return test_idx[: num * batch], num
    num = -(-count // batch)  # ceiling division
    # np.resize repeats test_idx cyclically, so the last batch is filled even when
    # the shortfall exceeds len(test_idx); for smaller shortfalls this equals
    # appending test_idx[:pad].
    return np.resize(test_idx, num * batch), num


def make_dihedral_dataset_with_test(
    p: int,
    batch_size: int,
    num_batches: int,
    seed: int,
    *,
    test_batch_size: int | None = None,
    shuffle_test: bool = True,
    drop_remainder: bool = False,
):
    """Build batched train/test splits of the D_p multiplication table.

    Each ordered pair (g, h) of D_p elements is one example with input
    ``(idx[g], idx[h])`` and label ``idx[g*h]``. ``num_batches * batch_size``
    pairs are drawn without replacement for training via a JAX permutation
    keyed by ``seed``. The remaining pairs form the test set; if training uses
    every pair, the test set is all pairs instead.

    Args:
        p: dihedral parameter (group has 2*p elements).
        batch_size, num_batches: training batch layout.
        seed: seed for the train permutation and the test shuffle.
        test_batch_size: test batch size; defaults to ``batch_size``.
        shuffle_test: shuffle test examples before batching.
        drop_remainder: drop the trailing partial test batch instead of
            padding it by cycling through the test examples again.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` as JAX arrays with shapes
        ``(num_batches, batch_size, 2)``, ``(num_batches, batch_size)``,
        ``(K, B_test, 2)``, ``(K, B_test)``.

    Raises:
        ValueError: if the train size exceeds the number of pairs, or the test
            batch size is not positive.
    """
    G, idx = dn_elements(p)
    total_pairs = len(G) * len(G)
    total_train = num_batches * batch_size
    if total_train > total_pairs:
        raise ValueError(
            f"Train size {total_train} > total pairs {total_pairs}; cannot fill train without repeats."
        )
    B_test = int(test_batch_size) if test_batch_size is not None else int(batch_size)
    if B_test <= 0:
        raise ValueError(f"Test batch size must be positive, got {B_test}")

    pairs = np.array([(idx[g], idx[h]) for g in G for h in G], dtype=np.int32)
    labels = np.array([idx[mult(g, h, p)] for g in G for h in G], dtype=np.int32)

    perm = np.array(jax.random.permutation(jax.random.PRNGKey(seed), total_pairs))
    train_idx = perm[:total_train]
    x_train = jnp.array(pairs[train_idx]).reshape(num_batches, batch_size, 2)
    y_train = jnp.array(labels[train_idx]).reshape(num_batches, batch_size)

    test_is_full = (total_train == total_pairs)
    test_idx = _select_test_indices(train_idx, total_pairs, test_is_full, seed, shuffle_test)
    use, K_test = _batch_test_indices(test_idx, B_test, drop_remainder)
    x_test_batches = jnp.array(pairs[use].reshape(K_test, B_test, 2))
    y_test_batches = jnp.array(labels[use].reshape(K_test, B_test))

    # Sanity check: every test label is the product of its two inputs.
    x_flat = np.asarray(x_test_batches).reshape(-1, 2)
    y_flat = np.asarray(y_test_batches).reshape(-1)
    y_from_x = np.array([idx_mul(i, j, G, idx, p) for i, j in x_flat], dtype=np.int32)
    assert np.array_equal(y_flat, y_from_x), "Mismatch: some (x,y) in test are misaligned"

    if not test_is_full:
        train_set = set(map(tuple, np.asarray(x_train).reshape(-1, 2)))
        test_set = set(map(tuple, x_flat))
        assert train_set.isdisjoint(test_set), "Train and test are not disjoint!"

    return x_train, y_train, x_test_batches, y_test_batches

def check_representation_consistency(G, R, mult, p, tol=1e-6):
    """Check that ``R`` is a homomorphism on G: ``R(g*h) ~= R(g) @ R(h)``.

    For every ordered pair (g, h), computes the ``jnp.linalg.norm`` of
    ``R(mult(g, h, p)) - R(g) @ R(h)``. Pairs whose error exceeds ``tol`` are
    printed, as before, and collected.

    Args:
        G: iterable of group elements.
        R: callable mapping an element to a matrix supporting ``@``.
        mult: group multiplication ``mult(g, h, p)``.
        p: group parameter forwarded to ``mult``.
        tol: error threshold.

    Returns:
        List of ``(g, h, err)`` tuples for every inconsistent pair; empty if
        R is consistent. Previously this returned None, so existing callers
        that ignore the result are unaffected.
    """
    failures = []
    for g in G:
        R_g = R(g)  # loop-invariant for the inner loop
        for h in G:
            err = float(jnp.linalg.norm(R(mult(g, h, p)) - R_g @ R(h)))
            if err > tol:
                print(f"Inconsistency at g={g}, h={h}, error={err:.2e}")
                failures.append((g, h, err))
    return failures

def enumerate_subgroups_Dn(n):
    """List subgroups of D_n as ``(name, element_list)`` pairs, skipping duplicate element sets.

    Produced in this order:
      * ``C_d``: rotation subgroup of order d for each divisor d of n;
      * ``Refl2_m``: ``{('rot', 0), ('ref', m)}`` for each m;
      * ``Dih_d_axis_m``: the order-d rotation subgroup together with the
        reflections ``('ref', m + t*(n//d))``, for each divisor d >= 2 and
        each m in ``range(n//d)``.
    """
    found = []
    keys_seen = set()

    def _register(label, members):
        signature = frozenset(members)
        if signature in keys_seen:
            return
        keys_seen.add(signature)
        found.append((label, members))

    divisors = [d for d in range(1, n + 1) if n % d == 0]

    for d in divisors:
        stride = n // d
        _register(f"C_{d}", [('rot', t * stride) for t in range(d)])

    for m in range(n):
        _register(f"Refl2_{m}", [('rot', 0), ('ref', m)])

    for d in divisors:
        if d < 2:
            continue
        stride = n // d
        rotations = [('rot', t * stride) for t in range(d)]
        for m in range(stride):
            reflections = [('ref', (m + t * stride) % n) for t in range(d)]
            # `rotations + reflections` builds a fresh list for each subgroup.
            _register(f"Dih_{d}_axis_{m}", rotations + reflections)

    return found

def is_subgroup(H, mult, inv, p):
    """Return True if H contains the identity and is closed under ``mult`` and ``inv``.

    Args:
        H: collection of hashable group elements (tuples here).
        mult: multiplication ``mult(a, b, p)``.
        inv: inversion ``inv(a, p)``.
        p: group parameter forwarded to ``mult``/``inv``.

    The identity is taken to be ``('rot', 0)``. Membership is tested against
    a set built once, instead of scanning the list on every product, which
    made the original check O(|H|^3).
    """
    members = set(H)
    if ('rot', 0) not in members:
        return False
    if not all(mult(a, b, p) in members for a in H for b in H):
        return False
    return all(inv(a, p) in members for a in H)

def inv(g, p):
    """Return the inverse of a D_p element.

    Rotations ``('rot', k)`` map to ``('rot', -k mod p)``. Any other tag is
    returned as ``('ref', k)``, i.e. reflections are treated as self-inverse.
    """
    kind, k = g
    if kind == 'rot':
        return ('rot', (-k) % p)
    return ('ref', k)

def build_coset_masks(G, subgroups, mult, p, side="left"):
    """Build a boolean mask over G for every coset of every subgroup.

    Args:
        G: list of group elements; mask positions follow this order.
        subgroups: iterable of ``(name, elements)`` pairs.
        mult: multiplication ``mult(a, b, p)``.
        p: group parameter forwarded to ``mult``.
        side: ``"left"`` for cosets gH, ``"right"`` for cosets Hg.

    Returns:
        Dict mapping ``(subgroup_name, coset_id)`` to a ``bool`` array of
        length ``len(G)``. Coset ids are assigned in order of the first
        element of G that falls in each coset.

    Raises:
        ValueError: if ``side`` is neither ``"left"`` nor ``"right"``. Any
            other value previously fell through silently to right cosets.
    """
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    index_of = {g: i for i, g in enumerate(G)}
    coset_masks = {}
    for H_name, H_elems in subgroups:
        covered = set()
        cid = 0
        for g in G:
            if index_of[g] in covered:
                continue  # g already lies in a coset recorded for this subgroup
            if side == "left":
                coset_idx = [index_of[mult(g, h, p)] for h in H_elems]
            else:
                coset_idx = [index_of[mult(h, g, p)] for h in H_elems]
            covered.update(coset_idx)
            mask = np.zeros(len(G), dtype=bool)
            mask[coset_idx] = True
            coset_masks[(H_name, cid)] = mask
            cid += 1
    return coset_masks

def mult_chain(elems: List[Tuple[str,int]], n: int):
    """Left-fold ``mult`` over elems: ``((e0 * e1) * e2) * ...`` in D_n.

    An empty sequence yields the identity ``('rot', 0)``. A single element is
    returned unchanged, without reduction mod n.
    """
    if not elems:
        return ('rot', 0)
    head, *tail = elems
    product = head
    for factor in tail:
        product = mult(product, factor, n)
    return product

def build_cayley_table(n: int):
    """Build the Cayley table of D_n.

    Returns:
        (G, idx, table): the element list (rotations then reflections), the
        element -> position dict, and a ``(2n, 2n)`` int32 JAX array with
        ``table[i, j] == idx[mult(G[i], G[j], n)]``.
    """
    elements = [('rot', i) for i in range(n)] + [('ref', i) for i in range(n)]
    position = {g: i for i, g in enumerate(elements)}
    size = len(elements)
    rows = [[position[mult(a, b, n)] for b in elements] for a in elements]
    # reshape keeps the (0, 0) shape when n == 0 (rows would otherwise be 1-D).
    table_np = np.array(rows, dtype=np.int32).reshape(size, size)
    return elements, position, jnp.array(table_np, dtype=jnp.int32)

def make_dihedral_dataset_k_ary(n: int,
                                batch_size: int,
                                num_batches: int,
                                seed: int,
                                arity: int = 3,
                                exhaustive: bool = False):
    """Sample sequences of D_n element indices labelled by their left-folded product.

    Each example is a length-``arity`` sequence of indices into G; its label is
    ``table[...table[table[x0, x1], x2]..., x_{arity-1}]``.

    Args:
        n: dihedral parameter (group has 2*n elements).
        batch_size, num_batches: output batch layout.
        seed: seed for ``jax.random.PRNGKey``.
        arity: sequence length (>= 2).
        exhaustive: if True, draw distinct sequences without replacement from
            all ``(2n)**arity`` sequences. Otherwise sample indices uniformly
            with replacement.

    Returns:
        ``(x, y, G, idx, table)`` with ``x`` of shape
        ``(num_batches, batch_size, arity)`` and ``y`` of shape
        ``(num_batches, batch_size)``.

    Raises:
        ValueError: if ``arity < 2``, or if ``exhaustive`` is set and more
            examples are requested than distinct sequences exist. Previously
            this case failed with an opaque reshape error.
    """
    if arity < 2:
        raise ValueError("arity must be >= 2")
    G, idx, table = build_cayley_table(n)
    m = len(G)
    total = num_batches * batch_size
    key = jax.random.PRNGKey(seed)
    if exhaustive:
        grid = np.array(np.meshgrid(*[np.arange(m) for _ in range(arity)], indexing='ij'))
        all_seq = grid.reshape(arity, -1).T
        num_seq = all_seq.shape[0]
        if total > num_seq:
            raise ValueError(
                f"exhaustive=True needs {total} distinct sequences but only {num_seq} exist"
            )
        perm = jax.random.permutation(key, num_seq)[:total]
        arr = jnp.array(all_seq, dtype=jnp.int32)[perm]
    else:
        arr = jax.random.randint(key, shape=(total, arity), minval=0, maxval=m, dtype=jnp.int32)
    res = arr[:, 0]
    for t in range(1, arity):
        res = table[res, arr[:, t]]
    x = arr.reshape(num_batches, batch_size, arity)
    y = res.reshape(num_batches, batch_size)
    return x, y, G, idx, table

def make_eval_grid_k_ary(G, idx, table, n: int, arity: int, batch_size: int):
    """Enumerate every length-``arity`` index sequence over G, batched for evaluation.

    Sequences are listed in ``np.meshgrid(..., indexing='ij')`` order, i.e.
    lexicographically. Labels are the left fold of ``table`` over each
    sequence. If the count is not a multiple of ``batch_size``, all-zero
    sequences are appended to fill the last batch. Their labels are computed
    from ``table`` like any other row, rather than hard-coded to 0, so the
    padded (x, y) rows stay consistent for any table. Padded rows are still
    duplicates, and callers may want to exclude them from metrics.

    ``idx`` and ``n`` are accepted for signature compatibility; they are not
    used here.

    Returns:
        ``(x_eval_batches, y_eval_batches)`` with shapes
        ``(K, batch_size, arity)`` and ``(K, batch_size)``, int32.
    """
    m = len(G)
    grid = np.array(np.meshgrid(*[np.arange(m) for _ in range(arity)], indexing='ij'))
    x_eval = jnp.array(grid.reshape(arity, -1).T, dtype=jnp.int32)
    total_eval_points = x_eval.shape[0]
    num_eval_batches = -(-total_eval_points // batch_size)  # ceiling division
    pad = num_eval_batches * batch_size - total_eval_points
    if pad > 0:
        x_eval = jnp.concatenate([x_eval, jnp.zeros((pad, arity), dtype=jnp.int32)], axis=0)
    # Labels are computed after padding so padded rows get their true product.
    res = x_eval[:, 0]
    for t in range(1, arity):
        res = table[res, x_eval[:, t]]
    y_eval = res.astype(jnp.int32)
    x_eval_batches = x_eval.reshape(num_eval_batches, batch_size, arity)
    y_eval_batches = y_eval.reshape(num_eval_batches, batch_size)
    return x_eval_batches, y_eval_batches

if __name__ == "__main__":
    # Self-check of the D_n implementation: closure, two-sided inverses, and
    # that every enumerated subgroup really is a subgroup.
    identity = ('rot', 0)
    for n in (3, 4, 5, 6, 8, 10, 18):
        elements = [('rot', i) for i in range(n)] + [('ref', i) for i in range(n)]
        element_set = set(elements)
        assert all(mult(a, b, n) in element_set for a in elements for b in elements)
        for g in elements:
            g_inv = inv(g, n)
            assert mult(g, g_inv, n) == identity and mult(g_inv, g, n) == identity
        failing = [name for name, H in enumerate_subgroups_Dn(n)
                   if not is_subgroup(H, mult, inv, n)]
        if failing:
            print(f"[n={n}] Not subgroups: {failing}")
        else:
            print(f"[n={n}] all enumerated subgroups pass")
    print("Done.")
