import h5py
import numpy as np
from constrained_sindy import ConstrainedSINDy
from argparse import ArgumentParser
import warnings
from tqdm import tqdm
import os

import sys
sys.path.append('../')
import numpy as np
import pysindy as ps
from sklearn.metrics import mean_squared_error

from sindy.ext_pde_library import ExtendedPDELibrary
from stlsq2 import STLSQ2
from eval_eq import eval_eq_params

# Candidate polynomial terms of degree 1 to 3 in at most two field components.
# Each entry pairs the numeric term with the function that builds its display
# name by concatenating the variable names (e.g. "u", "v" -> "uvv").
# Spatial derivatives are not listed here; the PDE libraries built further down
# request them through their `derivative_order` argument.
_POLY_TERMS = [
    (lambda a: a, lambda a: a),
    (lambda a: a * a, lambda a: a + a),
    (lambda a, b: a * b, lambda a, b: a + b),
    (lambda a: a * a * a, lambda a: a + a + a),
    (lambda a, b: a * b * b, lambda a, b: a + b + b),
    (lambda a, b: a * a * b, lambda a, b: a + a + b),
]
library_functions = [term for term, _ in _POLY_TERMS]
library_function_names = [label for _, label in _POLY_TERMS]

def get_gt_params(data_path):
    """Return the ground-truth coefficient matrix and dataset tag for a data file.

    The dataset variant is inferred from the file name (directory and extension
    are ignored), split on underscores:

    * fewer than three fields (e.g. ``rd.h5``): the original system;
    * ``rd_sb{i}_eps{eps}``: symmetry-breaking variant ``i`` with strength
      ``eps``.

    Args:
        data_path: Path to the dataset file.

    Returns:
        A tuple ``(params, name)``. ``params`` is a ``(2, 19)`` float array with
        one row per equation (presumably ordered like the candidate terms of
        the feature library used by the caller -- verify against the library).
        ``name`` is ``"original"`` or ``"sb{i}_{eps:.2f}"``.

    Raises:
        ValueError: If the symmetry-breaking type or strength in the file name
            is not numeric, or if the type is neither 1 nor 2.
    """
    stem = os.path.splitext(os.path.basename(data_path))[0]
    fields = stem.split('_')

    # Coefficients of the unperturbed system; type-1 symmetry breaking uses the
    # same matrix, type 2 only modifies two entries.
    params = np.array([
        [1, 0, 0, 0, 0, -1, 1, -1, 1, 0, 0, 0.1, 0, 0, 0, 0, 0, 0.1, 0],
        [0, 1, 0, 0, 0, -1, -1, -1, -1, 0, 0, 0, 0.1, 0, 0, 0, 0, 0, 0.1],
    ])
    if len(fields) < 3:
        return params, "original"

    # symmetry breaking: rd_sb{i}_eps{eps}
    sb_type = int(fields[1][2:])
    # Parse the strength as a plain number; the file name must never be
    # evaluated as Python code.
    sb_eps = float(fields[2][3:])

    if sb_type == 1:
        return params, f"sb1_{sb_eps:.2f}"
    if sb_type == 2:
        # Linear cross-coupling between the two components.
        params[0, 1] = -sb_eps
        params[1, 0] = -sb_eps
        return params, f"sb2_{sb_eps:.2f}"

    # Fail loudly instead of implicitly returning None, which callers would
    # only notice as an obscure tuple-unpacking error.
    raise ValueError(
        f"Unsupported symmetry-breaking type {sb_type} in data file name {stem!r}"
    )


# Time step used to finite-difference the noisy data for the non-weak libraries.
# NOTE(review): hard-coded rather than derived from the dataset's `t` array;
# confirm it matches the sampling interval of the file being loaded.
_DT = 0.05


