import os
os.environ["JAX_ENABLE_X64"] = "1"

import jax
from jax import config
config.update("jax_default_matmul_precision", "high")  

import jax.numpy as jnp
from typing import Dict, Tuple, Any
from functools import partial
import numpy as np
import dihedral

def make_irreps_Dn(n):
    """Build the element list of the dihedral group D_n and its irrep descriptors.

    Group elements are tagged tuples ``('rot', m)`` and ``('ref', m)`` for
    ``m = 0 .. n-1``; following the matrices built below, ``('ref', m)`` is
    represented as ``S @ R_m`` (i.e. ``s r^m``).

    Every irrep descriptor is a 4-tuple ``(name, dim, rho, tag)``:

    * ``name`` -- ``'triv'``, ``'sign'``, ``'rp'``/``'srp'`` (even ``n`` only)
      or ``'2D_k'`` for ``k = 1 .. (n - 1) // 2``;
    * ``dim`` -- 1 or 2;
    * ``rho`` -- callable mapping a group element to a ``(dim, dim)``
      complex64 matrix;
    * ``tag`` -- integer label stored next to the irrep (0, -1, ``n // 2``,
      ``-n // 2`` or ``k``); it is not interpreted inside this function.

    Returns:
        ``(G, irreps)`` -- the ``2 * n`` group elements (all rotations first,
        then all reflections) and the list of irrep descriptors.
    """

    def _as_1x1(value):
        # Wrap a scalar character value as a 1x1 complex64 matrix.
        return jnp.array([[value]], dtype=jnp.complex64)

    # One-dimensional irreps that exist for every n:
    #   triv: everything -> +1;  sign: rotations -> +1, reflections -> -1.
    irreps = [
        ('triv', 1, lambda g: _as_1x1(1.0), 0),
        ('sign', 1, lambda g: _as_1x1(1 if g[0] == 'rot' else -1), -1),
    ]

    # Two additional one-dimensional irreps appear only when n is even:
    #   rp : r^m -> (-1)^m, reflections get the same value;
    #   srp: like rp but with an extra factor -1 on reflections.
    if n % 2 == 0:
        irreps += [
            ('rp', 1, lambda g: _as_1x1((-1.0) ** g[1]), n // 2),
            ('srp', 1,
             lambda g: _as_1x1((-1.0) ** g[1] * (-1 if g[0] == 'ref' else 1)),
             -n // 2),
        ]

    # Reflection generator used by all two-dimensional irreps.
    reflect = np.array([[1.0, 0.0],
                        [0.0, -1.0]], dtype=np.float64)

    # Two-dimensional irreps: rotation by 2*pi*k*m/n, optionally followed by
    # the reflection (S @ R, i.e. s r^m).
    for k in range(1, (n - 1) // 2 + 1):
        def rho_2d(g, k=k):  # k bound as a default to avoid late binding
            angle = 2.0 * jnp.pi * k * g[1] / n
            cos_t, sin_t = np.cos(angle), np.sin(angle)
            rotation = np.array([[cos_t, -sin_t],
                                 [sin_t, cos_t]], dtype=np.float64)
            matrix = rotation if g[0] == 'rot' else reflect @ rotation
            return jnp.array(matrix, dtype=jnp.complex64)

        irreps.append((f'2D_{k}', 2, rho_2d, k))

    G = [(kind, i) for kind in ('rot', 'ref') for i in range(n)]
    return G, irreps

def build_rho_cache(G, irreps):
    """Tabulate every irrep on every group element.

    Args:
        G: iterable of group elements, passed one at a time to each ``rho``.
        irreps: iterable of 4-tuples ``(name, dim, rho, tag)``; only ``name``
            and the callable ``rho`` are used.

    Returns:
        Dict mapping each irrep name to the matrices ``rho(g)`` stacked along
        a new leading axis, in the iteration order of ``G``.
    """
    cache = {}
    for name, _dim, rho, _tag in irreps:
        matrices = [rho(g) for g in G]
        cache[name] = jnp.stack(matrices, axis=0)
    return cache


#### fast GFT
def _A_from(z):
    """Expand ``z`` into 2x2 real blocks ``[[Re z, -Im z], [Im z, Re z]]``.

    Two trailing axes of size 2 are appended, so an input of shape ``S``
    produces an output of shape ``S + (2, 2)``.
    """
    re, im = z.real, z.imag
    top_row = jnp.stack([re, -im], axis=-1)
    bottom_row = jnp.stack([im, re], axis=-1)
    return jnp.stack([top_row, bottom_row], axis=-2)

def _B_from(z):
    """Expand ``z`` into 2x2 real blocks ``[[Re z, Im z], [Im z, -Re z]]``.

    Two trailing axes of size 2 are appended, so an input of shape ``S``
    produces an output of shape ``S + (2, 2)``.
    """
    re, im = z.real, z.imag
    top_row = jnp.stack([re, im], axis=-1)
    bottom_row = jnp.stack([im, -re], axis=-1)
    return jnp.stack([top_row, bottom_row], axis=-2)
def _A_from_pair(a_k, a_nk):
    """Combine a pair of coefficients into 2x2 blocks ``[[c, -s], [s, c]]``.

    With ``c = (a_k + a_nk) / 2`` and ``s = -i (a_k - a_nk) / 2``.  Two
    trailing axes of size 2 are appended to the (broadcast) input shape.
    """
    half_sum = 0.5 * (a_k + a_nk)
    half_diff = -0.5j * (a_k - a_nk)
    top_row = jnp.stack([half_sum, -half_diff], axis=-1)
    bottom_row = jnp.stack([half_diff, half_sum], axis=-1)
    return jnp.stack([top_row, bottom_row], axis=-2)

def _B_from_pair(b_k, b_nk):
    """Combine a pair of coefficients into 2x2 blocks ``[[c, s], [s, -c]]``.

    With ``c = (b_k + b_nk) / 2`` and ``s = -i (b_k - b_nk) / 2``.  Two
    trailing axes of size 2 are appended to the (broadcast) input shape.
    """
    half_sum = 0.5 * (b_k + b_nk)
    half_diff = -0.5j * (b_k - b_nk)
    top_row = jnp.stack([half_sum, half_diff], axis=-1)
    bottom_row = jnp.stack([half_diff, -half_sum], axis=-1)
    return jnp.stack([top_row, bottom_row], axis=-2)

def _fft_meta_from_irreps(irreps, group_size: int):
    """Derive the irrep naming/index tables used by the FFT stages.

    Everything is computed from ``group_size`` alone (``n = group_size // 2``);
    the ``irreps`` argument is accepted but not read.

    Returns:
        Dict with keys ``'n'``, ``'even'`` (whether ``n`` is even), ``'ks'``
        (tuple ``1 .. (n - 1) // 2``), ``'oned_names'``, ``'twod_names'`` and
        the name -> position maps ``'name2idx_1d'`` / ``'name2idx_2d'``.
    """
    n = group_size // 2
    is_even = n % 2 == 0
    ks = tuple(range(1, (n - 1) // 2 + 1))

    # 'rp' / 'srp' are only listed when n is even.
    oned_names = ['triv', 'sign']
    if is_even:
        oned_names.extend(['rp', 'srp'])
    twod_names = [f'2D_{k}' for k in ks]

    return {
        'n': n,
        'even': is_even,
        'ks': ks,
        'oned_names': oned_names,
        'twod_names': twod_names,
        'name2idx_1d': dict(zip(oned_names, range(len(oned_names)))),
        'name2idx_2d': dict(zip(twod_names, range(len(twod_names)))),
    }

from functools import partial

@partial(jax.jit, static_argnums=(1,2,3))
def _fft_stage1_all_r(f_blk_cplx: jnp.ndarray, n: int, even: bool, ks: tuple):
    """First FFT stage: transform the leading axis for all irreps at once.

    The leading axis of ``f_blk_cplx`` is split into its first ``n`` entries
    (rotation part) and the remaining entries (reflection part); each part is
    transformed with a length-preserving FFT along axis 0.

    Args:
        f_blk_cplx: 3-D array, unpacked as ``(G, _, Bc)``.
        n, even, ks: static (hashable) arguments -- the rotation count, whether
            the extra one-dimensional irreps are produced, and the frequencies
            of the two-dimensional irreps.

    Returns:
        ``(Y1d_all, Y2d_all)``.  ``Y1d_all`` stacks, in order, ``triv``,
        ``sign`` and (when ``even``) ``rp``, ``srp``, each of shape
        ``f_blk_cplx.shape[1:]``.  ``Y2d_all`` stacks one
        ``(2, 2) + f_blk_cplx.shape[1:]`` block per ``k`` in ``ks``; for empty
        ``ks`` it is a zero-length array of shape ``(0, 2, 2, G, Bc)``.
    """
    G, _, Bc = f_blk_cplx.shape

    rot_hat = jnp.fft.fft(f_blk_cplx[:n], axis=0)
    ref_hat = jnp.fft.fft(f_blk_cplx[n:], axis=0)

    def _sum_and_diff(freq):
        # rot + ref and rot - ref at a single frequency bin.
        rot_bin, ref_bin = rot_hat[freq], ref_hat[freq]
        return [rot_bin + ref_bin, rot_bin - ref_bin]

    # Frequency 0 yields triv / sign; frequency n // 2 yields rp / srp.
    one_dim = _sum_and_diff(0)
    if even:
        one_dim.extend(_sum_and_diff(n // 2))
    Y1d_all = jnp.stack(one_dim, axis=0)

    two_dim = []
    for k in ks:
        mirror = (n - k) % n  # conjugate frequency bin
        block = (_A_from_pair(rot_hat[k], rot_hat[mirror])
                 + _B_from_pair(ref_hat[k], ref_hat[mirror]))
        # Bring the 2x2 matrix axes to the front.
        two_dim.append(jnp.transpose(block, (2, 3, 0, 1)))

    if two_dim:
        Y2d_all = jnp.stack(two_dim, axis=0)
    else:
        Y2d_all = jnp.zeros((0, 2, 2, G, Bc), dtype=f_blk_cplx.dtype)

    return Y1d_all, Y2d_all

@partial(jax.jit, static_argnums=(1,2,3))
def _fft_stage2_for_r(Yr: jnp.ndarray, n: int, even: bool, ks: tuple):
    """Second FFT stage: transform axis 2 of a stage-1 result.

    Axis 2 of ``Yr`` is split into its first ``n`` entries (rotation part)
    and the remaining entries (reflection part); each part is transformed
    with an FFT along that axis.

    Args:
        Yr: 4-D array indexed as ``(a, b, group, batch)``.
        n, even, ks: static (hashable) arguments -- the rotation count, whether
            the extra one-dimensional irreps are produced, and the frequencies
            of the two-dimensional irreps.

    Returns:
        ``(S1d, S2d)``.  ``S1d`` stacks ``triv``, ``sign`` and (when ``even``)
        ``rp``, ``srp``, each of shape ``(a, b, batch)``.  ``S2d`` stacks one
        ``(a, b, 2, 2, batch)`` block per ``k`` in ``ks`` and has a leading
        axis of length 0 when ``ks`` is empty.
    """
    rot_hat = jnp.fft.fft(Yr[:, :, :n, :], axis=2)
    ref_hat = jnp.fft.fft(Yr[:, :, n:, :], axis=2)

    def _bin(spectrum, freq):
        # Select one frequency bin along the transformed axis.
        return spectrum[:, :, freq, :]

    def _sum_and_diff(freq):
        rot_bin, ref_bin = _bin(rot_hat, freq), _bin(ref_hat, freq)
        return [rot_bin + ref_bin, rot_bin - ref_bin]

    # Frequency 0 yields triv / sign; frequency n // 2 yields rp / srp.
    one_dim = _sum_and_diff(0)
    if even:
        one_dim.extend(_sum_and_diff(n // 2))
    S1d = jnp.stack(one_dim, axis=0)

    two_dim = []
    for k in ks:
        mirror = (n - k) % n  # conjugate frequency bin
        block = (_A_from_pair(_bin(rot_hat, k), _bin(rot_hat, mirror))
                 + _B_from_pair(_bin(ref_hat, k), _bin(ref_hat, mirror)))
        # Move the batch axis behind the freshly appended 2x2 axes.
        two_dim.append(jnp.moveaxis(block, -3, -1))

    if two_dim:
        S2d = jnp.stack(two_dim, axis=0)
    else:
        empty_shape = (0, Yr.shape[0], Yr.shape[1], 2, 2, Yr.shape[-1])
        S2d = jnp.zeros(empty_shape, dtype=Yr.dtype)
    return S1d, S2d

def _group_dft_preacts_inner_nojit(preacts, rho_cache, irreps, group_size):
    """Two-sided group Fourier coefficients of ``preacts`` via two FFT stages.

    ``preacts`` is reshaped to ``(G, G, N)`` with ``G = group_size`` and
    processed in chunks along the last (sample) axis.  Each chunk is cast to
    complex64, passed through ``_fft_stage1_all_r`` (leading axis) and, per
    left irrep, ``_fft_stage2_for_r`` (second group axis).  All coefficients
    are scaled by ``1 / group_size**2``.

    Args:
        preacts: array-like whose total size is a multiple of
            ``group_size**2``.
        rho_cache: not used by this FFT-based implementation; accepted so the
            signature stays compatible with callers.
        irreps: sequence of ``(name, dim, rho, tag)`` descriptors; only
            ``name`` and ``dim`` are read.  Names must be among those produced
            by ``_fft_meta_from_irreps`` for this ``group_size``.
        group_size: order of the group, ``2 * n``; must be even.

    Returns:
        Dict keyed by ``(r_name, s_name, sample_index)`` whose values are
        complex64 arrays of shape ``(d_r, d_r, d_s, d_s)``.

    Raises:
        AssertionError: if ``group_size`` is odd.
    """
    G = int(group_size)
    n = G // 2
    assert 2 * n == G, "group_size must be even for D_n"

    f_grid = jnp.asarray(preacts).reshape(G, G, -1)
    N = int(f_grid.shape[-1])

    # Pick the chunk size from a rough per-sample memory estimate.  Every
    # chunk is cast to complex64 below regardless of the input dtype, so the
    # working element size is always 8 bytes.
    TARGET_BYTES = 500 * 1024 * 1024
    elem_bytes = 8  # sizeof(complex64)
    bytes_per_sample = (G * G + 2 * n * G + 64) * elem_bytes
    B = max(1, min(N, TARGET_BYTES // bytes_per_sample))

    # Use the normalised integer G so that n / even / ks agree with the
    # slicing parameters computed above.
    meta = _fft_meta_from_irreps(irreps, G)
    oned_names = meta['oned_names']
    twod_names = meta['twod_names']
    name2idx_1d = meta['name2idx_1d']
    name2idx_2d = meta['name2idx_2d']
    even = meta['even']
    ks = meta['ks']
    inv_gsq = jnp.asarray(1.0, dtype=jnp.complex64) / (G * G)

    Fhat = {}
    for start in range(0, N, B):
        curB = min(B, N - start)
        f_blk = jax.lax.dynamic_slice_in_dim(
            f_grid, start, curB, axis=2
        ).astype(jnp.complex64)  # (G, G, curB)

        Y1d_all, Y2d_all = _fft_stage1_all_r(f_blk, n, even, ks)
        for r_name, d_r, _rho, _tag in irreps:
            if d_r == 1:
                # Add two singleton axes so 1-D irreps share the 2-D layout.
                Yr = Y1d_all[name2idx_1d[r_name]][None, None, :, :]
            else:
                Yr = Y2d_all[name2idx_2d[r_name]]

            S1d, S2d = _fft_stage2_for_r(Yr, n, even, ks)

            for s_name in oned_names:
                Z = (S1d[name2idx_1d[s_name]] * inv_gsq).astype(jnp.complex64)
                Z = Z[..., None, None, :]  # (d_r, d_r, 1, 1, curB)
                for bi in range(curB):
                    Fhat[(r_name, s_name, start + bi)] = Z[..., bi]

            for s_name in twod_names:
                right_block = (
                    S2d[name2idx_2d[s_name]] * inv_gsq
                ).astype(jnp.complex64)  # (d_r, d_r, 2, 2, curB)
                for bi in range(curB):
                    Fhat[(r_name, s_name, start + bi)] = right_block[..., bi]

    return Fhat

def jit_wrap_group_dft(rho_cache, irreps, group_size):
    """Return a one-argument callable with the group data bound in.

    The returned function forwards ``preacts`` together with the captured
    ``rho_cache``, ``irreps`` and ``group_size`` to
    ``_group_dft_preacts_inner_nojit``.  Despite the name, no ``jax.jit`` is
    applied by this wrapper itself.
    """
    def _apply(preacts):
        return _group_dft_preacts_inner_nojit(
            preacts, rho_cache, irreps, group_size
        )

    return _apply


