import os
import toml
import torch as tc
import numpy as np
from math import ceil
from argparse import ArgumentParser
from tabulate import tabulate
from sklearn.model_selection import train_test_split
from src.data import MassSpring
from src.model import GNN
from src.train import sample_and_linear_solve
from src.utils import Mesh, SamplingArgs, LinearSolveArgs, eval_dxdt

# --- Command line: path to the experiment's TOML config and the output directory.
argparser = ArgumentParser()
argparser.add_argument(
    "-f", "--toml",
    type=str,
    required=True,
    help=".toml file for reading config of the experiment",
)
argparser.add_argument(
    "-o", "--outdir",
    type=str,
    required=True,
    help="Output directory to save the MSE results",
)
args = vars(argparser.parse_args())
config = toml.load(args["toml"])
out_dir = args["outdir"]

# --- Floating-point precision: maps the config's "dtype" string to the
# matching NumPy and torch dtypes. Compared with `==` (not hashed) so any
# unrecognised value ends in the same ValueError.
_PRECISIONS = (
    ("float", (np.float32, tc.float32)),
    ("double", (np.float64, tc.float64)),
)
_precision = config["dtype"]
_selected = next((pair for label, pair in _PRECISIONS if _precision == label), None)
if _selected is None:
    raise ValueError("Unknown precision")
dtype_np, dtype_tc = _selected

# --- Global problem / runtime settings.
n_obj = config["n_obj"]
dof = config["dof"]
auto_diff_mode = config["auto_diff_mode"]
device = config["device"]

# --- Sweep values.
# Encoder-width x network-width study:
enc_widths = config["enc_widths"]
network_widths = config["network_widths"]
# Local-pooling x number-of-message-passes study:
n_msg_passes = config["n_msg_passes"]
local_poolings = config["local_poolings"]

# --- Per-stage config sections.
data_config = config["data"]
model_config = config["model"]
train_config = config["train"]

# Model hyper-parameters used whenever a sweep does not override them.
default_enc_width = model_config["enc_width"]
default_network_width = model_config["network_width"]
default_local_pooling = model_config["local_pooling"]

def prepare_data(n_obj, data_seed):
    """Sample a random MassSpring dataset and split it into train/test parts.

    States are drawn uniformly from ``[data_config["x_min"], data_config["x_max"])``
    with ``np.random.default_rng(data_seed)``. The feature vector is split in
    half along the last axis into the ``q`` and ``p`` arguments of ``MassSpring``.
    The split itself is shuffled with ``random_state=data_seed + 1``.

    Args:
        n_obj: Object count(s) passed to ``MassSpring``; the feature size is
            ``np.prod(n_obj) * 2 * dof``.
        data_seed: Seed for the sampling RNG (and, offset by one, for the split).

    Returns:
        Tuple ``(x_train, x_test, L, edge_index, dxdt_train, dxdt_test)``:
        torch tensors of the flattened system state and its ``dxdt`` for each
        split, ``system.L()`` converted to a tensor, and ``system.edge_index()``.
    """
    generator = np.random.default_rng(data_seed)
    n_points = data_config["n_points"]
    state_dim = np.prod(n_obj) * 2 * dof

    samples = generator.uniform(
        data_config["x_min"], data_config["x_max"], size=(n_points, state_dim)
    ).astype(dtype_np)
    q_part, p_part = np.split(samples, 2, axis=-1)

    system = MassSpring(
        n_points, state_dim, q_part, p_part, n_obj, dof,
        data_config["mass"], data_config["spring_constant"],
        Mesh(data_config["meshing"]),
    )

    states = tc.from_numpy(system.to_array(flatten=True))
    derivatives = tc.from_numpy(system.dxdt(flatten=True))
    n_train = ceil(n_points * data_config["train_test_split"])
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        states,
        derivatives,
        train_size=n_train,
        shuffle=True,
        random_state=data_seed + 1,
    )

    L_mat = tc.from_numpy(system.L())
    return x_train, x_test, L_mat, system.edge_index(), dxdt_train, dxdt_test

