import os, torch
import json
import numpy as np
from continual_rl.experiments.run_metadata import RunMetadata
from continual_rl.utils.utils import Utils
from continual_rl.utils.common_exceptions import OutputDirectoryNotSetException
from metrics import Average_Performance, Average_Forgetting

# Compute device selected once at import time: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Experiment name -> task list.
#   * "atari_6_tasks_5_cycles" is a flat list of Atari environment ids.
#   * "minihack_nav_paired_2_cycles" is a list of 2-tuples of MiniHack environment ids
#     (presumably (train, eval) pairs — TODO confirm against the loader that consumes this dict).
# An "experiment" key selecting which entry to run (e.g. "minihack_nav_paired_2_cycles") was
# previously sketched here but is intentionally absent. Add more entries as needed.
raw_experiment = dict(
    atari_6_tasks_5_cycles=[
        "SpaceInvadersNoFrameskip-v4",
        "KrullNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "HeroNoFrameskip-v4",
        "StarGunnerNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
    ],
    minihack_nav_paired_2_cycles=[
        ("Room-Random-5x5-v0", "Room-Random-15x15-v0"),
        ("Corridor-R2-v0", "Corridor-R5-v0"),
        ("Room-Dark-5x5-v0", "Room-Dark-15x15-v0"),
        ("Corridor-R3-v0", "Corridor-R5-v0"),
        ("Room-Monster-5x5-v0", "Room-Monster-15x15-v0"),
        ("CorridorBattle-v0", "CorridorBattle-Dark-v0"),
        ("Room-Trap-5x5-v0", "Room-Trap-15x15-v0"),
        ("HideNSeek-v0", "HideNSeek-Big-v0"),
        ("Room-Ultimate-5x5-v0", "Room-Ultimate-15x15-v0"),
        ("HideNSeek-Lava-v0", "HideNSeek-Big-v0"),
    ],
)






class InvalidTaskAttributeException(Exception):
    """Raised when the tasks of one Experiment disagree on an attribute that must be shared
    (the observation space, or the action space behind a shared action_space_id)."""

    def __init__(self, error_msg):
        # Forward the message unchanged so str(exc) and exc.args match a plain Exception.
        Exception.__init__(self, error_msg)


