import os
import toml
import numpy as np
import torch as tc
from tabulate import tabulate
from argparse import ArgumentParser
from math import ceil
from tqdm import tqdm
from sklearn.model_selection import KFold
from src.data import MassSpring
from src.utils import Mesh, SamplingArgs, LinearSolveArgs, eval_dxdt
from src.train import sample_and_linear_solve
from src.model import GNN

argparser = ArgumentParser()
argparser.add_argument("-f", "--toml", type=str, required=True,
                       help=".toml file for reading config of the experiment")
argparser.add_argument("-o", "--outdir", type=str, required=True,
                       help="Output directory to save the integration results and models")
argparser.add_argument("-a", "--abs", action="store_true", default=False,
                       help="Take absolute differences of positions when encoding edge features (incorporates additional bias).")
args = vars(argparser.parse_args())
out_dir = args["outdir"]
config = toml.load(args["toml"])

# Resolve the configured precision into matching numpy / torch dtypes.
if config["dtype"] == "float":
    dtype_np, dtype_tc = np.float32, tc.float32
elif config["dtype"] == "double":
    dtype_np, dtype_tc = np.float64, tc.float64
else:
    raise ValueError(f"Unknown precision {config['dtype']!r}; expected 'float' or 'double'")

# Create the output directory up front so an unusable path fails now rather
# than after the (potentially long) experiment, when results are written.
# An empty out_dir means "current directory" and needs no creation.
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

# Master RNG; get_seed() draws every data/model/sampling seed from it.
rng = np.random.default_rng(config["seed"])


def get_seed():
    """Draw a fresh integer seed in [0, 10**9) from the master RNG."""
    return rng.integers(0, 10**9)


dof = config["dof"]
n_obj = config["n_obj"]
cross_validation_n_splits = config["cross_validation_n_splits"]

# Noise scales
noise_scales = config["noise_scales"]
print("-> Noise scales are in", noise_scales)

def train_test_split(x, y, y_noisy, train_set_ratio, rng):
    """Randomly split samples into a noisy-target train set and a clean-target test set.

    Args:
        x: Inputs, indexable by an integer index array along the first axis.
        y: Noise-free targets aligned with ``x``; used for the test set.
        y_noisy: Noisy targets aligned with ``x``; used for the training set.
        train_set_ratio: Fraction of samples assigned to the training set; the
            resulting training size is rounded up.
        rng: ``numpy.random.Generator`` used to shuffle the sample order.

    Returns:
        ``(x_train, x_test, y_train, y_test)`` with ``y_train`` taken from
        ``y_noisy`` and ``y_test`` taken from ``y``.
    """
    indices = np.arange(len(x))
    rng.shuffle(indices)
    # Round away float noise before taking the ceiling: e.g. 100 * 0.55 evaluates
    # to 55.00000000000001, which a bare ceil() would turn into 56.
    train_size: int = ceil(round(len(x) * train_set_ratio, 8))
    train_indices, test_indices = indices[:train_size], indices[train_size:]
    # Get y from truths for the test set, from noisy samples for the train set
    x_train, x_test = x[train_indices], x[test_indices]
    y_train, y_test = y_noisy[train_indices], y[test_indices]
    return x_train, x_test, y_train, y_test

def prepare_mass_spring_data(n_obj, data_seed, noise_scale=0.0):
    """Generate a random mass-spring dataset and split it into train/test sets.

    Inputs are drawn uniformly from [config["x_min"], config["x_max"]) with a
    local RNG seeded by ``data_seed``, so the sampled inputs are identical for
    every ``noise_scale``. Whether the train/test shuffle also matches across
    noise scales depends on how many draws ``system.dxdt`` takes from ``rng``
    (TODO confirm in src.data). Training targets are the derivatives computed
    with ``noise_scale``; test targets are the noise-free derivatives.

    Args:
        n_obj: Object configuration; ``np.prod(n_obj)`` is used as the object
            count (presumably a grid shape for the rectangular mesh -- confirm).
        data_seed: Seed of the local RNG used for sampling, noise and the split.
        noise_scale: Forwarded to ``system.dxdt`` for the training targets only.

    Returns:
        ``(x_train, x_test, L, edge_index, dxdt_train, dxdt_test)`` where ``L``
        is ``system.L()`` converted to a torch tensor and ``edge_index`` is
        ``system.edge_index()`` as returned.
    """
    # Local RNG (shadows the module-level one) so data generation is driven by
    # data_seed rather than by the master seed stream.
    rng = np.random.default_rng(data_seed)
    # 2*dof features per object; the feature vector is split into q and p halves below.
    n_features = np.prod(n_obj) * 2*dof
    x = rng.uniform(config["x_min"], config["x_max"], size=(config["n_points"], n_features)).astype(dtype_np)
    q, p = np.split(x, 2, axis=-1)
    system = MassSpring(config["n_points"], n_features, q, p, n_obj, dof,
                        config["mass"], config["spring_constant"], Mesh("rectangular"))

    x_all = tc.from_numpy(system.to_array(flatten=True))
    # Noise-free derivatives (no RNG passed): used as test targets.
    y_all = tc.from_numpy(system.dxdt(flatten=True, noise_scale=0.0, rng=None))
    # Noisy derivatives drawn with the same local rng: used as training targets.
    y_all_noisy = tc.from_numpy(system.dxdt(flatten=True, noise_scale=noise_scale, rng=rng))

    x_train, x_test, dxdt_train, dxdt_test = train_test_split(x_all, y_all, y_all_noisy,
                                                              config["train_set_ratio"], rng)

    # Matrix returned by MassSpring.L(); its meaning is defined in src.data.
    L = tc.from_numpy(system.L())
    return x_train, x_test, L, system.edge_index(), dxdt_train, dxdt_test

