import sys
sys.path.append ('../../')

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap, random
import scipy
import argparse
import os
import time
import numpy.polynomial.chebyshev as cheb

# ---------------- Command-line arguments ---------------- #
parser = argparse.ArgumentParser(description="SINN_high")
parser.add_argument("--d", type=int, default=2, help="dimension of data")
parser.add_argument("--q", type=int, default=5, help="hyperparameter of hyperbolic cross")
parser.add_argument("--device", type=int, default=2,
                    help="GPU index exported as CUDA_VISIBLE_DEVICES")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument('--percentage', type=float, default=1.0,
                    help='Mask percentage (0~1, default: 1)')
args = parser.parse_args()
# Restrict the visible GPU for JAX. This has to be set before JAX first
# touches a device (it is done here, before any array is created below);
# NOTE(review): relies on JAX initialising its backend lazily after
# `import jax` — confirm if device selection ever seems to be ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)
# index set
def index_set(i, N_max, mode):
    '''
    Return the 1-D frequency indices admitted at refinement level ``i``.

    Level ``i >= 1`` admits ``N_i = 2**i + 1`` indices, capped at ``N_max``.
    With ``n = min(N_i, N_max)``:
      * "cheby":   0, 1, ..., n - 1
      * "fourier": -(n - 1), ..., -1, 0, 1, ..., n - 1
    Level 0 is the empty set (checked before ``mode`` is inspected).

    :param i: refinement level (non-negative int)
    :param N_max: cap on ``n``, the number of non-negative indices admitted
    :param mode: "cheby" or "fourier"
    :return: 1-D int numpy array of indices
    :raises ValueError: if ``mode`` is not "cheby" or "fourier" (for i >= 1)
    '''
    if i == 0:
        return np.array([], dtype=int)
    n = min(2**i + 1, N_max)
    if mode == "cheby":
        return np.arange(n)
    if mode == "fourier":
        return np.arange(-n + 1, n)
    # Any other mode is a caller error; returning None here would only
    # surface later as an obscure failure in the caller's set operations.
    raise ValueError(f"unsupported mode {mode!r}; expected 'cheby' or 'fourier'")

def generate_Idq_cheby(q, d, N_max=100):
    '''
    Build the Chebyshev hyperbolic-cross multi-index set I_{d,q}.

    For every level vector (i_1, ..., i_d) with all i_j >= 1 and
    i_1 + ... + i_d <= q, take the tensor product over axes of the indices
    that are new at level i_j -- index_set(i_j) minus index_set(i_j - 1),
    mode "cheby", capped by N_max -- and return the union of all these grids.

    :param q: hyperbolic-cross level budget (bound on the level-vector l1 norm)
    :param d: number of dimensions
    :param N_max: per-axis cap forwarded to index_set
    :return: (M, d) int array of unique multi-indices with rows sorted
             lexicographically (np.unique, axis=0); shape (0, d) when q < d,
             because no level vector can then satisfy the constraint
    '''
    # The indices that are new at a given level are the same on every axis,
    # so compute them once per level rather than once per (vector, axis).
    new_at_level = {
        level: np.setdiff1d(index_set(level, N_max, "cheby"),
                            index_set(level - 1, N_max, "cheby"))
        for level in range(1, q + 1)
    }

    # Enumerate only admissible level vectors (entries >= 1, sum <= q).
    # Materialising all q**d candidates and filtering is exponential in d.
    # An explicit stack avoids Python's recursion limit for large d.
    idq_list = []
    stack = [()]
    while stack:
        levels = stack.pop()
        filled = len(levels)
        if filled == d:
            axes = [new_at_level[level] for level in levels]
            grids = np.array(np.meshgrid(*axes)).reshape(d, -1).T
            idq_list.append(grids)
            continue
        # Each of the (d - filled - 1) axes still to fill needs at least level 1.
        max_level = q - sum(levels) - (d - filled - 1)
        for level in range(1, max_level + 1):
            stack.append(levels + (level,))

    if not idq_list:
        return np.empty((0, d), dtype=int)
    idq = np.concatenate(idq_list, axis=0)
    return np.unique(idq, axis=0)

def get_cof_high(k, coe_1d):
    """Multiply together the entries of ``coe_1d`` picked out by the indices ``k``.

    :param k: index (or array of indices) into ``coe_1d``
    :param coe_1d: array of 1-D coefficients
    :return: product of ``coe_1d[k_j]`` over all entries of ``k``
    """
    selected = jnp.take(coe_1d, k)
    return jnp.prod(selected)

def high_icheby(c_pred, k_set, x_test, batch_size=50000):
    '''
    Evaluate a d-dimensional tensor-product Chebyshev expansion at one point.

    Computes  sum_n c_pred[n] * prod_j T_{k_set[n, j]}(x_test[j])  using the
    identity T_k(x) = cos(k * arccos(x)). The basis functions are processed in
    chunks of ``batch_size`` to bound the size of the intermediate array.

    :param c_pred: (N,) expansion coefficients
    :param k_set: (N, d) Chebyshev degrees per basis function and axis
    :param x_test: (d,) evaluation point; entries outside [-1, 1] make
                   arccos, and therefore the result, NaN
    :param batch_size: number of basis functions per chunk; must be > 0
    :return: scalar, real part of the expansion value (0 when N == 0)
    :raises ValueError: if ``batch_size`` is not positive
    '''
    # A non-positive batch size either divides by zero or, when negative,
    # yields an empty loop that silently returns 0.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")

    # Loop-invariant: the angle only depends on the evaluation point.
    theta = jnp.arccos(x_test)
    N = k_set.shape[0]
    u_pred = 0.0
    for start in range(0, N, batch_size):
        end = min(start + batch_size, N)
        basis = jnp.prod(jnp.cos(k_set[start:end] * theta), axis=1)
        u_pred += jnp.sum(c_pred[start:end] * basis)
    # jnp.real discards any imaginary part should c_pred be complex.
    return jnp.real(u_pred)


