import numpy as np
import nninfo
import os
import re

# Module-level logger obtained from nninfo's logging helper; the experiment
# additionally attaches a per-experiment file handler to logging in __init__.
log = nninfo.log.get_logger(__name__)

# Public API of this module.
__all__ = ["Experiment", "Schedule"]


class Experiment:
    """
    Manages the entire experiment. It is directly in contact with the user script and with
    the main components.
    After connecting the components, the user script should preferably use the methods of
    this class.

    1) is given an instance of TaskManager to feed data into the program and split the dataset
       if necessary

    2) is given an instance of NeuralNetwork in model.py that will be trained and tested on

    3) is given an instance of Trainer that is responsible
       for each chapter (predefined set of training epochs)
       of the training process (gets data from TaskManager via dataset_name).

    4) is given an instance of Tester that is called after each chunk of training is done
       (gets data from TaskManager)

    5) can be given an instance of Schedule. This makes automating the experiment easier.

    6) creates an instance of CheckpointManager that stores the main experiment parameters
       for loading afterwards. This can be used to
       a) analyze the training afterwards
       b) resume training from last or earlier chapters
    """

    def __init__(self, experiment_id=None, overwrite_experiment_dir=False, load=False):
        """
        Init instance of experiment. Also sets up logging, so best use at the beginning of the script.

        Args:
            experiment_id (int or str, optional): Every instance is recommended to have an unique id,
                then it can be used for naming of files etc.. If guild is used, it is recommended
                to pass the guild run id (or a short version of it) here as a str.
                If None, a new directory is created with the next free numeric id
                (one larger than the largest numeric ``exp_<id>`` directory, or 0 if there is none).
            overwrite_experiment_dir (bool, optional): If true,
                the directory corresponding to this experiment_id
                is overwritten in case it already exists. A user prompt is triggered to
                avoid unexpected loss of data.
            load (bool, optional): If true and experiment_id is given, the existing directory
                ``exp_<id>`` is used instead of creating one. In any case where load is true,
                components are loaded from components/ and the last listed checkpoint is restored.
        """
        if experiment_id is None:
            # create a new experiment directory with the next id possible
            directory_reader = nninfo.file_io.FileManager("../experiments/", read=True)
            experiment_id = self._next_experiment_id(
                directory_reader.list_subdirs_in_dir()
            )
            log.info("Initializing new experiment with id {}.".format(experiment_id))

            standard_dir_maker = nninfo.file_io.FileManager(
                "../experiments/", write=True
            )
            self._experiment_dir = standard_dir_maker.make_experiment_dir(
                experiment_id, overwrite=False
            )
        elif load:
            # load from experiment directory; numeric ids are zero-padded to 4 digits
            dirname = "exp_" + (
                "{:04d}".format(int(experiment_id))
                if isinstance(experiment_id, int)
                else str(experiment_id)
            )
            log.info("Initializing experiment by loading {}".format(dirname))

            self._experiment_dir = os.path.dirname(
                __file__
            ) + "/../experiments/{}/".format(dirname)
        else:
            # create new experiment directory. if overwrite_experiment_dir is True,
            # user possibly is asked whether to overwrite an existing directory
            standard_dir_maker = nninfo.file_io.FileManager(
                "../experiments/", write=True
            )
            self._experiment_dir = standard_dir_maker.make_experiment_dir(
                experiment_id, overwrite=overwrite_experiment_dir
            )

        self.experiment_id = experiment_id
        self._run_id = 0

        print(self._experiment_dir)
        nninfo.log.add_exp_file_handler(self._experiment_dir)
        log.info("Starting exp_{}".format(experiment_id))

        # create placeholder for relevant objects
        self._task = None
        self._network = None
        self._trainer = None
        self._tester = None
        self._schedule = None

        # create checkpoint_manager that saves checkpoints to _experiment_dir for the entire experiment
        self._checkpoint_manager = nninfo.file_io.CheckpointManager()
        self._checkpoint_manager.parent = self
        self._components_locked = False
        self._closed = False

        if load:
            self.load_components()
            # load the last experiment checkpoint
            checkpoints = self._checkpoint_manager.list_checkpoints()
            if checkpoints:
                self._checkpoint_manager.load(filename=checkpoints[-1])

    @staticmethod
    def _next_experiment_id(experiment_dirs):
        """
        Return the id following the largest numeric id among ``exp_<id>`` directories.

        Args:
            experiment_dirs (iterable of str): Paths or names of experiment directories.
                Trailing path separators are tolerated. Entries that do not start with
                ``exp_`` or whose id is not an integer (e.g. str guild run ids) are ignored.

        Returns:
            int: ``max(numeric ids) + 1``, or ``0`` if no numeric experiment directory exists.
        """
        prefix = "exp_"
        ids = []
        for path in experiment_dirs:
            # normpath drops a trailing separator so that basename is the directory name
            name = os.path.basename(os.path.normpath(path))
            if not name.startswith(prefix):
                continue
            # slice off the prefix; str.lstrip("exp_") would strip a *character set*
            try:
                ids.append(int(name[len(prefix):]))
            except ValueError:
                continue
        return max(ids, default=-1) + 1

    def lock_and_save_components(self):
        """
        For experiment consistency reasons, the experiment components are
        locked using
        this function once the first epoch of training
        is starting. At the moment,
        it cannot (should not) be unlocked again once locked.

        Saves all experiment components once they are locked.
        """
        self.save_components()
        self._components_locked = True

    def save_components(self):
        """
        Saves all main components of the experiment to components/.
        Components can also be saved manually before locking.

        Saves network parameters to network.json.

        Saves trainer parameters to trainer.json.

        Saves tester parameters to tester.json.

        Saves schedule to schedule.json.

        Saves task structure (indices) to task.json.
        Also saves the full dataset if dataset requires it.

        Components that are not connected (None) are skipped.
        """
        component_dir = self._experiment_dir + "components/"
        component_saver = nninfo.file_io.FileManager(component_dir, write=True)

        if self._network is not None:
            log.info("Saving network settings to network.json.")
            network_settings = self._network.get_network_settings()
            component_saver.write(network_settings, "network.json")
        if self._task is not None:
            log.info("Calling task.save().")
            self._task.save(component_dir)
        if self._trainer is not None:
            log.info("Saving training settings to trainer.json.")
            training_settings = self._trainer.get_training_parameters()
            component_saver.write(training_settings, "trainer.json")
        if self._schedule is not None:
            log.info("Calling schedule.save()")
            self._schedule.save(component_dir)
        if self._tester is not None:
            log.info("Saving Testing parameters to tester.json")
            testing_settings = self._tester.get_testing_parameters()
            component_saver.write(testing_settings, "tester.json")
        # TODO: check for overwriting.

    def load_components(self):
        """
        Rebuild and connect every component whose file exists in components/.

        Reads schedule.json, network.json, task.json, tester.json and trainer.json
        (each only if present) and connects the reconstructed objects via connect().
        If afterwards all key components are connected and at least one checkpoint
        exists, the trainer's components are initialized.
        """
        component_dir = self._experiment_dir + "components/"
        component_loader = nninfo.file_io.FileManager(component_dir, read=True)
        component_list = component_loader.list_files_in_dir()
        if "schedule.json" in component_list:
            schedule = nninfo.exp.Schedule()
            schedule.load(component_dir)
            self.connect(schedule=schedule)
            log.info("Successfully loaded and connected schedule.")
        if "network.json" in component_list:
            d = component_loader.read("network.json")
            if d["net_type"] == "noisy_feedforward":
                network = nninfo.model.NoisyNeuralNetwork(**d)
            else:
                network = nninfo.model.NeuralNetwork(**d)
            self.connect(network=network)
            log.info("Successfully loaded and connected network.")
        if "task.json" in component_list:
            task = nninfo.task.TaskManager(reload=True, component_dir=component_dir)
            self.connect(task=task)
            log.info("Successfully loaded and connected task.")
        if "tester.json" in component_list:
            d = component_loader.read("tester.json")
            tester = nninfo.exp_comp.Tester(d["dataset_name"])
            self.connect(tester=tester)
            log.info("Successfully loaded and connected tester.")
        if "trainer.json" in component_list:
            d = component_loader.read("trainer.json")
            trainer = nninfo.exp_comp.Trainer()

            trainer.set_training_parameters(
                dataset_name=d["dataset_name"],
                optim_str=d["optim_str"],
                loss_str=d["loss_str"],
                lr=d["lr"],
                shuffle=d["shuffle"],
                batch_size=d["batch_size"],
                n_epochs_chapter=d["n_epochs_chapter"],
                quantizer=d.get("quantizer", None)
            )
            self.connect(trainer=trainer)
            log.info("Successfully loaded and connected trainer.")

        if (
            self.all_key_components_connected
            and self._checkpoint_manager.list_checkpoints() != []
        ):
            # use the connected trainer: a local `trainer` only exists if
            # trainer.json was read in this call
            self._trainer.initialize_components()

    def run_following_schedule(self, continue_run=False, chapter_ends=None, use_cuda=False, use_ipex=False):
        """
        Train chapter by chapter until the last entry of chapter_ends is reached.

        Args:
            continue_run (bool): If False, training must start at chapter 0 and epoch 0
                (otherwise an error is logged and nothing is trained). If True, training
                continues from the current chapter.
            chapter_ends (list of int, optional): Epochs at which chapters end. Defaults to
                the connected schedule's chapter_ends; if neither is available, an error
                is logged and the method returns without training.
            use_cuda (bool): Forwarded to Trainer.train_chapter.
            use_ipex (bool): Forwarded to Trainer.train_chapter.

        Raises:
            ValueError: If chapter_ends[c] of the chapter about to be started does not
                equal the experiment's current epoch_id.
        """
        if chapter_ends is None:
            if self._schedule is None:
                log.error(
                    "You can only use run_following_schedule if you have "
                    + "a schedule connected to the experiment or pass a schedule."
                )
                return
            else:
                chapter_ends = self._schedule.chapter_ends

        if continue_run:
            log.warning(
                "Continuing run {} at chapter {}.".format(self.run_id, self.chapter_id)
            )
        else:
            if self.chapter_id != 0 or self.epoch_id != 0:
                log.error(
                    "You can only use run_following_schedule if you reset the training to a new run."
                )
                return

        info = "Starting training on run {} starting at chapter {}, epoch {}".format(
            self.run_id, self.chapter_id, self.epoch_id
        )
        log.info(info)
        print(info)
        for c in range(self.chapter_id, len(chapter_ends) - 1):
            if chapter_ends[c] != self.epoch_id:
                msg = (
                    "Error on continuing schedule,"
                    + " schedule.chapter_ends[{}]={}".format(c, chapter_ends[c])
                    + " and experiment's self.epoch_id={} ".format(self.epoch_id)
                    + "do not fit together."
                )
                log.error(msg)
                raise ValueError(msg)
            # running c+1:
            n_epochs_chapter = chapter_ends[c + 1] - chapter_ends[c]
            self._trainer.train_chapter(n_epochs_chapter=n_epochs_chapter, use_cuda=use_cuda, use_ipex=use_ipex)

    @staticmethod
    def _chapter_ends_up_to(chapter_ends, stop_epoch):
        """
        Return the leading chapter ends that are not greater than stop_epoch.

        The result stops right before the first entry greater than stop_epoch; if no
        entry exceeds stop_epoch, the full list is returned.

        Args:
            chapter_ends (list of int): Chapter end epochs.
            stop_epoch (int): Epoch after which no chapter should end.

        Returns:
            list of int: The truncated chapter ends.
        """
        beyond = np.asarray(chapter_ends) > stop_epoch
        # np.argmax of an all-False mask is 0, which would wrongly yield an empty
        # schedule; keep every chapter end in that case.
        cut_off = int(np.argmax(beyond)) if beyond.any() else len(chapter_ends)
        return list(chapter_ends[:cut_off])

    def continue_runs_following_schedule(self, runs_id_list, stop_epoch, schedule=None, use_cuda=False, use_ipex=False):
        """
        Continue each given run from its last checkpoint along the continued schedule.

        Args:
            runs_id_list (iterable of int): Runs to continue.
            stop_epoch (int): Chapter ends greater than this epoch are not trained.
            schedule (Schedule, optional): Schedule whose chapter_ends_continued is used.
                Defaults to the connected schedule; if neither is available, an error is
                logged and the method returns.
            use_cuda (bool): Forwarded to run_following_schedule.
            use_ipex (bool): Forwarded to run_following_schedule.
        """
        if schedule is None:
            if self._schedule is None:
                log.error(
                    "You can only use continue_runs_following_schedule if you have "
                    + "a schedule connected to the experiment or pass a schedule."
                )
                return
            schedule = self._schedule

        chapter_ends = self._chapter_ends_up_to(
            schedule.chapter_ends_continued, stop_epoch
        )
        for run_id in runs_id_list:
            last_ckpt = self._checkpoint_manager.list_checkpoints([run_id])[-1]
            self._checkpoint_manager.load(filename=last_ckpt)
            self.run_following_schedule(
                continue_run=True,
                chapter_ends=chapter_ends,
                use_cuda=use_cuda,
                use_ipex=use_ipex,
            )

    def rerun(self, n_runs, like_run_id=None):
        """
        Reruns the experiment for a given number of runs. For doing this, it uses
        the checkpoints of a previous run that are found in the checkpoints directory
        and produces the same checkpoints with a new network initialization.

        Args:
            n_runs (int): Number of additional runs that should be performed.
            like_run_id (int): Run id of the run that should be replicated. If not set,
                defaults to run_id=0.
        """
        log.info("Setting up rerun of experiment: n_runs=" + str(n_runs))

        def get_ids_from_filename(filename):
            # all integer groups in the name; below they are unpacked as
            # (run_id, chapter_id, epoch_id) — assumes exactly three digit groups
            substrings = re.findall(r"\d+", filename)
            return [int(x) for x in substrings]

        if like_run_id is None:
            like_run_id = 0

        ckpt_list = self._checkpoint_manager.list_checkpoints(run_ids=[like_run_id])

        log.info("Extracting schedule from checkpoint directory.")
        rerun_chapter_ends = []
        for c in ckpt_list:
            _, _, epoch_id = get_ids_from_filename(c)
            rerun_chapter_ends.append(epoch_id)

        log.warning("Extracted schedule: " + str(rerun_chapter_ends))

        last_run_id, _, _ = get_ids_from_filename(
            self._checkpoint_manager.list_checkpoints()[-1]
        )

        for i in range(n_runs):
            # getting everything to the same state as requested
            self.load_checkpoint(run_id=like_run_id, chapter_id=0)
            # get the new run_id
            self._run_id = last_run_id + 1
            # reinitialize the network with a new seed
            self._network.init_weights(randomize_seed=True)
            self.run_following_schedule(chapter_ends=rerun_chapter_ends)
            last_run_id = self._run_id

    def disconnect(self):
        """
        Detach every ExperimentComponent attribute from this experiment.

        Each component's parent is reset to None, the attribute on this experiment
        is set to None, and the experiment's log file handler is removed.
        """
        log.info("Disconnecting all components from experiment.")
        # iterate over a snapshot so that writing back into __dict__ is safe
        for attr, value in list(self.__dict__.items()):
            if isinstance(value, nninfo.exp_comp.ExperimentComponent):
                value.parent = None
                self.__dict__[attr] = None
        nninfo.log.remove_exp_file_handler(self._experiment_dir)

    def save_executing_code(self):
        """
        **********************
        * NOT FUNCTIONAL YET *
        **********************
        For full restoring capability of the experiment this function
        allows the storage of the entire
        setup, including the code that ran the experiment.

        """

        def execution_type():
            try:
                shell = get_ipython().__class__.__name__
                if shell == "google.colab._shell":
                    return ""
                elif shell == "ZMQInteractiveShell":
                    return "notebook"  # Jupyter notebook or qtconsole
                elif shell == "TerminalInteractiveShell":
                    return "ipython_terminal"  # Terminal running IPython
                else:
                    return "other"  # Other type (?)
            except NameError:
                return "script"  # Probably standard Python interpreter

    def save_checkpoint(self):
        """
        Calls the CheckpointManager to save the current state of the network and the optimizer
        (is necessary for optimizers that depend on their own past)
        together with the state of random number generators (of numpy and torch).
        """
        self._checkpoint_manager.save()

    def load_checkpoint(self, run_id=None, chapter_id=None, filename=None):
        """
        Load an old checkpoint using either (run_id AND chapter_id) or filename.

        Keyword Args:
            run_id (int): identifying int for the run.
            chapter_id (int): identifying int for the chapter.
            filename (str): name of the file that is supposed to be loaded. Can
                include * symbols for uncertain/unimportant parts of the filename.

        Raises:
            AttributeError: If neither (run_id and chapter_id) nor filename is given.
            ValueError: If the CheckpointManager reports a non-zero return code.
        """
        if run_id is not None and chapter_id is not None:
            ret_code = self._checkpoint_manager.load(run_id, chapter_id)
        elif filename is not None:
            ret_code = self._checkpoint_manager.load(filename=filename)
        else:
            raise AttributeError(
                "load_checkpoint needs either run_id and chapter_id, or filename."
            )

        if ret_code != 0:
            raise ValueError("Could not load checkpoint.")

    def list_checkpoints(self, run_ids=None, chapter_ids=None):
        """
        List all checkpoints that CheckpointManager can find in the experiments checkpoint
        directory.

        Keyword Args:
            run_ids (list): only take runs contained in this list
            chapter_ids (list): only take chapters contained in this list
        """
        return self._checkpoint_manager.list_checkpoints(run_ids, chapter_ids)

    def connect(
        self, network=None, task=None, trainer=None, tester=None, schedule=None
    ):
        """
        Only function that allows to connect new components to this experiment.
        If some components are already
        in place, the connection to those is overwritten if new components are given.
        Typically, define your components outside and connect them using this function.

        Arguments that are not instances of the expected class (including None) are ignored.
        Once all key components are present and the network's input/output dimensions
        match the task's, this experiment is set as the parent of each key component.

        Keyword Args:
            network (nninfo NeuralNetwork): The NeuralNetwork to train and test on. Make sure the input
                and output fits to your dataset in the task manager.
            task (nninfo TaskManager): The task manager, where your dataset and subdatasets are stored.
            trainer (nninfo Trainer): Trainer that will carry out training chapters you define.
            tester (nninfo Tester): Tester that will perform tests, during training or afterwards.
            schedule (nninfo Schedule, optional): Schedule providing the chapter ends.

        Raises:
            PermissionError: If the components are locked.
        """
        if self._components_locked:
            log.error("Experiment components are locked. Not allowed to be changed.")
            log.error("As a tip, connect existing components to new experiment.")
            raise PermissionError

        if isinstance(network, nninfo.model.NeuralNetwork):
            self._network = network
        if isinstance(task, nninfo.task.TaskManager):
            self._task = task
        if isinstance(trainer, nninfo.exp_comp.Trainer):
            self._trainer = trainer
        if isinstance(tester, nninfo.exp_comp.Tester):
            self._tester = tester
        if isinstance(schedule, Schedule):
            self._schedule = schedule
        if self.all_key_components_connected and self._check_in_out_dimensions():
            self._connect()

    def _connect(self):
        """
        Helper function to set this experiment as the parent of each component.
        This is used often by the other components to use each others functions
        and is the only connection between them at the moment.
        """
        self._network.parent = self
        self._task.parent = self
        self._trainer.parent = self
        self._tester.parent = self

    def _check_in_out_dimensions(self):
        """
        Compare the network's input/output dimensions with the task's.

        Prints a message for each mismatch.

        Returns:
            bool: True if both input and output dimensions coincide.
        """
        net_input_dim, net_output_dim = self.network.get_input_output_dimensions()
        task_input_dim, task_output_dim = self.task.get_input_output_dimensions()
        ok_flag = True
        if net_input_dim != task_input_dim:
            print(
                "Input dimension of network ({}) and task ({}) do not coincide.".format(
                    net_input_dim, task_input_dim
                )
            )
            ok_flag = False
        if net_output_dim != task_output_dim:
            print(
                "Output dimension of network ({}) and task ({}) do not coincide.".format(
                    net_output_dim, task_output_dim
                )
            )
            ok_flag = False
        return ok_flag

    @property
    def all_key_components_connected(self):
        """
        Property (function that is disguised as an object variable)
        that checks whether all components for this experiment
        are already in place. (For now, all are needed to start the experiment,
        this could be changed in the future though, for example Test might not
        be relevant for every experiment.)

        Returns:
             (bool): All components are connected, True or False.
        """

        all_comp_flag = True
        if self._task is None:
            log.info("Task still missing.")
            all_comp_flag = False
        if self._network is None:
            log.info("Network still missing.")
            all_comp_flag = False
        if self._trainer is None:
            log.info("Trainer still missing.")
            all_comp_flag = False
        if self._tester is None:
            log.info("Tester still missing.")
            all_comp_flag = False
        return all_comp_flag

    @property
    def experiment_dir(self):
        """Directory of this experiment (str, ends with a separator when built here)."""
        return self._experiment_dir

    @property
    def id(self):
        """Alias for experiment_id."""
        return self.experiment_id

    def checkpoint_loader_sets_run_id(self, run_id):
        """
        Overwrite the current run id (and log the change).

        NOTE(review): presumably called by the CheckpointManager when a checkpoint of
        another run is loaded — confirm against file_io.
        """
        log.info("Run_id is being changed from {} to {}.".format(self._run_id, run_id))
        self._run_id = run_id

    @property
    def run_id(self):
        """Id of the current run."""
        return self._run_id

    @property
    def chapter_id(self):
        """Number of chapters trained so far by the connected trainer."""
        return self.trainer.n_chapters_trained

    @property
    def epoch_id(self):
        """Number of epochs trained so far by the connected trainer."""
        return self.trainer.n_epochs_trained

    @property
    def network(self):
        return self._network

    @property
    def trainer(self):
        return self._trainer

    @property
    def task(self):
        return self._task

    @property
    def tester(self):
        return self._tester

    @property
    def schedule(self):
        return self._schedule

    @property
    def components_locked(self):
        return self._components_locked


