import os, torch
import json
import numpy as np
from continual_rl.experiments.run_metadata import RunMetadata
from continual_rl.utils.utils import Utils
from continual_rl.utils.common_exceptions import OutputDirectoryNotSetException
from metrics import Average_Performance, Average_Forgetting



# Torch device for this process: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Task sequences keyed by experiment name.
# The Atari sequence lists one environment id per task; the MiniHack sequence
# lists pairs of environment ids.
raw_experiment = dict(
    atari_6_tasks_5_cycles=[
        "SpaceInvadersNoFrameskip-v4",
        "KrullNoFrameskip-v4",
        "BeamRiderNoFrameskip-v4",
        "HeroNoFrameskip-v4",
        "StarGunnerNoFrameskip-v4",
        "MsPacmanNoFrameskip-v4",
    ],
    minihack_nav_paired_2_cycles=[
        ("Room-Random-5x5-v0", "Room-Random-15x15-v0"),
        ("Corridor-R2-v0", "Corridor-R5-v0"),
        ("Room-Dark-5x5-v0", "Room-Dark-15x15-v0"),
        ("Corridor-R3-v0", "Corridor-R5-v0"),
        ("Room-Monster-5x5-v0", "Room-Monster-15x15-v0"),
        ("CorridorBattle-v0", "CorridorBattle-Dark-v0"),
        ("Room-Trap-5x5-v0", "Room-Trap-15x15-v0"),
        ("HideNSeek-v0", "HideNSeek-Big-v0"),
        ("Room-Ultimate-5x5-v0", "Room-Ultimate-15x15-v0"),
        ("HideNSeek-Lava-v0", "HideNSeek-Big-v0"),
    ],
    # Further experiments can be added here.
)



class InvalidTaskAttributeException(Exception):
    """Raised when the tasks of an experiment disagree on an attribute they must share."""

    def __init__(self, error_msg):
        # error_msg becomes the exception's sole argument (and its str()).
        Exception.__init__(self, error_msg)


