import os
import sys
from argparse import ArgumentParser
from math import ceil

import numpy as np
import toml
import torch as tc
from sklearn.model_selection import train_test_split
from tabulate import tabulate

from src.data import MassSpring
from src.model import GNN
from src.train import sample_and_linear_solve
from src.utils import Mesh, SamplingArgs, LinearSolveArgs, eval_dxdt

# ---------------------------------------------------------------------------
# Command-line interface and experiment configuration.
# These module-level names are read as globals by the helper functions below.
# ---------------------------------------------------------------------------
argparser = ArgumentParser()
argparser.add_argument("-f", "--toml", type=str, help=".toml file for reading config of the experiment", required=True)
argparser.add_argument("-o", "--outdir", type=str, help="Output directory to save the MSE results", required=True)
args = vars(argparser.parse_args())
config = toml.load(args["toml"])
out_dir = args["outdir"]

# Map the configured precision name to the matching NumPy / PyTorch dtypes.
_PRECISION_DTYPES = {
    "float": (np.float32, tc.float32),
    "double": (np.float64, tc.float64),
}
_precision = config["dtype"]
if not isinstance(_precision, str) or _precision not in _PRECISION_DTYPES:
    # Report the offending value so a typo in the TOML file is easy to spot.
    raise ValueError(
        f"Unknown precision {_precision!r}; expected one of {sorted(_PRECISION_DTYPES)}"
    )
dtype_np, dtype_tc = _PRECISION_DTYPES[_precision]

n_obj, dof, auto_diff_mode, device = config["n_obj"], config["dof"], config["auto_diff_mode"], config["device"]

# Ablation grid: rows iterate over encoder widths, columns over network widths.
enc_widths = config["enc_widths"]
network_widths = config["network_widths"]

# Per-stage configuration sections of the TOML file.
data_config = config["data"]
model_config = config["model"]
train_config = config["train"]

def prepare_data(n_obj, data_seed):
    """Build a seeded mass-spring dataset and split it into train/test parts.

    States are drawn uniformly between ``data_config["x_min"]`` and
    ``data_config["x_max"]``, cast to ``dtype_np`` and halved along the last
    axis into ``q`` and ``p`` before being handed to ``MassSpring``.

    Args:
        n_obj: object layout forwarded to ``MassSpring``; the feature count is
            ``prod(n_obj) * 2 * dof``.
        data_seed: seeds the sampling RNG; ``data_seed + 1`` seeds the
            train/test shuffle.

    Returns:
        ``(x_train, x_test, L, edge_index, dxdt_train, dxdt_test)``, where
        ``L`` is ``system.L()`` converted to a torch tensor and ``edge_index``
        is whatever ``system.edge_index()`` returns.
    """
    sampler = np.random.default_rng(data_seed)
    cfg = data_config
    num_samples = cfg["n_points"]
    feature_count = np.prod(n_obj) * 2 * dof

    samples = sampler.uniform(cfg["x_min"], cfg["x_max"], size=(num_samples, feature_count))
    samples = samples.astype(dtype_np)
    q, p = np.split(samples, 2, axis=-1)
    system = MassSpring(
        num_samples, feature_count, q, p, n_obj, dof,
        cfg["mass"], cfg["spring_constant"], Mesh(cfg["meshing"]),
    )

    # Flattened states first, then their time derivatives.
    states = tc.from_numpy(system.to_array(flatten=True))
    derivatives = tc.from_numpy(system.dxdt(flatten=True))
    n_train = ceil(num_samples * cfg["train_test_split"])
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        states, derivatives, train_size=n_train, shuffle=True, random_state=data_seed + 1,
    )

    L = tc.from_numpy(system.L())
    edge_index = system.edge_index()
    return x_train, x_test, L, edge_index, dxdt_train, dxdt_test

