import copy
from typing import List, Tuple

import numpy as np
from easydict import EasyDict

from ding.utils.compression_helper import jpeg_data_decompressor


class GameSegment:
    """
    Overview:
        A game segment from a full episode trajectory.

        The length of one episode in (Atari) games is often quite large. This class represents a single game segment
        within a larger trajectory, split into several blocks.

    Interfaces:
        - __init__
        - __len__
        - reset
        - pad_over
        - is_full
        - legal_actions
        - append
        - get_unroll_obs
        - get_obs
        - zero_obs
        - get_targets
        - game_segment_to_array
        - store_search_stats
    """

    def __init__(self, action_space: int, game_segment_length: int = 200, config: 'EasyDict' = None) -> None:
        """
        Overview:
            Init the ``GameSegment`` according to the provided arguments.
        Arguments:
            - action_space (:obj:`int`): the action space; either the number of discrete actions or an object
              exposing ``.n`` (see ``legal_actions``).
            - game_segment_length (:obj:`int`): the transition number of one ``GameSegment`` block.
            - config (:obj:`EasyDict`): policy config. Despite the ``None`` default it is required: this method reads
              ``num_unroll_steps``, ``td_steps``, ``discount_factor``, ``gray_scale``, ``transform2string``,
              ``sampled_algo``, ``gumbel_algo``, ``use_ture_chance_label_in_chance_encoder`` and
              ``model.{frame_stack_num, action_space_size, observation_shape}`` (plus ``model.image_channel`` for
              3-D observation shapes).
        """
        self.action_space = action_space
        self.game_segment_length = game_segment_length
        self.num_unroll_steps = config.num_unroll_steps
        self.td_steps = config.td_steps
        self.frame_stack_num = config.model.frame_stack_num
        self.discount_factor = config.discount_factor
        self.action_space_size = config.model.action_space_size
        self.gray_scale = config.gray_scale
        self.transform2string = config.transform2string
        self.sampled_algo = config.sampled_algo
        self.gumbel_algo = config.gumbel_algo
        self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder

        obs_shape = config.model.observation_shape
        if isinstance(obs_shape, int) or len(obs_shape) == 1:
            # vector obs input, e.g. classical control and box2d environments
            self.zero_obs_shape = obs_shape
        elif len(obs_shape) == 3:
            # image obs input, e.g. atari environments: channel count comes from ``model.image_channel``
            self.zero_obs_shape = (config.model.image_channel, obs_shape[-2], obs_shape[-1])
        else:
            # Any other rank (e.g. 2-D observations): use the configured shape as-is, so that ``zero_obs`` works
            # instead of failing later because ``zero_obs_shape`` was never set.
            self.zero_obs_shape = tuple(obs_shape)

        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []

        self.action_mask_segment = []
        self.to_play_segment = []

        # those from the target_model
        self.target_values = []
        self.target_rewards = []
        self.target_policies = []

        self.improved_policy_probs = []

        if self.sampled_algo:
            self.root_sampled_actions = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []

    def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
        """
        Overview:
            Get the observation frames ``obs_segment[timestep : timestep + frame_stack_num + num_unroll_steps]``.
        Arguments:
            - timestep (:obj:`int`): The start time step.
            - num_unroll_steps (:obj:`int`): The extra number of observation frames beyond the frame stack.
            - padding (:obj:`bool`): If True and the slice is shorter than requested, repeat the last available
              frame until the requested length is reached.
        Returns:
            - stacked_obs: The selected frames. If ``transform2string`` is set, each frame is decompressed with
              ``jpeg_data_decompressor`` and a list is returned.
        """
        stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps]
        if padding:
            pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs)
            if pad_len > 0:
                pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)])
                stacked_obs = np.concatenate((stacked_obs, pad_frames))
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def zero_obs(self) -> List:
        """
        Overview:
            Return ``frame_stack_num`` observation frames filled with zeros.
        Returns:
            - (:obj:`list`): ``frame_stack_num`` float32 arrays of shape ``zero_obs_shape``.
        """
        return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)]

    def get_obs(self) -> List:
        """
        Overview:
            Return the latest ``frame_stack_num`` stacked frames in the format used for model inference.
        Returns:
            - stacked_obs (:obj:`list`): The most recent stacked frames (decompressed if ``transform2string``).
        """
        # obs_segment holds frame_stack_num initial frames plus one frame per appended transition,
        # so the two counters below must agree.
        timestep_obs = len(self.obs_segment) - self.frame_stack_num
        timestep_reward = len(self.reward_segment)
        assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format(
            timestep_obs, timestep_reward
        )
        timestep = timestep_reward
        stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num]
        if self.transform2string:
            stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
        return stacked_obs

    def append(
            self,
            action: np.ndarray,
            obs: np.ndarray,
            reward: np.ndarray,
            action_mask: np.ndarray = None,
            to_play: int = -1,
            chance: int = 0,
    ) -> None:
        """
        Overview:
            Append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}
            (and chance_{t} when ``use_ture_chance_label_in_chance_encoder`` is enabled).
        """
        self.action_segment.append(action)
        self.obs_segment.append(obs)
        self.reward_segment.append(reward)

        self.action_mask_segment.append(action_mask)
        self.to_play_segment.append(to_play)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment.append(chance)

    def pad_over(
            self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List,
            next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None,
    ) -> None:
        """
        Overview:
            To make sure the value targets are correct, we need to add (o_t, r_t, etc.) from the next game_segment,
            which is necessary for the bootstrapped values at the end states of the previous game_segment.
            e.g: len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
            but r_101, r_102, ... are from the next game_segment.
        Arguments:
            - next_segment_observations (:obj:`list`): o_t from the next game_segment
            - next_segment_rewards (:obj:`list`): r_t from the next game_segment
            - next_segment_root_values (:obj:`list`): root values of MCTS from the next game_segment
            - next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment
            - next_segment_improved_policy (:obj:`list`): root children select policy of MCTS from the next
              game_segment (required only when ``gumbel_algo`` is enabled)
            - next_chances (:obj:`list`): chance labels from the next game_segment (required only when
              ``use_ture_chance_label_in_chance_encoder`` is enabled)
        """
        assert len(next_segment_observations) <= self.num_unroll_steps
        assert len(next_segment_child_visits) <= self.num_unroll_steps
        assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
        assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
        # ==============================================================
        # The core difference between GumbelMuZero and MuZero
        # ==============================================================
        if self.gumbel_algo:
            assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps

        # NOTE: next block observation should start from (stacked_observation - 1) in next trajectory.
        # Observations are deep-copied so this segment does not share observation objects with the next one.
        self.obs_segment.extend(copy.deepcopy(observation) for observation in next_segment_observations)
        self.reward_segment.extend(next_segment_rewards)
        self.root_value_segment.extend(next_segment_root_values)
        self.child_visit_segment.extend(next_segment_child_visits)

        if self.gumbel_algo:
            self.improved_policy_probs.extend(next_segment_improved_policy)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment.extend(next_chances)

    def get_targets(self, timestep: int) -> Tuple:
        """
        Overview:
            Return the (value, reward, policy) targets stored at step ``timestep``.
        """
        return self.target_values[timestep], self.target_rewards[timestep], self.target_policies[timestep]

    def store_search_stats(
            self, visit_counts: List, root_value: List, root_sampled_actions: List = None, improved_policy: List = None, idx: int = None
    ) -> None:
        """
        Overview:
            Store the normalized visit count distribution and value of the root node after MCTS.
        Arguments:
            - visit_counts (:obj:`list`): raw visit counts of the root children; normalized to sum to 1 before storing.
            - root_value: the root value from the search.
            - root_sampled_actions (:obj:`list`): sampled root actions (stored only when ``sampled_algo``).
            - improved_policy (:obj:`list`): Gumbel MuZero improved policy (stored only when ``gumbel_algo``).
            - idx (:obj:`int`): if None, append new statistics; otherwise overwrite the statistics at position ``idx``.
        """
        sum_visits = sum(visit_counts)
        if sum_visits == 0:
            # if the sum of visit counts is 0, set it to a small value to avoid division by zero
            sum_visits = 1e-6
        visit_distribution = [visit_count / sum_visits for visit_count in visit_counts]

        if idx is None:
            self.child_visit_segment.append(visit_distribution)
            self.root_value_segment.append(root_value)
            if self.sampled_algo:
                self.root_sampled_actions.append(root_sampled_actions)
            # store the improved policy in Gumbel MuZero: \pi'=softmax(logits + \sigma(CompletedQ))
            if self.gumbel_algo:
                self.improved_policy_probs.append(improved_policy)
        else:
            self.child_visit_segment[idx] = visit_distribution
            self.root_value_segment[idx] = root_value
            # Keep the sampled actions aligned with the overwritten visit distribution.
            if self.sampled_algo and root_sampled_actions is not None:
                self.root_sampled_actions[idx] = root_sampled_actions
            # improved_policy_probs is only populated for Gumbel MuZero; indexing it otherwise raises IndexError.
            if self.gumbel_algo:
                self.improved_policy_probs[idx] = improved_policy

    def game_segment_to_array(self) -> None:
        """
        Overview:
            Post-process the data when a `GameSegment` block is full. This function converts various game segment
            elements into numpy arrays for easier manipulation and processing.
        Structure:
            The structure and shapes of different game segment elements are as follows. Let's assume
            `game_segment_length`=20, `stack`=4, `num_unroll_steps`=5, `td_steps`=5:

            - obs:            game_segment_length + stack + num_unroll_steps, 20+4+5
            - action:         game_segment_length -> 20
            - reward:         game_segment_length + num_unroll_steps + td_steps -1  20+5+5-1
            - root_values:    game_segment_length + num_unroll_steps + td_steps -> 20+5+5
            - child_visits:   game_segment_length + num_unroll_steps -> 20+5
            - to_play:        game_segment_length -> 20
            - action_mask:    game_segment_length -> 20
        Examples:
            Here is an illustration of the structure of `obs` and `rew` for two consecutive game segments
            (game_segment_i and game_segment_i+1):

            - game_segment_i (obs):     4       20        5
                                      ----|----...----|-----|
            - game_segment_i+1 (obs):              4       20        5
                                                  ----|----...----|-----|

            - game_segment_i (rew):        20        5      4
                                      ----...----|------|-----|
            - game_segment_i+1 (rew):                 20        5    4
                                                 ----...----|------|-----|

        Postprocessing:
            - self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment.
            - self.action_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_segment.
            - self.reward_segment (:obj:`numpy.ndarray`): A numpy array version of the original reward_segment.
            - self.child_visit_segment (:obj:`numpy.ndarray`): A numpy array version of the original child_visit_segment.
            - self.root_value_segment (:obj:`numpy.ndarray`): A numpy array version of the original root_value_segment.
            - self.improved_policy_probs (:obj:`numpy.ndarray`): A numpy array version of the original improved_policy_probs.
            - self.action_mask_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_mask_segment.
            - self.to_play_segment (:obj:`numpy.ndarray`): A numpy array version of the original to_play_segment.
            - self.chance_segment (:obj:`numpy.ndarray`, optional): A numpy array version of the original chance_segment. Only
               created if `self.use_ture_chance_label_in_chance_encoder` is True.

        .. note::
            For environments with a variable action space, such as board games, the elements in `child_visit_segment` may have
            different lengths. In such scenarios, it is necessary to use the object data type for `self.child_visit_segment`.
        """
        self.obs_segment = np.array(self.obs_segment)
        self.action_segment = np.array(self.action_segment)
        self.reward_segment = np.array(self.reward_segment)

        # Check if all elements in self.child_visit_segment have the same length
        if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
            self.child_visit_segment = np.array(self.child_visit_segment)
        else:
            # In the case of environments with a variable action space, such as board games,
            # the elements in child_visit_segment may have different lengths.
            # In such scenarios, it is necessary to use the object data type.
            self.child_visit_segment = np.array(self.child_visit_segment, dtype=object)

        self.root_value_segment = np.array(self.root_value_segment)
        self.improved_policy_probs = np.array(self.improved_policy_probs)

        self.action_mask_segment = np.array(self.action_mask_segment)
        self.to_play_segment = np.array(self.to_play_segment)
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = np.array(self.chance_segment)

    def reset(self, init_observations: np.ndarray) -> None:
        """
        Overview:
            Initialize the game segment using ``init_observations``,
            which is the previous ``frame_stack_num`` stacked frames.
            All per-step buffers (including the Gumbel improved policies and sampled root actions)
            are cleared so they stay aligned with each other.
        Arguments:
            - init_observations (:obj:`list`): list of the stack observations in the previous time steps.
        """
        self.obs_segment = []
        self.action_segment = []
        self.reward_segment = []

        self.child_visit_segment = []
        self.root_value_segment = []
        # Must be reset too: otherwise stale entries (or an ndarray left by game_segment_to_array)
        # would break later appends in store_search_stats/pad_over.
        self.improved_policy_probs = []

        self.action_mask_segment = []
        self.to_play_segment = []
        if self.sampled_algo:
            self.root_sampled_actions = []
        if self.use_ture_chance_label_in_chance_encoder:
            self.chance_segment = []

        assert len(init_observations) == self.frame_stack_num

        for observation in init_observations:
            self.obs_segment.append(copy.deepcopy(observation))

    def is_full(self) -> bool:
        """
        Overview:
            Check whether the current game segment is full, i.e. holds at least ``game_segment_length`` actions.
        Returns:
            bool: True if the game segment is full, False otherwise.
        """
        return len(self.action_segment) >= self.game_segment_length

    def legal_actions(self) -> List[int]:
        """
        Overview:
            Return all action indices ``[0, n)``. ``n`` is ``action_space.n`` when the action space exposes ``.n``,
            otherwise ``action_space`` itself is taken as the number of actions (e.g. a plain int).
        """
        num_actions = getattr(self.action_space, 'n', self.action_space)
        return list(range(num_actions))

    def __len__(self) -> int:
        """Return the number of actions (transitions) stored in this segment."""
        return len(self.action_segment)