def create_gnn(n_obj, edge_index, model_seed):
    """Build an untrained GNN configured from the experiment config and CLI flags.

    Args:
        n_obj: Object configuration forwarded to ``GNN``.
        edge_index: Graph connectivity forwarded to ``GNN``.
        model_seed: Seed forwarded to ``GNN`` for its parameter initialization.

    Returns:
        A new ``GNN`` instance in the configured dtype.
    """
    # msg_width is derived as network_width minus enc_width.
    message_width = config["network_width"] - config["enc_width"]
    gnn_kwargs = dict(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        take_absolute_diff=args["abs"],
        direct=False,
        msg_width=message_width,
        enc_width=config["enc_width"],
        local_pooling="sum",
        global_pooling="sum",
        activ_str=config["activ_str"],
        init_method="none",
        seed=model_seed,
        dtype=dtype_tc,
    )
    return GNN(**gnn_kwargs)

def sample_fit(model, x_train, L, dxdt_train, sampling_seed):
    """Fit ``model`` with sample_and_linear_solve and report its training error.

    Args:
        model: Model to fit (modified by ``sample_and_linear_solve``).
        x_train: Training inputs.
        L: Matrix forwarded to the solver and to ``eval_dxdt``.
        dxdt_train: Training targets.
        sampling_seed: Seed for the parameter sampler.

    Returns:
        ``(mse, rel2)`` as computed by ``eval_dxdt`` on the training data.
    """
    sampler_cfg = SamplingArgs(seed=sampling_seed, param_sampler=config["param_sampler"],
                               sample_uniformly=True, dtype=dtype_np)
    solver_cfg = LinearSolveArgs(
        driver=config["driver"],
        rcond=config["rcond"],
        device=config["device"],
    )
    sample_and_linear_solve(model, x_train, L, dxdt_train, sampler_cfg, solver_cfg)
    # Error on the data the model was just fitted to (not a held-out score).
    train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, verbose=False)
    return train_mse, train_rel2

def _print_dxdt_ranges(dxdt_train, dxdt_test):
    """Print min/max of the leading ``dof`` target columns and of the remaining columns."""
    # NOTE(review): "dqdt" is [..., :dof] and "dpdt" is [..., dof:]. Inputs are split
    # into q/p halves in prepare_mass_spring_data; if dxdt uses the same layout,
    # these slices only match q/p when there is a single object -- confirm
    # against MassSpring.dxdt(flatten=True).
    print(f"- dqdt_train min {dxdt_train[..., :dof].min():.15e}  dpdt_train min {dxdt_train[..., dof:].min():.15e}")
    print(f"- dqdt_train max {dxdt_train[..., :dof].max():.15e}  dpdt_train max {dxdt_train[..., dof:].max():.15e}")
    print(f"- dqdt_test  min {dxdt_test[..., :dof].min():.15e}  dpdt_test min {dxdt_test[..., dof:].min():.15e}")
    print(f"- dqdt_test  max {dxdt_test[..., :dof].max():.15e}  dpdt_test max {dxdt_test[..., dof:].max():.15e}")