def create_gnn(n_obj, edge_index, enc_width, network_width, model_seed):
    """Construct the GNN for one (encoder width, network width) ablation cell.

    The message width handed to the model is ``network_width - enc_width``.
    All remaining hyperparameters come from ``model_config`` and the module
    globals ``dof`` and ``dtype_tc``.

    Args:
        n_obj: object layout forwarded to ``GNN``.
        edge_index: graph connectivity forwarded to ``GNN``.
        enc_width: encoder width for this cell.
        network_width: total width from which the message width is derived.
        model_seed: seed forwarded to ``GNN``.

    Returns:
        The constructed ``GNN`` instance.

    Raises:
        ValueError: if ``enc_width`` exceeds ``network_width``, which would
            make the message width negative.
    """
    msg_width = network_width - enc_width
    # Fail fast with a clear message instead of pushing a negative layer
    # width into the model constructor partway through the sweep.
    if msg_width < 0:
        raise ValueError(
            f"enc_width ({enc_width}) must not exceed network_width ({network_width}); "
            f"got msg_width={msg_width}"
        )
    return GNN(dof=dof, n_obj=n_obj, edge_index=edge_index,
               direct=model_config["direct"],
               msg_width=msg_width,
               enc_width=enc_width,
               local_pooling=model_config["local_pooling"], global_pooling=model_config["global_pooling"],
               activ_str=model_config["activ_str"], init_method=model_config["init_method"],
               seed=model_seed, dtype=dtype_tc)

def sample_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly, sampling_seed):
    """Sample the model's hidden parameters, solve for the rest, and score on train data.

    Builds ``SamplingArgs`` and ``LinearSolveArgs`` from the arguments and
    ``train_config``, runs ``sample_and_linear_solve`` on ``model`` and then
    evaluates it with ``eval_dxdt`` on the same training data.

    Returns:
        ``(mse, rel2)`` as produced by ``eval_dxdt`` on the training set.
    """
    sampler_opts = SamplingArgs(param_sampler=param_sampler, seed=sampling_seed,
                                sample_uniformly=sample_uniformly,
                                resample_duplicates=train_config["resample_duplicates"],
                                dtype=dtype_np)
    solver_opts = LinearSolveArgs(mode=auto_diff_mode, driver=train_config["driver"],
                                  rcond=train_config["rcond"], device=device)

    sample_and_linear_solve(model, x_train, L, dxdt_train, sampler_opts, solver_opts)

    train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False)
    return train_mse, train_rel2

# ---------------------------------------------------------------------------
# Ablation sweep: one test-set MSE per (encoder width, network width) pair.
# Rows of the result follow `enc_widths`, columns follow `network_widths`.
# ---------------------------------------------------------------------------

# Create the output directory before the (potentially long) sweep so that a
# bad path fails immediately instead of after all models have been trained.
os.makedirs(out_dir, exist_ok=True)

all_error = []  # Test-set MSE; one inner list per encoder width
for enc_width in enc_widths:
    col_err = []
    for width in network_widths:
        # Prepare data. prepare_data is seeded by data_seed, so every cell sees
        # the same split (assuming MassSpring is deterministic). It is rebuilt
        # per run rather than hoisted, presumably so each model starts from
        # freshly created tensors.
        x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(n_obj, data_config["data_seed"])
        # Create GNN
        rf_hgn = create_gnn(n_obj, edge_index, enc_width, width, model_config["model_seed"])
        # Train (SWIM) RF-HGN
        sample_fit(rf_hgn, x_train, L, dxdt_train, param_sampler=train_config["param_sampler"], sample_uniformly=True, sampling_seed=train_config["sampling_seed"])
        # Evaluate (Test set MSE); the second return value of eval_dxdt is unused here.
        mse, _ = eval_dxdt(rf_hgn, x_test, L, dxdt_test, mode=auto_diff_mode, verbose=False)
        col_err.append(mse)
        print(f"Test MSE[{enc_width}, {width}] = {mse}")
    all_error.append(col_err)

all_mse = np.vstack(all_error)
all_mse_path = os.path.join(out_dir, "ablation_study.npy")
np.save(all_mse_path, all_mse)
print("All MSEs are gathered as:")
print(all_mse)
print()
table_title = "Table: Ablation study with rows (increasing encoder widths) and columns (increasing network widths)"
# Label rows and columns with the actual widths so the table is readable on its own.
table = tabulate(
    list(all_mse),
    headers=[f"width={w}" for w in network_widths],
    showindex=[f"enc={e}" for e in enc_widths],
    floatfmt=".2e",
)
print(table_title + '\n' + table)
table_path = os.path.join(out_dir, "ablation_study.txt")
with open(table_path, 'w', encoding="utf-8") as f:
    f.write(table_title + '\n' + table)
print(f"-> MSEs are saved at: '{all_mse_path}'")
print(f"-> Table saved at: {table_path}")
# sys.exit instead of the site-injected builtin `exit`, which is not always available.
sys.exit(0)