class Experiment(object):
    def __init__(self, tasks, continual_testing_freq=None, cycle_count=1, num_timesteps=None):
        """
        The Experiment class contains everything that should be held consistent when the experiment is used as a
        setting for a baseline.

        A single experiment can cover tasks with a variety of action spaces. It is up to the policy on how they wish
        to handle this, but what the Experiment does is create a dictionary mapping action_space_id to action space, and
        ensures that all tasks claiming the same id use the same action space.

        The observation space is restricted to being the same for all tasks. Construction raises
        InvalidTaskAttributeException if this is violated, or if two tasks share an action_space_id but not an
        action space.

        :param tasks: A list of subclasses of TaskBase. These need to have a consistent observation space.
        :param continual_testing_freq: The number of task timesteps between continual-evaluation passes over the
            tasks flagged ``with_continual_eval``. If None, no periodic evaluation is run; one evaluation is still run
            at the end of each task so that the per-cycle reward matrix can be built.
        :param cycle_count: The number of times to cycle through the list of tasks.
        :param num_timesteps: Stored as ``self.num_timesteps``; this class does not read it.
        """
        self.tasks = tasks
        self.action_spaces = self._get_action_spaces(self.tasks)
        self.observation_space = self._get_common_attribute(
            [task.observation_space for task in self.tasks]
        )
        self.task_ids = [task.task_id for task in tasks]
        self._output_dir = None
        self._continual_testing_freq = continual_testing_freq
        self._cycle_count = cycle_count
        self.num_timesteps = num_timesteps

    def set_output_dir(self, output_dir):
        """Set the directory that logs, checkpoints and reward matrices are written to."""
        self._output_dir = output_dir

    @property
    def output_dir(self):
        """The output directory. Raises OutputDirectoryNotSetException if set_output_dir was never called."""
        if self._output_dir is None:
            raise OutputDirectoryNotSetException("Output directory not set, but is attempting to be used. Call set_output_dir.")
        return self._output_dir

    @property
    def _logger(self):
        # Re-obtained from Utils.create_logger on every access; whether that call caches is not visible here.
        return Utils.create_logger(f"{self.output_dir}/core_process.log")

    @classmethod
    def _get_action_spaces(cls, tasks):
        """
        Build a map from action_space_id to action space.

        :raises InvalidTaskAttributeException: if two tasks share an action_space_id but not an action space.
        """
        action_space_map = {}  # Maps action space id to its action space

        for task in tasks:
            if task.action_space_id not in action_space_map:
                action_space_map[task.action_space_id] = task.action_space
            elif action_space_map[task.action_space_id] != task.action_space:
                raise InvalidTaskAttributeException(f"Action sizes were mismatched for task {task.action_space_id}")

        return action_space_map

    @classmethod
    def _get_common_attribute(cls, task_attributes):
        """
        Return the value shared by every entry of ``task_attributes`` (None when it is empty).

        :raises InvalidTaskAttributeException: if any two entries compare unequal.
        """
        common_attribute = None

        for task_attribute in task_attributes:
            if common_attribute is None:
                common_attribute = task_attribute

            if task_attribute != common_attribute:
                raise InvalidTaskAttributeException("Tasks do not have a common attribute.")

        return common_attribute

    def _run_continual_eval(self, task_run_id, policy, summary_writer, total_timesteps):
        """
        Run continual evaluation on every task whose spec has ``with_continual_eval`` set.

        :param task_run_id: Index of the task currently training. Not used here; kept for interface compatibility.
        :param total_timesteps: Passed to each task's ``continual_eval`` as ``timestep_log_offset``.
        :return: A list with one entry per non-None reward report yielded by the evaluated tasks' runners: the mean
            of the report's first element.
        """
        rewards = []
        for test_task_run_id, test_task in enumerate(self.tasks):
            # Deliberately not checking test_task._task_spec.eval_mode: some eval tasks (for train/test pairs)
            # should still be continually evaluated.
            if not test_task._task_spec.with_continual_eval:
                continue

            self._logger.info(f"Continual eval for task: {test_task_run_id}")

            # Don't increment the total_timesteps counter for continual tests
            test_task_runner = test_task.continual_eval(
                test_task_run_id,
                policy,
                summary_writer,
                output_dir=self.output_dir,
                timestep_log_offset=total_timesteps,
            )
            for _, reward_to_return in test_task_runner:
                if reward_to_return is not None:
                    rewards.append(np.array(reward_to_return[0]).mean())

            self._logger.info(f"Completed continual eval for task: {test_task_run_id}")

        return rewards

    def _score_cycle(self, cycle_id, reward_matrix):
        """
        Save one cycle's reward matrix to ``reward_numpy{cycle_id}.npy`` and score it.

        :param reward_matrix: One row per task run in this cycle: the continual-eval rewards recorded for that task.
        :return: ``np.array([Average_Performance(matrix), Average_Forgetting(matrix)])``.
        """
        reward_numpy = np.array(reward_matrix)
        self._logger.info(f"reward_numpy1: {reward_numpy}")
        save_path = os.path.join(self.output_dir, f'reward_numpy{cycle_id}.npy')
        os.makedirs(self.output_dir, exist_ok=True)  # Make sure the directory exists
        np.save(save_path, reward_numpy)
        print(f"数组已保存到: {save_path}")

        performance = Average_Performance(reward_numpy)
        self._logger.info(f"Cycle {cycle_id} performance: {performance}")
        print(f"per_0: {performance}")

        forgetting = Average_Forgetting(reward_numpy)
        self._logger.info(f"Cycle {cycle_id} forgetting: {forgetting}")
        print(f"F_0: {forgetting}")

        return np.array([performance, forgetting])

    def _run(self, policy, summary_writer):
        """
        Train/evaluate ``policy`` on every task for ``cycle_count`` cycles, resuming from the saved run metadata.

        While a task runs, non-eval-mode tasks are checkpointed every ``policy.config.timesteps_per_save`` timesteps
        and at task end. Continual eval runs every ``continual_testing_freq`` task timesteps (when not None). The
        rewards of the latest eval during a task (or of one eval run at task end, if none ran during it) become
        that task's row in the cycle's reward matrix, which is saved and scored after each cycle.

        :return: The per-cycle [performance, forgetting] arrays averaged over cycles (also logged).
        """
        # Load as necessary
        policy.load(self.output_dir)
        run_metadata = RunMetadata(self.output_dir)
        start_cycle_id = run_metadata.cycle_id
        start_task_id = run_metadata.task_id
        start_task_timesteps = run_metadata.task_timesteps
        eval_freq = self._continual_testing_freq
        # Only updated after a task is complete. To get the current within-task number, add task_timesteps
        total_train_timesteps = run_metadata.total_train_timesteps
        timesteps_per_save = policy.config.timesteps_per_save

        metrics = []

        for cycle_id in range(start_cycle_id, self._cycle_count):
            # NOTE(review): when resuming mid-cycle (start_task_id > 0) this matrix only gets rows for the resumed
            # tasks — confirm Average_Performance / Average_Forgetting tolerate that.
            reward_matrix = []

            for task_run_id, task in enumerate(self.tasks[start_task_id:], start=start_task_id):
                # Run the current task as a generator so we can intersperse testing tasks during the run
                self._logger.info(f"Starting cycle {cycle_id} task {task_run_id}")
                task_runner = task.run(
                    task_run_id,
                    policy,
                    summary_writer,
                    self.output_dir,
                    timestep_log_offset=total_train_timesteps,
                    task_timestep_start=start_task_timesteps,
                )

                task_timesteps = start_task_timesteps
                last_eval_step = 0
                last_timestep_saved = None
                latest_rewards = None  # Rewards of the most recent continual eval run during this task
                task_complete = False

                while not task_complete:
                    try:
                        task_timesteps, _ = next(task_runner)
                    except StopIteration:
                        task_complete = True

                    if not task._task_spec.eval_mode:
                        if last_timestep_saved is None or task_timesteps - last_timestep_saved >= timesteps_per_save or \
                                task_complete:
                            # Save the metadata that allows us to resume where we left off.
                            # This will not copy files in large_file_path such as
                            # replay buffers, and is intended for debugging model changes
                            # at task boundaries.
                            run_metadata.save(cycle_id, task_run_id, task_timesteps, total_train_timesteps)
                            policy.save(self.output_dir, cycle_id, task_run_id, task_timesteps)
                            if task_complete:
                                task_boundary_dir = os.path.join(self.output_dir, f'cycle{cycle_id}_task{task_run_id}')
                                os.makedirs(task_boundary_dir, exist_ok=True)
                                policy.save(task_boundary_dir, cycle_id, task_run_id, task_timesteps)

                            last_timestep_saved = task_timesteps

                    # Evaluate intermittently; every step is too slow. Skipped entirely when no frequency is set.
                    if eval_freq is not None and task_timesteps - last_eval_step >= eval_freq:
                        latest_rewards = self._run_continual_eval(
                            task_run_id,
                            policy,
                            summary_writer,
                            total_train_timesteps + task_timesteps,
                        )
                        # Advance by whole multiples of eval_freq so that a large timestep jump triggers one eval
                        # instead of a series of back-to-back catch-up evals on subsequent iterations.
                        if eval_freq > 0:
                            last_eval_step += eval_freq * ((task_timesteps - last_eval_step) // eval_freq)

                # The runner is exhausted here; do not call next() on it again.
                if latest_rewards is None:
                    # No eval ran during this task (eval_freq is None or the task was too short). Evaluate once now
                    # so this task's row reflects the policy after it, rather than reusing a previous task's rewards.
                    latest_rewards = self._run_continual_eval(
                        task_run_id,
                        policy,
                        summary_writer,
                        total_train_timesteps + task_timesteps,
                    )

                print(f"cycle {cycle_id} task {task_run_id}已完成。")

                # After the task completes, accumulate its training timesteps (kept cumulative across tasks so
                # logging offsets and the saved resume metadata stay consistent).
                if not task._task_spec.eval_mode:
                    total_train_timesteps += task_timesteps

                # The next task starts from its beginning
                start_task_timesteps = 0

                reward_matrix.append(np.array(latest_rewards))

            # Every later cycle starts again from the first task
            start_task_id = 0

            metrics.append(self._score_cycle(cycle_id, reward_matrix))

        metrics = np.array(metrics)
        self._logger.info(f"Metrics: {metrics}")

        metrics = np.mean(metrics, axis=0)
        self._logger.info(f"Metrics: {metrics}")
        return metrics

    def try_run(self, policy, summary_writer):
        """
        Run the experiment. On any exception, log it, shut the policy down, and re-raise.

        :return: Whatever ``_run`` returns (the cycle-averaged metrics).
        """
        try:
            return self._run(policy, summary_writer)
        except Exception as e:
            self._logger.exception(f"Failed with exception: {e}")
            policy.shutdown()
            raise