def create_gnn(n_obj, edge_index, model_seed, enc_width=default_enc_width, network_width=default_network_width, n_msg_pass=1, local_pooling=default_local_pooling):
    """Build a ``GNN`` from the ``[model]`` config plus per-call overrides.

    The message width is whatever remains of ``network_width`` after the
    encoder: ``msg_width = network_width - enc_width``.

    Args:
        n_obj: Object count(s) forwarded to ``GNN``.
        edge_index: Graph connectivity forwarded to ``GNN``.
        model_seed: Seed forwarded as ``seed``.
        enc_width: Encoder width (defaults to ``model.enc_width``).
        network_width: Total width (defaults to ``model.network_width``).
        n_msg_pass: Number of message passes, forwarded as ``n_msg_passes``.
        local_pooling: Local pooling option (defaults to ``model.local_pooling``).

    Returns:
        The constructed ``GNN`` instance.
    """
    gnn_kwargs = dict(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        direct=model_config["direct"],
        msg_width=network_width - enc_width,
        enc_width=enc_width,
        local_pooling=local_pooling,
        global_pooling=model_config["global_pooling"],
        activ_str=model_config["activ_str"],
        init_method=model_config["init_method"],
        seed=model_seed,
        dtype=dtype_tc,
        n_msg_passes=n_msg_pass,
    )
    return GNN(**gnn_kwargs)

def sample_fit(model, x_train, L, dxdt_train, param_sampler, sample_uniformly, sampling_seed):
    """Sample the model's parameters, solve for the rest, and score on the train set.

    Sampling options come from the arguments plus ``train.resample_duplicates``.
    Linear-solve options come from ``auto_diff_mode``, ``train.driver``,
    ``train.rcond`` and ``device``.

    Returns:
        Tuple ``(dxdt_mse, dxdt_rel2, cond)``: the two training-set metrics
        returned by ``eval_dxdt`` and the value ``sample_and_linear_solve``
        returns with ``return_cond=True`` (the linear system's condition number).
    """
    sampler_opts = SamplingArgs(
        param_sampler=param_sampler,
        seed=sampling_seed,
        sample_uniformly=sample_uniformly,
        resample_duplicates=train_config["resample_duplicates"],
        dtype=dtype_np,
    )
    solver_opts = LinearSolveArgs(
        mode=auto_diff_mode,
        driver=train_config["driver"],
        rcond=train_config["rcond"],
        device=device,
    )
    condition_number = sample_and_linear_solve(
        model, x_train, L, dxdt_train, sampler_opts, solver_opts, return_cond=True
    )
    train_mse, train_rel2 = eval_dxdt(
        model, x_train, L, dxdt_train, mode=auto_diff_mode, verbose=False
    )
    return train_mse, train_rel2, condition_number

def tabulate_errors(all_error, table_title, out_filename_prefix, value_label="MSE"):
    """Stack a grid of results, print it, and save it as ``.npy`` and ``.txt``.

    Args:
        all_error: Sequence of rows, each a sequence of scalar results; rows
            are stacked with ``np.vstack``.
        table_title: Title line printed and written above the table.
        out_filename_prefix: File name (without extension) used for both the
            ``.npy`` array and the ``.txt`` table inside ``out_dir``.
        value_label: Name of the tabulated quantity in the console messages
            (for example "condition number"). The default "MSE" reproduces the
            original messages exactly.
    """
    # The --outdir directory may not exist yet; without this, np.save/open
    # raise FileNotFoundError. An empty out_dir means the current directory.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    values = np.vstack(all_error)
    values_path = os.path.join(out_dir, out_filename_prefix + ".npy")
    np.save(values_path, values)

    print(f"All {value_label}s are gathered as:")
    print(values)
    print()

    table = tabulate(
        list(values),
        floatfmt=".3e",
    )
    report = table_title + '\n' + table
    print(report)

    table_path = os.path.join(out_dir, out_filename_prefix + ".txt")
    with open(table_path, 'w', encoding="utf-8") as f:
        f.write(report)

    print(f"-> {value_label}s are saved at: '{values_path}'")
    print(f"-> Table saved at: {table_path}")

