import abc
import itertools
from collections import OrderedDict
from typing import Dict, List

import numpy as np
import gymnasium as gym
from gymnasium import spaces

# from gym3.types import DictType, Discrete, TensorType

from mcu.arm.utils.vpt_lib.actions import Buttons


class ActionMapping(abc.ABC):
    """Abstract base for converting between the factored MC action space and a custom one.

    The factored space, as described by ``stats_ac_space``, is a dict with a
    multi-binary ``"buttons"`` entry (one flag per entry of ``Buttons.ALL``) and a
    two-component binned ``"camera"`` entry.

    :param n_camera_bins: Number of bins per camera axis in the factored space. Must be
        odd; the middle bin (``n_camera_bins // 2``) is stored as ``camera_null_bin``.
    """

    # Default groups of mutually exclusive buttons. Every group starts with "none",
    # meaning nothing in the group was pressed. Subclasses may replace or extend this.
    BUTTONS_GROUPS = OrderedDict(
        [
            ("hotbar", ["none"] + [f"hotbar.{slot}" for slot in range(1, 10)]),
            ("fore_back", ["none", "forward", "back"]),
            ("left_right", ["none", "left", "right"]),
            ("sprint_sneak", ["none", "sprint", "sneak"]),
            ("use", ["none", "use"]),
            ("drop", ["none", "drop"]),
            ("attack", ["none", "attack"]),
            ("jump", ["none", "jump"]),
        ]
    )

    def __init__(self, n_camera_bins: int = 11):
        assert n_camera_bins % 2 == 1, "n_camera_bins should be odd"
        self.n_camera_bins = n_camera_bins
        # With an odd bin count the centre bin is the one stored as the "null" camera bin.
        self.camera_null_bin = n_camera_bins // 2

        # Description of the original factored action space, kept for stats code.
        buttons_space = spaces.MultiBinary(n=len(Buttons.ALL))
        camera_space = spaces.MultiDiscrete([n_camera_bins, n_camera_bins])
        self.stats_ac_space = spaces.Dict({"buttons": buttons_space, "camera": camera_space})

    @abc.abstractmethod
    def from_factored(self, ac: Dict) -> Dict:
        """Convert a factored action into this mapping's action space.

        :param ac: Dictionary of actions that must have a batch dimension
        """

    @abc.abstractmethod
    def to_factored(self, ac: Dict) -> Dict:
        """Convert an action from this mapping's action space back to the factored one.

        :param ac: Dictionary of actions that must have a batch dimension
        """

    @abc.abstractmethod
    def get_action_space_update(self):
        """Return the action space definition used to update the env action space."""

    @abc.abstractmethod
    def get_zero_action(self):
        """Return the zero (null) action of this action space."""

    def factored_buttons_to_groups(self, ac_buttons: np.ndarray, button_group: List[str]) -> List[str]:
        """Pick, per batch element, which option of a mutually exclusive button group was chosen.

        :param ac_buttons: Button actions from the factored space, indexed as
            ``[batch, button]`` with ``len(Buttons.ALL)`` columns.
        :param button_group: Options of one mutually exclusive group, e.g.
            ``['none', 'forward', 'back']``. ``'none'`` must be the first element; every
            other element must appear in ``Buttons.ALL``.

        Returns a list with one entry per batch element, each taken from ``button_group``.
        ``'none'`` is returned when no button of the group is pressed. When several are
        pressed, the one listed latest in ``button_group`` wins, except that pressing all
        of a forward/back or left/right group counts as pressing nothing.
        """
        assert ac_buttons.shape[1] == len(
            Buttons.ALL
        ), f"There should be {len(Buttons.ALL)} buttons in the factored buttons space"
        assert button_group[0] == "none", "This function only works if 'none' is in button_group"

        # Columns of the factored buttons that belong to this group, in group order.
        columns = [Buttons.ALL.index(name) for name in button_group if name != "none"]
        # Indexing with a list produces a copy, so the writes below leave ac_buttons untouched.
        pressed = ac_buttons[:, columns]

        # Opposing directions pressed together cancel each other out.
        for first, second in (("forward", "back"), ("left", "right")):
            if first in button_group and second in button_group:
                cancelled = np.all(pressed, axis=-1)
                pressed[cancelled] = 0

        hits = np.nonzero(pressed)
        chosen = ["none"] * ac_buttons.shape[0]
        # Hits are visited in row-major order, so a later column overwrites an earlier one
        # of the same row: e.g. hotbar.1 and hotbar.2 pressed together resolve to hotbar.2.
        for row, column in zip(hits[0], hits[1]):
            chosen[row] = button_group[column + 1]  # +1 skips the leading "none" entry
        return chosen

