import argparse
import pickle
import h5py
import numpy as np
import torch
import os
import sys
from functools import partial
from scipy.integrate import solve_ivp
from evaluation.eval_eq import di_sindy_truth
import matplotlib.pyplot as plt

def get_dataset(pde):
    """Load a PDE definition and a random sample of its test trajectories.

    Args:
        pde: Name of the PDE. It is used to build the input paths
            ``data/{pde}_test_4096_default.h5`` and ``data/{pde}_default.pkl``.

    Returns:
        A tuple ``(t, u, pde_obj)``:

        * ``t`` -- ``nt_effective`` evenly spaced time points ending at
          ``pde_obj.tmax`` with spacing ``pde_obj.dt``.
        * ``u`` -- four rows of the ``test/pde_{nt}-{nx}`` dataset, drawn with
          replacement using the global NumPy RNG (seed it beforehand for
          reproducible samples).
        * ``pde_obj`` -- the object unpickled from the ``.pkl`` file.
    """
    test_data_path = f"data/{pde}_test_4096_default.h5"
    pde_path = f"data/{pde}_default.pkl"

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated files.
    with open(pde_path, 'rb') as pkl_file:
        pde_obj = pickle.load(pkl_file)

    nt = pde_obj.nt_effective
    nx = pde_obj.nx
    dataset = f'pde_{nt}-{nx}'

    # The context manager closes the HDF5 handle even if a lookup below raises.
    with h5py.File(test_data_path, 'r') as h5_file:
        test_group = h5_file['test']
        dataset_size = test_group[dataset].shape[0]
        n_data = 4
        ind = np.random.randint(dataset_size, size=n_data)
        # Materialise as an ndarray first: the random indices may repeat and
        # are unsorted, so they are applied with NumPy fancy indexing.
        u = np.array(test_group[dataset])[ind]

    t = np.linspace(
        pde_obj.tmax - (pde_obj.nt_effective - 1) * pde_obj.dt,
        pde_obj.tmax,
        pde_obj.nt_effective,
    )
    return t, u, pde_obj

def _periodic_derivatives(u, dx):
    """Return central-difference d/dx .. d4/dx4 of ``u`` along the last axis.

    Neighbours are fetched with ``np.roll``, so the stencils wrap around the
    ends of the last axis.
    """
    fwd1 = np.roll(u, -1, axis=-1)
    bwd1 = np.roll(u, 1, axis=-1)
    fwd2 = np.roll(u, -2, axis=-1)
    bwd2 = np.roll(u, 2, axis=-1)
    d1 = (fwd1 - bwd1) / (2 * dx)
    d2 = (fwd1 - 2 * u + bwd1) / dx / dx
    d3 = (fwd2 - 2 * fwd1 + 2 * bwd1 - bwd2) / 2 / dx / dx / dx
    d4 = (fwd2 - 4 * fwd1 + 6 * u - 4 * bwd1 + bwd2) / dx / dx / dx / dx
    return d1, d2, d3, d4


def _library_rhs(coef, u, derivatives):
    """Return the 20-term polynomial-library right-hand side.

    Columns 0-4 of ``coef[0]`` weight ``u`` and its four derivatives; columns
    5-19 weight every product of two of those five fields, ordered as
    ``u*u, u*u_x, ..., u*u_xxxx, u_x*u_x, ..., u_xxxx*u_xxxx``.
    """
    fields = (u,) + tuple(derivatives)
    # Each entry lists the factors multiplied, left to right, onto the
    # coefficient; squares are formed before the coefficient is applied.
    factor_sets = [(field,) for field in fields]
    for i, left in enumerate(fields):
        factor_sets.append((left ** 2,))
        factor_sets.extend((left, right) for right in fields[i + 1:])

    dudt = None
    for k, factors in enumerate(factor_sets):
        term = coef[0, k]
        for factor in factors:
            term = term * factor
        dudt = term if dudt is None else dudt + term
    return dudt


