from copy import deepcopy

import numpy as np
import torch
from mlwiz.evaluation.config import Config
from mlwiz.experiment import Experiment
from mlwiz.static import LOSS, MAIN_LOSS, SCORE, MAIN_SCORE
from mlwiz.training.engine import TrainingEngine
from skopt import gp_minimize
from skopt.space import Integer

from model import DynamicArchitecture


class WidthExperiment(Experiment):
    """
    Experiment whose engine creation also wires the engine's optimizer
    into the model.
    """

    def create_engine(
        self,
        config: Config,
        model: DynamicArchitecture,
    ) -> TrainingEngine:
        """
        Builds the training engine via the parent class and then passes the
        engine's optimizer to the model, to allow for dynamic optimization
        of newly inserted layers.

        Args:
            config: the configuration forwarded to the parent implementation
            model: the model that receives the optimizer through
                ``set_optimizer``

        Returns:
            the training engine created by the parent class
        """
        training_engine = super().create_engine(config, model)

        # The model needs a handle on the optimizer created by the engine.
        optimizer = training_engine.optimizer
        model.set_optimizer(optimizer)

        return training_engine


class LocalSearchExperiment(Experiment):
    """
    Performs a local search experiment as defined in https://arxiv.org/pdf/2005.02960
    Because local search has not been implemented in mlwiz, this experiment
    assumes that there is only one configuration specified in the configuration file,
    so that only run_test is called. Run_test performs a model selection using
    the value in the `min_neurons` and `max_neurons` keys before selecting the best
    results and running the final re-training.
    Note that this local search experiment focuses on the number of neurons only.

    IMPORTANT: Please look at results in the experiment.log file and not those returned by mlwiz
    """

    def run_valid(self, dataset_getter, logger):
        """
        Not supported: model selection happens inside :meth:`run_test`.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Please read the class documentation.")

    def run_test(self, dataset_getter, logger):
        """
        Runs the search over ``num_hidden_neurons`` followed by the final
        retrainings.

        Search phase (inner fold 0): at most ``local_search_budget``
        candidates are sampled between ``min_neurons`` and ``max_neurons``;
        each one is trained and evaluated on the inner validation set. The
        search stops early at the first candidate that does not improve on
        the best validation result so far.

        Retraining phase (outer fold 0): the best candidate is retrained
        10 times on the outer split; mean and standard deviation of the test
        loss/score are printed and logged.

        **Do not use the test to train the model
        nor for early stopping reasons!**

        Configuration keys read from ``self.model_config``: ``batch_size``,
        ``shuffle`` (optional, defaults to ``True``),
        ``higher_results_are_better``, ``min_neurons``, ``max_neurons``,
        ``local_search_budget`` and ``epochs``. The key
        ``num_hidden_neurons`` must NOT be present.

        Args:
            dataset_getter (:class:`~mlwiz.data.provider.DataProvider`):
                a data provider
            logger (:class:`~mlwiz.log.logger.Logger`): the logger

        Returns:
            a tuple of training,validation,test dictionaries of the LAST
            retraining run. Each dictionary has two keys:

            * ``LOSS`` (as defined in ``mlwiz.static``)
            * ``SCORE`` (as defined in ``mlwiz.static``)

            The aggregated statistics over all retrainings are only
            available in the printed/logged summary string.

        Raises:
            AssertionError: if ``num_hidden_neurons`` is in the config.
            RuntimeError: if the search ends without any accepted candidate
                (e.g., a budget of 0 or a NaN validation result).
        """
        batch_size = self.model_config["batch_size"]
        shuffle = (
            self.model_config["shuffle"]
            if "shuffle" in self.model_config
            else True
        )

        dataset_getter.set_inner_k(0)
        dataset_getter.set_outer_k(0)

        # Work on a shallow copy so that the stored configuration is not
        # modified when num_hidden_neurons is injected below
        config = dict(self.model_config)

        ### ------------------ LOCAL SEARCH PHASE ------------------------- ###
        higher_results_are_better = config["higher_results_are_better"]

        # Instantiate the Dataset
        train_loader = dataset_getter.get_inner_train(
            batch_size=batch_size, shuffle=shuffle
        )
        val_loader = dataset_getter.get_inner_val(
            batch_size=batch_size, shuffle=shuffle
        )

        dim_input_features = dataset_getter.get_dim_input_features()
        dim_target = dataset_getter.get_dim_target()

        min_neurons = config["min_neurons"]
        max_neurons = config["max_neurons"]

        assert (
            "num_hidden_neurons" not in config
        ), "You should not have num_hidden_neurons in your config for local search experiment!"
        local_search_budget = int(config["local_search_budget"])

        if higher_results_are_better:
            # we assume the score needs to be tracked
            best_score = float("-inf")
        else:
            # we assume the loss needs to be tracked
            best_score = float("inf")
        best_num_hidden_neurons = None

        logger.log("Starting local search experiment.")
        # Bounded loop: the iteration counter must advance so that the
        # budget is actually enforced and the log reports the real iteration
        for t in range(local_search_budget):
            # NOTE(review): torch.randint excludes the upper bound, so
            # max_neurons itself is never sampled - confirm this is intended
            num_hidden_neurons = int(
                torch.randint(min_neurons, max_neurons, (1,))
            )
            config["num_hidden_neurons"] = num_hidden_neurons

            # Instantiate the Model
            model = self.create_model(dim_input_features, dim_target, config)

            # Instantiate the engine (it handles the training loop and the
            # inference phase by abstracting the specifics)
            training_engine = self.create_engine(config, model)

            # Only the validation results drive the search; the remaining
            # entries are discarded (check the ordering is correct)
            (
                _,
                _,
                _,
                val_loss,
                val_score,
                _,
                _,
                _,
                _,
            ) = training_engine.train(
                train_loader=train_loader,
                validation_loader=val_loader,
                test_loader=None,
                max_epochs=config["epochs"],
                logger=logger,
            )

            if higher_results_are_better:
                metric_name = "score"
                candidate = val_score["main_score"].item()
                improved = candidate > best_score
            else:
                metric_name = "loss"
                candidate = val_loss["main_loss"].item()
                improved = candidate < best_score

            if not improved:
                logger.log("Reached local minimum, abort search. ")
                break

            best_score = candidate
            best_num_hidden_neurons = num_hidden_neurons
            logger.log(
                f"Found new best {metric_name} {candidate} at iteration {t}. "
            )

        if best_num_hidden_neurons is None:
            # Happens with a budget of 0 or when the very first validation
            # result cannot be compared (e.g., NaN): fail loudly instead of
            # retraining with num_hidden_neurons=None
            raise RuntimeError(
                "Local search did not select any value for num_hidden_neurons: "
                "check local_search_budget and the validation results."
            )

        ### ------------------ FINAL RETRAINING PHASE --------------------- ###
        num_hidden_neurons = best_num_hidden_neurons
        logger.log(f"Running final retrainings with {num_hidden_neurons}.")
        num_retrainings = 10  # HARDCODED

        config["num_hidden_neurons"] = num_hidden_neurons

        # Instantiate the Dataset
        train_loader = dataset_getter.get_outer_train(
            batch_size=batch_size, shuffle=shuffle
        )
        val_loader = dataset_getter.get_outer_val(
            batch_size=batch_size, shuffle=shuffle
        )
        test_loader = dataset_getter.get_outer_test(
            batch_size=batch_size, shuffle=shuffle
        )

        # Call this after the loaders: the datasets may need to be instantiated
        # with additional parameters
        dim_input_features = dataset_getter.get_dim_input_features()
        dim_target = dataset_getter.get_dim_target()

        test_losses, test_scores = [], []
        for _ in range(num_retrainings):
            # Instantiate the Model
            model = self.create_model(dim_input_features, dim_target, config)

            # Instantiate the engine (it handles the training loop and the
            # inference phase by abstracting the specifics)
            training_engine = self.create_engine(config, model)

            (
                train_loss,
                train_score,
                _,
                val_loss,
                val_score,
                _,
                test_loss,
                test_score,
                _,
            ) = training_engine.train(
                train_loader=train_loader,
                validation_loader=val_loader,
                test_loader=test_loader,
                max_epochs=config["epochs"],
                logger=logger,
            )

            # Store plain Python floats so that the NumPy aggregation below
            # does not depend on the container type of the metric values
            test_losses.append(float(test_loss[MAIN_LOSS]))
            test_scores.append(float(test_score[MAIN_SCORE]))

            train_res = {LOSS: train_loss, SCORE: train_score}
            val_res = {LOSS: val_loss, SCORE: val_score}
            test_res = {LOSS: test_loss, SCORE: test_score}

        test_loss_mean, test_loss_std = np.mean(test_losses), np.std(
            test_losses
        )
        test_score_mean, test_score_std = np.mean(test_scores), np.std(
            test_scores
        )
        summary = f"Exp ended with test loss {test_loss_mean} +- {test_loss_std} and test score {test_score_mean} +- {test_score_std}"
        print(summary)
        # Also send the summary to the logger: the class documentation
        # points users to the log for the final results
        logger.log(summary)

        # THIS IS NOT RELEVANT FOR LOCAL SEARCH ANYMORE, WE TWEAKED MLWIZ
        # USE THE LOGGED STRING ABOVE
        return train_res, val_res, test_res


class BayesOptSearchExperiment(Experiment):
    """
    Performs a Bayesian Optimization search experiment using Expected Improvement
    as the acquisition function.
    Because BO search has not been implemented in mlwiz, this experiment
    assumes that there is only one configuration specified in the configuration file,
    so that only run_test is called. Run_test performs a model selection using
    the value in the `min_neurons` and `max_neurons` keys before selecting the best
    results and running the final re-training.
    Note that this BO search experiment focuses on the number of neurons only.

    IMPORTANT: Please look at results in the experiment.log file and not those returned by mlwiz
    """

    def run_valid(self, dataset_getter, logger):
        """
        Not supported: model selection happens inside :meth:`run_test`.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Please read the class documentation.")

    def run_test(self, dataset_getter, logger):
        """
        Runs the Bayesian Optimization search over ``num_hidden_neurons``
        followed by the final retrainings.

        Search phase (inner fold 0): ``gp_minimize`` evaluates
        ``bo_search_budget`` candidates in the integer space defined by
        ``min_neurons`` and ``max_neurons``. Each candidate is trained and
        evaluated on the inner validation set; the objective is the
        validation loss, or the negated validation score when
        ``higher_results_are_better`` is set.

        Retraining phase (outer fold 0): the best candidate is retrained
        10 times on the outer split; mean and standard deviation of the test
        loss/score are printed and logged.

        **Do not use the test to train the model
        nor for early stopping reasons!**

        Configuration keys read from ``self.model_config``: ``batch_size``,
        ``shuffle`` (optional, defaults to ``True``),
        ``higher_results_are_better``, ``min_neurons``, ``max_neurons``,
        ``bo_search_budget`` and ``epochs``. The key ``num_hidden_neurons``
        must NOT be present.

        Args:
            dataset_getter (:class:`~mlwiz.data.provider.DataProvider`):
                a data provider
            logger (:class:`~mlwiz.log.logger.Logger`): the logger

        Returns:
            a tuple of training,validation,test dictionaries of the LAST
            retraining run. Each dictionary has two keys:

            * ``LOSS`` (as defined in ``mlwiz.static``)
            * ``SCORE`` (as defined in ``mlwiz.static``)

            The aggregated statistics over all retrainings are only
            available in the printed/logged summary string.

        Raises:
            AssertionError: if ``num_hidden_neurons`` is in the config.
        """
        batch_size = self.model_config["batch_size"]
        shuffle = (
            self.model_config["shuffle"]
            if "shuffle" in self.model_config
            else True
        )

        dataset_getter.set_inner_k(0)
        dataset_getter.set_outer_k(0)

        # Work on a shallow copy so that the stored configuration is not
        # modified when num_hidden_neurons is injected below
        config = dict(self.model_config)

        ### ------------------ BO SEARCH PHASE ---------------------------- ###
        higher_results_are_better = config["higher_results_are_better"]

        # Instantiate the Dataset
        train_loader = dataset_getter.get_inner_train(
            batch_size=batch_size, shuffle=shuffle
        )
        val_loader = dataset_getter.get_inner_val(
            batch_size=batch_size, shuffle=shuffle
        )

        dim_input_features = dataset_getter.get_dim_input_features()
        dim_target = dataset_getter.get_dim_target()

        min_neurons = config["min_neurons"]
        max_neurons = config["max_neurons"]

        assert (
            "num_hidden_neurons" not in config
        ), "You should not have num_hidden_neurons in your config for BO search experiment!"
        bo_search_budget = int(config["bo_search_budget"])

        logger.log("Starting Bayesian Optimization experiment.")

        def objective(params):
            """
            Trains a model with the proposed number of hidden neurons and
            returns the value to be minimized by the optimizer.
            """
            # The optimizer may hand over a NumPy integer: use a plain int
            # in the configuration
            num_hidden_neurons = int(params[0])
            # Each evaluation gets its own configuration copy
            copy_config = deepcopy(config)
            copy_config["num_hidden_neurons"] = num_hidden_neurons

            # Instantiate the Model
            model = self.create_model(
                dim_input_features, dim_target, copy_config
            )

            # Instantiate the engine (it handles the training loop and the
            # inference phase by abstracting the specifics)
            training_engine = self.create_engine(copy_config, model)

            # Only the validation results drive the search; the remaining
            # entries are discarded (check the ordering is correct)
            (
                _,
                _,
                _,
                val_loss,
                val_score,
                _,
                _,
                _,
                _,
            ) = training_engine.train(
                train_loader=train_loader,
                validation_loader=val_loader,
                test_loader=None,
                max_epochs=copy_config["epochs"],
                logger=logger,
            )

            if higher_results_are_better:
                # assume accuracy: negate it because the optimizer minimizes
                return -float(val_score["main_score"])
            else:
                # assume loss
                return float(val_loss["main_loss"])

        # torch.initial_seed() is a 64-bit value, whereas NumPy's legacy
        # RandomState only accepts seeds in [0, 2**32 - 1]: fold the seed
        # into that range (seeds already in range are left unchanged)
        random_state = int(torch.initial_seed()) % (2**32)

        # Perform Bayesian optimization
        result = gp_minimize(
            objective,  # Objective function
            [
                Integer(min_neurons, max_neurons, name="num_hidden_neurons")
            ],  # Hyperparameter space
            n_calls=bo_search_budget,  # Number of evaluations
            n_random_starts=1,
            random_state=random_state,  # Ensures reproducibility
        )

        # Plain int, consistent with what the objective put in the config
        num_hidden_neurons = int(result.x[0])
        logger.log(f"Best num_neurons found by BO are {num_hidden_neurons}.")
        if higher_results_are_better:
            # Undo the negation applied inside the objective
            logger.log(f"Best score found by BO are {-result.fun}.")
        else:
            logger.log(f"Best loss found by BO are {result.fun}.")

        ### ------------------ FINAL RETRAINING PHASE --------------------- ###
        logger.log(f"Running final retrainings with {num_hidden_neurons}.")
        num_retrainings = 10  # HARDCODED

        config["num_hidden_neurons"] = num_hidden_neurons

        # Instantiate the Dataset
        train_loader = dataset_getter.get_outer_train(
            batch_size=batch_size, shuffle=shuffle
        )
        val_loader = dataset_getter.get_outer_val(
            batch_size=batch_size, shuffle=shuffle
        )
        test_loader = dataset_getter.get_outer_test(
            batch_size=batch_size, shuffle=shuffle
        )

        # Call this after the loaders: the datasets may need to be instantiated
        # with additional parameters
        dim_input_features = dataset_getter.get_dim_input_features()
        dim_target = dataset_getter.get_dim_target()

        test_losses, test_scores = [], []
        for _ in range(num_retrainings):
            # Instantiate the Model
            model = self.create_model(dim_input_features, dim_target, config)

            # Instantiate the engine (it handles the training loop and the
            # inference phase by abstracting the specifics)
            training_engine = self.create_engine(config, model)

            (
                train_loss,
                train_score,
                _,
                val_loss,
                val_score,
                _,
                test_loss,
                test_score,
                _,
            ) = training_engine.train(
                train_loader=train_loader,
                validation_loader=val_loader,
                test_loader=test_loader,
                max_epochs=config["epochs"],
                logger=logger,
            )

            # Store plain Python floats so that the NumPy aggregation below
            # does not depend on the container type of the metric values
            test_losses.append(float(test_loss[MAIN_LOSS]))
            test_scores.append(float(test_score[MAIN_SCORE]))

            train_res = {LOSS: train_loss, SCORE: train_score}
            val_res = {LOSS: val_loss, SCORE: val_score}
            test_res = {LOSS: test_loss, SCORE: test_score}

        test_loss_mean, test_loss_std = np.mean(test_losses), np.std(
            test_losses
        )
        test_score_mean, test_score_std = np.mean(test_scores), np.std(
            test_scores
        )
        summary = f"Exp ended with test loss {test_loss_mean} +- {test_loss_std} and test score {test_score_mean} +- {test_score_std}"
        print(summary)
        # Also send the summary to the logger: the class documentation
        # points users to the log for the final results
        logger.log(summary)

        # THIS IS NOT RELEVANT FOR BO SEARCH ANYMORE, WE TWEAKED MLWIZ
        # USE THE LOGGED STRING ABOVE
        return train_res, val_res, test_res
