import sys
sys.path.append ('../../')

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap, random, jit, lax
import scipy
import argparse
import os
import time
import numpy.polynomial.chebyshev as cheb

# Command-line configuration of the run.
parser = argparse.ArgumentParser(description="SINN_high")
parser.add_argument("--d", type=int, default=2, help="dimension of data")
parser.add_argument("--q", type=int, default=10, help="hyperparameter of hyperbolic cross")
# The help text of --device used to repeat the --q description; the value is
# only used below to select the visible CUDA device.
parser.add_argument("--device", type=int, default=0, help="index of the GPU exposed through CUDA_VISIBLE_DEVICES")
parser.add_argument("--alpha", type=float, default=1.0, help="hyperparameter of convection equation")
parser.add_argument("--T", type=float, default=0.01, help="terminal time")
parser.add_argument("--seed", type=int, default=0, help="seed")
parser.add_argument("--percentage", type=float, default=1.0, help="training percentage")
args = parser.parse_args()
# Restrict the CUDA devices visible to this process to the requested one.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)

def build_D_c_matrix(N_x, ord=1):
    """
    Build the Chebyshev differentiation matrix acting on coefficients.

    The result maps the Chebyshev coefficients of a polynomial of degree at
    most ``N_x - 1`` to the Chebyshev coefficients of its ``ord``-th
    derivative. It is assembled as ``B @ D_p**ord @ T`` where ``T`` evaluates
    a coefficient vector on the Chebyshev-Lobatto nodes, ``D_p`` differentiates
    nodal values, and ``B`` converts nodal values back to coefficients.

    Parameters:
        N_x: number of Chebyshev-Lobatto nodes (equal to the number of
            coefficients)
        ord: derivative order, default 1; ``ord=0`` yields ``B @ T``, i.e. the
            identity up to round-off

    Returns:
        D_c: (N_x, N_x) array, the coefficient-space differentiation matrix

    Raises:
        ValueError: if ``N_x < 1`` or ``ord < 0``
    """
    if N_x < 1:
        raise ValueError("N_x must be at least 1")
    if ord < 0:
        raise ValueError("ord must be non-negative")

    if N_x == 1:
        # A single coefficient describes a constant: every derivative of
        # positive order is zero and the order-0 operator is the identity.
        # (chebpts2 requires at least two points, so this case is handled
        # without building nodes.)
        return np.eye(1) if ord == 0 else np.zeros((1, 1))

    # 1. Chebyshev-Lobatto nodes (second-kind points) and the nodal
    #    first-derivative matrix D_p.
    x = cheb.chebpts2(N_x)
    if N_x == 2:
        # Two nodes: the interpolant is linear, its slope is (f1 - f0) / 2.
        D_p = np.array([[-0.5, 0.5],
                        [-0.5, 0.5]])
    else:
        # c_0 = c_N = 2, all others 1, with alternating sign (-1)^j.
        c = np.ones(N_x)
        c[0] = 2.0
        c[-1] = 2.0
        c = c * ((-1.0) ** np.arange(N_x))

        X = np.tile(x, (N_x, 1))
        dX = X - X.T  # dX[i, j] = x_j - x_i

        # Off-diagonal entries; the added identity only avoids dividing by
        # zero on the diagonal.
        D_p = np.outer(c, 1.0 / c) / (dX + np.eye(N_x))
        # Flip the sign of the off-diagonals and set each diagonal entry so
        # that every row sums to zero.
        D_p = np.diag(np.sum(D_p, axis=1)) - D_p

    # 2. Evaluation matrix: T[i, j] = T_j(x_i).
    T = np.zeros((N_x, N_x))
    for j in range(N_x):
        unit = np.zeros(N_x)
        unit[j] = 1.0
        T[:, j] = cheb.Chebyshev(unit)(x)

    # 3. Values-to-coefficients matrix B, built so that B @ T is the identity.
    N = N_x - 1
    w = np.ones(N_x)  # quadrature weights, halved at both end nodes
    w[0] = 0.5
    w[-1] = 0.5
    B = (2.0 / N) * (T.T * w)
    # The coefficients k = 0 and k = N carry an additional factor 1/2.
    B[0, :] *= 0.5
    B[-1, :] *= 0.5

    # 4. Apply the nodal derivative `ord` times between evaluation and
    #    projection back to coefficients.
    D_c = T
    for _ in range(ord):
        D_c = D_p @ D_c

    return B @ D_c

