import jax
import jax.numpy as jnp
from jax import vmap, random
import numpy as np
import numpy.polynomial.chebyshev as cheb
import scipy
from numpy.polynomial import chebyshev as C
from numpy.polynomial import polynomial as P
from collections import defaultdict
from itertools import product

def build_D_c_matrix(N_x, ord=1):
    """
    Build the Chebyshev differentiation matrix that acts on coefficients.

    The returned matrix maps the Chebyshev coefficients of a polynomial of
    degree ``N_x - 1`` to the Chebyshev coefficients of its ``ord``-th
    derivative.  It is assembled as ``B @ D_p**ord @ T`` where

    * ``T`` evaluates a coefficient vector at the Chebyshev-Lobatto nodes,
    * ``D_p`` differentiates nodal values (collocation matrix),
    * ``B`` is the inverse of ``T`` (discrete Chebyshev transform).

    Parameters:
        N_x: number of Chebyshev-Lobatto nodes (= number of coefficients),
            must be >= 1.
        ord: derivative order, default 1.  ``ord=0`` yields the identity
            (up to rounding).

    Returns:
        D_c: ``(N_x, N_x)`` array.

    Raises:
        ValueError: if ``N_x < 1`` or ``ord < 0``.
    """
    if N_x < 1:
        raise ValueError("N_x must be >= 1")
    if ord < 0:
        raise ValueError("ord must be >= 0")

    if N_x == 1:
        # Only the constant mode exists: T = B = [[1]] and every derivative
        # of a constant vanishes.  Handled separately because the general
        # formulas below divide by N = N_x - 1 = 0.
        return np.eye(1) if ord == 0 else np.zeros((1, 1))

    # 1. Chebyshev-Lobatto nodes (ascending, from -1 to 1) and the nodal
    #    first-derivative matrix D_p.
    x = cheb.chebpts2(N_x)

    # c_0 = c_N = 2, all others 1, with alternating sign (-1)^j.
    c = np.ones(N_x)
    c[0] = 2.0
    c[-1] = 2.0
    c = c * ((-1.0) ** np.arange(N_x))

    # dX[i, j] = x_j - x_i; the identity keeps the diagonal division finite.
    dX = x[np.newaxis, :] - x[:, np.newaxis]
    D_p = np.outer(c, 1.0 / c) / (dX + np.eye(N_x))
    # Subtracting flips the off-diagonal sign to c_i / (c_j (x_i - x_j)) and
    # sets each diagonal entry so that every row sums to zero.
    D_p = np.diag(np.sum(D_p, axis=1)) - D_p

    # 2. Evaluation matrix: T[i, j] = T_j(x_i).
    T = cheb.chebvander(x, N_x - 1)

    # 3. B = T^{-1} from the discrete orthogonality of Chebyshev polynomials
    #    on the Lobatto nodes: end nodes carry weight 1/2 and the first and
    #    last modes are halved once more.
    N = N_x - 1
    w = np.ones(N_x)
    w[0] = 0.5
    w[-1] = 0.5
    B = (2.0 / N) * (T.T * w)
    B[0, :] *= 0.5
    B[-1, :] *= 0.5

    # 4. Differentiate `ord` times in nodal space, then go back to
    #    coefficient space.
    D_c = T
    for _ in range(ord):
        D_c = D_p @ D_c

    return B @ D_c

def build_diff_matrix(k_sp, D_c, axis=0):
    """
    Lift a 1D coefficient-space derivative matrix to a sparse multi-index set.

    Parameters:
        k_sp: (m, dim) array of multi-indices (the retained modes).
        D_c: 2D matrix indexed as ``D_c[k, j]``; row ``k`` is looked up with
            the ``axis`` component of each multi-index.
        axis: component of the multi-index along which to differentiate.

    Returns:
        M: (m, m) array.  ``M[i, col] = D_c[k_sp[i][axis], j]`` whenever the
        multi-index obtained from ``k_sp[i]`` by replacing its ``axis``
        component with ``j`` is itself a member of ``k_sp`` (at position
        ``col``).  Couplings to modes outside ``k_sp`` are dropped.
    """
    modes = np.asarray(k_sp)
    n_modes, _ = modes.shape

    # Lookup table: multi-index tuple -> its position in `modes`.
    position = {}
    for pos in range(n_modes):
        position[tuple(modes[pos])] = pos

    M = np.zeros((n_modes, n_modes))

    for out_pos, mode in enumerate(modes):
        # Row of the 1D operator selected by this mode's index along `axis`.
        d_row = D_c[mode[axis], :]

        for src in np.nonzero(d_row)[0]:
            # Same multi-index, but with the `axis` component replaced.
            neighbour = list(mode)
            neighbour[axis] = src
            in_pos = position.get(tuple(neighbour))
            if in_pos is None:
                # Coupled mode is not part of the sparse set: skip it.
                continue
            M[out_pos, in_pos] = d_row[src]

    return M