def fun(t, u, pde, model, coef, shape):
    """Right-hand side ``du/dt`` of the discovered PDE, for ``solve_ivp``.

    Args:
        t: Current time; only used for the ``'nKdV'`` time-dependent scaling.
        u: Flattened state; it is reshaped to ``shape`` before differencing.
        pde: Object providing ``dx`` and whose ``str()`` is one of ``'KdV'``,
            ``'KS'``, ``'Burgers'`` or ``'nKdV'``.
        model: ``'disindy'`` for
            ``-u*u_x + c0*u_x + c1*u_xx + c2*u_xxx + c3*u_xxxx``, or
            ``'esindy'`` for the 20-term polynomial library.
        coef: 2-D coefficient array; only row 0 is read (4 columns for
            ``'disindy'``, 20 for ``'esindy'``).
        shape: Shape that ``u`` is reshaped to; the last axis is the one
            differentiated.

    Returns:
        The time derivative flattened to 1-D.

    Raises:
        ValueError: If the PDE name or the model is not supported.
    """
    u = u.reshape(shape)
    pde_name = str(pde)
    if pde_name not in ('KdV', 'KS', 'Burgers', 'nKdV'):
        raise ValueError(f"Unsupported PDE: {pde_name!r}")

    derivatives = _periodic_derivatives(u, pde.dx)
    if model == 'disindy':
        dudx, dudxdx, dudxdxdx, dudxdxdxdx = derivatives
        dudt = (-u * dudx + coef[0, 0] * dudx + coef[0, 1] * dudxdx
                + coef[0, 2] * dudxdxdx + coef[0, 3] * dudxdxdxdx)
    elif model == 'esindy':
        dudt = _library_rhs(coef, u, derivatives)
    else:
        raise ValueError(f"Unsupported model: {model!r}")

    if pde_name == 'nKdV':
        # nKdV applies a time-dependent factor to the whole right-hand side.
        dudt = dudt * np.exp(t / 50)
    return dudt.reshape(-1)

def eval_ltp_accuracy(t, u, pde, model, coef):
    """Roll the model forward from the first snapshot and score it against ``u``.

    The state ``u[:, 0]`` is flattened and integrated with RK45 over
    ``[t[0], t[-1]]``, sampling the solution at every entry of ``t``.

    Args:
        t: Time points at which the solution is evaluated.
        u: Reference data; ``u[:, 0]`` must be 2-D.
            # presumably (trajectory, time, space) -- confirm against caller
        pde, model, coef: Forwarded unchanged to ``fun``.

    Returns:
        The squared error averaged over axes 0 and 2 (one value per time
        point), or ``None`` when the solver reports ``status == -1``.
    """
    state_shape = u[:, 0].shape
    rhs = partial(fun, pde=pde, model=model, coef=coef, shape=state_shape)

    result = solve_ivp(
        rhs,
        (t[0], t[-1]),
        u[:, 0].reshape(-1),
        method='RK45',
        t_eval=t,
    )
    if result.status == -1:
        return None

    # solve_ivp returns (flattened state, time); restore the two state axes
    # and move time into the middle so the layout matches ``u``.
    unflattened = result.y.reshape(state_shape[0], state_shape[1], -1)
    predicted = unflattened.transpose(0, 2, 1)
    squared_error = (predicted - u) ** 2
    return np.mean(squared_error, axis=(0, 2))

def aggregate_results(run_name, t, u, pde, model, min_seed=0, max_seed=100):
    """Evaluate every saved coefficient set of one run and stack the errors.

    Scans ``<cwd>/<result_dir>/<run_name>`` for ``.npz`` files, where
    ``result_dir`` is a module-level name that must be assigned before this
    function is called (the ``__main__`` block sets it).

    Args:
        run_name: Sub-directory of ``result_dir`` holding the ``.npz`` files.
        t, u, pde, model: Forwarded unchanged to ``eval_ltp_accuracy``.
        min_seed: Smallest seed (inclusive) to evaluate.
        max_seed: Upper seed bound (exclusive).

    Returns:
        A 2-D array with one row of per-time errors for each seed whose
        integration succeeded, in sorted filename order.

    Raises:
        ValueError: If no file yields a successful evaluation, or a ``.npz``
            filename does not carry an integer after its first four
            characters.
    """
    directory = os.path.join(os.getcwd(), result_dir, run_name)
    errors = []
    # Sorted so that the row order does not depend on the filesystem's
    # listing order.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.npz'):
            continue
        print(filename)
        # The seed is whatever follows the first four characters of the stem
        # (presumably files are named like ``seed<N>.npz`` -- confirm).
        seed = int(filename.split('.')[0][4:])
        if seed >= max_seed or seed < min_seed:
            continue
        # NpzFile keeps the archive open until closed; release it as soon as
        # the coefficients have been read.
        with np.load(os.path.join(directory, filename)) as res:
            coef = res['coefficients']
        error = eval_ltp_accuracy(t, u, pde, model, coef)
        if error is None:
            # The integration failed for this seed; leave it out.
            print('continue')
            continue
        errors.append(error)
    if not errors:
        raise ValueError(
            f"No successful evaluations for seeds in [{min_seed}, {max_seed}) "
            f"under {directory}"
        )
    return np.stack(errors)

