import toml
import numpy as np
import torch as tc
from argparse import ArgumentParser
from sklearn.model_selection import train_test_split
from src.data import MassSpring
from src.utils import Mesh, SamplingArgs, LinearSolveArgs, eval_dxdt
from src.train import sample_and_linear_solve
from src.model import GNN

argparser = ArgumentParser()
argparser.add_argument(
    "-f",
    "--toml",
    type=str,
    required=True,
    help=".toml file for reading config of the experiment",
)
argparser.add_argument(
    "-a",
    "--abs",
    action="store_true",
    default=False,
    help="Take absolute differences of positions when encoding edge features (incorporates additional bias).",
)
args = vars(argparser.parse_args())
config = toml.load(args["toml"])

# Resolve the configured precision string into matching numpy / torch dtypes.
precision = config["dtype"]
if precision == "float":
    dtype_np, dtype_tc = np.float32, tc.float32
elif precision == "double":
    dtype_np, dtype_tc = np.float64, tc.float64
else:
    raise ValueError("Unknown precision")

# File constants
# One module-level generator, seeded from the config, from which all
# derived seeds below are drawn.
rng = np.random.default_rng(config["seed"])


def get_seed():
    """Draw a fresh integer seed in ``[0, 10**9)`` from the module-level ``rng``."""
    return rng.integers(0, 10**9)


n_obj = config["n_obj"]
dof = config["dof"]

# Model constants
# Message width is whatever remains of the total network width after the
# encoder width is taken out.
msg_width = config["network_width"] - config["enc_width"]
activ_str = "softplus"  # RFF will overwrite this with cosine

# Experiment results for MSEs and relative L2 errors, filled cell by cell:
#   rows    -> enc_sigmas
#   columns -> msg_sigmas
_grid_shape = (len(config["enc_sigmas"]), len(config["msg_sigmas"]))
avg_mse_list = np.empty(_grid_shape, dtype=np.float64)
avg_rel2_list = np.empty(_grid_shape, dtype=np.float64)

def prepare_mass_spring_data(n_obj, data_seed):
    """Sample a mass-spring system and split it into train and test sets.

    Every random choice -- the uniformly sampled states *and* the
    train/test split -- is derived from ``data_seed``. The same seed
    therefore always reproduces the same data and the same split,
    independent of how many seeds were drawn elsewhere in the script.

    Args:
        n_obj: object count passed to ``MassSpring``; ``np.prod`` is taken
            of it, so a per-axis shape may also be accepted (TODO confirm
            against ``MassSpring``).
        data_seed: seed for this dataset's private numpy ``Generator``.

    Returns:
        ``(x_train, x_test, L, edge_index, dxdt_train, dxdt_test)``:
        the x / dxdt pairs are the ``train_test_split`` of torch tensors
        built from ``system.to_array(flatten=True)`` and
        ``system.dxdt(flatten=True)``; ``L`` is ``system.L()`` converted with
        ``tc.from_numpy``; ``edge_index`` is ``system.edge_index()``.
    """
    # A generator private to this dataset (deliberately not the module-level
    # ``rng``, which it used to shadow).
    data_rng = np.random.default_rng(data_seed)

    # The sampled state is split into equal q and p halves, ``dof`` columns
    # per object in each half.
    n_features = np.prod(n_obj) * 2 * dof
    x = data_rng.uniform(config["x_min"], config["x_max"],
                         size=(config["n_points"], n_features)).astype(dtype_np)
    q, p = np.split(x, 2, axis=-1)
    system = MassSpring(config["n_points"], n_features, q, p, n_obj, dof,
                        config["mass"], config["spring_constant"], Mesh("rectangular"))

    x_all = tc.from_numpy(system.to_array(flatten=True))
    y_all = tc.from_numpy(system.dxdt(flatten=True))

    # Draw the split seed from ``data_rng`` rather than the shared
    # ``get_seed()``. Drawing from the shared generator would give one
    # ``data_seed`` a different train/test split in every
    # (enc_sigma, msg_sigma) cell of the sweep, confounding the comparison.
    split_seed = int(data_rng.integers(0, 10**9))
    x_train, x_test, dxdt_train, dxdt_test = train_test_split(
        x_all, y_all, train_size=config["train_set_ratio"], random_state=split_seed)

    L = tc.from_numpy(system.L())
    return x_train, x_test, L, system.edge_index(), dxdt_train, dxdt_test

