import os
import toml
import numpy as np
import torch as tc
import time
from tabulate import tabulate
from argparse import ArgumentParser
from math import ceil
from tqdm import tqdm
from sklearn.model_selection import KFold
from src.data import LennardJones, get_equil_r, get_eq_lattice_x0
from src.model import GNN
from src.train import sample_and_linear_solve
from src.utils import eval_dxdt, LinearSolveArgs, SamplingArgs

argparser = ArgumentParser()
argparser.add_argument("-f", "--toml", type=str, required=True,
                       help=".toml file for reading config of the experiment")
argparser.add_argument("-o", "--outdir", type=str, required=True,
                       help="Output directory to save the integration results and models")
argparser.add_argument("-a", "--abs", action="store_true", default=False,
                       help="Take absolute differences of positions when encoding edge features (incorporates additional bias).")

args = vars(argparser.parse_args())
out_dir = args["outdir"]
config = toml.load(args["toml"])

# Floating-point precision shared by the NumPy data pipeline and the torch model.
if config["dtype"] == "float":
    dtype_np, dtype_tc = np.float32, tc.float32
elif config["dtype"] == "double":
    dtype_np, dtype_tc = np.float64, tc.float64
else:
    raise ValueError("Unknown precision")

# Master RNG: every per-repeat data/model/sampling seed is drawn from it.
rng = np.random.default_rng(config["seed"])


def get_seed():
    """Draw a fresh integer seed in [0, 10**9) from the master RNG."""
    return rng.integers(0, 10**9)


# Degrees of freedom per object: q and p each carry `dof` components per object.
dof = 2
n_obj = config["n_obj_x"] * config["n_obj_y"]

GB = 1024 ** 3  # bytes per GiB (binary units)


def convert_bytes_to_GB(n_bytes):
    """Convert a byte count to GiB (1024**3 bytes)."""
    return n_bytes / GB


# Interaction cutoff radius from config["cutoff"]; the string "inf" maps to np.inf.
# (Previously the numeric branch read the non-existent key config["config"].)
cutoff = np.inf if config["cutoff"] == "inf" else config["cutoff"]
q_uniform_scale = config["q_uniform_scale"]
p_uniform_scale = config["p_uniform_scale"]

def get_x0_and_box(nx, ny, sig, r_eq, q_uniform_scale, p_uniform_scale, rng):
    """Create a jittered lattice initial state and a tight box around it.

    Intended for a reflective boundary condition. Positions start on the lattice
    returned by ``get_eq_lattice_x0(nx, ny, sig)`` and get uniform noise from
    [-q_uniform_scale, q_uniform_scale); momenta start at zero and get uniform
    noise from [-p_uniform_scale, p_uniform_scale). Both are cast to ``dtype_np``.

    Args:
        nx, ny: lattice size passed to ``get_eq_lattice_x0``.
        sig: lattice parameter passed to ``get_eq_lattice_x0``.
        r_eq: padding length used for the box bounds.
        q_uniform_scale, p_uniform_scale: half-widths of the uniform noise.
        rng: ``np.random.Generator`` used for both noise draws.

    Returns:
        q: positions, same shape as the lattice from ``get_eq_lattice_x0``.
        p: momenta, same shape as ``q``.
        box_start: scalar ``min(q) - r_eq / 2`` taken over all coordinates.
        box_end: scalar ``max(q) + r_eq`` taken over all coordinates.
    """
    q = get_eq_lattice_x0(nx, ny, sig).astype(dtype_np)
    p = np.zeros_like(q, dtype=dtype_np)
    # Draw order (q noise before p noise) determines which random numbers each receives.
    q += rng.uniform(-q_uniform_scale, q_uniform_scale, size=q.shape).astype(dtype_np)
    p += rng.uniform(-p_uniform_scale, p_uniform_scale, size=p.shape).astype(dtype_np)
    # NOTE(review): the box is padded by r_eq / 2 below but by a full r_eq above;
    # confirm this asymmetry is intended.
    return q, p, np.min(q) - (r_eq / 2.0), np.max(q) + (r_eq)

def get_indices(arr, indices):
    """Return ``[arr[i] for i in indices]`` as a new list.

    Works for any indexable ``arr`` (e.g. a Python list of per-sample edge
    indices) where fancy indexing is unavailable.
    """
    return [arr[index] for index in indices]

