from continual_rl.utils.utils import Utils
from continual_rl.utils.common_exceptions import OutputDirectoryNotSetException


class InvalidTaskAttributeException(Exception):
    """Raised when the tasks given to an Experiment disagree on an attribute that they are required to share."""

    def __init__(self, error_msg):
        """
        :param error_msg: Human-readable description of the mismatch; forwarded unchanged to Exception.
        """
        super().__init__(error_msg)


class Experiment(object):
    def __init__(self, tasks, continual_testing_freq=None, cycle_count=1):
        """
        The Experiment class contains everything that should be held consistent when the experiment is used as a
        setting for a baseline.

        A single experiment can cover tasks with a variety of action spaces. It is up to the policy on how they wish
        to handle this, but what the Experiment does is create a dictionary mapping action_space_id to action space, and
        ensures that all tasks claiming the same id use the same action space.

        The observation space and time batch size are both restricted to being the same for all tasks. This
        initialization raises InvalidTaskAttributeException if this is violated.

        The output directory is not a constructor argument: it must be supplied via set_output_dir before the
        experiment is run.

        :param tasks: A list of subclasses of TaskBase. These need to have a consistent observation size.
        :param continual_testing_freq: If not None, all non-eval tasks will be periodically "force-eval"'d to evaluate
        forward/reverse transfer. The period is currently counted in steps of the running task's generator, not in
        timesteps (see the TODO below).
        :param cycle_count: The number of loops through all the tasks we make
        """
        self.tasks = tasks
        self.action_spaces = self._get_action_spaces(self.tasks)
        self.observation_space = self._get_common_attribute([task.observation_space for task in self.tasks])
        self.time_batch_size = self._get_common_attribute([task.time_batch_size for task in self.tasks])
        self._output_dir = None
        self._continual_testing_freq = continual_testing_freq  # TODO: will be timesteps, currently train-steps
        self._cycle_count = cycle_count

    def set_output_dir(self, output_dir):
        """
        Set the directory handed to tasks and used for the core process log.

        :param output_dir: The directory in which logs will be stored.
        """
        self._output_dir = output_dir

    @property
    def output_dir(self):
        """
        The directory set via set_output_dir.

        :raises OutputDirectoryNotSetException: if set_output_dir has not been called with a non-None value.
        """
        if self._output_dir is None:
            raise OutputDirectoryNotSetException("Output directory not set, but is attempting to be used. "
                                                 "Call set_output_dir.")
        return self._output_dir

    @property
    def _logger(self):
        """
        Logger writing to core_process.log inside the output directory.

        Utils.create_logger is invoked on every access; requires the output directory to be set.
        """
        return Utils.create_logger(f"{self.output_dir}/core_process.log")

    @classmethod
    def _get_action_spaces(cls, tasks):
        """
        Build the mapping from action_space_id to action space.

        :param tasks: The tasks to collect action spaces from.
        :return: A dict keyed by action_space_id, holding the action space shared by every task using that id.
        :raises InvalidTaskAttributeException: if two tasks with the same action_space_id have different action spaces.
        """
        action_space_map = {}  # Maps action_space_id to its action space

        for task in tasks:
            if task.action_space_id not in action_space_map:
                action_space_map[task.action_space_id] = task.action_space
            elif action_space_map[task.action_space_id] != task.action_space:
                raise InvalidTaskAttributeException(f"Action sizes were mismatched for task {task.action_space_id}")

        return action_space_map

    @classmethod
    def _get_common_attribute(cls, task_attributes):
        """
        Return the single value shared by every entry of task_attributes.

        :param task_attributes: An iterable of attribute values, one per task.
        :return: The shared value, or None if task_attributes is empty.
        :raises InvalidTaskAttributeException: if any entry differs from the first one.
        """
        # A private sentinel (rather than None) marks "nothing seen yet", so that None is treated like any other
        # attribute value and a sequence such as [None, 5] is correctly reported as mismatched.
        unset = object()
        common_attribute = unset

        for task_attribute in task_attributes:
            if common_attribute is unset:
                common_attribute = task_attribute
            elif task_attribute != common_attribute:
                raise InvalidTaskAttributeException("Tasks do not have a common attribute.")

        return None if common_attribute is unset else common_attribute

    def _run_forced_tests(self, task_run_id, policy, summary_writer, total_timesteps):
        """
        Run every non-eval task other than the current one to completion with force_eval=True.

        :param task_run_id: Index (into self.tasks) of the task currently being run; it is skipped.
        :param policy: Passed through to each task's run.
        :param summary_writer: Passed through to each task's run.
        :param total_timesteps: Passed to each task's run as timestep_log_offset.
        """
        # Run a small amount of eval on all non-eval, not-currently-running tasks
        for test_task_run_id, test_task in enumerate(self.tasks):
            if test_task_run_id == task_run_id or test_task._task_spec.eval_mode:
                continue

            self._logger.info(f"Continual eval for task: {test_task_run_id}")

            # Don't increment the total_timesteps counter for continual tests
            test_task_runner = test_task.run(test_task_run_id, policy, summary_writer,
                                             output_dir=self.output_dir,
                                             force_eval=True,
                                             timestep_log_offset=total_timesteps)

            # Exhaust the runner; the values it yields are deliberately ignored
            for _ in test_task_runner:
                pass

            self._logger.info(f"Completed continual eval for task: {test_task_run_id}")

    def _run(self, policy, summary_writer):
        """
        Run every task in order, cycle_count times, interleaving forced evaluation of the other tasks.

        :param policy: Passed through to each task's run.
        :param summary_writer: Passed through to each task's run.
        """
        total_train_timesteps = 0

        for cycle_id in range(self._cycle_count):
            for task_run_id, task in enumerate(self.tasks):
                # Run the current task as a generator so we can intersperse testing tasks during the run
                self._logger.info(f"Cycle {cycle_id}, Starting task {task_run_id}")
                task_complete = False
                task_runner = task.run(task_run_id, policy, summary_writer, output_dir=self.output_dir, force_eval=False,
                                       timestep_log_offset=total_train_timesteps)
                task_timesteps = 0  # What timestep the task is currently on. Cumulative during a task.
                task_step = 0  # For scheduling continual testing
                continual_freq = self._continual_testing_freq

                while not task_complete:
                    try:
                        task_timesteps = next(task_runner)
                    except StopIteration:
                        # task_timesteps keeps the last yielded value. The scheduling check below still runs on this
                        # final iteration.
                        task_complete = True

                    # If we're already doing eval, don't do a forced eval run (nothing has trained to warrant it anyway)
                    if continual_freq is not None and not task._task_spec.eval_mode:
                        # Evaluate intermittently. Every time is too slow
                        if task_step % continual_freq == 0:
                            self._run_forced_tests(task_run_id, policy, summary_writer,
                                                   total_train_timesteps + task_timesteps)

                    task_step += 1

                self._logger.info(f"Cycle {cycle_id} Task {task_run_id} complete")

                # Only increment the global counter for training (it's supposed to represent number of frames *trained on*)
                if not task._task_spec.eval_mode:
                    total_train_timesteps += task_timesteps

    def try_run(self, policy, summary_writer):
        """
        Run the experiment; on failure, log the exception, shut the policy down, and re-raise.

        :param policy: The policy to run; its shutdown() is called if the run fails.
        :param summary_writer: Passed through to each task's run.
        """
        try:
            self._run(policy, summary_writer)
        except Exception as e:
            try:
                self._logger.exception(f"Failed with exception: {e}")
            finally:
                # Shut the policy down even when logging itself fails (e.g. the output directory was never set)
                policy.shutdown()

            # Bare raise re-raises the exception being handled with its original traceback
            raise