def build_boundary_matrix(k_set, N_max, side=+1, axis=0):
    """
    Build matrix A so that A @ c gives boundary sums at x_axis = side.

    Each row collects all modes that share the same indices in every
    dimension other than ``axis``; the entry for a mode is the value of the
    Chebyshev polynomial of its ``axis`` index at the boundary, i.e. ``1`` at
    ``side == +1`` and ``(-1) ** k_axis`` otherwise.

    Only index combinations that actually occur in ``k_set`` get a row, so
    the cost does not grow like ``(N_max + 1) ** (d - 1)``.

    Parameters
    ----------
    k_set : (N, d) int array
        Multi-index array for Chebyshev modes.
    N_max : int
        Largest index allowed in the dimensions other than ``axis``.
    side : int
        Boundary side; ``+1`` selects x = +1, any other value is treated
        as x = -1.
    axis : int
        Dimension index along which the boundary condition is applied.

    Returns
    -------
    A : (M, N) array
        Boundary operator matrix.  Rows are ordered lexicographically by the
        remaining (non-``axis``) indices; M is the number of distinct
        remaining-index tuples present in ``k_set``.

    Raises
    ------
    KeyError
        If a mode has a non-``axis`` index outside ``[0, N_max]``.
    """
    k_set = np.asarray(k_set)
    n_modes, _ = k_set.shape

    k_axis = k_set[:, axis]
    # Remaining indices of every mode, as hashable tuples of Python scalars.
    other_keys = [
        tuple(row) for row in np.delete(k_set, axis, axis=1).tolist()
    ]

    for key in other_keys:
        if any(v < 0 or v > N_max for v in key):
            raise KeyError(key)

    # Sorted tuples reproduce the C-order enumeration of the full index grid
    # restricted to the combinations that are present.
    row_of = {key: r for r, key in enumerate(sorted(set(other_keys)))}

    # T_k(+1) = 1 and T_k(-1) = (-1)^k.
    signs = np.ones(n_modes) if side == +1 else (-1.0) ** k_axis

    A = np.zeros((len(row_of), n_modes))
    rows = np.array([row_of[key] for key in other_keys], dtype=int)
    A[rows, np.arange(n_modes)] = signs
    return A

def chebyshev_translation_matrix(N, alpha):
    """
    Build the matrix that re-expands shifted Chebyshev polynomials.

    The returned ``(N + 1, N + 1)`` array ``R`` satisfies

        T_n(x + alpha) = sum_k R[n, k] T_k(x),    0 <= n <= N,

    i.e. row ``n`` holds the Chebyshev coefficients of ``T_n(x + alpha)``.
    Entries with ``k > n`` are zero.
    """
    size = N + 1
    # Filled column by column (column n <-> T_n) and transposed on return.
    coeff_by_column = np.zeros((size, size))
    shift_binomial = [alpha, 1.0]  # power-basis coefficients of (alpha + x)

    for degree in range(size):
        # Chebyshev coefficient vector of T_degree, converted to monomials.
        unit = np.zeros(degree + 1)
        unit[degree] = 1.0
        monomial = C.cheb2poly(unit)

        # Substitute x -> x + alpha term by term: a * (x + alpha)^power.
        shifted = np.zeros_like(monomial)
        for power, a in enumerate(monomial):
            shifted = P.polyadd(shifted, a * P.polypow(shift_binomial, power))

        # Back to the Chebyshev basis; the result may be shorter than
        # degree + 1 after trimming.
        cheb_coeffs = C.poly2cheb(shifted)
        coeff_by_column[:len(cheb_coeffs), degree] = cheb_coeffs

    return coeff_by_column.T