def create_gnn(n_obj, edge_index, model_seed):
    """Build a fresh, seeded GNN for the given graph topology.

    The model uses this script's fixed settings:
    - ``direct=False``;
    - "sum" for both local and global pooling;
    - ``init_method="none"``;
    - the module-level ``msg_width``, ``activ_str`` and ``dtype_tc``;
    - ``config["enc_width"]`` as the encoder width;
    - the ``--abs`` CLI flag as ``take_absolute_diff``.
    """
    pooling = "sum"
    return GNN(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        take_absolute_diff=args["abs"],
        direct=False,
        msg_width=msg_width,
        enc_width=config["enc_width"],
        local_pooling=pooling,
        global_pooling=pooling,
        activ_str=activ_str,
        init_method="none",
        seed=model_seed,
        dtype=dtype_tc,
    )

def sample_fit(model, x_train, L, dxdt_train, enc_sigma, msg_sigma, sampling_seed):
    """Sample ``model``'s parameters, fit it by linear solve, and score it on the training data.

    Sampling uses the "fourier" parameter sampler with uniform sampling,
    the given encoder / message sigmas, and ``sampling_seed``. The linear
    solve takes its driver, rcond and device from the config.

    Returns:
        ``(dxdt_mse, dxdt_rel2)`` from ``eval_dxdt`` on the *training* set.
    """
    sampler_cfg = SamplingArgs(
        seed=sampling_seed,
        param_sampler="fourier",
        sample_uniformly=True,
        dtype=dtype_np,
        enc_sigma=enc_sigma,
        msg_sigma=msg_sigma,
    )
    solver_cfg = LinearSolveArgs(
        driver=config["driver"],
        rcond=config["rcond"],
        device=config["device"],
    )
    sample_and_linear_solve(model, x_train, L, dxdt_train, sampler_cfg, solver_cfg)

    # Fit quality on the same data the model was fitted to.
    train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, verbose=False)
    return train_mse, train_rel2

# Draw every per-repeat seed up front (all data seeds, then all model seeds,
# then all sampling seeds). Repeat i then uses the same three seeds in every
# (enc_sigma, msg_sigma) cell of the grid.
data_seeds = [get_seed() for _ in range(config["n_repeats"])]
model_seeds = [get_seed() for _ in range(config["n_repeats"])]
sampling_seeds = [get_seed() for _ in range(config["n_repeats"])]
repeat_seeds = list(zip(data_seeds, model_seeds, sampling_seeds))

for row, enc_sigma in enumerate(config["enc_sigmas"]):
    for col, msg_sigma in enumerate(config["msg_sigmas"]):
        # Sum the held-out metrics over all repeats, then average them.
        mse_sum = 0.0
        rel2_sum = 0.0
        for data_seed, model_seed, sampling_seed in repeat_seeds:
            x_train, x_test, L, edge_index, dxdt_train, dxdt_test = prepare_mass_spring_data(n_obj, data_seed)
            model = create_gnn(n_obj, edge_index, model_seed)
            # sample_fit also returns train-set metrics; only the test-set
            # metrics below enter the result grid.
            sample_fit(model, x_train, L, dxdt_train, enc_sigma, msg_sigma, sampling_seed)
            test_mse, test_rel2 = eval_dxdt(model, x_test, L, dxdt_test, verbose=False)
            mse_sum += test_mse
            rel2_sum += test_rel2

        mean_mse = mse_sum / config["n_repeats"]
        mean_rel2 = rel2_sum / config["n_repeats"]

        avg_mse_list[row, col] = mean_mse
        avg_rel2_list[row, col] = mean_rel2
        print(f"-> avg mse, rel2 [{enc_sigma}, {msg_sigma}] = {mean_mse:.5f}, {mean_rel2:.5f}")


print("MSE List, Rows Node and Edge Encoder Sigma, Columns Message Encoder Sigma")
print(avg_mse_list)
print("Relative L2 List, Rows Node and Edge Encoder Sigma, Columns Message Encoder Sigma")
print(avg_rel2_list)