def build_derivative_matrix(k_sp, D_c, axis=0):
    """
    Lift a 1D coefficient-space matrix to a subset of multi-indices.

    k_sp: (m, dim) array, selected multi-index subset
    D_c: 2D array indexed as ``D_c[row, col]`` by 1D mode numbers
    axis: position inside each multi-index along which ``D_c`` acts

    return:
        M : (m, m) matrix. ``M[i, c] = D_c[k_sp[i][axis], j]`` whenever that
        entry is non-zero and the multi-index obtained from ``k_sp[i]`` by
        replacing its ``axis`` component with ``j`` is row ``c`` of ``k_sp``.
        Couplings to multi-indices outside ``k_sp`` are dropped.
    """
    k_sp = np.asarray(k_sp)
    m, _ = k_sp.shape

    # Multi-index (as tuple) -> its row number in k_sp.
    lookup = dict(zip(map(tuple, k_sp), range(m)))

    M = np.zeros((m, m))
    for src, multi in enumerate(k_sp):
        weights = D_c[multi[axis], :]
        for j in np.nonzero(weights)[0]:
            # Same multi-index, with only the `axis` component replaced by j.
            shifted = list(multi)
            shifted[axis] = j
            dst = lookup.get(tuple(shifted))
            if dst is not None:
                M[src, dst] = weights[j]

    return M


def get_cof_high(k,coe_1d):
    """Return the product of the entries of ``coe_1d`` selected by the indices ``k``."""
    picked = jnp.take(coe_1d, k)
    return jnp.prod(picked)

# index set
def index_set(i, N_max, mode):
    """
    Return the 1D index set of level ``i``.

    Level 0 is the empty set. For ``i >= 1`` let ``n = min(2**i + 1, N_max)``:
    mode "cheby" yields ``0, ..., n - 1`` and mode "fourier" yields
    ``-(n - 1), ..., n - 1``.

    Parameters:
        i: level
        N_max: cap on the set size parameter ``n``
        mode: "cheby" or "fourier"

    Returns:
        1D integer array of indices

    Raises:
        ValueError: if ``mode`` is not one of the supported values (previously
            this case silently returned ``None``)
    """
    if i == 0:
        return np.array([], dtype=int)

    n = min(2**i + 1, N_max)
    if mode == "cheby":
        return np.arange(n)
    if mode == "fourier":
        return np.arange(-n + 1, n)
    raise ValueError(f"unknown mode {mode!r}; expected 'cheby' or 'fourier'")

def generate_Idq_cheby(q, d, N_max=100):
    """
    Build the d-dimensional index set of Chebyshev modes for parameter ``q``.

    Every level vector ``(i_1, ..., i_d)`` with entries in ``1..q`` and
    ``i_1 + ... + i_d <= q`` contributes the tensor product of the 1D
    increments ``index_set(i_j) \\ index_set(i_j - 1)``; the union of all
    contributions is returned with duplicate rows removed.

    Parameters:
        q: bound on the sum of the levels
        d: number of dimensions
        N_max: cap forwarded to ``index_set``

    Returns:
        (n, d) integer array of unique multi-indices, sorted row-wise by
        ``np.unique``. When ``q < d`` no level vector is admissible and an
        empty ``(0, d)`` array is returned (previously ``np.concatenate``
        raised on the empty list).
    """
    # Each level is at least 1, so the level sum is at least d.
    if q < d:
        return np.empty((0, d), dtype=int)

    levels = np.arange(1, q + 1)
    level_vectors = np.array(np.meshgrid(*([levels] * d))).reshape(d, -1).T
    admissible = level_vectors[level_vectors.sum(axis=1) <= q]

    blocks = []
    for level in admissible:
        # New 1D indices introduced at each level compared with the previous one.
        increments = [
            np.setdiff1d(index_set(i_d, N_max, "cheby"),
                         index_set(i_d - 1, N_max, "cheby"))
            for i_d in level
        ]
        blocks.append(np.array(np.meshgrid(*increments)).reshape(d, -1).T)

    return np.unique(np.concatenate(blocks, axis=0), axis=0)