def aggregate_results_gt(t, u, pde):
    """Score the reference coefficients listed in ``di_sindy_truth``.

    The entry keyed by ``str(pde)`` is evaluated with the ``'disindy'`` model
    form, and the result of ``eval_ltp_accuracy`` is returned as-is (so it is
    ``None`` whenever that function returns ``None``).
    """
    reference_coef = di_sindy_truth[str(pde)]
    return eval_ltp_accuracy(t, u, pde, 'disindy', reference_coef)

def get_args():
    """Parse the command line and attach the torch device to the namespace.

    Options: ``--pde`` (str, default ``'KdV'``), ``--gpu`` (int, default 0),
    ``--seed`` (int, default 0) and the flag ``--load``.

    Returns:
        The parsed namespace with an extra ``device`` attribute:
        ``cuda:<gpu>`` when CUDA is available and ``--gpu`` is not -1,
        otherwise ``cpu``.
    """
    parser = argparse.ArgumentParser()

    typed_options = (
        ('--pde', str, 'KdV'),
        ('--gpu', int, 0),
        ('--seed', int, 0),
    )
    for flag, kind, default in typed_options:
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument('--load', action='store_true')

    args = parser.parse_args()

    if torch.cuda.is_available() and args.gpu != -1:
        device_name = f'cuda:{args.gpu}'
    else:
        device_name = 'cpu'
    args.device = torch.device(device_name)

    return args

# Script entry point: evaluate (or reload) the long-term prediction errors of
# every model variant for one PDE, then plot the mean +/- one standard
# deviation over seeds and save the figure as a PDF.
if __name__ == '__main__':
    args = get_args()

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    args = vars(args)

    t, u, pde = get_dataset(pde=args['pde'])

    pde_name = str(pde)
    eval_ltp_save_dir = 'eval_ltp_results'
    results_path = f"{eval_ltp_save_dir}/{pde_name}.npz"

    # (result key, run-name suffix, model form used by the integrator,
    #  legend label) -- the order fixes both the saved keys and the plot order.
    runs = [
        ('disindy', 'disindy', 'disindy', 'DI-SINDy (Ours)'),
        ('sindy', 'sindy', 'esindy', 'SINDy'),
        ('esindy_m3', 'esindy_1e-3', 'esindy', r'EquivSINDy-r ($\lambda=10^{-3}$)'),
        ('esindy_m2', 'esindy_1e-2', 'esindy', r'EquivSINDy-r ($\lambda=10^{-2}$)'),
        ('esindy_m1', 'esindy_1e-1', 'esindy', r'EquivSINDy-r ($\lambda=10^{-1}$)'),
    ]

    if not args['load']:
        print(f"{pde_name.lower()}_gt")
        error_gt = aggregate_results_gt(t, u, pde)
        if error_gt is None:
            # eval_ltp_accuracy signals a failed integration with None.
            raise RuntimeError(
                f"Integration with the ground-truth coefficients failed for {pde_name}"
            )
        print(error_gt.shape)

        # Must stay a module-level name: aggregate_results reads it as a global.
        result_dir = 'eval_results'

        eval_ltp_results = {'gt': error_gt}
        for key, suffix, model, _ in runs:
            run_name = f"{pde_name.lower()}_{suffix}"
            print(run_name)
            error = aggregate_results(run_name, t, u, pde, model, max_seed=50)
            print(error.shape)
            eval_ltp_results[key] = error

        os.makedirs(eval_ltp_save_dir, exist_ok=True)
        np.savez(results_path, **eval_ltp_results)

    else:
        # Copy the arrays out so the archive can be closed before plotting.
        with np.load(results_path) as archive:
            eval_ltp_results = {key: archive[key] for key in archive.files}

    plt.plot(t, eval_ltp_results['gt'], label='Ground truth')

    for key, _, _, label in runs:
        mean = np.mean(eval_ltp_results[key], axis=0)
        std = np.std(eval_ltp_results[key], axis=0)
        line, = plt.plot(t, mean, label=label)
        # Shade the +/- one standard deviation band in the line's own colour.
        plt.fill_between(t, mean - std, mean + std, color=line.get_color(), alpha=0.1)

    if pde_name == 'Burgers':
        plt.ylim(-2e-7, 5e-6)
    plt.legend()
    plt.xlabel('Time')
    plt.ylabel('MSE')
    plt.title(f'Long-term prediction error for {pde_name} equation')
    plt.tight_layout()
    plt.savefig(f"{eval_ltp_save_dir}/{pde_name}.pdf")
    