import deepxde as dde
import numpy as np
import torch
import pandas as pd
from sklearn.metrics import mean_squared_error
import re
import os
import importlib
import pinn_pde_definitions
from sklearn.metrics import mean_squared_error
import json
import argparse

def jsonify(d):
    """Recursively convert ``d`` into JSON-serializable built-in Python types.

    Conversion rules:
      - list / tuple            -> new list with converted elements
      - dict                    -> new dict with converted values (input is not modified)
      - numpy.ndarray           -> nested lists via ``tolist()``
      - pandas DataFrame/Series -> nested lists via ``values.tolist()``
      - Python / NumPy bool     -> ``bool``
      - Python / NumPy integer  -> ``int`` (signed and unsigned)
      - Python / NumPy float    -> ``float``
      - None, str               -> returned unchanged
      - anything else           -> ``str(d)``, with a printed warning

    Parameters
    ----------
    d : object
        Value to convert.

    Returns
    -------
    object
        A structure built only from list, dict, str, int, float, bool and None.
    """
    if isinstance(d, (list, tuple)):
        return [jsonify(v) for v in d]
    if isinstance(d, dict):
        # Build a new dict rather than overwriting the caller's values in place.
        return {k: jsonify(v) for k, v in d.items()}
    if isinstance(d, np.ndarray):
        return d.tolist()
    if isinstance(d, (pd.DataFrame, pd.Series)):
        return d.values.tolist()
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(d, (bool, np.bool_)):
        return bool(d)
    if isinstance(d, (int, np.integer)):
        return int(d)
    if isinstance(d, (float, np.floating)):
        return float(d)
    if d is None or isinstance(d, str):
        return d
    print("WARNING: attempting to store ",d,"as a str for json")
    return str(d)

def read_file(filename, label='target', use_dataframe=True, sep=None):
    """Load a delimited table and split it into features and a label column.

    Parameters
    ----------
    filename : str or os.PathLike
        Path to the table. Paths ending in ``gz`` are read as gzip-compressed.
    label : str
        Name of the target column.
    use_dataframe : bool
        If True, features are returned as a pandas DataFrame; otherwise as a
        NumPy array (``DataFrame.values``).
    sep : str or None
        Column separator. ``None`` asks pandas to sniff it, which requires the
        python parsing engine.

    Returns
    -------
    X : pandas.DataFrame or numpy.ndarray
        All columns except ``label``.
    y : numpy.ndarray
        Values of the ``label`` column.
    feature_names : numpy.ndarray
        Names of the feature columns, in file order.

    Raises
    ------
    KeyError
        If ``label`` is not a column of the file.
    ValueError
        If the number of feature columns does not match ``feature_names``.
    """
    compression = 'gzip' if str(filename).endswith('gz') else None

    print('compression:',compression)
    print('filename:',filename)

    # Separator sniffing (sep=None) is only supported by the python engine;
    # selecting it explicitly avoids pandas' "falling back to python engine" warning.
    engine = 'python' if sep is None else None
    input_data = pd.read_csv(filename, sep=sep, compression=compression, engine=engine)

    if label not in input_data.columns:
        raise KeyError(
            f"label column {label!r} not found in {filename}; "
            f"available columns: {list(input_data.columns)}"
        )

    feature_names = np.array([c for c in input_data.columns.values if c != label])

    X = input_data.drop(columns=label)
    if not use_dataframe:
        X = X.values
    y = input_data[label].values

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if X.shape[1] != feature_names.shape[0]:
        raise ValueError(
            f"feature matrix has {X.shape[1]} columns but {feature_names.shape[0]} feature names"
        )

    return X, y, feature_names

# def heat_pde(x, y, ):
#     a=0.4
#     dy_t = dde.grad.jacobian(y, x, i=0, j=1)
#     dy_xx = dde.grad.hessian(y, x, i=0, j=0)
#     return dy_t - a * dy_xx


# def setup_heat(filename):
#     dataset = pd.read_csv(filename,sep='\t')
#     X_obs,y_obs, feature_names = read_file(filename, label='target', use_dataframe=True, sep=None)
    
    
#     # Define the solution domain
#     geom = dde.geometry.Interval(0, 1)
#     timedomain = dde.geometry.TimeDomain(0, 2)
#     geomtime = dde.geometry.GeometryXTime(geom, timedomain)
    
#     # Define data for PINN
#     bc_data = dde.icbc.PointSetBC(X_obs,y_obs)

#     data = dde.data.PDE(
#         geomtime,
#         heat_pde,
#         bcs = [bc_data],
#         num_domain=1000,
#         num_boundary=len(y_obs)
#     )

