import jax
import jax.numpy as jnp
from jax import jit, lax, vmap
import argparse
import time
import numpy as np  # 仅用于数据加载和保存
import os

# Limit the precision used when printing arrays to 3 digits. This affects
# display only; it does not change the dtype used for computation.
jnp.set_printoptions(precision=3)
# Disabled: explicit toggle of JAX's 64-bit mode, kept for reference.
# jax.config.update("jax_enable_x64", False)

# Command-line options. They are parsed at import time into the module-level
# `args` object that main() reads.
parser = argparse.ArgumentParser(description="SINN_high with JAX acceleration and memory optimization")
parser.add_argument("--d", type=int, default=2, help="dimension of data")
parser.add_argument("--q", type=int, default=10, help="hyperparameter of hyperbolic cross")
parser.add_argument("--alpha", type=float, default=1.0, help="hyperparameter of heat equation")
parser.add_argument("--device", type=str, default="3", help="GPU devices to use, e.g., '0,1' for multiple GPUs")
parser.add_argument("--T", type=float, default=0.01, help="terminal time")
parser.add_argument("--chunk_size", type=int, default=500, help="Chunk size for memory optimization")
parser.add_argument("--percentage", type=float, default=1.0, help="training percentage")
parser.add_argument("--seed", type=int, default=0, help="random seed")
args = parser.parse_args()

# Restrict which GPUs are visible to this process.
# NOTE(review): this runs after `import jax`; presumably it relies on JAX
# creating its backend lazily on first use -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device

def find_common_vectors(k, f):
    """
    Find the rows of ``k`` that also occur as rows of ``f``.

    Parameters:
    k: array-like of N_k vectors of dimension d, shape (N_k, d)
    f: array-like of N_f vectors of dimension d, shape (N_f, d)

    Returns:
    common_vectors: list of the rows of ``k`` that are present in ``f``,
        in the order they appear in ``k``
    positions: list of the indices of those rows within ``k``

    Raises:
    ValueError: if the two inputs do not have the same vector dimension.
    """
    k_rows = np.array(k)
    f_rows = np.array(f)

    if k_rows.shape[1] != f_rows.shape[1]:
        raise ValueError("两个数据集的向量维度不一致")

    # Hashable tuples give constant-time membership tests against f.
    known = {tuple(row) for row in f_rows}

    positions = [idx for idx, row in enumerate(k_rows) if tuple(row) in known]
    common_vectors = [k_rows[idx] for idx in positions]
    return common_vectors, positions



def index_set(i, N_max, mode):
    """Return the one-dimensional index set of level ``i``.

    Level ``i >= 1`` has nominal size ``N_i = 2**i + 1``, truncated to
    ``m = min(N_i, N_max)``:

    * ``'cheby'``:   ``0, 1, ..., m - 1``
    * ``'fourier'``: ``-(m - 1), ..., m - 1``

    Args:
        i: level; level 0 is the empty set.
        N_max: upper bound applied to the level size.
        mode: ``'cheby'`` or ``'fourier'``.

    Returns:
        An empty list for ``i == 0``, otherwise a 1-D integer ndarray.

    Raises:
        ValueError: if ``i != 0`` and ``mode`` is not a known mode
            (previously this surfaced as an UnboundLocalError).
    """
    if i == 0:
        return []
    size = min(2 ** i + 1, N_max)
    if mode == 'cheby':
        return np.arange(size)
    if mode == 'fourier':
        return np.arange(-size + 1, size)
    raise ValueError(f"unknown mode {mode!r}; expected 'cheby' or 'fourier'")


def _generate_Idq(q, d, N_max, mode):
    """Shared implementation of the ``generate_Idq_*`` functions.

    For every level vector ``(i_1, ..., i_d)`` with entries in ``1..q`` and
    ``i_1 + ... + i_d <= q``, takes the tensor product of the per-level
    increments ``index_set(i_j) \\ index_set(i_j - 1)`` and returns the union
    of all resulting multi-indices.
    """
    # The increment of a level depends only on the level, so compute each once
    # instead of once per level vector.
    increments = {
        level: np.setdiff1d(index_set(level, N_max, mode),
                            index_set(level - 1, N_max, mode))
        for level in range(1, q + 1)
    }

    levels_1d = np.arange(1, q + 1)
    level_grids = np.meshgrid(*([levels_1d] * d))
    level_vectors = np.concatenate([g.reshape(-1, 1) for g in level_grids], -1)

    blocks = []
    for levels in level_vectors:
        # All entries are >= 1, so the plain sum is the l1 norm.
        if levels.sum() <= q:
            grids = np.meshgrid(*[increments[int(lv)] for lv in levels])
            blocks.append(np.concatenate([g.reshape(-1, 1) for g in grids], -1))

    # No admissible level vector (e.g. q < d): return an empty index set
    # rather than failing inside np.concatenate.
    if not blocks:
        return np.empty((0, d), dtype=int)

    points = np.concatenate(blocks, 0)
    unique_points = set(map(tuple, points))
    if not unique_points:
        return np.empty((0, d), dtype=int)
    return np.array(list(unique_points))


def generate_Idq_cheby(q, d, N_max=100):
    """Build the d-dimensional sparse index set from ``'cheby'`` 1-D levels.

    Args:
        q: bound on the sum of the per-dimension levels.
        d: number of dimensions.
        N_max: truncation passed to ``index_set`` for every level.

    Returns:
        Integer ndarray of shape (n, d) holding the unique multi-indices;
        shape (0, d) if there are none. Row order follows set iteration and
        should not be relied upon.
    """
    return _generate_Idq(q, d, N_max, 'cheby')


def generate_Idq_fourier(q, d, N_max=100):
    """Build the d-dimensional sparse index set from ``'fourier'`` 1-D levels.

    Args:
        q: bound on the sum of the per-dimension levels.
        d: number of dimensions.
        N_max: truncation passed to ``index_set`` for every level.

    Returns:
        Integer ndarray of shape (n, d) holding the unique multi-indices
        (entries may be negative); shape (0, d) if there are none. Row order
        follows set iteration and should not be relied upon.
    """
    return _generate_Idq(q, d, N_max, 'fourier')


@jit
def fourier_polynomial(k, X):
    """Evaluate the Fourier modes exp(i * k.x) for every pair of rows.

    Args:
        k: wave vectors, one per row.
        X: evaluation points, one per row (same row length as ``k``).

    Returns:
        Complex array whose entry [a, b] is ``exp(1j * dot(k[a], X[b]))``.
    """
    phase = jnp.matmul(k, jnp.transpose(X))
    return jnp.exp(1j * phase)


@jit
def get_cof_high(k, coe_1d):
    """Return the product of the 1-D coefficients selected by ``k``.

    Each entry of ``k`` is used directly as an index into ``coe_1d``
    (negative entries therefore count from the end of ``coe_1d``), and the
    selected values are multiplied together.
    """
    selected = coe_1d[k]
    return selected.prod()


@jit
def step_fn(i, val):
    """Advance the state by one Heun (two-stage Runge-Kutta) step.

    Written as a ``lax.fori_loop`` body: ``i`` is the loop counter (unused)
    and ``val`` is the carry ``(V, lap, alpha, gamma, dt)``. Only ``V`` is
    updated; ``gamma`` is the per-step multiplier applied to the state and
    the remaining entries are passed through unchanged.

    Returns:
        The carry tuple with ``V`` replaced by its updated value.
    """
    state, lap, alpha, gamma, dt = val
    slope_a = gamma * state
    slope_b = gamma * (state + slope_a)
    advanced = state + 0.5 * (slope_a + slope_b)
    return (advanced, lap, alpha, gamma, dt)
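# Alternative time step kept for reference (disabled): a three-stage
# Shu-Osher SSP-RK3 update. Note that it unpacks a 4-tuple carry
# (V, lap, alpha, dt), which differs from the 5-tuple unpacked by the
# active step_fn above.
# def step_fn(i, val):
#     """Single update step, for use with lax.fori_loop."""
#     V, lap, alpha, dt = val
#     k1 = V + alpha * lap[:, None] * dt * V
#     k2 = (3 * V + k1 + alpha * lap[:, None] * dt * k1) / 4
#     V_new = (V + 2 * k2 + 2 * alpha * lap[:, None] * dt * k2) / 3
#     return (V_new, lap, alpha, dt)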


@jit
def process_chunk(chunk, test_points, u0_1d_fft, N_t, dt, alpha):
    """Evolve one chunk of wave vectors and sum its contribution per point.

    Args:
        chunk: wave vectors of this chunk, one per row.
        test_points: evaluation points, one per row.
        u0_1d_fft: 1-D coefficient table indexed by the entries of ``chunk``.
        N_t: number of time steps.
        dt: time-step size.
        alpha: coefficient multiplying the Laplacian symbol.

    Returns:
        ``(final_contrib, initial_contrib)``: the real parts of the sum over
        the chunk's modes, evaluated at every test point, after and before
        the time stepping respectively.
    """
    wave_vectors = chunk

    # One amplitude per wave vector: product of the 1-D coefficients.
    amplitudes = vmap(get_cof_high, in_axes=(0, None))(wave_vectors, u0_1d_fft)

    # Laplacian symbol of each mode: -|k|^2.
    neg_k_squared = -jnp.sum(wave_vectors ** 2, axis=1)

    # Mode values at the test points, scaled by the amplitudes.
    basis = fourier_polynomial(wave_vectors, test_points)
    V_start = amplitudes[:, None] * basis

    # Per-mode multiplier consumed by step_fn, then N_t steps of it.
    gamma = alpha * neg_k_squared[:, None] * dt
    carry = (V_start, neg_k_squared, alpha, gamma, dt)
    V_end = lax.fori_loop(0, N_t, step_fn, carry)[0]

    # Collapse the mode axis; only the real part is kept.
    final_contrib = jnp.real(jnp.sum(V_end, axis=0))
    initial_contrib = jnp.real(jnp.sum(V_start, axis=0))

    return final_contrib, initial_contrib


def compute_u_pred(test_points, u0_1d_fft, kdq, N_t, dt, alpha, chunk_size):
    """Compute the prediction chunk by chunk to bound peak memory.

    The wave vectors in ``kdq`` are processed ``chunk_size`` rows at a time
    with ``process_chunk`` and the per-point contributions are accumulated.

    Args:
        test_points: evaluation points, one per row.
        u0_1d_fft: 1-D coefficient table forwarded to ``process_chunk``.
        kdq: wave vectors, one per row.
        N_t: number of time steps.
        dt: time-step size.
        alpha: coefficient forwarded to ``process_chunk``.
        chunk_size: number of wave vectors per chunk; must be positive.

    Returns:
        ``(u_pred, history)`` where ``u_pred`` has one value per test point
        and ``history`` stacks the initial and the final sums, in that order.

    Raises:
        ValueError: if ``chunk_size`` is not positive. A negative value used
            to produce an empty loop and a silently all-zero prediction.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    n_k = len(kdq)
    n_points = test_points.shape[0]

    # Running sums over all chunks.
    u_pred = jnp.zeros(n_points, dtype=jnp.float32)
    initial_history = jnp.zeros(n_points, dtype=jnp.float32)

    for chunk_index, start in enumerate(range(0, n_k, chunk_size)):
        stop = min(start + chunk_size, n_k)

        final_contrib, initial_contrib = process_chunk(
            kdq[start:stop], test_points, u0_1d_fft, N_t, dt, alpha
        )

        u_pred += final_contrib
        initial_history += initial_contrib

        # Report progress every tenth chunk.
        if chunk_index % 10 == 0:
            print(f"Processed {stop}/{n_k} k values")

    # Row 0: state before time stepping; row 1: state after.
    history = jnp.stack([initial_history, u_pred])

    return u_pred, history


def main():
    """Run the full pipeline: build indices, load data, predict, save.

    Steps:
      1. Generate the Fourier index set for ``args.q`` / ``args.d``.
      2. Load the test points, target values, the ``k_full`` index list and
         the 1-D FFT coefficients from ``../../data``.
      3. Optionally restrict ``k_full`` to the stored subset selected by
         ``args.percentage`` / ``args.seed``.
      4. Keep only the generated indices that occur in ``k_full`` and compute
         the prediction in chunks.
      5. Print the relative L2 error and wall-clock time, and save the
         prediction, timing and history to ``results/fourier_t_{d}.npz``.
    """
    T1 = time.perf_counter()

    d = args.d
    q = args.q
    T = args.T
    alpha = args.alpha
    chunk_size = args.chunk_size

    N_x = 100
    N_t = 100
    dt = (T - 0) / N_t
    N_max = N_x

    # Index-set generation runs in NumPy on the host.
    print("Generating index set...")
    kdq = generate_Idq_fourier(q=q, d=d, N_max=int(N_max / 2))
    kdq = kdq.astype(int)  # make sure the indices are integers
    print(f"Generated {len(kdq)} indices for d={d}, q={q}")

    # Load the data; the context managers close the .npz archives.
    print("Loading data...")
    with np.load(f'../../data/test_point_t_{d}.npz') as test_data:
        test_points = jnp.array(test_data['x_test'])
        u_target = jnp.array(test_data['u_test'])
        # Only used for host-side matching, so keep it as a NumPy array.
        k_full = np.asarray(test_data['k_full'])

    # FFT coefficients of the initial condition.
    with np.load('../../data/fourier_t_1d.npz') as fft_data:
        u0_1d_fft = jnp.array(fft_data['u0_1d_fft'])

    if args.percentage < 1:
        mask_path = f"../../heat_fourier/indices_{d}_{args.percentage}_{args.seed}.npz"
        with np.load(mask_path) as mask_data:
            indices_mask = np.sort(mask_data["indices"])
        k_full = k_full[indices_mask]

    # Keep only the generated indices that are present in k_full.
    _, positions = find_common_vectors(kdq, k_full)
    kdq = kdq[np.asarray(positions, dtype=int)]

    # Chunked evaluation of the prediction.
    print("Starting computation with chunk size:", chunk_size)
    u_pred, history = compute_u_pred(
        test_points, u0_1d_fft, kdq, N_t, dt, alpha, chunk_size
    )

    # Relative L2 error against the target.
    error = jnp.linalg.norm(u_pred.flatten() - u_target.flatten()) / jnp.linalg.norm(u_target.flatten())

    # JAX dispatches work asynchronously; wait for the results so the timer
    # covers the actual computation rather than just its dispatch.
    history.block_until_ready()
    error.block_until_ready()

    T2 = time.perf_counter()
    execution_time = T2 - T1

    print(f'error: {error:.2e}')
    print(f"execution time : {execution_time:.6f} s")

    # Save the results as NumPy arrays; create the output directory if needed.
    os.makedirs('results', exist_ok=True)
    np.savez(
        f'results/fourier_t_{d}.npz',
        u_pred=np.asarray(u_pred),
        execution_time=execution_time,
        history=np.asarray(history)
    )


if __name__ == "__main__":
    main()


