# -*- coding: utf-8 -*-
# import os
# import sys
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import numpy as np
import torch
import os
import yaml

# from Poisson import *
# from Advection import *
# from Heat import *
# from Wave import *
# from Diffusion_reaction import *
import importlib
import re
import pandas as pd

def gen_dataset(problem_name, problem_index, X=None, *,
                save_root="./PDEdataset-psr-noise", noise_std=0.001):
    """Generate one noisy PDE dataset and its metadata file.

    Imports the module named ``problem_name`` and fetches two callables from
    it: ``<problem_name><problem_index>_sol`` and
    ``<problem_name><problem_index>_eva``. The ``_sol`` callable is invoked as
    ``expr(inputs, eva_torch, data_path)``. It must write a tab-separated
    table to ``data_path``, because that file is read back immediately
    afterwards. Its return value is stored as ``sol`` in ``metadata.yaml``.

    Gaussian noise (mean 0, std ``noise_std``) is then added to the
    ``target`` column of that table, and the file is rewritten gzip-compressed.

    Args:
        problem_name: Importable module name, e.g. ``"Heat"``.
        problem_index: Problem suffix, e.g. ``"1_2D"``.
        X: Sequence of coordinate columns, one per input variable. Each
            column is converted to a ``torch.tensor`` with
            ``requires_grad=True``. Defaults to the module-level ``X`` for
            backward compatibility with the ``__main__`` driver, which binds
            that global before each call.
        save_root: Root output directory; files are written to
            ``<save_root>/<problem_name>/<problem_index>/``.
        noise_std: Standard deviation of the noise added to ``target``.

    Raises:
        AttributeError: If the problem module lacks the ``_sol`` or ``_eva``
            callable.
        NameError: If ``X`` is not given and no module-level ``X`` exists.
    """
    eva_module = importlib.import_module(problem_name)
    # Plain getattr fails fast with a clear message. The previous hasattr()
    # guards left the names unbound and crashed later with UnboundLocalError.
    expr = getattr(eva_module, f"{problem_name}{problem_index}_sol")
    eva_torch = getattr(eva_module, f"{problem_name}{problem_index}_eva")

    if X is None:
        if "X" not in globals():
            raise NameError(
                "gen_dataset needs input coordinates: pass X explicitly "
                "or define a module-level X before calling it")
        X = globals()["X"]

    save_dir = os.path.join(save_root, problem_name, problem_index)
    os.makedirs(save_dir, exist_ok=True)
    metadata_path = os.path.join(save_dir, "metadata.yaml")
    data_path = os.path.join(save_dir, "data.tsv.gz")

    # Build fresh leaf tensors instead of overwriting the caller's X in place.
    inputs = [torch.tensor(column, requires_grad=True) for column in X]
    sol_str = expr(inputs, eva_torch, data_path)

    # Read back the table written by ``expr`` and perturb its target column.
    data = pd.read_csv(data_path, sep='\t')
    if 'target' in data.columns:
        data['target'] += np.random.normal(0, noise_std, size=data['target'].shape)
        data.to_csv(data_path, sep='\t', index=False, compression='gzip')
    else:
        import warnings
        # A missing column would otherwise silently produce noise-free data.
        warnings.warn(
            f"{data_path}: no 'target' column found; noise was not added")

    metadata = {
        "problem_name": problem_name,
        "problem_index": problem_index,
        "sol": sol_str,
        "eva": eva_torch.__name__,
    }
    with open(metadata_path, 'w', encoding='utf-8') as file:
        yaml.dump(metadata, file)

def sample_points(dim, n_axis=100, n_interior=300,
                  x_range=(0.1, 1), t_range=(0.1, 2), decimals=3):
    """Sample collocation points for a ``dim``-dimensional space-time problem.

    For every spatial axis, ``n_axis`` points lie on that axis: that axis's
    coordinate is drawn uniformly from ``x_range`` and every other spatial
    coordinate is zero. A further ``n_interior`` points have all spatial
    coordinates drawn from ``x_range``. The time coordinate of every point is
    drawn uniformly from ``t_range``. All values are rounded to ``decimals``
    places.

    With the defaults this gives 400 points for ``dim == 1``, 500 for 2 and
    600 for 3, i.e. the layouts the driver previously hard-coded.

    Returns:
        ndarray of shape ``(dim + 1, dim * n_axis + n_interior)`` whose rows
        are ``x1, ..., x_dim, t``.

    Raises:
        ValueError: If ``dim < 1``.
    """
    if dim < 1:
        raise ValueError(f"dim must be >= 1, got {dim}")

    def draw(bounds, size):
        return np.round(np.random.uniform(bounds[0], bounds[1], size), decimals)

    columns = []
    for axis in range(dim):
        # Axis-line segments (non-zero only on ``axis``), then interior points.
        segments = [draw(x_range, n_axis) if k == axis else np.zeros(n_axis)
                    for k in range(dim)]
        segments.append(draw(x_range, n_interior))
        columns.append(np.concatenate(segments))
    t = draw(t_range, dim * n_axis + n_interior)
    return np.stack(columns + [t], axis=0)


if __name__ == '__main__':
    # Edit these lists to regenerate only a subset of the problems.
    name_group = ["Diffusion_reaction", "Wave", "Heat", "Advection", "Poisson"]
    index_group = ["1_1D", "1_2D", "1_3D", "2_1D", "2_2D", "2_3D"]

    dim_pattern = re.compile(r'_(\d+)D')
    for problem_name in name_group:
        for problem_index in index_group:
            match = dim_pattern.search(problem_index)
            if match is None:
                raise ValueError(
                    f"cannot parse a dimension from problem index {problem_index!r}")
            dim = int(match.group(1))

            # gen_dataset reads the module-level name X, so it must be
            # (re)bound here at module scope before every call.
            X = torch.tensor(sample_points(dim)).tolist()
            gen_dataset(problem_name, problem_index)
    
    

    