import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import random
from algorithms.mu_zero import utils, MuZeroEvaluator
from algorithms.utils.types import SpielGame, SpielState, SpielAction, ChoicePolicy, ChancePolicy, DynamicsInput, \
    TauPolicy, HiddenState, Subtrajectory, DynamicsTestResult, DynamicsTestLog
from algorithms.utils.params import Params
from typing import List, Optional


class MuZeroDynamicsLab(object):
    """Checks a MuZero model's learned dynamics against a real game rollout.

    A random game is played on the real game; every recorded step is encoded
    with the representation network. From each recorded step the dynamics
    network is then unrolled for up to ``k`` steps along the actions that were
    really taken, and the prediction / tau heads are compared with what the
    real game reports. Each check appends ``1`` (passed) or ``0`` (failed) to a
    list stored in the log under one of these keys: ``'choice_pass_action'``,
    ``'chance_pass_action'``, ``'choice_strict_acc'``, ``'choice_top_acc'``,
    ``'choice_single_acc'`` and ``'tau_acc'``.
    """

    # Tau-head output index expected for each value of ``state.current_player()``:
    # players 0 and 1 map to themselves, -1 maps to slot 2 and -4 to slot 3.
    # NOTE(review): -1 / -4 look like OpenSpiel's chance / terminal player ids - confirm.
    _TAU_INDEX_BY_PLAYER = {0: 0, 1: 1, -1: 2, -4: 3}

    def __init__(self, game: SpielGame, evaluator: MuZeroEvaluator, params: Params):
        """Store the game, the model wrapper and the settings read from ``params``.

        :param game: game object used to create fresh initial states.
        :param evaluator: object exposing ``rep_model``, ``pred_model`` and ``dyn_model``.
        :param params: provides ``pass_action``, ``k`` and ``extractor``.
        """
        self._game = game
        self._pass_action = params.pass_action
        self._evaluator = evaluator
        # Maximum number of dynamics steps unrolled from each recorded step.
        self._k = params.k
        # Supplies ``state_feature_extractor`` and ``action_to_image``.
        self._helper = params.extractor
        # Set by ``run_dynamics_test_single``; the analysis methods append to it.
        self._log = None  # type: Optional[DynamicsTestLog]

    def run_dynamics_test_single(self, log: DynamicsTestLog) -> None:
        """Play one random game and analyse every sub-trajectory of it.

        Chance outcomes are sampled according to their probabilities and player
        actions uniformly from the legal actions. Chance nodes that occur before
        the first player decision are applied but not recorded. Every recorded
        step becomes the start of a window of at most ``k + 1`` steps which is
        handed to ``analyze_subtrajectory_dynamics``.

        :param log: mapping from metric name to a list; results are appended to it.
        """
        self._log = log
        state = self._game.new_initial_state()  # type: SpielState
        # One (state, hidden_state, legal_actions, action, is_chance_node) tuple
        # per recorded step. Keeping the fields together guarantees they stay
        # aligned with each other.
        trajectory = []  # type: List[tuple]
        game_has_started = False

        while not state.is_terminal():
            is_chance_node = state.is_chance_node()  # type: bool
            legal_actions = state.legal_actions()  # type: List[int]
            if is_chance_node:
                outcomes, probs = zip(*state.chance_outcomes())
                action = np.random.choice(outcomes, p=probs)  # type: SpielAction
            else:
                game_has_started = True
                action = random.choice(legal_actions)  # type: SpielAction
            if game_has_started:
                # Snapshot the state *before* the action is applied.
                trajectory.append((state.clone(), self._encode_state(state),
                                   legal_actions, action, is_chance_node))
            state.apply_action(action)

        window = self._k + 1
        for start in range(len(trajectory)):
            self.analyze_subtrajectory_dynamics(trajectory[start:start + window])

    def _encode_state(self, state: SpielState) -> HiddenState:
        """Return the representation model's hidden state for ``state`` as a numpy array."""
        state_feature = self._helper.state_feature_extractor(state)
        hidden_state, _ = self._evaluator.rep_model(state_feature, training=False)
        return hidden_state.numpy()

    @staticmethod
    def _top_action(policy: np.ndarray) -> int:
        """Return the index of the largest entry of the flattened ``policy``.

        Ties resolve to the lowest index.
        """
        return int(np.argmax(policy.flatten()))

    def test_pass_action(self, policy: np.ndarray) -> int:
        """Return 1 if the pass action is the top action with probability above 0.5, else 0."""
        flat = policy.flatten()
        top = int(np.argmax(flat))
        return int(top == self._pass_action and flat[top] > 0.5)

    @staticmethod
    def test_single_choice(legal_actions_actual: List[int], choice_policy: ChoicePolicy, state: SpielState) -> int:
        """Return 1 if the most probable action is one of the legal actions, else 0.

        ``state`` is accepted for interface compatibility and is not used.
        """
        return int(MuZeroDynamicsLab._top_action(choice_policy) in set(legal_actions_actual))

    @staticmethod
    def test_choice_top(legal_actions_actual: List[int], choice_policy: ChoicePolicy) -> int:
        """Return 1 if the ``len(legal_actions_actual)`` most probable actions are exactly the legal ones."""
        flat = choice_policy.flatten()
        # Stable sort on the negated probabilities: descending order, with ties
        # keeping the lower action index first.
        ranked = np.argsort(-flat, kind='stable')
        top = {int(action) for action in ranked[:len(legal_actions_actual)]}
        return int(top == set(legal_actions_actual))

    @staticmethod
    def test_choice_acc(legal_actions_actual: List[int], choice_policy: ChoicePolicy) -> int:
        """Return 1 if no illegal action gets more than the uniform probability ``1 / n``, else 0."""
        flat = choice_policy.flatten()
        avg_prob = 1 / len(flat)
        legal = set(legal_actions_actual)
        overweight = np.flatnonzero(flat > avg_prob)
        return int(all(int(action) in legal for action in overweight))

    @staticmethod
    def test_tau_acc(tau_policy: TauPolicy, state: SpielState) -> int:
        """Return 1 if the most probable tau entry matches ``state.current_player()``, else 0.

        The expected entry is looked up in ``_TAU_INDEX_BY_PLAYER``; any other
        player value never counts as a match.
        """
        expected = MuZeroDynamicsLab._TAU_INDEX_BY_PLAYER.get(state.current_player())
        predicted = MuZeroDynamicsLab._top_action(tau_policy)
        return int(expected is not None and predicted == expected)

    def _log_prediction_checks(self, hidden_state, state: SpielState,
                               legal_actions: List[int], is_chance_node: bool) -> None:
        """Run the prediction heads on ``hidden_state`` and log the policy checks.

        At a chance node only the choice head is checked (it should predict the
        pass action). At a player node the chance head is checked for the pass
        action and the choice head is checked against the legal actions.
        """
        choice_logits, chance_logits, _ = self._evaluator.pred_model(hidden_state, training=False)
        choice_policy = utils.np_softmax(choice_logits.numpy())  # type: ChoicePolicy
        chance_policy = utils.np_softmax(chance_logits.numpy())  # type: ChancePolicy
        if is_chance_node:
            self._log['choice_pass_action'].append(self.test_pass_action(choice_policy))
            return
        self._log['chance_pass_action'].append(self.test_pass_action(chance_policy))
        self._log['choice_strict_acc'].append(self.test_choice_acc(legal_actions, choice_policy))
        self._log['choice_top_acc'].append(self.test_choice_top(legal_actions, choice_policy))
        self._log['choice_single_acc'].append(self.test_single_choice(legal_actions, choice_policy, state))

    def analyze_subtrajectory_dynamics(self, subtrajectory: Subtrajectory) -> None:
        """Unroll the dynamics model along ``subtrajectory`` and log every check.

        Only the first step's hidden state is used directly; later hidden states
        come from repeatedly applying the dynamics model to the actions that
        were actually played. After each dynamics step the tau output is
        compared with the real next state, then the prediction heads are checked
        on the new hidden state.

        :param subtrajectory: iterable of ``(state, hidden_state, legal_actions,
            action, is_chance_node)`` tuples. It is only iterated, never
            modified; an empty one is a no-op.
        """
        steps = iter(subtrajectory)
        try:
            state, hidden_state, legal_actions, action, is_chance_node = next(steps)
        except StopIteration:
            return
        self._log_prediction_checks(hidden_state, state, legal_actions, is_chance_node)

        for state, _, legal_actions, next_action, is_chance_node in steps:
            # ``action`` is still the action taken at the previous step here.
            action_image = self._helper.action_to_image(action)
            dynamics_input = tf.concat((hidden_state, action_image), axis=1)  # type: DynamicsInput
            hidden_state, tau_logits = self._evaluator.dyn_model(dynamics_input, training=False)
            tau_policy = utils.np_softmax(tau_logits.numpy())  # type: TauPolicy
            self._log['tau_acc'].append(self.test_tau_acc(tau_policy, state))
            self._log_prediction_checks(hidden_state, state, legal_actions, is_chance_node)
            action = next_action