#     # Define neural network
#     net = dde.nn.FNN([2] + [40] * 6 + [1], "tanh", "Glorot normal")

#     # Define model
#     model = dde.Model(data, net)
    
#     return model, dataset





def _spatial_geometry(problem_index):
    """Build the unit-hypercube spatial domain for a problem index such as ``"1_2D"``.

    The spatial dimension ``d`` is parsed from the ``_<d>D`` part of the index.

    Raises
    ------
    ValueError
        If the index has no ``_<d>D`` part or ``d`` is not 1, 2 or 3.
    """
    match = re.search(r'_(\d+)D', str(problem_index))
    if match is None:
        raise ValueError(
            f"cannot parse spatial dimension from problem index {problem_index!r}; "
            "expected something like '1_2D'"
        )
    dim = int(match.group(1))
    if dim == 1:
        return dde.geometry.Interval(0, 1)
    if dim == 2:
        return dde.geometry.Rectangle((0, 0), (1, 1))
    if dim == 3:
        return dde.geometry.Cuboid((0.0, 0.0, 0.0), (1.0, 1.0, 1.0))
    raise ValueError(f"unsupported spatial dimension {dim} in problem index {problem_index!r}")


def _fit_error_stats(y_true, y_pred):
    """Return ``(MSE, variance)`` of the flattened error ``y_true - y_pred``.

    Both inputs are flattened so that a 1-D target and an ``(N, 1)`` prediction
    are compared element-wise instead of broadcasting to an ``(N, N)`` matrix.
    """
    err = np.ravel(y_true) - np.ravel(y_pred)
    return np.mean(err ** 2), np.var(err)


def _residual_stats(residual):
    """Return ``(mean squared value, variance)`` of a PDE residual array."""
    r = np.ravel(np.asarray(residual, dtype=float))
    return np.mean(r ** 2), np.var(r)


def _evaluate(model, pde, X, y):
    """Return ``(mse, error variance, residual MSE, residual variance)`` of ``model`` on ``(X, y)``."""
    mse, err_var = _fit_error_stats(y, model.predict(X))
    res_mse, res_var = _residual_stats(model.predict(X, operator=pde))
    return mse, err_var, res_mse, res_var