def train_test_split(x, edge_index, y, y_noisy, train_set_ratio, rng):
    """Randomly partition samples into a training and a test set.

    The first ``ceil(len(x) * train_set_ratio)`` entries of a shuffled index
    permutation form the training set; the rest form the test set.
    Training targets come from ``y_noisy``, test targets from the clean ``y``.

    Returns:
        (x_train, x_test, edge_index_train, edge_index_test, y_train, y_test);
        the two edge_index parts are lists built with ``get_indices``.
    """
    perm = np.arange(len(x))
    rng.shuffle(perm)
    n_train = ceil(len(x) * train_set_ratio)
    train_idx = perm[:n_train]
    test_idx = perm[n_train:]
    return (
        x[train_idx],
        x[test_idx],
        get_indices(edge_index, train_idx),
        get_indices(edge_index, test_idx),
        y_noisy[train_idx],  # training targets may carry noise
        y[test_idx],         # test targets are always the ground truth
    )

def prepare_lennard_jones_data(n_obj, data_seed, noise_scale=0.0):
    """Sample Lennard-Jones states, compute their time derivatives and split them.

    ``config["n_points"]`` initial states are drawn with ``get_x0_and_box`` and
    wrapped in a ``LennardJones`` system. Its ``dxdt`` gives the clean targets;
    a second ``dxdt`` call with ``noise_scale`` gives the (possibly noisy)
    training targets.

    Args:
        n_obj: number of objects per state.
        data_seed: seed of the RNG driving state sampling, target noise and the
            train/test shuffle.
        noise_scale: noise passed to ``system.dxdt`` for the training targets.

    Returns:
        (x_train, x_test, L, edge_index_train, edge_index_test, dxdt_train, dxdt_test).
        ``L`` is ``system.L()`` as a tensor (presumably the structure matrix of
        the dynamics; confirm against ``LennardJones``).
    """
    rng = np.random.default_rng(data_seed)
    n_points = config["n_points"]
    n_features = n_obj * (2 * dof)
    half_width = n_features // 2
    r_eq = get_equil_r(config["sigma"])

    q = np.empty(shape=(n_points, half_width))
    p = np.empty(shape=(n_points, half_width))
    for row in range(n_points):
        q_row, p_row, _box_start, _box_end = get_x0_and_box(
            config["n_obj_x"], config["n_obj_y"], config["sigma"],
            r_eq, q_uniform_scale, p_uniform_scale, rng,
        )
        q[row] = q_row.reshape(-1)
        p[row] = p_row.reshape(-1)

    # Stack, cast to the experiment precision, then split back into q and p views.
    stacked = np.column_stack([q, p]).astype(dtype_np)
    q, p = np.split(stacked, 2, axis=-1)
    system = LennardJones(n_points, n_features, q, p, n_obj, dof,
                          config["mass"], config["epsilon"], config["sigma"], cutoff)

    x_all = tc.from_numpy(system.to_array(flatten=True))
    edge_index_all = system.edge_index(degree_limit=None)  # no degree limit during training
    y_clean = tc.from_numpy(system.dxdt(flatten=True, noise_scale=0.0, rng=None))
    y_noisy = tc.from_numpy(system.dxdt(flatten=True, noise_scale=noise_scale, rng=rng))

    splits = train_test_split(x_all, edge_index_all, y_clean, y_noisy, config["train_set_ratio"], rng)
    x_train, x_test, edge_index_train, edge_index_test, dxdt_train, dxdt_test = splits

    l_tensor = tc.from_numpy(system.L())
    return x_train, x_test, l_tensor, edge_index_train, edge_index_test, dxdt_train, dxdt_test

def create_gnn(n_obj, edge_index, model_seed):
    """Construct a ``GNN`` with the architecture settings taken from ``config``.

    The message width is ``config["network_width"] - config["enc_width"]``;
    pooling is "sum" both locally and globally, and ``init_method`` is "none".
    ``model_seed`` is forwarded as the GNN's ``seed``, and ``--abs`` controls
    ``take_absolute_diff``.
    """
    enc_width = config["enc_width"]
    msg_width = config["network_width"] - enc_width
    return GNN(
        dof=dof,
        n_obj=n_obj,
        edge_index=edge_index,
        take_absolute_diff=args["abs"],
        direct=False,
        msg_width=msg_width,
        enc_width=enc_width,
        local_pooling="sum",
        global_pooling="sum",
        activ_str=config["activ_str"],
        init_method="none",
        seed=model_seed,
        dtype=dtype_tc,
    )

