import copy
import enum
import numpy as np

from dataclasses import dataclass

from expground.types import AgentID, PolicyID, Dict, Any, Sequence, Union, Tuple, List
from expground.logger import Log


@dataclass
class PayoffTable:
    """Dense payoff tensor of one agent over the joint policy space.

    ``table`` has one axis per entry of ``agents`` (in that order). Position ``k``
    along an agent's axis corresponds to the policy id mapped to ``k`` in
    ``_policy_idx[agent]``. ``identify`` names the owning agent; it selects the
    axis used by ``get_agent_support_idx``.
    """

    identify: AgentID
    agents: Sequence[AgentID]
    table: Any = None

    def __post_init__(self):
        # per-agent mapping: policy id -> position along that agent's axis
        self._policy_idx = {agent: {} for agent in self.agents}
        # the table always starts empty on every axis (a ``table`` argument is
        # discarded); it grows through ``expand``
        self.table = np.zeros([0] * len(self.agents), dtype=np.float32)

    def __getitem__(self, key: Dict[str, Sequence[PolicyID]]) -> np.ndarray:
        """Return the sub-tensor spanned by the given policies of every agent.

        ``key`` maps each agent to a policy id or a list of policy ids (a single
        id counts as a list of length one). The result has shape
        ``(len(key[agents[0]]), ..., len(key[agents[-1]]))``.

        Raises:
            RuntimeError: if a policy id has not been registered.
        """
        return self.table[self._get_combination_index(key)]

    def __setitem__(self, key: Dict[AgentID, Sequence[PolicyID]], value: float):
        """Assign ``value`` to every cell selected by ``key`` (same key format as indexing)."""
        self.table[self._get_combination_index(key)] = value

    def __setstate__(self, state):
        # Restore attributes from a pickled state. ``table`` is copied so the
        # instance owns a private array (presumably to obtain a writable buffer
        # when the state was deserialized read-only -- confirm).
        for name, val in state.items():
            self.__dict__[name] = val.copy() if name == "table" else val

    def expand(self, policy_mapping: Dict[AgentID, List[PolicyID]]):
        """Register unseen policies of every agent and zero-pad the table for them.

        Policy ids already known, and repeated ids inside one list, are skipped.
        Every new axis slot therefore corresponds to exactly one policy id.
        """
        pad_info = []
        for agent in self.agents:
            known = self._policy_idx[agent]
            # dict.fromkeys drops duplicates while keeping first-seen order
            fresh = [
                pid for pid in dict.fromkeys(policy_mapping[agent]) if pid not in known
            ]
            offset = len(known)
            known.update({pid: offset + k for k, pid in enumerate(fresh)})
            # new slots are appended at the end of this agent's axis
            pad_info.append((0, len(fresh)))
        self.table = np.pad(self.table, pad_info)

    def _get_combination_index(
        self, policy_combination: Dict[AgentID, Sequence[PolicyID]]
    ) -> Tuple:
        """Translate a policy combination into an ``np.ix_`` open-mesh index.

        Raises:
            RuntimeError: if a policy id has not been registered via
                ``expand``/``replace``.
        """
        res = []
        for agent in self.agents:
            policy_seq = policy_combination[agent]
            if not isinstance(policy_seq, List):
                policy_seq = [policy_seq]

            lookup = self._policy_idx[agent]
            idx = []
            for p in policy_seq:
                if lookup.get(p) is None:
                    raise RuntimeError(
                        "Policy %s does not exist, legal for agent %s is %s"
                        % (p, agent, self._policy_idx[agent])
                    )
                idx.append(lookup[p])
            res.append(idx)
        return np.ix_(*res)

    def get_agent_support_idx(self, pid: PolicyID):
        """Return the position of ``pid`` along the axis of ``self.identify``."""
        return self._policy_idx[self.identify][pid]

    def replace(
        self, policy_mapping: Dict[AgentID, PolicyID], replaced: Dict[AgentID, PolicyID]
    ):
        """Rename policies in place: ``replaced[aid]`` becomes ``policy_mapping[aid]``.

        The new id takes over the axis position of the old one, so the payoffs
        stored for the old policy are kept for the new id.
        """
        for aid, old_pid in replaced.items():
            slot = self._policy_idx[aid].pop(old_pid)
            self._policy_idx[aid][policy_mapping[aid]] = slot