class Experiment(object):
    """Run a sequence of tasks against one policy and report continual-learning metrics.

    The task list is looped over ``cycle_count`` times. Only tasks at even indices of
    ``tasks`` are run, and only even-indexed tasks with ``_task_spec.with_continual_eval``
    set are continually evaluated; odd-indexed tasks are skipped by both loops
    (presumably the held-out half of train/test pairs -- TODO confirm against the caller).
    For each cycle, one row of continual-eval rewards per run task is collected into a
    reward matrix. That matrix is saved to ``output_dir`` and reduced to
    average-performance / average-forgetting metrics.
    """

    def __init__(self, tasks, continual_testing_freq=None, cycle_count=1, num_timesteps=None):
        """
        Args:
            tasks: Task objects; each must expose ``observation_space``, ``action_space``,
                ``action_space_id`` and ``task_id``.
            continual_testing_freq: Task timesteps between periodic continual-eval runs.
                If falsy (e.g. the default ``None``), periodic eval is disabled. One eval is
                still run at the end of each task so the reward matrix has a row for it.
            cycle_count: Number of passes over the full task list.
            num_timesteps: Stored as ``self.num_timesteps``; not read by this class.

        Raises:
            InvalidTaskAttributeException: If tasks differ in observation space, or share an
                ``action_space_id`` but have different action spaces.
        """
        self.tasks = tasks
        self.action_spaces = self._get_action_spaces(self.tasks)
        self.observation_space = self._get_common_attribute(
            [task.observation_space for task in self.tasks]
        )
        self.task_ids = [task.task_id for task in tasks]
        self._output_dir = None
        self._continual_testing_freq = continual_testing_freq
        self._cycle_count = cycle_count
        self.num_timesteps = num_timesteps

    def set_output_dir(self, output_dir):
        self._output_dir = output_dir

    @property
    def output_dir(self):
        """The output directory; raises OutputDirectoryNotSetException if it was never set."""
        if self._output_dir is None:
            raise OutputDirectoryNotSetException("Output directory not set, but is attempting to be used. Call set_output_dir.")
        return self._output_dir

    @property
    def _logger(self):
        return Utils.create_logger(f"{self.output_dir}/core_process.log")

    @classmethod
    def _get_action_spaces(cls, tasks):
        """Map each task's ``action_space_id`` to its ``action_space``.

        Raises:
            InvalidTaskAttributeException: If two tasks share an id but have different
                action spaces.
        """
        action_space_map = {}  # Maps action space id to its action space

        for task in tasks:
            if task.action_space_id not in action_space_map:
                action_space_map[task.action_space_id] = task.action_space
            elif action_space_map[task.action_space_id] != task.action_space:
                raise InvalidTaskAttributeException(f"Action sizes were mismatched for task {task.action_space_id}")

        return action_space_map

    @classmethod
    def _get_common_attribute(cls, task_attributes):
        """Return the value shared by every entry of ``task_attributes`` (None if empty).

        Leading ``None`` entries are skipped when picking the reference value.

        Raises:
            InvalidTaskAttributeException: If any entry differs from the reference value.
        """
        common_attribute = None

        for task_attribute in task_attributes:
            if common_attribute is None:
                common_attribute = task_attribute

            if task_attribute != common_attribute:
                raise InvalidTaskAttributeException("Tasks do not have a common attribute.")

        return common_attribute

    @staticmethod
    def _eval_due(task_timesteps, last_eval_step, eval_freq):
        """Return True when a periodic continual eval should run at this step.

        A falsy ``eval_freq`` (e.g. ``None``) disables periodic eval. Without this check,
        the comparison would raise ``TypeError``.
        """
        return bool(eval_freq) and task_timesteps - last_eval_step >= eval_freq

    def _run_continual_eval(self, task_run_id, policy, summary_writer, total_timesteps):
        """Briefly evaluate every even-indexed task that has continual eval enabled.

        Args:
            task_run_id: Index of the task currently being run. It is not used by the
                evaluation itself and is kept for interface compatibility.
            policy: Policy to evaluate.
            summary_writer: Passed through to each task's ``continual_eval``.
            total_timesteps: Used as the log offset for the evaluation results. It is not
                incremented by these evaluations.

        Returns:
            list: One entry per non-None reward report yielded by the eval runners, in task
            order. Each entry is the mean of that report's first element.
        """
        logger = self._logger
        logger.debug("Continual eval at total timestep %s", total_timesteps)
        rewards = []
        for test_task_run_id, test_task in enumerate(self.tasks):
            if test_task_run_id % 2 != 0:
                continue
            # Deliberately not checking test_task._task_spec.eval_mode: some eval tasks
            # (for train/test pairs) should still be continually evaluated.
            if not test_task._task_spec.with_continual_eval:
                continue

            logger.info(f"Continual eval for task: {test_task_run_id}")

            # Don't increment the total_timesteps counter for continual tests
            test_task_runner = test_task.continual_eval(
                test_task_run_id,
                policy,
                summary_writer,
                output_dir=self.output_dir,
                timestep_log_offset=total_timesteps,
            )
            for _, reward_to_return in test_task_runner:
                if reward_to_return is not None:
                    logger.debug("Continual eval rewards: %s", reward_to_return)
                    rewards.append(np.array(reward_to_return[0]).mean())

            logger.info(f"Completed continual eval for task: {test_task_run_id}")

        return rewards

    def _summarize_cycle(self, cycle_id, reward_matrix):
        """Save one cycle's reward matrix and compute its metrics.

        Args:
            cycle_id: Cycle index, used in the output file name and in log messages.
            reward_matrix: One row of continual-eval rewards per task run in this cycle.

        Returns:
            np.ndarray: ``[average_performance, average_forgetting]`` for this cycle.
        """
        reward_numpy1 = np.array(reward_matrix)
        self._logger.info(f"reward_numpy1: {reward_numpy1}")
        # Keep every other row and every other column.
        # NOTE(review): rows already correspond only to even-indexed tasks, so this extra
        # subsampling presumably matches how the eval rewards are laid out -- confirm.
        reward_numpy1 = reward_numpy1[::2][:, ::2]
        self._logger.info(f"reward_numpy1: {reward_numpy1}")

        save_path = os.path.join(self.output_dir, f'reward_numpy{cycle_id}.npy')
        os.makedirs(self.output_dir, exist_ok=True)  # Make sure the directory exists
        np.save(save_path, reward_numpy1)
        print(f"数组已保存到: {save_path}")

        performance = Average_Performance(reward_numpy1)
        self._logger.info(f"Cycle {cycle_id} performance: {performance}")
        print(f"per_0: {performance}")

        forgetting = Average_Forgetting(reward_numpy1)
        self._logger.info(f"Cycle {cycle_id} forgetting: {forgetting}")
        print(f"F_0: {forgetting}")

        return np.array([performance, forgetting])

    def _run(self, policy, summary_writer):
        """Run all cycles and log per-cycle and cross-cycle metrics.

        Resumes from the cycle, task and timestep recorded in ``RunMetadata``.
        """
        # Load as necessary
        policy.load(self.output_dir)
        run_metadata = RunMetadata(self._output_dir)
        start_cycle_id = run_metadata.cycle_id
        start_task_id = run_metadata.task_id
        start_task_timesteps = run_metadata.task_timesteps
        eval_freq = self._continual_testing_freq
        # Only updated after a task is complete. To get the current within-task number, add task_timesteps
        total_train_timesteps = run_metadata.total_train_timesteps

        timesteps_per_save = policy.config.timesteps_per_save

        metrics = []

        for cycle_id in range(start_cycle_id, self._cycle_count):
            # One row per task run in this cycle.
            reward_matrix = []

            for task_run_id, task in enumerate(self.tasks[start_task_id:], start=start_task_id):
                if task_run_id % 2 != 0:
                    continue

                # Run the current task as a generator so we can intersperse testing tasks during the run
                self._logger.info(f"Starting cycle {cycle_id} task {task_run_id}")
                task_runner = task.run(
                    task_run_id,
                    policy,
                    summary_writer,
                    self.output_dir,
                    timestep_log_offset=total_train_timesteps,
                    task_timestep_start=start_task_timesteps,
                )

                task_timesteps = start_task_timesteps
                last_eval_step = 0
                last_timestep_saved = None
                # Rewards from the most recent continual eval during *this* task.
                task_rewards = None
                task_complete = False

                while not task_complete:
                    try:
                        task_timesteps, _ = next(task_runner)
                    except StopIteration:
                        task_complete = True

                    if not task._task_spec.eval_mode:
                        save_due = (
                            last_timestep_saved is None
                            or task_timesteps - last_timestep_saved >= timesteps_per_save
                            or task_complete
                        )
                        if save_due:
                            # Save the metadata that allows us to resume where we left off.
                            # This will not copy files in large_file_path such as
                            # replay buffers, and is intended for debugging model changes
                            # at task boundaries.
                            run_metadata.save(cycle_id, task_run_id, task_timesteps, total_train_timesteps)
                            policy.save(self.output_dir, cycle_id, task_run_id, task_timesteps)
                            if task_complete:
                                task_boundary_dir = os.path.join(self.output_dir, f'cycle{cycle_id}_task{task_run_id}')
                                os.makedirs(task_boundary_dir, exist_ok=True)
                                policy.save(task_boundary_dir, cycle_id, task_run_id, task_timesteps)

                            last_timestep_saved = task_timesteps

                    # Evaluate intermittently. Every time is too slow
                    if self._eval_due(task_timesteps, last_eval_step, eval_freq):
                        task_rewards = self._run_continual_eval(
                            task_run_id,
                            policy,
                            summary_writer,
                            total_train_timesteps + task_timesteps,
                        )
                        self._logger.debug("Continual eval rewards: %s", task_rewards)
                        last_eval_step += eval_freq

                # task_runner is exhausted at this point. Calling next() on it again would
                # raise StopIteration, so there is no trailing "resume" step.

                if task_rewards is None:
                    # No periodic eval fired during this task (short task, or eval disabled).
                    # Evaluate once now so this task gets its own row, instead of reusing the
                    # previous task's rewards or failing on the first task.
                    task_rewards = self._run_continual_eval(
                        task_run_id,
                        policy,
                        summary_writer,
                        total_train_timesteps + task_timesteps,
                    )

                print(f"cycle {cycle_id} task {task_run_id}已完成。")

                # After the task completes, add its steps to the running training total.
                if not task._task_spec.eval_mode:
                    total_train_timesteps += task_timesteps

                # Only the first task after a resume starts part-way through.
                start_task_timesteps = 0

                reward_matrix.append(np.array(task_rewards))

            # After each cycle, restart from the first task in the next cycle.
            start_task_id = 0

            metrics.append(self._summarize_cycle(cycle_id, reward_matrix))

        if not metrics:
            self._logger.info("No cycles were run; no metrics to report.")
            return

        metrics = np.array(metrics)
        self._logger.info(f"Metrics: {metrics}")
        print(metrics)

        # Average the per-cycle [performance, forgetting] pairs across cycles.
        metrics = np.mean(metrics, axis=0)
        self._logger.info(f"Metrics: {metrics}")
        print(metrics)

    def try_run(self, policy, summary_writer):
        """Run the experiment.

        On any exception, log it, shut the policy down, and re-raise.
        """
        try:
            self._run(policy, summary_writer)
        except Exception as e:
            self._logger.exception(f"Failed with exception: {e}")
            policy.shutdown()

            raise