def sample_fit(model, x_train, L, dxdt_train, sampling_seed, param_sampler, batch_size):
    """Fit ``model`` with ``sample_and_linear_solve`` and report its training error.

    Args:
        model: network to fit; it is modified by ``sample_and_linear_solve``.
        x_train, L, dxdt_train: training inputs, ``L`` tensor and targets.
        sampling_seed: seed for parameter sampling.
        param_sampler: sampler name (this script uses "random" and "relu").
        batch_size: batch size of the linear solve, or None for no batching.

    Returns:
        (mse, rel2) on the training data as computed by ``eval_dxdt``.
    """
    sampling_cfg = SamplingArgs(seed=sampling_seed, param_sampler=param_sampler,
                                sample_uniformly=True, dtype=dtype_np)
    solve_cfg = LinearSolveArgs(
        driver=config["driver"],
        rcond=config["rcond"],
        device=config["device"],
        batch_size=batch_size,
    )
    sample_and_linear_solve(model, x_train, L, dxdt_train, sampling_cfg, solve_cfg)
    train_mse, train_rel2 = eval_dxdt(model, x_train, L, dxdt_train, verbose=False)
    return train_mse, train_rel2

def cross_validate(x_train, x_test, L, edge_index_train, edge_index_test, dxdt_train, dxdt_test,
                   initial_model_seed, initial_sampling_seed, param_sampler, batch_size):
    """K-fold cross-validate one sampling configuration and score it on the test set.

    For each of ``config["cross_validation_n_splits"]`` folds (``KFold`` without
    shuffling), a fresh GNN is built on the fold's training graphs and fitted
    with ``sample_fit``. It is then evaluated on the held-out fold (CV error)
    and on the full test set (test error).

    Returns:
        (cv_mse, cv_rel2, test_mse, test_rel2, runtime, max_mem_GiB), each
        averaged over the folds. ``runtime`` is the wall-clock time in seconds
        spent in ``sample_fit``. ``max_mem_GiB`` is the peak CUDA memory
        allocated on ``config["device"]`` during that fit.
    """
    device = config["device"]
    kfold = KFold(n_splits=config["cross_validation_n_splits"])
    current_model_seed = initial_model_seed
    current_sampling_seed = initial_sampling_seed
    cv_mse, cv_rel2 = 0.0, 0.0
    test_mse, test_rel2 = 0.0, 0.0
    runtime = 0.0
    max_memory_allocated_in_GB = 0.0
    # Folds are split over the sample axis; flatten the rest to a 2-D view for KFold.
    for train_indices, val_indices in kfold.split(x_train.reshape(len(x_train), -1)):
        # NOTE(review): n_obj is wrapped in a list here; presumably GNN expects one entry per graph type — confirm.
        gnn = create_gnn([n_obj], get_indices(edge_index_train, train_indices), current_model_seed)
        current_model_seed += 1
        current_sampling_seed += 1

        # Measure peak memory and wall-clock time of the fit only.
        tc.cuda.reset_peak_memory_stats(device)
        tc.cuda.empty_cache()
        time0 = time.perf_counter()
        sample_fit(gnn, x_train[train_indices], L, dxdt_train[train_indices], current_sampling_seed,
                   param_sampler, batch_size)
        # CUDA work is queued asynchronously; wait for it to finish so the timer
        # covers the fit itself rather than just the kernel launches.
        tc.cuda.synchronize(device)
        time1 = time.perf_counter()
        max_memory_allocated_in_GB += convert_bytes_to_GB(tc.cuda.max_memory_allocated(device))
        runtime += time1 - time0

        # Validation error: evaluate on the held-out fold's graphs.
        gnn.edge_index = get_indices(edge_index_train, val_indices)
        dxdt_mse, dxdt_rel2 = eval_dxdt(gnn, x_train[val_indices], L, dxdt_train[val_indices], verbose=False)
        cv_mse += dxdt_mse
        cv_rel2 += dxdt_rel2

        # Test error of this fold's model.
        gnn.edge_index = edge_index_test
        dxdt_mse, dxdt_rel2 = eval_dxdt(gnn, x_test, L, dxdt_test, verbose=False)
        test_mse += dxdt_mse
        test_rel2 += dxdt_rel2

        # Derive the next fold's model and sampling seeds deterministically from the current ones.
        current_model_seed = np.random.default_rng(current_model_seed).integers(0, 10**9)
        current_sampling_seed = np.random.default_rng(current_sampling_seed).integers(0, 10**9)

    n_splits = kfold.n_splits
    cv_mse /= n_splits
    cv_rel2 /= n_splits
    test_mse /= n_splits
    test_rel2 /= n_splits
    runtime /= n_splits
    max_memory_allocated_in_GB /= n_splits

    return cv_mse, cv_rel2, test_mse, test_rel2, runtime, max_memory_allocated_in_GB

