from typing import Any, Tuple

from algorithms.abstract.bot import ZeroBot
import os, pickle

from algorithms.utils.params import Params
from algorithms.utils.types import SpielState, Player, SpielGame, SpielAction, ChoicePolicy
from typing import List, Dict


def get_opponent_id(player: Player) -> Player:
    """Return the id of the other seat in a two-player game (0 <-> 1)."""
    opponent = 1 - player
    return opponent


class BellmanBot(ZeroBot):
    """Bot that plays greedily from a precomputed table of state values.

    The table is unpickled from ``tables/<num_points>_<num_chex>_<num_die>.p``
    in this module's directory. It maps decoded state tuples (see
    :meth:`decode`) to float values. Moves are chosen by one-ply lookahead:
    each legal action is applied to a clone of the state, and the successor is
    looked up from the opponent's perspective. The action maximising
    ``1 - opponent_value`` is selected.
    """

    def __init__(self, game: SpielGame, params: Params, verbose=False):
        """Initialise the base bot and load the value table selected by ``params``.

        Args:
            game: Game instance, forwarded to ``ZeroBot.__init__``.
            params: Configuration. ``num_points``, ``num_chex`` and ``num_die``
                select the table file, and ``num_points`` also drives decoding.
            verbose: Forwarded to ``ZeroBot.__init__``.

        Raises:
            OSError: If the table file cannot be opened.
        """
        ZeroBot.__init__(self, game, params, verbose)
        pickle_fname = '{}_{}_{}.p'.format(params.num_points, params.num_chex, params.num_die)
        self._num_points = params.num_points
        self._num_chex = params.num_chex
        dir_ = os.path.dirname(__file__)
        path = os.path.join(dir_, 'tables', pickle_fname)
        # Close the handle deterministically instead of leaking it until GC.
        # pickle.load can execute arbitrary code: only load tables shipped with
        # this package, never files from an untrusted source.
        with open(path, 'rb') as table_file:
            self.table = pickle.load(table_file)  # type: Dict[tuple, float]

    def evaluate_state(self, state: SpielState, player: Player) -> float:
        """Return the table value of ``state`` as seen from ``player``.

        Raises:
            KeyError: If the decoded state is missing from the table.
        """
        return self.table[self.decode(state, player)]

    def value_search(self, state: SpielState) -> List[Tuple[int, float]]:
        """Evaluate every legal action with a one-ply lookahead.

        Returns:
            ``(action, value)`` pairs in ``state.legal_actions()`` order.
            ``value`` is the table value of the successor state from the
            *opponent's* perspective.
        """
        current_player = state.current_player()
        # NOTE(review): assumes the turn always passes to the opponent after a
        # move; confirm for game variants that grant extra turns.
        opponent = get_opponent_id(current_player)
        action_vals = []
        for action in state.legal_actions():
            working_state = state.clone()  # keep the caller's state untouched
            working_state.apply_action(action)
            action_vals.append((action, self.evaluate_state(working_state, opponent)))
        return action_vals

    def step_with_policy(self, state: SpielState) -> Tuple[List[Tuple[int, float]], SpielAction]:
        """Pick the greedy action for the player to move.

        Returns:
            ``(action_vals, best_action)``. ``action_vals`` is the output of
            :meth:`value_search` (opponent-perspective values). ``best_action``
            maximises ``1 - value``, which assumes table values are
            complementary between the two players (e.g. win probabilities).
            Ties go to the last such action. ``best_action`` is ``None`` when
            there are no legal actions.
        """
        action_vals = self.value_search(state)

        best_value = float('-inf')
        best_action = None
        for action, opponent_value in action_vals:
            my_value = 1 - opponent_value
            # ``>=`` rather than ``>``: on ties the later action wins.
            if my_value >= best_value:
                best_value = my_value
                best_action = action

        return action_vals, best_action

    def step(self, state: SpielState) -> SpielAction:
        """Return only the greedy action chosen by :meth:`step_with_policy`."""
        return self.step_with_policy(state)[1]

    def extract_white_player_info(self, obs: List[float]) -> tuple:
        """Consume white's section from the front of ``obs`` and decode it.

        Pops ``num_points + 2`` entries from ``obs`` (mutating it):
        - a home count;
        - one occupancy flag for each point ``1..num_points``;
        - a scored count.

        Returns:
            A sorted tuple with one entry per checker: ``0`` for each checker
            at home, the point index for each flag equal to 1, and
            ``num_points + 1`` for each scored checker.
        """
        n_player_home = obs.pop(0)
        player_info = [0] * int(n_player_home)
        for index in range(1, self._num_points + 1):
            if obs.pop(0) == 1:
                player_info.append(index)
        n_player_score = obs.pop(0)
        player_info.extend([self._num_points + 1] * int(n_player_score))
        return tuple(sorted(player_info))

    def extract_red_player_info(self, obs: List[float]) -> tuple:
        """Consume red's section from the front of ``obs`` and decode it.

        Pops ``num_points + 2`` entries from ``obs`` (mutating it):
        - a scored count;
        - ``num_points`` occupancy flags, where flag ``i`` maps to position
          ``num_points + 1 - i`` (red's orientation is reversed);
        - a home count.

        Returns:
            A sorted tuple of checker positions in the same encoding as
            :meth:`extract_white_player_info`.
        """
        n_player_score = obs.pop(0)
        player_info = [self._num_points + 1] * int(n_player_score)
        for index in range(1, self._num_points + 1):
            if obs.pop(0) == 1:
                player_info.append(self._num_points + 1 - index)
        n_player_home = obs.pop(0)
        player_info.extend([0] * int(n_player_home))
        return tuple(sorted(player_info))

    def decode(self, state: SpielState, player: Player) -> Tuple[tuple, tuple]:
        """Convert ``state`` into the table key ``(player_info, opponent_info)``.

        The observation is read as white's section followed by red's section.
        Player 0 is treated as white; any other player id as red.
        """
        # Defensive copy: the extract_* helpers consume the list with pop(0),
        # so never hand them a container owned by the state (or a non-list).
        obs = list(state.observation_tensor(player))  # type: List[float]
        white_info = self.extract_white_player_info(obs)
        red_info = self.extract_red_player_info(obs)
        if player == 0:
            return white_info, red_info
        return red_info, white_info

    def action_and_policy(self, state: SpielState) -> Tuple[SpielAction, ChoicePolicy]:
        """Return the greedy action and a deterministic policy target for it.

        ``self._extractor`` is not set in this class; presumably it is provided
        by ``ZeroBot`` (confirm).
        """
        action = self.step(state)
        choice_policy_target = self._extractor.choice_policy_deterministic(action)
        return action, choice_policy_target

    def search(self, state: SpielState) -> Any:
        """Not implemented for this bot; always returns ``None``."""
        pass

