import jax
import jax.numpy as jnp
from jax import jit, lax
import argparse
import time
import numpy as np  # 仅用于数据加载和保存
import os


def _build_parser():
    """Create the command-line parser for this script (flags in declaration order)."""
    cli = argparse.ArgumentParser(description="SINN_high with JAX acceleration")
    # (flag, type, default, help) for every supported option.
    option_specs = (
        ("--d", int, 2, "dimension of data"),
        ("--q", int, 5, "hyperparameter of hyperbolic cross"),
        ("--alpha", float, 1.0, "hyperparameter of heat equation"),
        ("--device", str, "0", "GPU devices to use, e.g., '0,1' for multiple GPUs"),
        ("--T", float, 0.01, "terminal time"),
        ("--percentage", float, 1.0, "training percentage"),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Restrict which GPUs are visible. The CUDA runtime reads this variable when it is
# initialised, so it only takes effect if set before JAX first touches a device.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device

def compute_u_pred_exact(x_test, coe_1d, k_1d, k_full, T, alpha=1.0):
    """Evaluate the closed-form Fourier solution of u_t = alpha * Laplacian(u) at time T.

    Args:
        x_test: evaluation points, one per row (column count must match k_full).
        coe_1d: 1-D coefficient table; each entry of k_full indexes it, and the
            coefficient of a multi-index is the product of its 1-D coefficients.
        k_1d: lookup table mapping index positions to integer wave numbers.
        k_full: integer multi-indices, shape (N_k, d); used to index both
            ``coe_1d`` and ``k_1d``.
        T: evaluation time.
        alpha: diffusion coefficient. The default 1.0 reproduces the previous
            behaviour, which hard-coded alpha = 1.

    Returns:
        Real part of sum_n c_n * exp(-alpha * |k_n|^2 * T) * exp(i k_n . x)
        at every row of ``x_test``.
    """
    # Wave vectors of every mode, looked up once and reused below.
    k = k_1d[k_full]
    coeffs = jnp.prod(coe_1d[k_full], axis=1)
    # Exact per-mode heat-kernel decay.
    decay = jnp.exp(-alpha * jnp.sum(k**2, axis=1) * T)
    phi = jnp.exp(1j * (k @ x_test.T))
    return jnp.real((coeffs * decay) @ phi)

def find_common_vectors(k, f):
    """
    Find the rows of k that also occur in f, together with their row indices in k.

    Args:
        k: array-like of N_k vectors, shape (N_k, d).
        f: array-like of N_f vectors, shape (N_f, d).

    Returns:
        common_vectors: list of the rows of k that appear in f, in k's order.
        positions: list of the corresponding row indices in k.

    Raises:
        ValueError: if k and f have a different number of columns.
    """
    k_rows = np.array(k)
    f_rows = np.array(f)

    if k_rows.shape[1] != f_rows.shape[1]:
        raise ValueError("两个数据集的向量维度不一致")

    # Hash every row of f once so each membership test is O(1).
    lookup = {tuple(row) for row in f_rows}

    matches = [(pos, row) for pos, row in enumerate(k_rows) if tuple(row) in lookup]
    common_vectors = [row for _, row in matches]
    positions = [pos for pos, _ in matches]
    return common_vectors, positions


def index_set(i,N_max,mode):
    """Return the 1-D index set of level ``i``, truncated by ``N_max``.

    Level 0 is the empty set (returned as an empty list). For level i >= 1 let
    n = min(2**i + 1, N_max):

    - ``'cheby'``   -> ``np.arange(n)``            i.e. 0, ..., n-1
    - ``'fourier'`` -> ``np.arange(-n + 1, n)``    i.e. -(n-1), ..., n-1

    Raises:
        ValueError: if ``i != 0`` and ``mode`` is neither ``'cheby'`` nor
            ``'fourier'``. Previously this fell through to an
            ``UnboundLocalError``.
    """
    if i == 0:
        return []
    n = min(2**i + 1, N_max)
    if mode == 'cheby':
        return np.arange(n)
    if mode == 'fourier':
        return np.arange(-n + 1, n)
    raise ValueError(f"unknown mode {mode!r}; expected 'cheby' or 'fourier'")


def generate_Idq_cheby(q,d,N_max=100):
    """Build the d-dimensional Chebyshev hyperbolic-cross index set of level q.

    Every level vector i in {1, ..., q}^d with |i|_1 <= q contributes the tensor
    product of the per-dimension increments
    ``setdiff1d(index_set(i_j), index_set(i_j - 1))``. The de-duplicated union
    of all contributions is returned.

    Args:
        q: hyperbolic-cross level (upper bound on the l1 norm of level vectors).
        d: spatial dimension.
        N_max: cap on the 1-D index range, forwarded to ``index_set``.

    Returns:
        Integer array of shape (n_points, d). Row order follows Python set
        iteration and is not sorted. If no level vector qualifies (e.g. d > q,
        since every entry is >= 1), an empty (0, d) array is returned instead of
        raising from ``np.concatenate([])``.
    """
    levels_1d = np.arange(1, q + 1)
    level_grids = np.meshgrid(*([levels_1d] * d))
    vec_i = np.concatenate([g.reshape(-1, 1) for g in level_grids], -1)
    # All entries are positive, so the l1 norm is just the row sum.
    vec_i = vec_i[vec_i.sum(axis=-1) <= q]

    idq = []
    for level in vec_i:
        indexs = [
            np.setdiff1d(index_set(i_d, N_max, 'cheby'), index_set(i_d - 1, N_max, 'cheby'))
            for i_d in level
        ]
        grids = np.meshgrid(*indexs)
        idq.append(np.concatenate([g.reshape(-1, 1) for g in grids], -1))

    if not idq:
        return np.empty((0, d), dtype=int)

    idq = np.concatenate(idq, 0)
    # Different level vectors can produce the same point; keep each point once.
    return np.array(list({tuple(point) for point in idq}))

def generate_Idq_fourier(q,d,N_max=100):
    """Build the d-dimensional Fourier hyperbolic-cross index set of level q.

    Every level vector i in {1, ..., q}^d with |i|_1 <= q contributes the tensor
    product of the per-dimension increments
    ``setdiff1d(index_set(i_j), index_set(i_j - 1))`` in ``'fourier'`` mode,
    i.e. symmetric integer wave numbers. The de-duplicated union of all
    contributions is returned.

    Args:
        q: hyperbolic-cross level (upper bound on the l1 norm of level vectors).
        d: spatial dimension.
        N_max: cap on the 1-D index range, forwarded to ``index_set``.

    Returns:
        Integer array of shape (n_points, d). Row order follows Python set
        iteration and is not sorted. If no level vector qualifies (e.g. d > q,
        since every entry is >= 1), an empty (0, d) array is returned instead of
        raising from ``np.concatenate([])``.
    """
    levels_1d = np.arange(1, q + 1)
    level_grids = np.meshgrid(*([levels_1d] * d))
    vec_i = np.concatenate([g.reshape(-1, 1) for g in level_grids], -1)
    # All entries are positive, so the l1 norm is just the row sum.
    vec_i = vec_i[vec_i.sum(axis=-1) <= q]

    idq = []
    for level in vec_i:
        indexs = [
            np.setdiff1d(index_set(i_d, N_max, 'fourier'), index_set(i_d - 1, N_max, 'fourier'))
            for i_d in level
        ]
        grids = np.meshgrid(*indexs)
        idq.append(np.concatenate([g.reshape(-1, 1) for g in grids], -1))

    if not idq:
        return np.empty((0, d), dtype=int)

    idq = np.concatenate(idq, 0)
    # Different level vectors can produce the same point; keep each point once.
    return np.array(list({tuple(point) for point in idq}))
@jit
def fourier_polynomial(k, X):
    """Evaluate the Fourier modes exp(i k·x) for every (wave vector, point) pair.

    Args:
        k: wave vectors, one per row.
        X: evaluation points, one per row (same number of columns as ``k``).

    Returns:
        Complex matrix whose entry [n, m] equals exp(i * k[n] · X[m]).
    """
    phase = jnp.matmul(k, jnp.transpose(X))
    return jnp.exp(1j * phase)

@jit
def get_cof_high(k, coe_1d):
    """Return the product of the 1-D coefficients selected by the multi-index ``k``.

    Every entry of ``k`` indexes ``coe_1d``. Negative entries count from the end,
    as with ordinary array indexing. The result is the product of all selected values.
    """
    selected = coe_1d[k]
    return selected.prod()


@jit
def compute_u_pred(test_points, u0_1d_fft, kdq, N_t, dt, alpha):
    """Integrate u_t = alpha * Laplacian(u) mode by mode with Heun's method (RK2).

    Args:
        test_points: evaluation points, one per row (column count must match kdq).
        u0_1d_fft: 1-D coefficient table. Each entry of ``kdq`` indexes it
            (negative entries wrap from the end), and a mode's initial coefficient
            is the product of its 1-D coefficients.
        kdq: integer wave vectors, shape (N_k, d).
        N_t: number of time steps.
        dt: time-step size.
        alpha: diffusion coefficient.

    Returns:
        u_pred: real part of the expansion after N_t steps, one value per test point.
        history: array of shape (2, n_points); row 0 is the real-valued expansion
            at step 0 and row 1 equals ``u_pred``.
    """
    # Tensor-product initial coefficients, vectorised: one gather + one reduction.
    # A Python loop over the traced `kdq` would be unrolled at trace time into one
    # op per wave vector, making compile time grow with N_k.
    coeffs = jnp.prod(u0_1d_fft[kdq], axis=1)
    lap = -jnp.sum(kdq**2, axis=1)
    phi = fourier_polynomial(kdq, test_points)
    V0 = coeffs[:, None] * phi

    # Each mode obeys dV/dt = alpha*lap*V. One Heun step multiplies V by
    # 1 + g + g**2/2 with g = alpha*lap*dt, whose magnitude stays <= 1 only for
    # -2 <= g <= 0. That condition is NOT checked here: callers must pick dt with
    # alpha * max|k|^2 * dt <= 2.
    gamma = alpha * lap[:, None] * dt

    def step_fn(_, V):
        k1 = gamma * V
        k2 = gamma * (V + k1)
        return V + 0.5 * (k1 + k2)

    V_final = lax.fori_loop(0, N_t, step_fn, V0)
    u_initial = jnp.real(jnp.sum(V0, axis=0))
    u_pred = jnp.real(jnp.sum(V_final, axis=0))
    history = jnp.stack([u_initial, u_pred])
    return u_pred, history


def main():
    """Run the hyperbolic-cross Fourier heat-equation solver and report its errors.

    Reads the CLI options from the module-level ``args``. Loads test points, the
    target solution and ``k_full`` from ``../../data/test_point_t_{d}.npz``, and
    the 1-D initial-condition coefficients from ``../../data/fourier_t_1d.npz``.

    Prints two relative L2 errors and the wall time, then writes
    ``results/fourier_t_{d}.npz``:
      - error_0: the t=0 reconstruction on the selected index set, compared with
        the reconstruction on all of ``k_full``.
      - error_T: the prediction at time T, compared with ``u_test``.
    """
    T1 = time.time()

    d = args.d
    q = args.q
    T = args.T
    alpha = args.alpha

    N_x = 100
    N_t = 100
    dt = (T - 0) / N_t
    N_max = N_x

    # The index set is cheap to build, so it is generated on the host with NumPy.
    kdq = generate_Idq_fourier(q=q, d=d, N_max=int(N_max / 2))
    kdq = kdq.astype(int)

    # Load data.
    test_data = np.load(f'../../data/test_point_t_{d}.npz')
    test_points = jnp.array(test_data['x_test'])
    u_target = jnp.array(test_data['u_test'])
    k_full = np.array(test_data['k_full'])

    # FFT coefficients of the initial condition.
    u0_1d_fft = jnp.array(np.load('../../data/fourier_t_1d.npz')['u0_1d_fft'])

    # Keep only the hyperbolic-cross wave vectors that also appear in k_full.
    _, indices = find_common_vectors(kdq, k_full)
    # An int array: an empty list would otherwise become a float array (unusable
    # as an index), and jax.random.choice rejects plain Python lists.
    indices = np.asarray(indices, dtype=int)
    if args.percentage < 1:
        n_keep = int(np.round(len(indices) * args.percentage))
        indices = jax.random.choice(jax.random.PRNGKey(0), indices,
                                    shape=(n_keep,),
                                    replace=False)
        indices = np.sort(np.asarray(indices))

    kdq = kdq[indices]

    # Predict on the selected (optionally subsampled) index set, so that --q and
    # --percentage actually affect the result.
    u_pred, history = compute_u_pred(test_points, u0_1d_fft, kdq, N_t, dt, alpha)

    # Errors.
    k_1d = np.around(np.fft.fftfreq(N_x) * N_x).astype('int32')
    u0_target = compute_u_pred_exact(test_points, u0_1d_fft, k_1d, k_full, 0.0)
    error_0 = jnp.linalg.norm(history[0] - u0_target) / jnp.linalg.norm(u0_target)
    error_T = jnp.linalg.norm(u_pred.flatten() - u_target.flatten()) / jnp.linalg.norm(u_target.flatten())
    # JAX dispatches asynchronously; wait for the results so the timing covers
    # the actual computation rather than just its enqueueing.
    error_0.block_until_ready()
    error_T.block_until_ready()

    T2 = time.time()
    execution_time = T2 - T1

    print(f'error_0: {error_0:.2e}, error_T: {error_T:.2e}')
    print(f"execution time : {execution_time:.6f} s")

    # Save results as NumPy arrays; create the output directory if it is missing.
    os.makedirs('results', exist_ok=True)
    np.savez(
        f'results/fourier_t_{d}.npz',
        u_pred=np.asarray(u_pred),
        execution_time=execution_time,
        u_target=np.asarray(u_target),
    )

if __name__ == "__main__":
    main()