def multi_index_leq(k):
    """
    Iterate over every multi-index j with 0 <= j_i <= k_i componentwise.

    Returns an ``itertools.product`` iterator of tuples (last component
    varies fastest).
    """
    per_axis = tuple(range(0, upper + 1) for upper in k)
    return product(*per_axis)
def chebyshev_shift_sparse(k_sp, c_k, alpha):
    """
    Re-expand a sparse tensor-product Chebyshev series after shifting every
    coordinate by ``alpha``.

    Each mode ``k`` contributes to all multi-indices ``j`` with
    ``j_d <= k_d``; the contribution is ``c_k`` times the product over
    dimensions of ``A[k_d, j_d]``, where ``A`` is the 1D translation matrix
    returned by ``chebyshev_translation_matrix``.

    Parameters
    ----------
    k_sp : ndarray, shape (N, dim)
        Sparse multi-index set (cast to int).
    c_k : ndarray, shape (N,)
        Chebyshev coefficients, one per row of ``k_sp``.
    alpha : float
        Shift (same for all dimensions).

    Returns
    -------
    c_new : collections.defaultdict
        keys: multi-index tuple
        values: shifted coefficients
        Being a ``defaultdict(float)``, reading a missing key inserts 0.0.
        Empty when ``k_sp`` has no rows.
    """
    k_sp = np.asarray(k_sp, dtype=int)
    c_k = np.asarray(c_k)

    N, dim = k_sp.shape
    c_new = defaultdict(float)

    # Nothing to shift; also avoids taking max() of an empty array.
    if N == 0:
        return c_new

    # One 1D translation matrix, large enough for the highest index present,
    # is shared by all dimensions.
    A = chebyshev_translation_matrix(k_sp.max(), alpha)

    for idx in range(N):
        k = k_sp[idx]
        ck = c_k[idx]

        for j in multi_index_leq(k):
            # Tensor-product weight: product of the 1D translation entries.
            weight = 1.0
            for k_d, j_d in zip(k, j):
                weight *= A[k_d, j_d]

            c_new[j] += ck * weight

    return c_new

def count_model_parameters(model) -> tuple[int, int]:
    """
    Count the floating-point/complex array parameters of a model.

    The model is treated as a JAX pytree (Equinox modules are pytrees); only
    leaves that are JAX arrays with an inexact (floating or complex) dtype
    are counted.  This is the same selection the previous
    ``eqx.filter(model, eqx.is_inexact_array)`` + ``jnp.ndarray`` check made,
    but it relies only on JAX, so it does not need ``equinox`` to be
    imported in this module.

    Args:
        model: model instance (any pytree).
    Returns:
        total_params: total number of array elements.
        total_bytes: total size of those elements in bytes.
    """
    total_params = 0
    total_bytes = 0

    for leaf in jax.tree_util.tree_leaves(model):
        # Skip non-array leaves (strings, callables, Python scalars, ...).
        if not isinstance(leaf, jnp.ndarray):
            continue
        # Skip integer / boolean arrays: they are not trainable parameters.
        if not jnp.issubdtype(leaf.dtype, jnp.inexact):
            continue
        total_params += leaf.size
        total_bytes += leaf.size * leaf.dtype.itemsize

    return total_params, total_bytes

def find_common_vectors(k, f):
    """
    Find the rows of ``k`` that also occur as rows of ``f``.

    Both inputs are converted to 2D arrays whose rows are the vectors.

    Returns:
        common_vectors: list of the matching rows of ``k``, in ``k`` order.
        positions: list of their row indices in ``k``.

    Raises:
        ValueError: if the two inputs have different vector lengths.
    """
    k = np.array(k)
    f = np.array(f)

    if k.shape[1] != f.shape[1]:
        raise ValueError("两个数据集的向量维度不一致")

    # Hash the rows of f once so each membership test is O(1).
    lookup = {tuple(row) for row in f}

    positions = [i for i, row in enumerate(k) if tuple(row) in lookup]
    common_vectors = [k[i] for i in positions]

    return common_vectors, positions