def main(args):
    """Train a PINN for each requested problem and write its evaluation metrics.

    For the ``(problem_name, problem_index)`` pair given by ``args`` (e.g.
    ``"Heat"`` and ``"1_1D"``; other names used so far: Diffusion_reaction, Wave,
    Advection, Poisson; other indices: 1_2D, 1_3D, 2_1D, 2_2D, 2_3D) this:

    1. builds a unit spatial domain from the ``_<d>D`` suffix of the index, plus
       a time domain [0, 2] when the training data has a ``t`` column;
    2. fits a 6x40 tanh network to the PDE
       ``pinn_pde_definitions.<name><index>`` and the training observations
       (Adam, then L-BFGS);
    3. evaluates on the training set and two test sets and writes the metrics
       to ``<results_path>/<name>/<index>/results.json``.

    Problems whose ``results.json`` already exists are skipped, as are problems
    without a matching PDE definition (a warning is printed).

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``name_group`` and ``index_group`` strings.

    Raises
    ------
    ValueError
        If ``index_group`` does not encode a spatial dimension of 1, 2 or 3.
    """
    name_group = [args.name_group]
    index_group = [args.index_group]
    train_root = "./PDEdataset-psr-noise"
    test_root1 = "./PDEdataset-psr"
    test_root2 = "./PDEdataset-random"
    results_path = "./results-psr-new2"

    for problem_name in name_group:
        for problem_index in index_group:
            geom = _spatial_geometry(problem_index)

            save_path_problem = os.path.join(problem_name, problem_index)
            train_flnm_data = os.path.join(train_root, save_path_problem, "data.tsv.gz")
            test_flnm_data1 = os.path.join(test_root1, save_path_problem, "data.tsv.gz")
            test_flnm_data2 = os.path.join(test_root2, save_path_problem, "data.tsv.gz")
            results_save_path_dir = os.path.join(results_path, save_path_problem)
            os.makedirs(results_save_path_dir, exist_ok=True)
            results_file = os.path.join(results_save_path_dir, "results.json")
            if os.path.exists(results_file):
                continue  # Skip if results already exist

            pde_name = problem_name + problem_index
            pde_define = getattr(pinn_pde_definitions, pde_name, None)
            if pde_define is None:
                print(f"WARNING: no PDE named {pde_name!r} in pinn_pde_definitions; skipping")
                continue

            X_obs, y_obs, feature_names = read_file(
                train_flnm_data, label='target', use_dataframe=False, sep=None
            )
            # NOTE(review): GeometryXTime appends t as the last coordinate; this
            # presumably requires 't' to be the last data column — confirm.
            if 't' in feature_names:
                geomtime = dde.geometry.GeometryXTime(geom, dde.geometry.TimeDomain(0, 2))
            else:
                geomtime = geom

            # PointSetBC subtracts `values` from an (N, 1) slice of the network
            # output, so a flat (N,) target would broadcast to (N, N).
            bc_data = dde.icbc.PointSetBC(X_obs, y_obs.reshape(-1, 1))
            data = dde.data.PDE(
                geomtime,
                pde_define,
                bcs=[bc_data],
                num_domain=1000,
                num_boundary=len(y_obs),
            )
            net = dde.nn.FNN([len(feature_names)] + [40] * 6 + [1], "tanh", "Glorot normal")
            model = dde.Model(data, net)

            model_name = pde_name + "_PINN"
            model_prefix = os.path.join(results_save_path_dir, model_name)

            # Periodic checkpoints are kept next to this problem's results so
            # different problems do not overwrite each other's files.
            checker = dde.callbacks.ModelCheckpoint(
                f"{model_prefix}.pt", save_better_only=True, period=5000
            )

            model.compile("adam", lr=1e-4, loss_weights=[1e-3, 1])
            model.train(iterations=10000, display_every=1000, callbacks=[checker])

            # Choose the optimizer for the refinement phase.
            # NOTE(review): L-BFGS is compiled without loss_weights, so the PDE and
            # data losses are weighted equally in this phase — confirm intentional.
            model.compile("L-BFGS")
            losshistory, train_state = model.train(callbacks=[checker])

            dde.saveplot(
                losshistory,
                train_state,
                issave=True,
                isplot=True,
                loss_fname=os.path.join(results_save_path_dir, "loss.dat"),
                train_fname=os.path.join(results_save_path_dir, "train.dat"),
                test_fname=os.path.join(results_save_path_dir, "test.dat"),
            )

            X_test1, y_test1, _ = read_file(test_flnm_data1, label='target', use_dataframe=False, sep=None)
            X_test2, y_test2, _ = read_file(test_flnm_data2, label='target', use_dataframe=False, sep=None)

            mse0, var0, res_mse0, res_var0 = _evaluate(model, pde_define, X_obs, y_obs)
            mse1, var1, res_mse1, res_var1 = _evaluate(model, pde_define, X_test1, y_test1)
            mse2, var2, res_mse2, res_var2 = _evaluate(model, pde_define, X_test2, y_test2)

            # Model.save appends "-<step>" and the backend suffix and returns the
            # real path, so record that instead of a guessed file name.
            model_path = model.save(model_prefix)

            results = {
                "train_path": train_flnm_data,
                "mse0": mse0,
                "train_variance": var0,
                "pde_residual_error0": res_mse0,
                "train_pde_variance": res_var0,
                "test_path1": test_flnm_data1,
                "mse1": mse1,
                "variance1": var1,
                "pde_residual_error1": res_mse1,
                "pde_variance1": res_var1,
                "test_path2": test_flnm_data2,
                "mse2": mse2,
                "variance2": var2,
                "pde_residual_error2": res_mse2,
                "pde_variance2": res_var2,
                "model_name": model_name,
                "model_path": model_path,
            }

            with open(results_file, 'w') as f:
                json.dump(jsonify(results), f, indent=4)
                

    

if __name__ == "__main__":
    # Command-line entry point: train and evaluate one (problem name, index) pair.
    # Known problem names: Diffusion_reaction, Wave, Heat, Advection, Poisson.
    # Known problem indices: 1_1D, 1_2D, 1_3D, 2_1D, 2_2D, 2_3D.
    # --help is kept enabled; plain (non-'?') options require a value, so a bare
    # flag cannot silently turn into None.
    parser = argparse.ArgumentParser(description="PINN for solving PDE.")
    parser.add_argument(
        '--name_group', type=str, default="Heat",
        help='Problem name, e.g. Heat, Wave, Poisson (default: %(default)s)',
    )
    parser.add_argument(
        '--index_group', type=str, default="1_1D",
        help='Problem index of the form <k>_<d>D, e.g. 1_2D (default: %(default)s)',
    )

    args = parser.parse_args()
    print("Arguments:", args.name_group, args.index_group)
    main(args)
    
