import jax
import jax.numpy as jnp
import numpy as np
import argparse
import os
import time
from jax import pmap, device_put

# Explicitly keep JAX's NaN debugging switched off for this run.
# NOTE(review): this call only touches `jax_debug_nans`; nothing here enables
# memory profiling or any other optimisation.
jax.config.update("jax_debug_nans", False)

# Command-line interface: problem dimension, index-set level, GPU selection,
# chunk size for the blocked evaluation, and the fraction of indices to keep.
parser = argparse.ArgumentParser(description="SINN_high with memory optimization")
parser.add_argument("--d", type=int, default=2, help="dimension of data")
parser.add_argument("--q", type=int, default=5, help="hyperparameter of hyperbolic cross")
parser.add_argument("--device", type=str, default="3", help="GPU devices to use, e.g., '0,1' for multiple GPUs")
parser.add_argument("--chunk_size", type=int, default=1000, help="Chunk size for memory optimization")
parser.add_argument("--percentage", type=float, default=1.0, help="training percentage")
args = parser.parse_args()

# Select the GPUs to use. The environment variable is assigned before the
# first `jax.devices()` call in this file.
# NOTE(review): presumably it has to be set before JAX initialises its
# backend for the restriction to take effect -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
devices = jax.devices()
print(f"Using devices: {devices}")
num_devices = len(devices)


# One-dimensional index-set generation.
def index_set(i, N_max, mode):
    """Return the 1-D index set belonging to level ``i``.

    Level 0 is the empty set.  For any other level the set holds
    ``m = min(2**i + 1, N_max)`` consecutive integers laid out according to
    ``mode``:

    * ``"cheby"``   -> ``0, 1, ..., m - 1``
    * ``"fourier"`` -> ``-(m - 1), ..., m - 1``

    Args:
        i: Level number (a Python int or a concrete scalar array).
        N_max: Upper bound applied to the level size ``2**i + 1``.
        mode: Either ``"cheby"`` or ``"fourier"``.

    Returns:
        A 1-D ``jnp`` integer array with the indices of the level.

    Raises:
        ValueError: If ``i`` is non-zero and ``mode`` is not a supported
            name.  Previously this case fell through and returned ``None``,
            which only failed later with an unrelated error.
    """
    if i == 0:
        return jnp.array([], dtype=int)

    bound = min(2 ** i + 1, N_max)
    if mode == "cheby":
        return jnp.arange(bound)
    if mode == "fourier":
        return jnp.arange(-bound + 1, bound)
    raise ValueError(
        f"unknown mode {mode!r}; expected 'cheby' or 'fourier'"
    )


def generate_Idq_fourier(q, d, N_max=100):
    """Build the d-dimensional Fourier index set for level parameter ``q``.

    Every level vector with entries in ``1..q`` whose entries sum to at most
    ``q`` contributes the tensor product of its per-axis "new" indices, i.e.
    the indices of ``index_set(level)`` that are absent from
    ``index_set(level - 1)``.  The contributions are stacked and duplicate
    rows are removed.

    Args:
        q: Largest level and bound on the sum of the levels.
        d: Number of dimensions.
        N_max: Cap forwarded to ``index_set``.

    Returns:
        An array of shape ``(n_indices, d)`` holding the unique index
        vectors, in the row order produced by ``jnp.unique``.

    NOTE(review): every level is at least 1, so for ``q < d`` no level vector
    qualifies and the concatenation receives an empty list -- confirm that
    callers never request that combination.
    """
    levels = jnp.arange(1, q + 1)
    # All level vectors in {1..q}^d, one per row.
    level_grid = jnp.array(jnp.meshgrid(*([levels] * d))).reshape(d, -1).T
    level_sums = jnp.linalg.norm(level_grid, ord=1, axis=-1)

    blocks = []
    for level_vec, level_sum in zip(level_grid, level_sums):
        if level_sum > q:
            continue
        # Indices introduced at each axis' level but not at the level below.
        fresh_per_axis = [
            jnp.setdiff1d(
                index_set(level, N_max, "fourier"),
                index_set(level - 1, N_max, "fourier"),
            )
            for level in level_vec
        ]
        blocks.append(jnp.array(jnp.meshgrid(*fresh_per_axis)).reshape(d, -1).T)

    stacked = jnp.concatenate(blocks, axis=0)
    return jnp.unique(stacked, axis=0)


@jax.jit
def fourier_polynomial(k, x):
    """Evaluate the product of ``exp(1j * k * x)`` over all entries.

    Args:
        k: Frequency values, broadcastable against ``x``.
        x: Evaluation coordinates.

    Returns:
        A complex scalar: the product of the element-wise exponentials.
    """
    per_axis_modes = jnp.exp(1j * k * x)
    return jnp.prod(per_axis_modes)