def _build_library(lib_name, x, y, t_window, n_test_funcs):
    """Build the candidate feature library for one time window.

    Args:
        lib_name: One of ``'weak'``, ``'original'`` or ``'extended'``.
        x: 1-D array of grid coordinates along the first spatial axis.
        y: 1-D array of grid coordinates along the second spatial axis.
        t_window: 1-D array of time stamps of the window (used by ``'weak'`` only).
        n_test_funcs: Value passed as ``K`` to the weak library.

    Returns:
        A tuple ``(lib, symm_constraint_path)`` where ``symm_constraint_path``
        is the pair of ``.npy`` files loaded by the symmetry-constrained
        methods for this library.

    Raises:
        ValueError: If ``lib_name`` is not a known library.
    """
    if lib_name == 'weak':
        X, Y, T = np.meshgrid(x, y, t_window, indexing="ij")
        XYT = np.transpose([X, Y, T], [1, 2, 3, 0])
        lib = ps.WeakPDELibrary(
            library_functions=library_functions,
            function_names=library_function_names,
            derivative_order=2,
            spatiotemporal_grid=XYT,
            is_uniform=True,
            periodic=True,
            K=n_test_funcs,
            include_interaction=False,
        )
        return lib, ("./basis_rot2d_ord2_deg3.npy", "./basis_orth_rot2d_ord2_deg3.npy")

    if lib_name not in ('original', 'extended'):
        raise ValueError(f"Unknown library type: {lib_name!r}")

    X, Y = np.meshgrid(x, y, indexing="ij")
    spatial_grid = np.transpose([X, Y], [1, 2, 0])
    if lib_name == 'original':
        lib = ps.PDELibrary(
            library_functions=library_functions,
            function_names=library_function_names,
            derivative_order=2,
            spatial_grid=spatial_grid,
            is_uniform=True,
            periodic=True,
            include_interaction=False,
        )
        return lib, ("./basis_rot2d_ord2_deg3.npy", "./basis_orth_rot2d_ord2_deg3.npy")

    lib = ExtendedPDELibrary(
        library_functions=library_functions,
        function_names=library_function_names,
        derivative_order=2,
        spatial_grid=spatial_grid,
        is_uniform=True,
        periodic=True
    )
    # not implemented yet
    return lib, ("./basis_rot2d_ord2_deg3_ext.npy", "./basis_orth_rot2d_ord2_deg3_ext.npy")


def _fit_model(model, u_noisy, weak):
    """Fit ``model`` on one noisy window.

    For the weak library the model is fitted on the data alone; otherwise the
    time derivative is estimated by finite differences along axis 2 and passed
    explicitly.
    """
    if weak:
        model.fit(u_noisy)
    else:
        u_dot = np.gradient(u_noisy, axis=2) / _DT
        model.fit(u_noisy, x_dot=u_dot)