# Ablation study with encoder and network widths.
# Rows: encoder widths (config "enc_widths"); columns: network widths
# (config "network_widths"). Each cell fits a fresh model on freshly
# prepared data, using the same data/model/sampling seeds.
def _width_trial(enc_width, width):
    """Fit one RF-HGN with the given widths; return (test MSE, condition number)."""
    # Prepare data
    x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(n_obj, data_config["data_seed"])
    # Create GNN
    rf_hgn = create_gnn(n_obj, edge_index, model_config["model_seed"],
                        enc_width=enc_width, network_width=width)
    # Train (SWIM) RF-HGN; only the condition number is kept from the fit
    _, _, cond = sample_fit(
        rf_hgn, x_train, L, dxdt_train,
        param_sampler=train_config["param_sampler"],
        sample_uniformly=True,
        sampling_seed=train_config["sampling_seed"],
    )
    # Evaluate (Test set MSE)
    mse, _ = eval_dxdt(rf_hgn, x_test, L, dxdt_test, mode=auto_diff_mode, verbose=False)
    print(f"Test MSE[{enc_width}, {width}] = {mse}")
    print(f"Cond[{enc_width}, {width}] = {cond}")
    return mse, cond


all_error = []  # Test-set MSE grid
all_cond = []   # Condition number of the matrix associated with the linear system
for enc_width in enc_widths:
    row_results = [_width_trial(enc_width, width) for width in network_widths]
    all_error.append([mse for mse, _ in row_results])
    all_cond.append([cond for _, cond in row_results])

tabulate_errors(
    all_error,
    "Table: Ablation study of Test MSE with rows (increasing encoder widths) and columns (increasing network widths)",
    "ablation_study"
)
tabulate_errors(
    all_cond,
    "Table: Ablation study of the condition number with rows (increasing encoder widths) and columns (increasing network widths)",
    "ablation_study_cond"
)


# Ablation study with number of message passes.
# Rows: local pooling types (config "local_poolings"); columns: number of
# message passes (config "n_msg_passes"). Encoder/network widths are not
# overridden here, so create_gnn falls back to its defaults.
all_error = []  # Test-set MSE grid
for local_pooling in local_poolings:
    col_err = []
    for n_msg_pass in n_msg_passes:
        # Prepare data (same data seed for every configuration)
        x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_data(n_obj, data_config["data_seed"])
        # Create GNN
        rf_hgn = create_gnn(n_obj, edge_index, model_config["model_seed"],
                            local_pooling=local_pooling, n_msg_pass=n_msg_pass)
        # Train (SWIM) RF-HGN; the returned training metrics are not used here
        sample_fit(rf_hgn, x_train, L, dxdt_train, param_sampler=train_config["param_sampler"], sample_uniformly=True, sampling_seed=train_config["sampling_seed"])
        # Evaluate (Test set MSE)
        mse, _ = eval_dxdt(rf_hgn, x_test, L, dxdt_test, mode=auto_diff_mode, verbose=False)
        col_err.append(mse)
        print(f"Test MSE[{local_pooling}, {n_msg_pass}] = {mse}")
    all_error.append(col_err)

tabulate_errors(
    all_error,
    "Table: Ablation study with rows (type of local pooling) and columns (increasing number of message passes)",
    "ablation_study_msg_passes"
)

# `exit()` is a helper injected by the `site` module for the interactive
# shell and is absent when site is disabled (e.g. `python -S`). Raising
# SystemExit directly is the program-safe equivalent of sys.exit(0).
raise SystemExit(0)
