import numpy as np
from typing import List
from algorithms.abstract.extractor import Extractor
from algorithms.utils.types import SpielAction, ActionImage, SpielState, StateFeature, TauPolicy, Value, ChancePolicy, \
    ChoicePolicy, TrainingTarget, SpielGame


class NannonExtractor(Extractor):
    """
    Extracts information from Spiel objects and values and transforms them into
    numpy arrays specific for the game Nannon.
    Other games can be implemented in similar fashion by defining a subclass of Extractor.

    Every image/policy produced here has a last axis of width ``num_points + 2``
    (stored as ``self._image_width``).

    Args:
        game: SpielGame, openspiel game object (forwarded to ``Extractor``).
        num_points: determines the feature width, ``num_points + 2``.
        num_chex: divisor used to normalise the first and last entry of each
            player's board vector in ``state_feature_extractor``.
        num_die: stored on the instance; not read by any method of this class.
        num_actions: forwarded to ``Extractor``.
        pass_action: forwarded to ``Extractor``; this class reads it back as
            ``self._pass_action`` (presumably set by the base class -- confirm).
        num_players: forwarded to ``Extractor``; this class reads it back as
            ``self._num_players`` (presumably set by the base class -- confirm).
    """
    def __init__(self,
                 game: SpielGame,
                 num_points: int,
                 num_chex: int,
                 num_die: int,
                 num_actions: int,
                 pass_action: int,
                 num_players: int = 4):
        super().__init__(game, num_actions, pass_action, num_players)
        self._num_points = num_points
        self._num_chex = num_chex
        self._num_die = num_die
        self._image_width = num_points + 2

    def _one_hot_policy(self, index: int) -> np.ndarray:
        """Return a float32 array of shape (1, image_width) with a 1.0 at ``index``."""
        policy = np.zeros((1, self._image_width), dtype=np.float32)
        # NOTE: a negative ``index`` wraps around (numpy indexing), exactly as
        # the previous inline implementations did.
        policy[0, index] = 1.0
        return policy

    def _policy_from_probs(self, probs: List[float]) -> np.ndarray:
        """Copy ``probs`` into the leading entries of a zero (1, image_width) float32 array."""
        policy = np.zeros((1, self._image_width), dtype=np.float32)
        # Element-wise copy so that more than ``image_width`` probabilities
        # still raise IndexError, as before.
        for i, prob in enumerate(probs):
            policy[0, i] = prob
        return policy

    def action_to_image(self, spiel_action: SpielAction) -> ActionImage:
        """
        One-hot encode an action as a float32 array of shape (1, 1, 1, image_width).

        A negative ``spiel_action`` yields an all-zero image.
        """
        action_image = np.zeros((1, 1, 1, self._image_width), dtype=np.float32)
        if spiel_action >= 0:
            action_image[0, 0, 0, spiel_action] = 1.0
        return action_image

    def final_action_image(self) -> ActionImage:
        """Return the (1, 1, 1, image_width) one-hot image of ``self._pass_action``."""
        action_image = np.zeros((1, 1, 1, self._image_width), dtype=np.float32)
        action_image[0, 0, 0, self._pass_action] = 1.0
        return action_image

    def state_feature_extractor(self, state: SpielState) -> StateFeature:
        """
        Build the state feature from ``state.observation_tensor(0)``.

        The observation is read as: ``image_width`` values for the white vector,
        ``image_width`` values for the red vector, then one value each for the
        die, the previous player and the current player.

        Returns:
            float32 array of shape (1, 4, 1, image_width); the four rows are
            white, red, to-play and chance, in that order.
        """
        # Copy once instead of repeatedly pop(0)-ing: avoids O(n) shifts per
        # element and leaves the list returned by the state untouched.
        obs = list(state.observation_tensor(0))
        width = self._image_width

        white_vec = np.array(obs[:width], dtype=np.float32)
        red_vec = np.array(obs[width:2 * width], dtype=np.float32)

        # Scale the first and last entry of each board vector by the checker count.
        chex = float(self._num_chex)
        for vec in (white_vec, red_vec):
            vec[0] /= chex
            vec[width - 1] /= chex

        die = obs[2 * width]
        prev_player = obs[2 * width + 1]
        curr_player = obs[2 * width + 2]

        chance_vec = np.zeros(width, dtype=np.float32)
        if die == -1.0:
            # A die value of -1 is flagged in the last slot.
            chance_vec[-1] = 1.0
        else:
            chance_vec[int(die) - 1] = 1.0

        to_play_vec = np.zeros(width, dtype=np.float32)
        # NOTE(review): the offsets 1 and 6 are hard-coded; they presumably
        # assume a specific player-id range and image width -- confirm.
        to_play_vec[int(prev_player) + 1] = 1.0
        to_play_vec[int(curr_player) + 6] = 1.0

        # (4, width) -> (4, 1, width) -> (1, 4, 1, width)
        state_feature = np.vstack((white_vec, red_vec, to_play_vec, chance_vec))
        state_feature = np.expand_dims(state_feature, axis=1)
        state_feature = np.expand_dims(state_feature, axis=0)
        return state_feature

    def final_target(self) -> TrainingTarget:
        """
        Return the training target used for a final position.

        The tau policy is one-hot at index 3 of a (1, 4) array, both the chance
        and choice policies are one-hot at ``self._pass_action``, and the value
        is 0.0.
        """
        # NOTE(review): 4 and 3 are hard-coded rather than derived from
        # ``self._num_players``; kept as-is to preserve behaviour -- confirm
        # whether they should track the constructor argument.
        terminal_player = 3
        tau_target = np.zeros((1, 4), dtype=np.float32)
        tau_target[0, terminal_player] = 1.0
        # Two separate arrays on purpose: the targets must not alias each other.
        chance_policy_target = self._one_hot_policy(self._pass_action)
        choice_policy_target = self._one_hot_policy(self._pass_action)
        value_target = np.array([0.0], dtype=np.float32)
        target = TrainingTarget(tau_policy=tau_target,
                                value=value_target,
                                chance_policy=chance_policy_target,
                                choice_policy=choice_policy_target)
        return target

    def tau_policy_deterministic(self, player_id: int) -> TauPolicy:
        """Return a (1, num_players) float32 one-hot array at ``player_id``."""
        tau_policy = np.zeros((1, self._num_players), dtype=np.float32)
        tau_policy[0, player_id] = 1.0
        return tau_policy

    def chance_policy(self, probs: List[float]) -> ChancePolicy:
        """Return a (1, image_width) float32 array whose leading entries are ``probs``."""
        return self._policy_from_probs(probs)

    def chance_policy_deterministic(self, spiel_action: SpielAction) -> ChancePolicy:
        """Return a (1, image_width) float32 one-hot chance policy at ``spiel_action``."""
        return self._one_hot_policy(spiel_action)

    def choice_policy(self, probs: List[float]) -> ChoicePolicy:
        """Return a (1, image_width) float32 array whose leading entries are ``probs``."""
        return self._policy_from_probs(probs)

    def choice_policy_deterministic(self, spiel_action: SpielAction) -> ChoicePolicy:
        """Return a (1, image_width) float32 one-hot choice policy at ``spiel_action``."""
        return self._one_hot_policy(spiel_action)

