import logging
import os

import utils

import numpy
import torch


class Experiment:
    """One active-learning experiment: a model paired with an acquisition function.

    An experiment repeatedly trains a model on a datapool, evaluates it on the
    test set, records the outcome in ``result`` and lets the acquisition
    function acquire more data, until the datapool holds at least
    ``max_labelled`` labelled items.

    Experiments are normally driven through :meth:`repeat`, which shares one
    test set and one freshly created datapool per repeat between several
    experiments.
    """

    # Id handed to the next created experiment; bumped in ``__init__``.
    ID = 0
    # Amount added to the seed on every repeat in ``repeat``.
    DELTA_SEED = 1
    # Maximum number of training attempts per acquisition step.
    MAX_TRAINING_RETRIES = 3

    def __init__(
        self,
        trainer, model, acquiref, max_labelled,
        reinit_model_per_acquisition_step,
        first_training_epochs, min_training_score, result,
        save_first_labelled_pool_path=None
    ):
        """Store the experiment configuration and log a short summary.

        Args:
            trainer: Factory whose ``init(model, epochs)`` returns an object
                providing ``train(datapool)`` (returns the training score)
                and ``evaluate(testset)``.
            model: Factory providing ``init()`` and ``get_descriptive_name()``.
            acquiref: Acquisition-function factory providing ``init(model)``
                (the returned object provides ``acquire(datapool)``),
                ``get_descriptive_name()`` and ``get_name()``.
            max_labelled: The acquisition loop stops once
                ``datapool.count_labelled()`` is no longer below this value.
            reinit_model_per_acquisition_step: If true, a new model is created
                via ``model.init()`` for every training attempt. Must be true
                exactly when ``first_training_epochs`` is ``None`` (checked
                when the experiment runs).
            first_training_epochs: ``epochs`` value passed to ``trainer.init``
                during the first acquisition step; later steps pass ``None``.
            min_training_score: Training scores below this value trigger a
                retraining attempt.
            result: Sink providing ``add_entry(datapool=..., model=...,
                acquiref=..., expt_repeat_id=..., score=...)``.
            save_first_labelled_pool_path: Optional directory; when given, the
                datapool of repeat 0 is saved there once the run finishes.
        """
        self._trainer = trainer
        self._model = model
        self._acquiref = acquiref
        self._testset = None  # injected later through ``_set_testset``
        self._max_labelled = max_labelled
        self._reinit_model_per_acquisition_step = reinit_model_per_acquisition_step
        self._first_training_epochs = first_training_epochs
        self._result = result
        self._min_training_score = min_training_score
        self._save_first_labelled_pool_path = save_first_labelled_pool_path

        self._id = Experiment.ID
        logging.info("Created experiment {}:".format(self._id))
        logging.info(" - Model: {}".format(self._model.get_descriptive_name()))
        logging.info(" - Acquisition function: {}".format(self._acquiref.get_descriptive_name()))
        Experiment.ID += 1

    @staticmethod
    def repeat(datapool_creator, init_labelled, n, expts, start_seed):
        """Run every experiment in ``expts`` ``n`` times on fresh datapools.

        Args:
            datapool_creator: Provides ``get_name()``, ``get_full_testset()``
                and ``create(init_labelled)``.
            init_labelled: Forwarded unchanged to ``datapool_creator.create``.
            n: Number of repeats.
            expts: Experiments to run. Each must not have a test set assigned
                yet (see ``_set_testset``).
            start_seed: Seed used to create the datapool of the first repeat;
                it grows by ``DELTA_SEED`` per repeat.
        """
        logging.info("Loading {} test set...".format(
            datapool_creator.get_name()
        ))

        # The test set is loaded once and shared by all experiments.
        testset = datapool_creator.get_full_testset()
        for expt in expts:
            expt._set_testset(testset)

        seed = start_seed
        for i in range(n):
            logging.info("Experiment repeat {}/{}".format(i+1, n))

            utils.random_seed(seed)
            datapool = datapool_creator.create(init_labelled)

            # NOTE: must update seed before, otherwise insidious bug
            # where random picks the exact same labelled data as
            # they are in the same random state.
            seed += Experiment.DELTA_SEED
            for expt in expts:
                # Re-seed before each experiment so that all experiments of
                # this repeat start from the same random state.
                utils.random_seed(seed)
                expt._run_active_learning(datapool, i)

    # === PROTECTED ===

    def _set_testset(self, testset):
        """Assign the test set; may only be done once per experiment."""
        assert self._testset is None
        self._testset = testset

    def _run_active_learning(self, datapool, expt_repeat_id):
        """Run the train / evaluate / record / acquire loop on a copy of ``datapool``.

        Args:
            datapool: Pool providing ``copy()`` and ``count_labelled()``; only
                the copy is handed to the trainer and acquisition function.
            expt_repeat_id: Index of the current repeat, stored with every
                result entry. For repeat 0 the datapool is additionally saved
                when ``save_first_labelled_pool_path`` was given.

        Raises:
            AssertionError: If the configuration is inconsistent or training
                never reaches ``min_training_score``.
        """
        logging.info("Running: experiment {}".format(self._id))
        datapool = datapool.copy()

        model = None
        assert self._reinit_model_per_acquisition_step == (
            self._first_training_epochs is None
        )
        epochs = self._first_training_epochs

        while True:
            model, trainer = self._train_with_retries(model, epochs, datapool)

            epochs = None  # use default epochs after this first round
            acquiref = self._acquiref.init(model)
            score = trainer.evaluate(self._testset)

            self._result.add_entry(
                datapool=datapool,
                model=model,
                acquiref=acquiref,
                expt_repeat_id=expt_repeat_id,
                score=score
            )

            # Stop as soon as the labelling budget is reached; the final
            # state has already been recorded above.
            if datapool.count_labelled() >= self._max_labelled:
                break
            acquiref.acquire(datapool)

        # Note: this runs after the loop, so the pool is saved in the state
        # it has once all acquisitions of repeat 0 are done.
        if expt_repeat_id == 0 and self._save_first_labelled_pool_path is not None:
            self._save_datapool(
                datapool, self._save_first_labelled_pool_path, self._acquiref.get_name()
            )

    def _train_with_retries(self, model, epochs, datapool):
        """Train until the score is acceptable, at most MAX_TRAINING_RETRIES times.

        Args:
            model: Model carried over from the previous acquisition step, or
                ``None`` if there is none yet.
            epochs: Forwarded to ``trainer.init``.
            datapool: Forwarded to ``trainer.train``.

        Returns:
            Tuple ``(model, trainer)`` of the first attempt whose training
            score was not below ``min_training_score``.

        Raises:
            AssertionError: If every attempt scored below the minimum.
        """
        for _ in range(Experiment.MAX_TRAINING_RETRIES):
            # A new model is only created when there is none yet or when
            # re-initialisation is requested; otherwise a retry keeps using
            # the same model instance.
            # NOTE(review): confirm that retrying on the same instance is
            # intended when re-initialisation is disabled.
            if model is None or self._reinit_model_per_acquisition_step:
                model = self._model.init()

            trainer = self._trainer.init(model, epochs)
            logging.info("Training with expected score >= {:.4f}".format(self._min_training_score))
            train_score = trainer.train(datapool)

            if train_score < self._min_training_score:
                logging.info(
                    "Training score ({:.4f}) was below expectations. Retraining...".format(
                        train_score
                    )
                )
                continue

            return model, trainer

        # Raised explicitly (rather than via an ``assert`` statement) so the
        # check is not stripped under ``python -O``; the exception type is
        # kept for backward compatibility.
        raise AssertionError("Hyperparameters did not allow for sufficient training.")

    def _save_datapool(self, datapool, fdir, name):
        """Save the items yielded by ``datapool`` to ``<fdir>/<name>.npy``.

        Iterating ``datapool`` must yield ``(x, y)`` pairs where the ``x``
        values can be combined with ``torch.stack`` and the ``y`` values can
        be turned into a ``torch.LongTensor``. The file holds two arrays
        written back to back with ``numpy.save`` (first the stacked ``x``
        values, then the ``y`` values), so reading it back takes two
        ``numpy.load`` calls on the same open file.

        Args:
            datapool: Iterable of ``(x, y)`` pairs.
            fdir: Target directory; created if it does not exist yet.
            name: File name without the ``.npy`` extension.
        """
        X = []
        Y = []
        for x, y in datapool:
            X.append(x)
            Y.append(y)
        X = torch.stack(X, dim=0).numpy()
        Y = torch.LongTensor(Y).numpy()
        # Checked before touching the file system so that a failure does not
        # leave an empty file behind.
        assert len(X) == len(Y), "{} != {}".format(len(X), len(Y))

        # Make sure the target directory exists; otherwise a finished run
        # would fail at the very last step. An empty ``fdir`` means the
        # current directory, which needs no creation.
        if fdir:
            os.makedirs(fdir, exist_ok=True)
        outf = os.path.join(fdir, name + ".npy")
        with open(outf, "wb") as f:
            numpy.save(f, X)
            numpy.save(f, Y)
        logging.info("Saved labelled data (n={}) to: {}".format(len(X), outf))