def get_cof_high(k,coe_1d):
    """Multiply together the entries of ``coe_1d`` found at the indices ``k``.

    NOTE: a function with this name is also defined earlier in this file;
    this later definition is the one in effect after the module is loaded.
    """
    selected = jnp.take(coe_1d, k)
    return jnp.prod(selected)

def extend_arccos(n,x):
    """
    Evaluate ``cos(n * arccos(x))`` with an extension to ``|x| > 1``.

    Three expressions are computed and selected element-wise:
      * ``-1 <= x <= 1``: ``cos(n * arccos(x))``
      * ``x > 1``:        ``cosh(n * arccosh(x))``
      * ``x < -1``:       ``(-1)**n * cosh(n * arccosh(-x))``
    """
    above = x > 1
    below = x < -1

    trig_branch = jnp.cos(n * jnp.arccos(x))
    cosh_right = jnp.cosh(n * jnp.arccosh(x))
    # Parity factor (-1)**n carries the value over from -x to x.
    cosh_left = (-1) ** n * jnp.cosh(n * jnp.arccosh(-x))

    return jnp.where(above, cosh_right, jnp.where(below, cosh_left, trig_branch))
def high_icheby(c_pred, k_set, x_test,batch_size=50000):
    '''
    Evaluate a multi-dimensional Chebyshev expansion at one point.

    The value is sum_m c_pred[m] * prod_j extend_arccos(k_set[m, j], x_test[j]),
    accumulated over slices of at most ``batch_size`` modes.

    :param c_pred: (N,) expansion coefficients
    :param k_set: (N,d) multi-indices of the modes
    :param x_test: (d,) evaluation point
    :param batch_size: number of modes handled per slice
    :return: real part of the accumulated sum
    '''
    n_modes = k_set.shape[0]
    n_batches = (n_modes + batch_size - 1) // batch_size

    def partial_sum(lo, hi):
        # Product of the 1D basis values over the dimensions, one per mode.
        basis = jnp.prod(extend_arccos(k_set[lo:hi], x_test), axis=1)
        return jnp.sum(c_pred[lo:hi] * basis)

    total = 0.0
    for b in range(n_batches):
        lo = b * batch_size
        total += partial_sum(lo, min(lo + batch_size, n_modes))
    return jnp.real(total)

def find_common_vectors(k, f):
    """
    Find the vectors of ``k`` that also occur in ``f``.

    Parameters:
    k: array of N_k vectors of dimension d, shape (N_k, d)
    f: array of N_f vectors of dimension d, shape (N_f, d)

    Returns:
    common_vectors: list of the rows of k that are present in f, in k's order
    positions: list of the row indices of those vectors within k

    Raises:
    ValueError: if the two inputs have different vector dimensions
    """
    k = np.array(k)
    f = np.array(f)

    if k.shape[1] != f.shape[1]:
        raise ValueError("两个数据集的向量维度不一致")

    # Hashable copies of the rows of f for constant-time membership tests.
    members = {tuple(row) for row in f}

    matches = [(pos, row) for pos, row in enumerate(k) if tuple(row) in members]
    positions = [pos for pos, _ in matches]
    common_vectors = [row for _, row in matches]

    return common_vectors, positions