# Running totals over repeats. Each cross_validate call returns values already
# averaged over the k folds; these accumulators sum them across repeats and are
# later divided by config["n_repeats"]. Despite the "_lst_avg" suffix they hold
# scalar sums, not lists.

# Cross-validation MSE / relative L2 error.
elm_cv_mse_lst_avg = elm_cv_rel2_lst_avg = 0.0
elm_batched_cv_mse_lst_avg = elm_batched_cv_rel2_lst_avg = 0.0
swim_cv_mse_lst_avg = swim_cv_rel2_lst_avg = 0.0
swim_batched_cv_mse_lst_avg = swim_batched_cv_rel2_lst_avg = 0.0

# Test-set MSE / relative L2 error.
elm_test_mse_lst_avg = elm_test_rel2_lst_avg = 0.0
elm_batched_test_mse_lst_avg = elm_batched_test_rel2_lst_avg = 0.0
swim_test_mse_lst_avg = swim_test_rel2_lst_avg = 0.0
swim_batched_test_mse_lst_avg = swim_batched_test_rel2_lst_avg = 0.0

# Fit runtime (seconds) and peak memory (GiB).
elm_runtime_avg = elm_mem_avg = 0.0
elm_batched_runtime_avg = elm_batched_mem_avg = 0.0
swim_runtime_avg = swim_mem_avg = 0.0
swim_batched_runtime_avg = swim_batched_mem_avg = 0.0

for repeat_idx in tqdm(range(config["n_repeats"])):
    # Seeds for this repeat, drawn from the master RNG in a fixed order.
    data_seed = get_seed()
    initial_model_seed = get_seed()
    initial_sampling_seed = get_seed()

    # New data for every repeat. The four methods below are cross-validated on
    # this same split and start from the same model/sampling seeds.
    dataset = prepare_lennard_jones_data(n_obj, data_seed)
    x_train, x_test, L, edge_index_train, edge_index_test, dxdt_train, dxdt_test = dataset
    print(f"Repeating experiments with new data: Repeat index {repeat_idx + 1}/{config['n_repeats']}")
    print(f"- dqdt_train min {dxdt_train[..., :dof].min():.15e}  dpdt_train min {dxdt_train[..., dof:].min():.15e}")
    print(f"- dqdt_train max {dxdt_train[..., :dof].max():.15e}  dpdt_train max {dxdt_train[..., dof:].max():.15e}")
    print(f"- dqdt_test  min {dxdt_test[..., :dof].min():.15e}  dpdt_test min {dxdt_test[..., dof:].min():.15e}")
    print(f"- dqdt_test  max {dxdt_test[..., :dof].max():.15e}  dpdt_test max {dxdt_test[..., dof:].max():.15e}")

    # ELM (param_sampler="random"), unbatched linear solve.
    param_sampler, batch_size = "random", None
    cv_mse, cv_rel2, test_mse, test_rel2, runtime, mem = cross_validate(
        *dataset, initial_model_seed, initial_sampling_seed, param_sampler, batch_size)
    elm_cv_mse_lst_avg += cv_mse
    elm_cv_rel2_lst_avg += cv_rel2
    elm_test_mse_lst_avg += test_mse
    elm_test_rel2_lst_avg += test_rel2
    elm_runtime_avg += runtime
    elm_mem_avg += mem

    # ELM (param_sampler="random"), batched linear solve.
    batch_size = config["batch_size"]
    cv_mse, cv_rel2, test_mse, test_rel2, runtime, mem = cross_validate(
        *dataset, initial_model_seed, initial_sampling_seed, param_sampler, batch_size)
    elm_batched_cv_mse_lst_avg += cv_mse
    elm_batched_cv_rel2_lst_avg += cv_rel2
    elm_batched_test_mse_lst_avg += test_mse
    elm_batched_test_rel2_lst_avg += test_rel2
    elm_batched_runtime_avg += runtime
    elm_batched_mem_avg += mem

    # SWIM (param_sampler="relu"), unbatched linear solve.
    param_sampler, batch_size = "relu", None
    cv_mse, cv_rel2, test_mse, test_rel2, runtime, mem = cross_validate(
        *dataset, initial_model_seed, initial_sampling_seed, param_sampler, batch_size)
    swim_cv_mse_lst_avg += cv_mse
    swim_cv_rel2_lst_avg += cv_rel2
    swim_test_mse_lst_avg += test_mse
    swim_test_rel2_lst_avg += test_rel2
    swim_runtime_avg += runtime
    swim_mem_avg += mem

    # SWIM (param_sampler="relu"), batched linear solve.
    batch_size = config["batch_size"]
    cv_mse, cv_rel2, test_mse, test_rel2, runtime, mem = cross_validate(
        *dataset, initial_model_seed, initial_sampling_seed, param_sampler, batch_size)
    swim_batched_cv_mse_lst_avg += cv_mse
    swim_batched_cv_rel2_lst_avg += cv_rel2
    swim_batched_test_mse_lst_avg += test_mse
    swim_batched_test_rel2_lst_avg += test_rel2
    swim_batched_runtime_avg += runtime
    swim_batched_mem_avg += mem

