import argparse
import os
import json

from src.methods.dfl_abstract import DFL
from src.training.experiment import Experiment
from src.training.environment_builder import EnvironmentBuilder

TRAIN_CONFIG_PATH = os.path.join("experiments", "training configurations")
TRAIN_RESULTS_PATH = os.path.join("experiments", "results")
MODEL_CONFIGURATIONS_PATH = os.path.join("experiments", "model configurations")


def parse_arguments() -> argparse.Namespace:
    """Parse the command-line arguments of the training script.

    One positional argument (``config_name``) is required; the remaining
    options are optional and fall back to the defaults listed below.

    Returns:
        The namespace produced by ``argparse`` with the attributes
        ``config_name``, ``epochs``, ``batch_size``, ``train_ratio``,
        ``val_ratio`` and ``time_limit``.
    """
    cli = argparse.ArgumentParser(description='DFL training')
    cli.add_argument("config_name", type=str, help="Name of the training configuration file")

    # (flag, value type, help text, default) for every optional argument.
    optional_flags = (
        ("--epochs", int, "Number of training epochs", 1000),
        ("--batch_size", int, "Number of batch samples", 32),
        ("--train_ratio", float, "Training samples ratio", 0.8),
        ("--val_ratio", float, "Validation samples ratio", 0.1),
        ("--time_limit", float, "Training time limit", 0.0),
    )
    for flag, value_type, help_text, fallback in optional_flags:
        cli.add_argument(flag, type=value_type, help=help_text, default=fallback)

    return cli.parse_args()


def save_results(_args: argparse.Namespace, _seeds: list[int], _configurations: list[str], _problem: str,
                 _dataset_name: str, _results) -> None:
    """Write a plain-text report of a training run into ``TRAIN_RESULTS_PATH``.

    The report is saved as ``<config_name>.txt`` and has three sections: the
    training parameters, the content of every model configuration file, and
    the string representation of the results.

    Args:
        _args: Parsed command-line arguments; ``config_name``, ``epochs``,
            ``batch_size``, ``train_ratio``, ``val_ratio`` and ``time_limit``
            are read from it.
        _seeds: Seeds used for the run.
        _configurations: Names (without extension) of the model configuration
            files located in ``MODEL_CONFIGURATIONS_PATH``.
        _problem: Name of the problem that was solved.
        _dataset_name: Name of the dataset that was used.
        _results: Object holding the results; it is written via ``str()``.

    Raises:
        OSError: If the results directory cannot be created, or a model
            configuration file or the output file cannot be opened.
    """
    # makedirs also creates missing parent directories and, with exist_ok,
    # does not fail when the directory is already there.
    os.makedirs(TRAIN_RESULTS_PATH, exist_ok=True)

    output_path = os.path.join(TRAIN_RESULTS_PATH, _args.config_name + ".txt")

    # A time limit of 0.0 on the command line stands for "no limit".
    _time_limit = None if _args.time_limit == 0.0 else _args.time_limit

    # Whatever is not used for training or validation is the test share.
    test_ratio = round(1.0 - (_args.train_ratio + _args.val_ratio), 2)

    with open(output_path, "w", encoding="utf-8") as f:

        f.write("TRAINING PARAMETERS:\n\n")
        # Use the function arguments, not module-level globals, so the report
        # reflects what the caller passed in and the function works when it
        # is called from another module.
        f.write("Problem: {}\n".format(_problem))
        f.write("Dataset name: {}\n".format(_dataset_name))
        f.write("Seeds: {}\n".format(_seeds))
        f.write("Epochs: {}\n".format(_args.epochs))
        f.write("Batch size: {}\n".format(_args.batch_size))
        f.write("Train limit: {}\n".format(_time_limit))
        f.write("Train-Val-Test ratio: {}-{}-{}\n\n".format(_args.train_ratio, _args.val_ratio, test_ratio))

        f.write("MODELS:\n\n")
        for i, config in enumerate(_configurations):
            f.write("{}) ".format(i + 1))
            json_path = os.path.join(MODEL_CONFIGURATIONS_PATH, config + ".json")
            with open(json_path, 'r', encoding="utf-8") as json_file:
                params = json.load(json_file)
            for name, value in params.items():
                f.write("{}: {}\n".format(name, value))
            f.write("\n")

        f.write("RESULTS:\n\n")
        f.write(str(_results))


if __name__ == '__main__':

    # Command-line options; a time limit of 0.0 is treated as "no limit".
    args = parse_arguments()
    time_limit = args.time_limit if args.time_limit != 0.0 else None

    # Load the training configuration selected on the command line.
    with open(os.path.join(TRAIN_CONFIG_PATH, args.config_name + ".json"), 'r') as file:
        parameters = json.load(file)

    # NOTE: these names are intentionally kept at module level with exactly
    # these spellings; do not rename them without checking the rest of the file.
    problem = parameters["problem"]
    dataset_name = parameters["dataset_name"]
    seeds = parameters["seeds"]
    configurations = parameters["configurations"]

    # Build the solver/dataset pair and the experiment around them.
    solver, dataset = EnvironmentBuilder.parse_problem(problem, dataset_name)
    input_dim, output_dim = dataset.get_dims()
    experiment = Experiment(args.config_name, dataset, solver, args.train_ratio, args.val_ratio)

    # Register one trainer per model configuration.
    for configuration in configurations:
        trainer: DFL = EnvironmentBuilder.parse_configuration(configuration, input_dim, output_dim)
        experiment.add_trainer(trainer)

    # A single seed uses the single-run entry point; anything else goes
    # through the multi-seed one.
    if len(seeds) != 1:
        results = experiment.launch_multiple(seeds, args.epochs, args.batch_size, time_limit=time_limit)
    else:
        results = experiment.launch_single(seeds[0], args.epochs, args.batch_size, time_limit=time_limit)

    save_results(args, seeds, configurations, problem, dataset_name, results)

    print("Results saved in {}\n".format(TRAIN_RESULTS_PATH))