class Schedule:
    """
    Plans the chapters of an experiment's training period.

    A chapter ends whenever training reaches one of the epochs listed in
    ``chapter_ends``. Chapters are the main structure of the training period and
    allow checkpoints to be saved at spaced intervals. This class can generate
    linearly and logarithmically spaced chapter ends. Because the generated epochs
    are rounded to whole numbers and de-duplicated, log-spaced planning may yield
    fewer chapter ends than ``n_chapters_wished`` asks for.

    Does not need to inherit from ExperimentComponent, because it does not call
    any other component.
    """

    def __init__(self):
        # Both lists are filled by one of the create_* methods or by load().
        self.chapter_ends = None
        self.chapter_ends_continued = None

    def create_log_spaced_chapters(self, n_epochs, n_chapters_wished):
        """
        Plan chapter ends that are logarithmically spaced between epoch 1 and n_epochs.

        Args:
            n_epochs (int): Total number of epochs for this experiment.
            n_chapters_wished (int): Number of chapters the experiment should take to
                reach n_epochs.

        Sets self.chapter_ends to the sorted, de-duplicated list of int epochs,
        prefixed with 0.

        Sets self.chapter_ends_continued to the same kind of list, continued up to
        n_epochs * n_epochs with twice as many chapters.
        """

        def rounded_log_grid(last_epoch, n_steps):
            # n_steps + 1 points from 10**0 to last_epoch, evenly spaced in log10
            grid = np.logspace(0, np.log10(last_epoch), n_steps + 1, endpoint=True)
            # np.unique sorts and removes duplicates created by rounding
            whole_epochs = np.unique(np.round(grid).astype(int))
            # epoch 0 leads the list for consistency
            return np.insert(whole_epochs, 0, 0)

        self.chapter_ends = rounded_log_grid(n_epochs, n_chapters_wished).tolist()
        self.chapter_ends_continued = rounded_log_grid(
            n_epochs * n_epochs, 2 * n_chapters_wished
        ).tolist()

    def create_lin_spaced_chapters(self, n_epochs, n_chapters_wished):
        """
        Plan chapter ends that are linearly spaced between epoch 0 and n_epochs.

        Args:
            n_epochs (int): Total number of epochs for this experiment.
            n_chapters_wished (int): Number of chapters the experiment should take to
                reach n_epochs.

        Sets self.chapter_ends to the sorted, de-duplicated list of int epochs.

        Sets self.chapter_ends_continued to the same kind of list, continued with the
        same spacing up to n_epochs * 100.
        """

        def rounded_lin_grid(last_epoch, n_steps):
            grid = np.linspace(0, last_epoch, n_steps + 1, endpoint=True)
            return np.unique(np.round(grid).astype(int))

        self.chapter_ends = rounded_lin_grid(n_epochs, n_chapters_wished).tolist()
        self.chapter_ends_continued = rounded_lin_grid(
            100 * n_epochs, 100 * n_chapters_wished
        ).tolist()

    def save(self, path):
        """Write both chapter-end lists to schedule.json inside the directory path."""
        writer = nninfo.file_io.FileManager(path, write=True)
        writer.write(
            {
                "chapter_ends": self.chapter_ends,
                "chapter_ends_continued": self.chapter_ends_continued,
            },
            "schedule.json",
        )

    def load(self, path):
        """Restore both chapter-end lists from schedule.json inside the directory path."""
        stored = nninfo.file_io.FileManager(path, read=True).read("schedule.json")
        self.chapter_ends = stored["chapter_ends"]
        self.chapter_ends_continued = stored["chapter_ends_continued"]

    def __str__(self):
        return "{}".format(self.chapter_ends)