class IDMActionMapping(ActionMapping):
    """Identity mapping used for the IDM: actions stay in the factored space."""

    def from_factored(self, ac: Dict) -> Dict:
        """Return the factored action ``ac`` unchanged (the same object, not a copy)."""
        return ac

    def to_factored(self, ac: Dict) -> Dict:
        """Return ``ac`` unchanged (the same object, not a copy)."""
        return ac

    def get_action_space_update(self):
        """Return the factored action space entries used to update the env action space.

        The result is a plain dict holding a multi-binary ``"buttons"`` space with one
        flag per entry of ``Buttons.ALL`` and a two-component ``"camera"`` space with
        ``n_camera_bins`` bins per axis.
        """
        # The bin count lives on the instance (set in ActionMapping.__init__); the bare
        # name ``n_camera_bins`` is not defined in this scope and raised NameError.
        return {
            "buttons": spaces.MultiBinary(n=len(Buttons.ALL)),
            "camera": spaces.MultiDiscrete([self.n_camera_bins, self.n_camera_bins]),
        }

    def get_zero_action(self):
        """Not provided for the identity mapping.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

class CameraHierarchicalMapping(ActionMapping):
    """Buttons are joint as in ButtonsJointMapping, but now a camera on/off meta action is added into this joint space.
    When this meta action is triggered, the separate camera head chooses a camera action which is also now a joint space.

    Joint button actions are indices into ``BUTTONS_COMBINATIONS`` (every combination of
    one choice per button group, followed by the standalone ``"inventory"`` action).
    Joint camera actions are indices into ``camera_combinations`` (every x-bin/y-bin pair).

    :param n_camera_bins: number of camera bins in the factored space
    """

    # Add camera meta action to BUTTONS_GROUPS
    BUTTONS_GROUPS = ActionMapping.BUTTONS_GROUPS.copy()
    BUTTONS_GROUPS["camera"] = ["none", "camera"]
    # "inventory" is its own joint action: it excludes every other button and the camera.
    BUTTONS_COMBINATIONS = list(itertools.product(*BUTTONS_GROUPS.values())) + ["inventory"]
    BUTTONS_COMBINATION_TO_IDX = {comb: i for i, comb in enumerate(BUTTONS_COMBINATIONS)}
    BUTTONS_IDX_TO_COMBINATION = {i: comb for i, comb in enumerate(BUTTONS_COMBINATIONS)}

    def __init__(self, *args, **kwargs):
        """Build the camera lookup tables and precompute the joint -> factored matrices.

        All arguments are forwarded to ``ActionMapping.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.camera_groups = OrderedDict(
            camera_x=[f"camera_x{i}" for i in range(self.n_camera_bins)],
            camera_y=[f"camera_y{i}" for i in range(self.n_camera_bins)],
        )
        self.camera_combinations = list(itertools.product(*self.camera_groups.values()))
        self.camera_combination_to_idx = {comb: i for i, comb in enumerate(self.camera_combinations)}
        self.camera_idx_to_combination = {i: comb for i, comb in enumerate(self.camera_combinations)}
        # Joint camera index whose x and y components are both the null bin.
        self.camera_null_idx = self.camera_combination_to_idx[
            (f"camera_x{self.camera_null_bin}", f"camera_y{self.camera_null_bin}")
        ]
        # Joint button index where every group (including the camera meta action) is "none".
        self._null_action = {
            "buttons": self.BUTTONS_COMBINATION_TO_IDX[tuple("none" for _ in range(len(self.BUTTONS_GROUPS)))]
        }
        self._precompute_to_factored()

    def _precompute_to_factored(self):
        """Precompute the joint action -> factored action matrix.

        Fills three lookup tables indexed by joint action index:
        ``BUTTON_IDX_TO_FACTORED`` (factored button vector per joint button action),
        ``BUTTON_IDX_TO_CAMERA_META_OFF`` (True where the joint button action is a
        combination with the camera meta action off) and ``CAMERA_IDX_TO_FACTORED``
        (x/y bin pair per joint camera action).
        """
        button_dim = self.stats_ac_space["buttons"].n
        n_button_actions = len(self.BUTTONS_IDX_TO_COMBINATION)
        self.BUTTON_IDX_TO_FACTORED = np.zeros((n_button_actions, button_dim), dtype=int)
        self.BUTTON_IDX_TO_CAMERA_META_OFF = np.zeros((n_button_actions), dtype=bool)
        self.CAMERA_IDX_TO_FACTORED = np.zeros((len(self.camera_idx_to_combination), 2), dtype=int)

        # Pre compute Buttons
        for jnt_ac, button_comb in self.BUTTONS_IDX_TO_COMBINATION.items():
            new_button_ac = np.zeros(len(Buttons.ALL), dtype="i")
            if button_comb == "inventory":
                # The inventory action leaves the camera-off flag False.
                new_button_ac[Buttons.ALL.index("inventory")] = 1
            else:
                for group_choice in button_comb[:-1]:  # Last one is camera
                    if group_choice != "none":
                        new_button_ac[Buttons.ALL.index(group_choice)] = 1

                if button_comb[-1] != "camera":  # This means camera meta action is off
                    self.BUTTON_IDX_TO_CAMERA_META_OFF[jnt_ac] = True
            self.BUTTON_IDX_TO_FACTORED[jnt_ac] = new_button_ac

        # Pre compute camera: translate each (x name, y name) pair back to its bin indices.
        camera_x_names = self.camera_groups["camera_x"]
        camera_y_names = self.camera_groups["camera_y"]
        for jnt_ac, camera_comb in self.camera_idx_to_combination.items():
            self.CAMERA_IDX_TO_FACTORED[jnt_ac, 0] = camera_x_names.index(camera_comb[0])
            self.CAMERA_IDX_TO_FACTORED[jnt_ac, 1] = camera_y_names.index(camera_comb[1])

    def from_factored(self, ac: Dict) -> Dict:
        """Converts a factored action (ac) to the new space. Assumes ac has a batch dim

        :param ac: Dict with 2-D arrays under ``"buttons"`` (one column per entry of
            ``Buttons.ALL``) and ``"camera"`` (x and y bin per row).
        :return: Dict with ``"buttons"`` and ``"camera"`` arrays of joint action indices,
            each with a trailing axis of length 1.
        :raises AssertionError: if either input array is not 2-D.
        :raises KeyError: if a camera bin does not name a known camera combination.
        """
        assert ac["camera"].ndim == 2, f"bad camera label, {ac['camera']}"
        assert ac["buttons"].ndim == 2, f"bad buttons label, {ac['buttons']}"
        buttons = ac["buttons"]
        camera = ac["camera"]

        # Get button choices for everything but camera
        choices_by_group = OrderedDict(
            (k, self.factored_buttons_to_groups(buttons, v)) for k, v in self.BUTTONS_GROUPS.items() if k != "camera"
        )
        # Set camera "on off" action based on whether non-null camera action was given
        camera_is_null = np.all(camera == self.camera_null_bin, axis=1)
        choices_by_group["camera"] = ["none" if is_null else "camera" for is_null in camera_is_null]

        # Loop invariants, computed once instead of once per batch element.
        inventory_idx = Buttons.ALL.index("inventory")
        null_camera_key = (
            f"camera_x{self.camera_null_bin}",
            f"camera_y{self.camera_null_bin}",
        )

        new_button_ac = []
        new_camera_ac = []
        for i in range(buttons.shape[0]):
            # Buttons -- inventory overrides whatever else was pressed
            if buttons[i, inventory_idx] == 1:
                button_key = "inventory"
            else:
                button_key = tuple([v[i] for v in choices_by_group.values()])
            new_button_ac.append(self.BUTTONS_COMBINATION_TO_IDX[button_key])

            # Camera -- inventory is also exclusive with camera
            if button_key == "inventory":
                camera_key = null_camera_key
            else:
                camera_key = (f"camera_x{camera[i][0]}", f"camera_y{camera[i][1]}")
            new_camera_ac.append(self.camera_combination_to_idx[camera_key])

        return dict(
            buttons=np.array(new_button_ac)[:, None],
            camera=np.array(new_camera_ac)[:, None],
        )

    def to_factored(self, ac: Dict) -> Dict:
        """Converts an action in the new space (ac) to the factored action space. Assumes ac has a batch dim

        :param ac: Dict with ``"buttons"`` and ``"camera"`` arrays of joint action indices
            whose last axis has length 1. 0-d arrays are also accepted. ``ac`` itself is
            never modified.
        :return: Dict with the factored ``"buttons"`` vectors and ``"camera"`` bin pairs.
            Camera bins are forced to the null bin wherever the joint button action has
            the camera meta action off.
        :raises AssertionError: if the last axis of either array is not of length 1.
        """
        # Work on local references so the caller's dict is not rewritten for 0-d inputs.
        camera_idx = ac["camera"]
        button_idx = ac["buttons"]
        if camera_idx.ndim == 0:
            camera_idx = camera_idx[None]
            button_idx = button_idx[None]

        assert camera_idx.shape[-1] == 1
        assert button_idx.shape[-1] == 1

        button_idx = np.squeeze(button_idx, -1)
        new_button_ac = self.BUTTON_IDX_TO_FACTORED[button_idx]
        camera_off = self.BUTTON_IDX_TO_CAMERA_META_OFF[button_idx]
        # Indexing with an integer array returns a new array, so the in-place write
        # below does not touch the precomputed table.
        new_camera_ac = self.CAMERA_IDX_TO_FACTORED[np.squeeze(camera_idx, -1)]
        new_camera_ac[camera_off] = self.camera_null_bin

        return dict(buttons=new_button_ac, camera=new_camera_ac)

    def get_action_space_update(self):
        """Return the joint action space entries used to update the env action space.

        Both entries are single-component ``MultiDiscrete`` spaces sized by the number of
        joint camera combinations and joint button combinations respectively.
        """
        return {
            "camera": spaces.MultiDiscrete([len(self.camera_combinations)]),
            "buttons": spaces.MultiDiscrete([len(self.BUTTONS_COMBINATIONS)])
        }

    def get_zero_action(self):
        """Return the null action: a dict holding only the all-"none" joint ``"buttons"`` index.

        The same dict object is returned on every call.
        """
        return self._null_action