def cross_validate(data_seed, initial_model_seed, initial_sampling_seed):
    """Run k-fold cross-validation on mass-spring data for every noise scale.

    For each entry of the module-level ``noise_scales`` a dataset is generated
    from ``data_seed`` (noisy training targets, clean test targets). A new GNN
    is fitted on each of the ``cross_validation_n_splits`` training folds. Its
    error is measured on the held-out fold (CV error) and on the test set
    (test error), and both are averaged over the folds.

    Args:
        data_seed: Seed passed to ``prepare_mass_spring_data``; the same seed is
            reused for every noise scale.
        initial_model_seed: GNN seed of the first fold of each noise scale.
        initial_sampling_seed: The first fold of each noise scale samples with
            ``initial_sampling_seed + 1``.

    Returns:
        Four numpy arrays of length ``len(noise_scales)``:
        ``(cv_mse, cv_rel2, test_mse, test_rel2)``, each averaged over the folds.
    """
    cv_mse_lst, cv_rel2_lst = [], []
    test_mse_lst, test_rel2_lst = [], []
    # Wrap the sized list (not the enumerate object) so tqdm can display a total.
    for idx, noise_scale in enumerate(tqdm(noise_scales)):
        # Train/Validation/Test split and do cross-validation and take average error on the test set
        x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_mass_spring_data(n_obj, data_seed, noise_scale)
        print()
        print(f"-> {cross_validation_n_splits}-Fold Cross-Validation with the following data: noise scale idx {idx+1}/{len(noise_scales)}")
        _print_dxdt_ranges(dxdt_train, dxdt_test)

        kfold = KFold(n_splits=cross_validation_n_splits)
        # NOTE: this seed schedule is kept unchanged so results stay reproducible.
        # The first fold uses initial_model_seed and initial_sampling_seed + 1.
        # Later folds draw fresh seeds from the master RNG, so they differ across
        # noise scales. The draw after the last fold also advances that RNG.
        current_model_seed = initial_model_seed
        current_sampling_seed = initial_sampling_seed
        cv_mse, cv_rel2 = 0.0, 0.0
        test_mse, test_rel2 = 0.0, 0.0
        # KFold expects a 2-D array; without shuffling only the sample count matters.
        for train_indices, val_indices in kfold.split(x_train.reshape(len(x_train), -1)):
            gnn = create_gnn(n_obj, edge_index, current_model_seed)
            current_model_seed += 1
            current_sampling_seed += 1
            sample_fit(gnn, x_train[train_indices], L, dxdt_train[train_indices], current_sampling_seed)
            # Validation error on the held-out fold.
            dxdt_mse, dxdt_rel2 = eval_dxdt(gnn, x_train[val_indices], L, dxdt_train[val_indices], verbose=False)
            cv_mse += dxdt_mse
            cv_rel2 += dxdt_rel2
            # Error of this fold's model on the (noise-free) test set.
            dxdt_mse, dxdt_rel2 = eval_dxdt(gnn, x_test, L, dxdt_test, verbose=False)
            test_mse += dxdt_mse
            test_rel2 += dxdt_rel2

            # Update seeds related to model and training
            current_model_seed, current_sampling_seed = get_seed(), get_seed()

        n_splits = kfold.n_splits
        cv_mse_lst.append(cv_mse / n_splits)
        cv_rel2_lst.append(cv_rel2 / n_splits)
        test_mse_lst.append(test_mse / n_splits)
        test_rel2_lst.append(test_rel2 / n_splits)

    # Return arrays so callers can combine the per-noise-scale errors elementwise.
    return (np.asarray(cv_mse_lst), np.asarray(cv_rel2_lst),
            np.asarray(test_mse_lst), np.asarray(test_rel2_lst))

# Per-noise-scale errors: each repeat adds its fold-averaged errors, and the
# sums are divided by the number of repeats afterwards.
n_scales = len(noise_scales)
# cross-validation MSE and relative L2 errors (average over k splits, gathered for all repeats)
cv_mse_lst_avg, cv_rel2_lst_avg = np.zeros(n_scales), np.zeros(n_scales)
# test MSE and relative L2 errors (average over k splits, gathered for all repeats)
test_mse_lst_avg, test_rel2_lst_avg = np.zeros(n_scales), np.zeros(n_scales)
for _ in tqdm(range(config["n_repeats"])):
    # Fresh data / model / sampling seeds for every repeat, drawn from the master RNG.
    cv_mse_lst, cv_rel2_lst, test_mse_lst, test_rel2_lst = cross_validate(
        data_seed=get_seed(),
        initial_model_seed=get_seed(),
        initial_sampling_seed=get_seed(),
    )
    cv_mse_lst_avg += cv_mse_lst
    cv_rel2_lst_avg += cv_rel2_lst
    test_mse_lst_avg += test_mse_lst
    test_rel2_lst_avg += test_rel2_lst

cv_mse_lst_avg /= config["n_repeats"]
cv_rel2_lst_avg /= config["n_repeats"]
test_mse_lst_avg /= config["n_repeats"]
test_rel2_lst_avg /= config["n_repeats"]

# One column per noise scale, one row per error metric.
results_table = tabulate(
    headers=["sigma"] + list(noise_scales),
    tabular_data=[
        ["CV MSE"] + list(cv_mse_lst_avg),
        ["CV L2 rel."] + list(cv_rel2_lst_avg),
        ["Test MSE"] + list(test_mse_lst_avg),
        ["Test L2 rel."] + list(test_rel2_lst_avg),
    ],
    floatfmt=".2e",
)

print("\nTable: Noise scale study results with additive Gaussian noise are displayed.")
print(results_table)
# Make sure the output directory exists so the results computed above are not
# lost to a missing path; an empty out_dir means the current directory.
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
filename = os.path.join(out_dir, "noise_scale_study.txt")
with open(filename, "w", encoding="utf-8") as f:
    f.write(results_table)
print("-> Results are saved at", filename)
# Explicit exit status without relying on the `site`-provided exit() helper.
raise SystemExit(0)