@jit
def compute_u_pred(test_points, coeffs, kdq, N_t, dt,D_c):
    """
    Advance the coefficient vector with classical RK4 and evaluate the result.

    Integrates ``dc/dt = D_c @ c`` for ``N_t`` steps of size ``dt`` starting
    from ``coeffs``, and evaluates the expansion with modes ``kdq`` at every
    row of ``test_points`` before and after the integration.

    Returns:
        V_final: coefficient vector after ``N_t`` steps
        history: array stacking the evaluations at the initial and at the
            final state (intermediate states are not stored, to save memory)
    """
    evaluate = vmap(high_icheby, (None, None, 0))

    def rk4_step(_, c):
        s1 = D_c @ c
        s2 = D_c @ (c + 0.5 * dt * s1)
        s3 = D_c @ (c + 0.5 * dt * s2)
        s4 = D_c @ (c + dt * s3)
        return c + (dt / 6) * (s1 + 2 * s2 + 2 * s3 + s4)

    u_initial = evaluate(coeffs, kdq, test_points)

    # lax.fori_loop keeps the time loop inside the compiled computation.
    V_final = lax.fori_loop(0, N_t, rk4_step, coeffs)

    u_final = evaluate(V_final, kdq, test_points)
    return V_final, jnp.array([u_initial, u_final])

# ---------------- Main computation ---------------- #
T1 = time.time()

d = args.d
q = args.q
T = args.T
alpha = args.alpha

N_x = 50
dt=1e-6
# NOTE(review): compute_u_pred performs N_t RK4 steps, so with the "+1" the
# integration ends at about T + dt whenever T/dt is a whole number -- confirm
# this matches how the reference u_test was generated.
N_t=int(T/dt+1)
N_max = N_x-1
kdq = generate_Idq_cheby(q=q, d=d, N_max=N_max)

data = jnp.load(f"../../data/test_point_cheby_t_{d}.npz")
test_points = data["x_test"]
u_target = data["u_test"]
k_full = data["k_full"]
u0_nd_cheb=data["u0_nd_cheb"]
diff_matrices = data["diff_matrices"]
u0_1d_cheb = jnp.load("../../data/cheby_t_1d.npz")["u0_1d_cheb"]

all_indices = jnp.arange(len(k_full))
k_mask = k_full.copy()
if args.percentage < 1:
    # Zero the initial coefficient and the operator row of every mode that is
    # not listed in the stored index file, so those modes stay at zero.
    indices = np.load(f'../../convection_chebyshev/indices_{d}_{args.percentage}_{args.seed}.npz')['indices']
    mask = jnp.isin(all_indices, indices, invert=True)
    indices_c = all_indices[mask]
    u0_nd_cheb[indices_c] = 0.0
    diff_matrices[indices_c, :] = 0.0

# Keep only the modes of k_full that also belong to kdq. The coefficient
# vector, the multi-index list and BOTH axes of the operator must be reduced
# together: compute_u_pred multiplies the operator with the coefficient vector
# and pairs each coefficient with its multi-index. Previously only the rows of
# the operator and the coefficients were removed, so any actual removal left
# mismatched shapes.
_, index_kmask = find_common_vectors(k_mask, kdq)
index_kmask = np.array(index_kmask, dtype=int)
k_mask = k_mask[index_kmask]
u0_nd_cheb = jnp.asarray(u0_nd_cheb[index_kmask])
diff_matrices = jnp.asarray(diff_matrices[np.ix_(index_kmask, index_kmask)])

u_nd_cheb, history = compute_u_pred(test_points, u0_nd_cheb, k_mask, N_t, dt,diff_matrices)

# JAX dispatches work asynchronously; wait for the result so that the timing
# below covers the computation itself.
u_pred = history[-1].block_until_ready()

T2 = time.time()
error = jnp.linalg.norm(u_pred.flatten() - u_target.flatten()) / jnp.linalg.norm(u_target.flatten())
execution_time = T2 - T1
print(f"error: {error:.2e}")
print(f"execution time: {execution_time:.6f} s")

# Make sure the output directory exists so the run is not lost at save time.
os.makedirs("results", exist_ok=True)
jnp.savez(f"results/cheby_t_{d}.npz", u_pred=u_pred, execution_time=execution_time,history=history,u_nd_cheb=u_nd_cheb)