def find_common_vectors(k, f):
    """
    Locate the rows of ``k`` that also occur as rows of ``f``.

    Parameters:
    k: array-like of shape (N_k, d), the candidate vectors.
    f: array-like of shape (N_f, d), the reference vectors.

    Returns:
    common_vectors: list of the rows of k that appear in f, in k's order.
    positions: list of the row indices of those vectors in k.

    Raises:
    ValueError: if k and f have a different number of columns (d).
    """
    k_arr = np.array(k)
    f_arr = np.array(f)

    if k_arr.shape[1] != f_arr.shape[1]:
        raise ValueError("两个数据集的向量维度不一致")

    # Hash every reference row once so each membership test is a set lookup.
    lookup = {tuple(row) for row in f_arr}

    hits = [(pos, row) for pos, row in enumerate(k_arr) if tuple(row) in lookup]
    common_vectors = [row for _, row in hits]
    positions = [pos for pos, _ in hits]
    return common_vectors, positions

@jax.jit
def solve_lstsq(diff_matrices, f_nd_cheb):
    """Solve ``diff_matrices @ u ~= f_nd_cheb`` in the least-squares sense.

    Only the solution ``u`` is returned; residuals, rank and singular values
    from ``jnp.linalg.lstsq`` are discarded. Under ``jax.jit`` the solve is
    compiled to XLA (GPU-targeted when a GPU is visible).
    """
    solution = jnp.linalg.lstsq(diff_matrices, f_nd_cheb)[0]
    return solution
# ---------------- Main computation ---------------- #
d = args.d
q = args.q
N_x = 50
N_max = N_x-1
kdq = generate_Idq_cheby(q=q, d=d, N_max=N_max)
print(f"Generated {len(kdq)} indices for d={d}, q={q}")

data = jnp.load(f"../../data/test_point_cheby_{d}.npz")
test_points = data["x_test"]
k_full = data["k_full"]
u_target = data["u_test"]
# NOTE(review): assumed layout -- the first len(k_full) rows of
# diff_matrices / f_nd_cheb correspond to the modes in k_full (any further
# rows presumably being boundary equations), and the columns of
# diff_matrices are one unknown per k_full mode. TODO confirm with the
# data-generation script.
diff_matrices = data["diff_matrices"]
f_nd_cheb = data["numerical_f_nd_cheb"]

# Optional random masking: keep only the rows listed in the precomputed
# index file for this (d, percentage, seed). Only row indices
# < len(k_full) are ever removed, so rows beyond that block are kept.
k_mask = k_full.copy()
indices_c = np.array([], dtype=int)  # nothing removed unless masking is on
if args.percentage < 1:
    indices = np.load(f'../../one_order_chebyshev/indices_{d}_{args.percentage}_{args.seed}.npz')['indices']
    all_indices = jnp.arange(len(k_full))
    mask = jnp.isin(all_indices, indices, invert=True)
    indices_c = all_indices[mask]
    f_nd_cheb = jnp.delete(f_nd_cheb, indices_c, axis=0)
    k_mask = jnp.delete(k_mask, indices_c, axis=0)
    diff_matrices = jnp.delete(diff_matrices, indices_c, axis=0)

# Restrict to the hyperbolic cross: drop every remaining row whose mode in
# k_mask is not part of kdq.
all_indices = jnp.arange(len(k_mask))
_, index_kmask = find_common_vectors(k_mask, kdq)
index_kmask = np.array(index_kmask)
mask_kdq = jnp.isin(all_indices, index_kmask, invert=True)
indices_kdq_c = all_indices[mask_kdq]
f_nd_cheb = jnp.delete(f_nd_cheb, indices_kdq_c, axis=0)
diff_matrices = jnp.delete(diff_matrices, indices_kdq_c, axis=0)

print(f"Left {f_nd_cheb.shape[0]} indices")
# JAX dispatches work asynchronously: without block_until_ready() the timer
# would stop as soon as the solve was queued, not when it finished.
# The first call of the jitted solve_lstsq also includes tracing/compilation.
T1 = time.perf_counter()
u_nd_cheb = solve_lstsq(diff_matrices, f_nd_cheb).block_until_ready()
T2 = time.perf_counter()

# The solve yields one coefficient per column of diff_matrices; drop the
# masked-out modes so the coefficients line up with k_mask. indices_c is
# empty when no mask was applied, so this is then a no-op.
if len(u_nd_cheb) != len(k_mask):
    u_nd_cheb = jnp.delete(u_nd_cheb, indices_c, axis=0)
if len(u_nd_cheb) != len(k_mask):
    raise ValueError(
        f"got {len(u_nd_cheb)} coefficients for {len(k_mask)} modes; "
        "diff_matrices columns do not match k_full"
    )
u_pred = vmap(high_icheby, (None, None, 0))(u_nd_cheb, k_mask, test_points)
error = jnp.linalg.norm(u_pred.flatten() - u_target.flatten()) / jnp.linalg.norm(u_target.flatten())
execution_time = T2 - T1

print(f"error: {error:.2e}")
print(f"execution time: {execution_time:.6f} s")
# savez does not create missing directories.
os.makedirs("results", exist_ok=True)
# NOTE(review): the file name omits args.seed, so runs with different seeds
# overwrite each other.
jnp.savez(f"results/chebyshev_{d}_{args.percentage}.npz", u_pred=u_pred, execution_time=execution_time, u_star=u_target)