# Running sums per metric, in column order: ELM, ELM (batched), SWIM, SWIM (batched).
cv_mse_all = [elm_cv_mse_lst_avg, elm_batched_cv_mse_lst_avg, swim_cv_mse_lst_avg, swim_batched_cv_mse_lst_avg]
cv_rel2_all = [elm_cv_rel2_lst_avg, elm_batched_cv_rel2_lst_avg, swim_cv_rel2_lst_avg, swim_batched_cv_rel2_lst_avg]
test_mse_all = [elm_test_mse_lst_avg, elm_batched_test_mse_lst_avg, swim_test_mse_lst_avg, swim_batched_test_mse_lst_avg]
test_rel2_all = [elm_test_rel2_lst_avg, elm_batched_test_rel2_lst_avg, swim_test_rel2_lst_avg, swim_batched_test_rel2_lst_avg]
runtime_all = [elm_runtime_avg, elm_batched_runtime_avg, swim_runtime_avg, swim_batched_runtime_avg]
mem_all = [elm_mem_avg, elm_batched_mem_avg, swim_mem_avg, swim_batched_mem_avg]

# Convert sums over repeats into means over repeats.
cv_mse_all, cv_rel2_all, test_mse_all, test_rel2_all, runtime_all, mem_all = [
    [total / config["n_repeats"] for total in metric_sums]
    for metric_sums in (cv_mse_all, cv_rel2_all, test_mse_all, test_rel2_all, runtime_all, mem_all)
]

arr_columns = ["ELM", "ELM (batched)", "SWIM", "SWIM (batched)"]
results_table = tabulate(
    headers=["Method"] + arr_columns,
    tabular_data=[
        ["CV MSE"] + cv_mse_all,
        ["CV L2 rel."] + cv_rel2_all,
        ["Test MSE"] + test_mse_all,
        ["Test L2 rel."] + test_rel2_all,
        ["Runtime (s)"] + runtime_all,
        ["Max memory usage (GiB)"] + mem_all,
    ],
    floatfmt=".2e"
)

print("\nTable: Batch-wise training study results with ELM and SWIM random feature sampling methods are displayed.")
print(results_table)
print("- Runtimes")
print(runtime_all)
print("- Memory usage")
print(mem_all)
# The output directory may not exist yet; create it so the results of a long
# run are not lost to a FileNotFoundError at the very end.
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
filename = os.path.join(out_dir, "lennard_jones_with_batching.txt")
with open(filename, "w", encoding="utf-8") as f:
    f.write(results_table + "\n" + str(runtime_all) + "\n" + str(mem_all))
print("-> Results are saved at", filename)
# `exit()` is an interactive-shell helper from `site`; raise SystemExit directly in scripts.
raise SystemExit(0)