@dataclass
class SimulationTable:
    """Tracks which joint policy combinations have been simulated.

    ``_table`` has one integer axis per agent in ``sorted_agents``. Position ``k``
    along an agent's axis is the policy ``_support[agent][k]``. Entries equal to
    0 are reported as not done by ``get_not_dones``.
    """

    axis_num: int
    sorted_agents: Sequence[AgentID]

    def __post_init__(self):
        assert self.axis_num == len(self.sorted_agents)
        # per-agent ordered list of policy ids; list position == axis index
        self._support = {agent: [] for agent in self.sorted_agents}
        self._table = np.zeros([0] * self.axis_num, dtype=np.int32)

    def __getitem__(self, key: Dict[str, Union[Sequence[PolicyID], PolicyID]]):
        """Return the status sub-array selected by ``key`` (policy id or list per agent)."""
        return self._table[self._get_index(key)]

    def __setitem__(
        self, key: Dict[AgentID, Union[Sequence[PolicyID], PolicyID]], value: int
    ):
        """Set the status of every combination selected by ``key`` to ``value``."""
        self._table[self._get_index(key)] = value

    def __setstate__(self, state):
        # Restore attributes from a pickled state; ``_table`` is copied so the
        # instance owns a private array (presumably to obtain a writable buffer
        # after deserialization -- confirm).
        for name, val in state.items():
            self.__dict__[name] = val.copy() if name == "_table" else val

    def get_not_dones(
        self, useful_agents: Dict[AgentID, List[PolicyID]]
    ) -> Sequence[Dict[AgentID, PolicyID]]:
        """Return the combinations whose status is 0, as ``{agent: policy_id}`` dicts.

        Args:
            useful_agents: filter, or ``None`` to disable filtering. When given, a
                combination is kept only if at least one agent's policy in it
                appears in that agent's (non-empty) list ``useful_agents[agent]``.
        """
        res = []
        for _idx in np.argwhere(self._table == 0):
            combination = {
                self.sorted_agents[axis]: self._support[self.sorted_agents[axis]][pos]
                for axis, pos in enumerate(_idx)
            }
            if useful_agents is not None and not any(
                useful_agents.get(aid) and pid in useful_agents[aid]
                for aid, pid in combination.items()
            ):
                continue
            res.append(combination)
        return res

    def expand(self, supports: Dict[AgentID, List[PolicyID]]):
        """Append unseen policies of every agent and zero-pad the table for them.

        Policy ids already known, and repeated ids inside one list, are skipped.
        Every new axis slot therefore corresponds to exactly one policy id.
        """
        pad_info = []
        for agent in self.sorted_agents:
            agent_support = self._support[agent]
            known = set(agent_support)
            # dict.fromkeys drops duplicates while keeping first-seen order
            fresh = [pid for pid in dict.fromkeys(supports[agent]) if pid not in known]
            agent_support.extend(fresh)
            pad_info.append((0, len(fresh)))
        self._table = np.pad(self._table, pad_info)

    def _get_index(self, keys: Dict[str, Union[Sequence[PolicyID], PolicyID]]):
        """Translate ``keys`` into an ``np.ix_`` open-mesh index.

        Raises:
            ValueError: if a policy id is not in the agent's support.
        """
        res = []
        for agent in self.sorted_agents:
            supports = keys[agent]
            if not isinstance(supports, List):
                supports = [supports]
            res.append(list(map(self._support[agent].index, supports)))
        return np.ix_(*res)

    def get_index(self, keys: Dict[str, Union[Sequence[PolicyID], PolicyID]]):
        """Public alias of ``_get_index``."""
        return self._get_index(keys)

    def get_agent_support_idx(self, agent, pid):
        """Return the position of ``pid`` along ``agent``'s axis."""
        return self._support[agent].index(pid)

    def replace(
        self, policy_mapping: Dict[AgentID, PolicyID], replaced: Dict[AgentID, PolicyID]
    ):
        """Rename policy ``replaced[aid]`` to ``policy_mapping[aid]`` in place.

        The status entries stored at the old policy's position are kept.

        Raises:
            ValueError: if a replaced policy id is not in the agent's support.
        """
        for aid, old_pid in replaced.items():
            pos = self._support[aid].index(old_pid)
            self._support[aid][pos] = policy_mapping[aid]


class RectifierType(enum.IntEnum):
    """Selects the function ``get_rectifier`` applies to raw rewards."""

    DEFAULT = 0  # identity: the value is returned unchanged
    ReLU = 1  # negative values are clipped to 0.0


def get_rectifier(_type: RectifierType):
    """Return the scalar rectifier function for ``_type``.

    Args:
        _type: a ``RectifierType`` member (or an equal int).

    Returns:
        A one-argument callable applied to each reward.

    Raises:
        ValueError: if ``_type`` does not match any ``RectifierType`` member.
    """
    if _type == RectifierType.DEFAULT:
        return lambda x: x
    if _type == RectifierType.ReLU:
        return lambda x: max(x, 0.0)
    # fail fast instead of returning None, which would only break at call time
    raise ValueError("unknown rectifier type: %r" % (_type,))