if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument('--n_runs', type=int, default=100)
    parser.add_argument('--n_timesteps', type=int, default=50)
    parser.add_argument('--noise', type=float, default=0.02)
    parser.add_argument('--lib', choices=['weak', 'original', 'extended'], default='weak')
    parser.add_argument('-D', '--data_path', type=str, default='../data/ReacDiff2D/rd.h5')
    parser.add_argument('-K', '--n_test_funcs', type=int, default=1000)
    parser.add_argument('-T', '--threshold', type=float, default=0.0)  # use method-specific default values if not set
    parser.add_argument('-S', '--use_symmetry', action='store_true')
    parser.add_argument('-SS', '--use_soft_symmetry', action='store_true')
    parser.add_argument('--seq_thres_method', choices=['raw', 'constrained'], default='raw')
    parser.add_argument('--sigma_a', type=float, default=1e3)
    parser.add_argument('--sigma_b', type=float, default=1e2)
    parser.add_argument('--alpha', type=float, default=1e-8)
    parser.add_argument('-R', '--save_results', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    with h5py.File(args.data_path, 'r') as f:
        x = np.asarray(f['spatial_grid'][:, 0, 0])  # (Nx,)
        y = np.asarray(f['spatial_grid'][0, :, 1])  # (Ny,)
        t = np.asarray(f['t'])  # (Nt,)
        u = np.asarray(f['u'])  # (Nx, Ny, Nt, 2)

    gt_params, dataset_name = get_gt_params(args.data_path)

    Nt = args.n_timesteps
    # Start indices are drawn without replacement from [0, n_windows).
    # NOTE(review): the last valid start index (u.shape[2] - Nt) is excluded;
    # kept as-is so that seeded runs stay reproducible.
    n_windows = u.shape[2] - Nt
    if n_windows <= 0:
        raise ValueError(
            f"--n_timesteps ({Nt}) must be smaller than the number of time steps in the data ({u.shape[2]})."
        )
    if args.n_runs > n_windows:
        # Sampling without replacement cannot yield more runs than windows.
        warnings.warn("Number of random runs is greater than the number of possible different data subsets.")
        n_runs = n_windows
    else:
        n_runs = args.n_runs

    np.random.seed(args.seed)
    start_idx = np.random.choice(np.arange(0, n_windows), size=n_runs, replace=False)

    # Noise level is relative to the RMS of the whole (clean) dataset; this does
    # not depend on the window, so compute it once.
    u_rms = np.sqrt(mean_squared_error(u.flatten(), np.zeros(u.size)))
    noise_std = u_rms * args.noise

    is_weak = args.lib == "weak"
    # -S takes precedence over -SS when both are given.
    method = "hard_symm" if args.use_symmetry else "soft_symm" if args.use_soft_symmetry else "sindy"
    subdir = f"./results/{dataset_name}/{method}/K{args.n_test_funcs}_noise{args.noise:.3f}_Nt{Nt}"

    correct_cnt = 0
    rmse_list = []

    for i, ts in enumerate(tqdm(start_idx)):
        te = ts + Nt

        # Build the library before drawing the noise: the order of calls that
        # may consume the global NumPy RNG is kept identical across versions
        # (the library may sample from it -- TODO confirm).
        lib, symm_constraint_path = _build_library(args.lib, x, y, t[ts:te], args.n_test_funcs)
        lib.fit(u[:, :, ts:te])

        # Noise is drawn for the full dataset and then windowed, so the random
        # stream (and thus the result for a given seed) is unchanged.
        u_noisy = u + np.random.normal(0, noise_std, u.shape)
        u_noisy = u_noisy[:, :, ts:te]

        if args.use_symmetry:
            # Hard constraint: coefficients restricted to the basis in Q.
            threshold = 0.1 if args.threshold == 0.0 else args.threshold
            Q = np.load(symm_constraint_path[0])
            # NOTE(review): alpha is fixed here and does not follow --alpha.
            optimizer = ps.STLSQ(threshold=threshold, alpha=1e-8, normalize_columns=False)
            model = ConstrainedSINDy(constraint_tensor=Q, feature_library=lib, optimizer=optimizer)
        elif args.use_soft_symmetry:
            # Soft constraint: Q and P are combined with the ratio sigma_b / sigma_a.
            threshold = 0.05 if args.threshold == 0.0 else args.threshold
            Q = np.load(symm_constraint_path[0])
            P = np.load(symm_constraint_path[1])
            ridge = 1/(2*args.sigma_a**2)
            breaking_factor = args.sigma_b/args.sigma_a
            if args.seq_thres_method == 'constrained':
                optimizer = ps.STLSQ(threshold=threshold, alpha=ridge, normalize_columns=False)
            elif args.seq_thres_method == 'raw':
                optimizer = STLSQ2(threshold=threshold, alpha=ridge, normalize_columns=False, weight_transform=(Q, P*breaking_factor))
            model = ConstrainedSINDy(Q, P, constraint_breaking_factor=breaking_factor, feature_library=lib, optimizer=optimizer)
        else:
            # Plain SINDy baseline.
            threshold = 0.05 * np.sqrt(2) if args.threshold == 0.0 else args.threshold
            optimizer = ps.STLSQ(threshold=threshold, alpha=args.alpha, normalize_columns=False)
            model = ps.SINDy(feature_library=lib, optimizer=optimizer)

        _fit_model(model, u_noisy, is_weak)

        if args.use_symmetry or args.use_soft_symmetry:
            pred_params = model.unconstrained_coefficients()
        else:
            pred_params = model.coefficients()

        success, rmse = eval_eq_params(gt_params, pred_params)
        result_dict = {
            "params": pred_params,
            "success": success,
            "rmse": rmse,
        }
        # Only successful runs contribute to the RMSE statistics.
        if success:
            correct_cnt += 1
            rmse_list.append(rmse)

        model.print()
        if args.save_results:
            os.makedirs(subdir, exist_ok=True)
            np.savez(f"{subdir}/run{i}.npz", **result_dict)

    # Avoid NumPy's empty-slice warning when no run succeeded; the printed
    # value is still "nan" in that case.
    mean_rmse = np.mean(rmse_list) if rmse_list else float("nan")
    std_rmse = np.std(rmse_list) if len(rmse_list) > 1 else 0
    print(f"Correct count: {correct_cnt}/{n_runs}")
    print(f"Average RMSE over {correct_cnt} runs: {mean_rmse:.4f} ({std_rmse:.4f})")
