import logging
import numpy as np
import ray
from ray import cloudpickle as pickle
from collections import deque


@ray.remote
class TimeBasedTaskGenerator:
    """Ray actor that schedules a curriculum of tasks by elapsed timesteps.

    The timestep budget ``stop_timesteps`` is split into equal-length
    segments, one per entry of ``self._tasks`` (in list order). The task
    returned by :meth:`get_task` / :meth:`sample_task` is the one whose
    segment contains the most recently reported ``timesteps``.

    The actor also holds a reset flag, two global step counters and a
    free-form ``info`` dict. Callers read and update these only through the
    methods below, because Ray actor state is not reachable by plain
    attribute access from other processes.
    """

    def __init__(self, config: dict, seed):
        """Create the generator.

        Args:
            config: Generator configuration. Only ``"stop_timesteps"`` is read
                (default ``100000000``). It is the timestep budget that is
                divided evenly among the tasks.
            seed: Value passed to ``np.random.seed``. Nothing in this class
                draws from NumPy's global RNG; presumably the seeding is for
                other code running in the same process. TODO confirm.
        """
        # Curriculum, in the order the tasks are scheduled. These names are
        # expected to be defined at module level elsewhere in this file.
        self._tasks = [
            TASK_academy_empty_goal_close,
            TASK_academy_3_vs_1,
            TASK_academy_pass_and_shoot_with_keeper,
            TASK_3_vs_3,
            TASK_5_vs_5,
            TASK_8_vs_8,
            TASK_11_vs_11_easy_stochastic,
        ]
        self._n_tasks = len(self._tasks)
        self.config = config
        self.timesteps = 0
        self.stop_timesteps = self.config.get("stop_timesteps", 100000000)
        self._seed = seed
        np.random.seed(self._seed)
        self.sample_flag = False

        self.global_env_steps = 0
        self.global_high_level_steps = 0

        self.info = {}

    def update_timesteps(self, timesteps: int) -> None:
        """Record the current training timestep used to pick the task."""
        self.timesteps = timesteps

    def set_sample_flag(self, flag: bool) -> None:
        """Set the flag that enables environments to reset."""
        self.sample_flag = flag

    def get_sample_flag(self) -> bool:
        """Return the flag set by :meth:`set_sample_flag`."""
        return self.sample_flag

    def _task_index(self) -> int:
        """Return the index of the task scheduled for ``self.timesteps``.

        The result is clamped to ``[0, n_tasks - 1]``. Once the budget is
        exhausted, or in the few trailing timesteps left over by the floor
        division, the last task is kept instead of raising ``IndexError``.
        Negative timesteps map to the first task instead of wrapping around.
        """
        # max(1, ...) guards against stop_timesteps < n_tasks, which would
        # make the segment length 0 and the division raise ZeroDivisionError.
        segment = max(1, self.stop_timesteps // self._n_tasks)
        return min(max(self.timesteps // segment, 0), self._n_tasks - 1)

    def sample_task(self):
        """Return the task scheduled for the current timestep."""
        return self._tasks[self._task_index()]

    def get_task(self):
        """Return the task scheduled for the current timestep.

        This is the same task that :meth:`sample_task` returns.
        """
        return self._tasks[self._task_index()]

    def update_info(self, new_info: dict) -> None:
        """Merge ``new_info`` into the shared info dict (existing keys overwritten)."""
        self.info.update(new_info)

    def get_info(self) -> dict:
        """Return the shared info dict."""
        return self.info

    def inc_env_steps(self) -> None:
        """Increment the global environment-step counter by one."""
        self.global_env_steps += 1

    def get_env_steps(self) -> int:
        """Return the global environment-step counter."""
        return self.global_env_steps

    def inc_high_level_steps(self) -> None:
        """Increment the global high-level-step counter by one."""
        self.global_high_level_steps += 1

    def get_high_level_steps(self) -> int:
        """Return the global high-level-step counter."""
        return self.global_high_level_steps

    def save(self) -> bytes:
        """Serializes this task generator's current state and returns it.

        Only mutable runtime state is saved. Configuration-derived values
        (``stop_timesteps``, the seed, the task list) are rebuilt from
        ``config`` on construction, so a restored actor follows its own
        config.

        Returns:
            The current state of this task generator as a serialized, pickled
            byte sequence.
        """
        return pickle.dumps(
            {
                "timesteps": self.timesteps,
                "sample_flag": self.sample_flag,
                "global_env_steps": self.global_env_steps,
                "global_high_level_steps": self.global_high_level_steps,
                "info": self.info,
            }
        )

    def restore(self, objs: bytes) -> None:
        """Restores this task generator's state from a sequence of bytes.

        Args:
            objs: The byte sequence to restore this task generator's state
                from, as produced by :meth:`save`.

        Raises:
            KeyError: If ``objs`` does not contain one of the keys written by
                :meth:`save`.
        """
        # SECURITY: unpickling can execute arbitrary code. Only pass bytes
        # that come from a trusted checkpoint written by save().
        state = pickle.loads(objs)
        self.timesteps = state["timesteps"]
        self.sample_flag = state["sample_flag"]
        self.global_env_steps = state["global_env_steps"]
        self.global_high_level_steps = state["global_high_level_steps"]
        self.info = state["info"]