@dataclass
class PayoffMatrix:
    """Payoff tables of all agents plus the shared simulation-status table.

    Every agent owns a ``PayoffTable`` over the same joint policy space, whose
    axes follow ``self.agents`` (sorted at construction). ``simultation_table``
    records which joint policy combinations have already been simulated.
    """

    # sequence of agent ids; sorted in ``__post_init__``
    agents: Sequence[AgentID]

    # replaced in ``__post_init__`` by ``{agent_id: PayoffTable}``
    payoff_matrix: Any = None

    # ``RectifierType`` value selecting how rewards are rectified before storage
    rectifier_type: int = 0

    def __post_init__(self):
        # canonical agent order: every table axis follows it
        self.agents = sorted(self.agents)
        self._agent_axes_mapping = dict(zip(self.agents, range(len(self.agents))))

        self.payoff_matrix: Dict[AgentID, PayoffTable] = {
            agent: PayoffTable(agent, self.agents) for agent in self.agents
        }
        # NOTE: the attribute name is misspelled but kept for compatibility
        self.simultation_table = SimulationTable(
            axis_num=len(self.agents), sorted_agents=self.agents
        )

        self.rectifier_type = RectifierType(self.rectifier_type)
        self.rectifier = get_rectifier(self.rectifier_type)

        Log.debug("created payoff matrix for agents=%s", self.agents)

    @property
    def agent_axes_mapping(self):
        """Mapping ``agent_id -> axis index`` in the (sorted) agent order."""
        return self._agent_axes_mapping

    def dict_to_matrix(
        self, results: Dict[AgentID, Sequence[Tuple[Dict, Dict[AgentID, float]]]]
    ) -> Dict[AgentID, np.ndarray]:
        """Convert per-agent result lists into dense arrays.

        ``results[agent]`` is a sequence of ``(policy_mapping, reward)`` pairs.
        The array built for ``agent`` has one axis per *other* agent, each of
        size ``dim``. ``dim ** (len(agents) - 1)`` is assumed to equal the number
        of results. Cells are addressed by the other agents' support indices
        and hold ``reward[agent]``.
        """
        mat = {}
        base = len(self.agents) - 1
        for agent, _results in results.items():
            # round before truncating: 64 ** (1 / 3) == 3.9999999999999996
            dim = int(round(len(_results) ** (1.0 / base)))
            mat[agent] = np.zeros([dim] * base)
            for pmapping, reward in _results:
                idx_comb = [None] * len(self.agents)
                for aid, pid in pmapping.items():
                    if aid == agent:
                        # the ego axis does not exist in this agent's array
                        continue
                    idx_comb[self._agent_axes_mapping[aid]] = self.payoff_matrix[
                        aid
                    ].get_agent_support_idx(pid)
                # the ego slot stays None (np.newaxis), which consumes no array
                # dimension, so e.g. (x, None, y) indexes a 2-D array at [x, y]
                mat[agent][tuple(idx_comb)] = reward[agent]
        return mat

    def get_sub_matrix(
        self, keys: Dict[AgentID, Sequence[PolicyID]]
    ) -> Dict[AgentID, np.ndarray]:
        """Return a sub matrix with given agent policy supports.

        Args:
            keys (Dict[AgentID, Sequence[PolicyID]]): A dict of agent policy supports.

        Returns:
            Dict[AgentID, np.ndarray]: Each agent's payoff sub-tensor over ``keys``.
        """
        return {aid: self.payoff_matrix[aid][keys] for aid in self.agents}

    def update_payoff_and_simulation_status(
        self,
        results: Sequence[Tuple[Dict, Dict[str, Any]]],
        set_dones: Sequence[bool] = None,
    ):
        """Store simulation results in every agent's payoff table.

        Args:
            results: ``(policy_mapping, res)`` pairs, where ``res["reward"][agent]``
                is the reward of ``agent`` under ``policy_mapping``. Rewards
                pass through ``self.rectifier`` before being stored.
            set_dones: optional per-result flags. Only results whose flag is
                truthy are marked as simulated; ``None`` marks all of them.
        """
        for i, (policy_mapping, res) in enumerate(results):
            if set_dones is None or set_dones[i]:
                self.simultation_table[policy_mapping] = 1
            rewards = res["reward"]
            for agent in self.agents:
                self.payoff_matrix[agent][policy_mapping] = self.rectifier(
                    rewards[agent]
                )

    def retrieve_combinations(
        self, support_dict: Dict[AgentID, Tuple[PolicyID]], group_by: Any = None
    ) -> Union[Dict, Tuple]:
        """Not implemented.

        Planned: with ``group_by=None`` return a tuple of combinations, otherwise
        group them by the given agent sequence.
        """
        raise NotImplementedError

    def gen_simulations(
        self, filter=None, split: bool = False
    ) -> Sequence[Union[Dict, Sequence[Dict]]]:
        """Generate the not-yet-simulated policy combinations.

        Args:
            filter (Dict[AgentID, List[PolicyID]], optional): Forwarded to
                ``SimulationTable.get_not_dones``. Only combinations containing
                at least one listed policy are kept. Defaults to None (no
                filtering).
            split (bool, optional): If True, split the combinations into
                ``axis_num`` consecutive chunks of equal length, with the
                remainder appended to the last chunk. Defaults to False.

        Returns:
            Sequence[Union[Dict, Sequence[Dict]]]: A list of ``{agent: policy_id}``
            dicts, or a list of such lists when ``split`` is True.
        """
        not_dones = self.simultation_table.get_not_dones(filter)
        if not split:
            return not_dones

        n_chunks = self.simultation_table.axis_num
        chunk_len, tail = divmod(len(not_dones), n_chunks)
        sizes = [chunk_len] * n_chunks
        sizes[-1] += tail

        chunks, start = [], 0
        for size in sizes:
            chunks.append(not_dones[start : start + size])
            start += size
        return chunks

    def expand(self, policy_mapping: Dict[AgentID, Sequence[PolicyID]]):
        """Register new policies in the simulation table and every payoff table."""
        self.simultation_table.expand(policy_mapping)
        for agent in self.agents:
            self.payoff_matrix[agent].expand(policy_mapping)

    def replace(
        self,
        new_support: Dict[AgentID, Sequence[PolicyID]],
        replaced: Dict[AgentID, PolicyID] = None,
    ):
        """Add new policies, or swap existing ones for new ids.

        Args:
            new_support: per-agent sequences of new policy ids.
            replaced: if None, ``new_support`` is simply added (see ``expand``).
                Otherwise ``replaced[aid]`` is renamed to ``new_support[aid][0]``
                and keeps its table positions.
        """
        if replaced is None:
            self.expand(new_support)
            return

        new_support = {aid: p[0] for aid, p in new_support.items()}
        # every payoff table indexes all agents, so each must be renamed
        for agent in self.agents:
            self.payoff_matrix[agent].replace(new_support, replaced)
        # the simulation table is shared: rename it exactly once (a second call
        # would no longer find the old policy id and raise ValueError)
        # NOTE(review): done flags and payoffs of the old policy are inherited by
        # the new id -- confirm this is intended.
        self.simultation_table.replace(new_support, replaced)

    @staticmethod
    def _joint_weights(factors: Sequence[np.ndarray]) -> np.ndarray:
        """Outer product of 1-D ``factors``; axis ``k`` of the result is ``factors[k]``."""
        joint = np.ones(())
        for vector in factors:
            joint = np.multiply.outer(joint, vector)
        return joint

    def aggregate(
        self,
        equilibrium: Dict[AgentID, Dict[PolicyID, float]],
        brs: Dict[AgentID, PolicyID] = None,
    ) -> Dict[AgentID, float]:
        """Expected payoff of each agent under a mixed-strategy profile.

        Args:
            equilibrium: per-agent mapping ``policy_id -> probability``. The listed
                policies span the aggregated sub-matrix.
            brs: optional per-agent best response. When given, agent ``k``'s value
                is computed with ``k`` playing ``brs[k]`` while all other agents
                keep their ``equilibrium`` mixtures.

        Returns:
            ``{agent: expected payoff}``. Keys of ``equilibrium`` that are not in
            ``self.agents`` keep the placeholder ``0.0``.
        """
        res = {agent_id: 0.0 for agent_id in equilibrium}
        population_combination = {
            agent: list(e.keys()) for agent, e in equilibrium.items()
        }
        # Axis k of every weight tensor belongs to self.agents[k]. This matches
        # the axis order of the sub-matrices returned by the payoff tables.
        weight_vectors = [
            np.asarray(list(equilibrium[agent].values()), dtype=np.float64)
            for agent in self.agents
        ]

        for i, agent in enumerate(self.agents):
            if brs is None:
                # m * m * ... * m: every agent plays its mixture
                comb = population_combination
                factors = weight_vectors
            else:
                # m * ... * 1 * ... * m: ego agent plays its best response
                comb = copy.copy(population_combination)
                comb[agent] = [brs[agent]]
                factors = weight_vectors[:i] + [np.ones(1)] + weight_vectors[i + 1 :]

            sub_matrix = self.payoff_matrix[agent][comb]
            weight_mat = self._joint_weights(factors)
            assert weight_mat.shape == sub_matrix.shape, (
                weight_mat.shape,
                sub_matrix.shape,
                equilibrium[agent],
            )
            res[agent] = (sub_matrix * weight_mat).sum()

        return res


# Public API: only PayoffMatrix is exported by ``from <module> import *``;
# the table classes and rectifier helpers are internal building blocks.
__all__ = ["PayoffMatrix"]