@jax.jit
def get_cof_high(k, coe_1d):
    """Multiply together the entries of ``coe_1d`` selected by ``k``.

    Args:
        k: Integer indices into ``coe_1d``; negative values index from the
            end of the array, as with ordinary array indexing.
        coe_1d: Array of one-dimensional coefficients.

    Returns:
        The scalar product of the selected coefficients.
    """
    selected = coe_1d[k]
    return jnp.prod(selected)


# Core blocked evaluation routine.
def process_chunk(kdq_chunk, f_fft_vals_chunk, ksq_chunk, test_points, chunk_size):
    """Accumulate the contribution of a set of modes at every test point.

    The modes are processed ``chunk_size`` at a time so that only one
    ``(n_points, chunk_size)`` basis matrix has to be held at once.

    For each mode the coefficient is ``f / ksq`` when ``ksq`` is non-zero and
    ``f`` itself otherwise; it multiplies ``fourier_polynomial(k, p)`` and the
    products are summed over the modes.
    NOTE(review): dividing by ``ksq`` presumably inverts a Laplacian symbol --
    confirm against the problem definition.

    Args:
        kdq_chunk: Mode index vectors, one per row.
        f_fft_vals_chunk: One coefficient per mode.
        ksq_chunk: One divisor per mode (zero entries are left undivided).
        test_points: Evaluation points, one per row.
        chunk_size: Number of modes handled per inner block; must be positive.

    Returns:
        A complex array with one accumulated value per test point.

    Raises:
        ValueError: If ``chunk_size`` is not positive.  A negative value used
            to skip the loop silently and return all zeros.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    n_points = test_points.shape[0]
    # The builtin `complex` resolves to the widest complex dtype JAX currently
    # allows, so no truncation warning is raised when x64 is disabled.
    u_pred_chunk = jnp.zeros(n_points, dtype=complex)

    for start in range(0, len(kdq_chunk), chunk_size):
        stop = start + chunk_size
        kdq_subchunk = kdq_chunk[start:stop]
        f_fft_vals_subchunk = f_fft_vals_chunk[start:stop]
        ksq_subchunk = ksq_chunk[start:stop]

        # Basis values, shape (n_points, modes in this block).
        phi_subchunk = jax.vmap(
            lambda p: jax.vmap(lambda k: fourier_polynomial(k, p))(kdq_subchunk)
        )(test_points)

        # Pick the per-mode coefficient on the 1-D arrays, so that only one
        # full-size product with `phi_subchunk` is formed instead of two.
        nonzero = ksq_subchunk != 0
        denom_subchunk = jnp.where(nonzero, ksq_subchunk, 1)
        coeff_subchunk = jnp.where(
            nonzero,
            f_fft_vals_subchunk / denom_subchunk,
            f_fft_vals_subchunk,
        )
        contrib_subchunk = coeff_subchunk[None, :] * phi_subchunk

        u_pred_chunk = u_pred_chunk + jnp.sum(contrib_subchunk, axis=1)
        # Drop the large temporaries before the next block is allocated.
        del phi_subchunk, contrib_subchunk

    return u_pred_chunk


# Multi-device evaluation routine.
def parallel_process(kdq, f_fft_vals, ksq, test_points, chunk_size, num_devices):
    """Evaluate the mode sum with the modes spread over several devices.

    The modes are split into ``num_devices`` equally sized contiguous groups,
    each group is handled by ``process_chunk`` on its own device, and the
    per-device partial sums are added together.

    ``pmap`` needs every mapped argument to be a single array whose leading
    axis equals the device count.  The previous implementation passed Python
    lists of (possibly unequal) pieces, which ``pmap`` cannot map over.  The
    inputs are therefore padded to a multiple of ``num_devices`` and reshaped.
    Padded modes carry a zero coefficient, so they add nothing to the result.

    Args:
        kdq: Mode index vectors, shape ``(n_modes, d)``.
        f_fft_vals: One coefficient per mode, shape ``(n_modes,)``.
        ksq: One divisor per mode, shape ``(n_modes,)``.
        test_points: Evaluation points; broadcast unchanged to every device.
        chunk_size: Inner block size forwarded to ``process_chunk``.
        num_devices: Number of devices to map over.

    Returns:
        The accumulated values at the test points, summed over all devices.

    Raises:
        ValueError: If ``num_devices`` is smaller than 1.
    """
    if num_devices < 1:
        raise ValueError(f"num_devices must be at least 1, got {num_devices}")

    kdq = jnp.asarray(kdq)
    f_fft_vals = jnp.asarray(f_fft_vals)
    ksq = jnp.asarray(ksq)

    n_modes, dim = kdq.shape
    per_device = -(-n_modes // num_devices)  # ceiling division
    n_pad = per_device * num_devices - n_modes
    if n_pad:
        # Zero coefficients make the padded modes contribute exactly zero.
        kdq = jnp.pad(kdq, ((0, n_pad), (0, 0)))
        f_fft_vals = jnp.pad(f_fft_vals, (0, n_pad))
        ksq = jnp.pad(ksq, (0, n_pad))

    kdq_sharded = kdq.reshape(num_devices, per_device, dim)
    f_fft_vals_sharded = f_fft_vals.reshape(num_devices, per_device)
    ksq_sharded = ksq.reshape(num_devices, per_device)

    def device_compute(kdq_chunk, f_fft_vals_chunk, ksq_chunk, points):
        return process_chunk(kdq_chunk, f_fft_vals_chunk, ksq_chunk, points, chunk_size)

    # Map the three per-mode arrays over devices; `None` broadcasts the test
    # points instead of requiring a manually replicated copy.
    results = pmap(device_compute, in_axes=(0, 0, 0, None))(
        kdq_sharded, f_fft_vals_sharded, ksq_sharded, test_points
    )

    # Combine the per-device partial sums.
    return jnp.sum(results, axis=0)

def find_common_vectors(k, f):
    """Find the rows of ``k`` that also occur as rows of ``f``.

    Args:
        k: Array-like of shape ``(N_k, d)``.
        f: Array-like of shape ``(N_f, d)``.

    Returns:
        A pair ``(common_vectors, positions)``: the matching rows of ``k`` as
        a list of NumPy arrays, and the list of their row indices in ``k``,
        both in the order in which they appear in ``k``.

    Raises:
        ValueError: If the two inputs have a different number of columns.
    """
    k_rows = np.array(k)
    f_rows = np.array(f)

    if k_rows.shape[1] != f_rows.shape[1]:
        raise ValueError("两个数据集的向量维度不一致")

    # Hashable copies of the rows of `f` for constant-time membership tests.
    known = {tuple(row) for row in f_rows}

    hits = [(pos, row) for pos, row in enumerate(k_rows) if tuple(row) in known]
    positions = [pos for pos, _ in hits]
    common_vectors = [row for _, row in hits]
    return common_vectors, positions

# ---------------- Main computation ---------------- #
T1 = time.time()

d = args.d
q = args.q
N_x = 100
N_max = N_x
chunk_size = args.chunk_size

# Build the index set (this step can itself be memory hungry).
print("Generating index set...")
kdq = generate_Idq_fourier(q=q, d=d, N_max=int(N_max / 2))
print(f"Generated {len(kdq)} indices for d={d}, q={q}")

# Load the test points, reference solution and 1-D coefficients.
print("Loading data...")
data = jnp.load(f"data/test_point_simple_{d}.npz")
test_points = data["x_test"]
u_target = data["u_test"]
f_1d_fft = jnp.load("data/fourier_1d.npz")["f_1d_fft"]

k_full = data['k_full']
# Keep only the generated indices that also appear in `k_full`.
_, indices = find_common_vectors(kdq, k_full)
if args.percentage < 1:
    n_keep = int(np.round(len(indices) * args.percentage))
    # `jax.random.choice` needs an array, not a Python list.
    indices = jax.random.choice(
        jax.random.PRNGKey(0),
        jnp.asarray(indices, dtype=int),
        shape=(n_keep,),
        replace=False,
    )
    indices = np.sort(np.asarray(indices))

kdq = np.array(kdq)
kdq = kdq[indices]

print("Preprocessing coefficients...")
ksq = -jnp.sum(kdq ** 2, axis=1)
f_fft_vals = jax.vmap(lambda k: get_cof_high(k, f_1d_fft))(kdq)

# Pick the evaluation path according to the number of visible devices.
if num_devices > 1:
    print(f"Using {num_devices} GPUs for parallel processing...")
    u_pred = parallel_process(kdq, f_fft_vals, ksq, test_points, chunk_size, num_devices)
else:
    print("Using single GPU with chunked processing...")
    u_pred = process_chunk(kdq, f_fft_vals, ksq, test_points, chunk_size)

# JAX dispatches work asynchronously; wait for the result so that the timer
# measures the computation rather than just the time needed to enqueue it.
u_pred.block_until_ready()
T2 = time.time()

# Relative L2 error against the reference solution.
error = jnp.linalg.norm(u_pred.flatten() - u_target.flatten()) / jnp.linalg.norm(u_target.flatten())
execution_time = T2 - T1

print(f"error: {error:.2e}")
print(f"execution time: {execution_time:.6f} s")

# Save the results; create the output directory first so that a missing
# folder does not discard the finished computation.
os.makedirs("results", exist_ok=True)
jnp.savez(f"results/fourier_{d}.npz", u_pred=u_pred, execution_time=execution_time)

