import numpy as np
from prettytable import PrettyTable

from src.readers.dataset_reader import DatasetReader
from src.methods.dfl_abstract import DFL
from src.solvers.solver import Solver

from src.utils.common import shuffle_arrays
from src.utils.probabilities import set_seed


class Experiment:
    """Train and evaluate a collection of DFL trainers on random splits of one dataset.

    Trainers are registered with :meth:`add_trainer`. :meth:`launch_single` runs
    every trainer once for a single seed; :meth:`launch_multiple` repeats the
    run for several seeds and reports ``mean +- std`` per metric.
    """

    # Number of decimals used when values are rounded for display.
    _PRECISION = 4

    # One entry per table column, in display order:
    # (column label, metrics dict it is read from, key in that dict,
    #  whether launch_single rounds the value before displaying it).
    _METRIC_SOURCES = (
        ("MSE", "test", "mse", True),
        ("Avg. regret", "test", "avg. regret", True),
        ("Avg. relative regret", "test", "avg. relative regret", True),
        ("Epochs", "train", "epochs", False),
        ("Total runtime", "train", "runtime", True),
        ("Pre-training runtime", "train", "pre-training runtime", True),
        ("Calls", "train", "calls", False),
        ("Avg. calls", "train", "avg. calls", True),
        ("Infeasible solutions ratio", "test", "infeasible solutions ratio", False),
    )

    def __init__(self, name: str, dataset: DatasetReader, solver: Solver, train_ratio: float, val_ratio: float):
        """Store the experiment configuration.

        Args:
            name: Label of the experiment (stored, not otherwise used by this class).
            dataset: Reader whose ``read()`` yields ``x, y, z, cost, problem_params``.
            solver: Solver handed to every trainer through ``set_solver``.
            train_ratio: Fraction of samples used for training, strictly in (0, 1).
            val_ratio: Fraction of samples used for validation, strictly in (0, 1).
                ``train_ratio + val_ratio`` must be below 1; the rest is the test set.

        Raises:
            AssertionError: If the ratios violate the constraints above.
        """
        # Raised explicitly instead of through `assert` so the checks are not
        # stripped under `python -O`; AssertionError is kept so that callers
        # observe the same exception type as before.
        if not 0.0 < train_ratio < 1.0:
            raise AssertionError(f"train_ratio must be in (0, 1), got {train_ratio!r}")
        if not 0.0 < val_ratio < 1.0:
            raise AssertionError(f"val_ratio must be in (0, 1), got {val_ratio!r}")
        if not train_ratio + val_ratio < 1.0:
            raise AssertionError("train_ratio + val_ratio must be < 1 so that the test set is not empty")

        self._name = name
        self._dataset = dataset
        self._solver = solver
        self._train_ratio = train_ratio
        self._val_ratio = val_ratio

        self._trainers: list[DFL] = []

        # Column labels of the result tables (the "Method" column is prepended later).
        self._metrics = [label for label, _, _, _ in self._METRIC_SOURCES]

    def add_trainer(self, trainer: DFL) -> None:
        """Register a trainer to be run by the ``launch_*`` methods."""
        self._trainers.append(trainer)

    def launch_single(self, seed: int, epochs: int, batch_size: int, time_limit: float | None = None) -> PrettyTable:
        """Run every registered trainer once on the split induced by ``seed``.

        Args:
            seed: Seed passed to ``_read_dataset`` to build the split.
            epochs: Forwarded to ``trainer.train``.
            batch_size: Forwarded to ``trainer.train`` and ``trainer.test``.
            time_limit: Forwarded to ``trainer.train``; ``None`` by default.

        Returns:
            A table with one row per trainer and one column per metric.

        Raises:
            AssertionError: If no trainer has been registered.
        """
        self._require_trainers()

        train_set, test_set, val_set, problem_params = self._read_dataset(seed)

        # Target statistics are computed on the training split only.
        y_mean = np.mean(train_set["y"], axis=0)
        y_std = np.std(train_set["y"], axis=0)

        table = PrettyTable(["Method"] + self._metrics)

        for trainer in self._trainers:

            print(f"TRAINING MODEL: {trainer.name}")

            train_metrics, test_metrics = self._fit_and_evaluate(
                trainer, train_set, val_set, test_set, problem_params, y_mean, y_std,
                epochs, batch_size, time_limit)

            collected = self._collect_metrics(train_metrics, test_metrics)

            row = [trainer.name]
            for label, _, _, rounded in self._METRIC_SOURCES:
                value = collected[label]
                row.append(round(value, self._PRECISION) if rounded else value)
            table.add_row(row)

        return table

    def launch_multiple(self, seeds: list[int], epochs: int, batch_size: int,
                        time_limit: float | None = None) -> PrettyTable:
        """Run every registered trainer once per seed and aggregate the metrics.

        Trainers are ``reset()`` before every seed except the first one.
        NOTE(review): the first seed starts from whatever state the trainers
        are in when this method is called — confirm that is intended when the
        experiment is launched more than once.

        Args:
            seeds: Seeds to run; at least two are required.
            epochs: Forwarded to ``trainer.train``.
            batch_size: Forwarded to ``trainer.train`` and ``trainer.test``.
            time_limit: Forwarded to ``trainer.train``; ``None`` by default.

        Returns:
            A table with one row per trainer whose cells are ``"mean +- std"``
            over the seeds, both rounded to four decimals.

        Raises:
            AssertionError: If no trainer is registered or fewer than two seeds are given.
        """
        self._require_trainers()
        if not len(seeds) > 1:
            raise AssertionError("launch_multiple needs at least two seeds")

        table = PrettyTable(["Method"] + self._metrics)

        # One accumulator per registered trainer, parallel to self._trainers.
        # Indexing by position (rather than by trainer object) does not require
        # trainers to be hashable and keeps duplicate registrations separate.
        history = [{label: [] for label in self._metrics} for _ in self._trainers]

        for seed_count, seed in enumerate(seeds):

            train_set, test_set, val_set, problem_params = self._read_dataset(seed)

            # Target statistics are computed on the training split only.
            y_mean = np.mean(train_set["y"], axis=0)
            y_std = np.std(train_set["y"], axis=0)

            for trainer, accumulated in zip(self._trainers, history):

                print(f"SEED {seed} - TRAINING MODEL: {trainer.name}")

                if seed_count > 0:
                    trainer.reset()

                train_metrics, test_metrics = self._fit_and_evaluate(
                    trainer, train_set, val_set, test_set, problem_params, y_mean, y_std,
                    epochs, batch_size, time_limit)

                for label, value in self._collect_metrics(train_metrics, test_metrics).items():
                    accumulated[label].append(value)

        for trainer, accumulated in zip(self._trainers, history):
            table.add_row([trainer.name] + [self._format_mean_std(accumulated[label]) for label in self._metrics])

        return table

    def _require_trainers(self) -> None:
        """Raise ``AssertionError`` when no trainer has been registered."""
        if not self._trainers:
            raise AssertionError("no trainer registered; call add_trainer() first")

    def _fit_and_evaluate(self, trainer: DFL, train_set: dict, val_set: dict, test_set: dict,
                          problem_params: dict, y_mean, y_std,
                          epochs: int, batch_size: int, time_limit: float | None) -> tuple[dict, dict]:
        """Configure ``trainer``, train it on train/val and evaluate it on test.

        Returns:
            ``(train_metrics, test_metrics)`` as returned by ``trainer.train``
            and ``trainer.test`` respectively.
        """
        trainer.set_solver(self._solver)
        trainer.set_problem_params(problem_params)
        trainer.set_y_stats(y_mean, y_std)

        train_metrics = trainer.train(train_set["x"], train_set["y"], train_set["z"], train_set["cost"],
                                      val_set["x"], val_set["y"], val_set["z"], val_set["cost"],
                                      epochs, batch_size, time_limit)

        test_metrics = trainer.test(test_set["x"], test_set["y"], test_set["z"], test_set["cost"], batch_size)

        return train_metrics, test_metrics

    @classmethod
    def _collect_metrics(cls, train_metrics: dict, test_metrics: dict) -> dict:
        """Map every column label to its raw value, in column order."""
        sources = {"train": train_metrics, "test": test_metrics}
        return {label: sources[origin][key] for label, origin, key, _ in cls._METRIC_SOURCES}

    @classmethod
    def _format_mean_std(cls, values: list) -> str:
        """Render ``values`` as ``"<mean> +- <std>"`` with rounded mean and std."""
        mean = round(float(np.mean(values)), cls._PRECISION)
        std = round(float(np.std(values)), cls._PRECISION)
        return f"{mean} +- {std}"

    def _read_dataset(self, seed: int) -> tuple[dict, dict, dict, dict]:
        """Read the dataset, shuffle it and cut it into train/validation/test parts.

        ``set_seed(seed)`` is called before reading and shuffling.
        NOTE(review): presumably this makes the shuffle reproducible per seed —
        confirm against ``src.utils.probabilities.set_seed``.

        Split sizes are ``int(ratio * n_samples)`` (truncated) for train and
        validation; the test part receives all remaining samples.

        Returns:
            ``(train_set, test_set, val_set, problem_params)`` — note the order:
            test comes before validation. Each set is a dict with the keys
            ``"x"``, ``"y"``, ``"z"`` and ``"cost"``.
        """
        set_seed(seed)

        x, y, z, cost, problem_params = self._dataset.read()

        x, y, z, cost = shuffle_arrays((x, y, z, cost))

        n_samples = x.shape[0]

        train_end = int(self._train_ratio * n_samples)
        val_end = train_end + int(self._val_ratio * n_samples)

        arrays = {"x": x, "y": y, "z": z, "cost": cost}

        def take(start, stop):
            # Same [start:stop] window applied to all four aligned arrays.
            return {key: values[start:stop] for key, values in arrays.items()}

        train_set = take(0, train_end)
        val_set = take(train_end, val_end)
        test_set = take(val_end, None)

        return train_set, test_set, val_set, problem_params
