from typing import Tuple, List

import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

from algorithms.utils.types import TrainingBatch, RepresentationInput, TrainingTarget, ProcessedRootData, \
    HiddenStatesTensor, RecurrentData, TauPolicy, Value, ChancePolicy, ChoicePolicy, DynamicsInput, \
    ProcessedRecurrentData

# Print floating-point values in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)
# Process-wide: ignore NumPy divide-by-zero floating-point errors (raised, for
# instance, when np.log is applied to a zero entry).
np.seterr(divide='ignore')


def np_softmax(logits) -> np.ndarray:
    """Return the softmax of `logits` taken over the last axis.

    The per-row maximum is subtracted before exponentiating, so large logits
    do not overflow in `np.exp`; the result is mathematically unchanged.

    Args:
      logits: Array-like of shape [..., num_classes].

    Returns:
      Array with the same shape as `logits` whose entries along the last axis
      sum to 1.
    """
    row_peak = np.amax(logits, axis=-1, keepdims=True)
    unnormalized = np.exp(logits - row_peak)
    normalizer = np.sum(unnormalized, axis=-1, keepdims=True)
    return unnormalized / normalizer


def process_root_data(training_examples: TrainingBatch) -> ProcessedRootData:
    """Extract the root-step inputs and targets from a training batch.

    Args:
      training_examples: A 6-element sequence. Only the first three entries
        (root state features, root targets, root `ss`) are used; the last
        three are discarded.

    Returns:
      A tuple `(representation_input, root_targets, root_ss)` where
      `representation_input` is the root state features converted to a
      float32 array with axis 1 squeezed out. The other two entries are
      passed through untouched.

    Raises:
      ValueError: If `training_examples` does not have exactly six entries,
        or if axis 1 of the feature array is not of size 1.
    """
    features, targets, ss, _, _, _ = training_examples
    feature_array = np.array(features, dtype=np.float32)
    # Axis 1 is expected to be a singleton dimension; np.squeeze raises otherwise.
    representation_input: RepresentationInput = np.squeeze(feature_array, axis=1)
    return representation_input, targets, ss


def process_recurrent_data(hidden_states: HiddenStatesTensor, recurrent_data: RecurrentData) -> ProcessedRecurrentData:
    """Build the dynamics-network input and split the recurrent targets.

    Args:
      hidden_states: Tensor concatenated (along axis 1) with the action images.
      recurrent_data: A 3-tuple `(action_image_batch, recurrent_target_batch,
        recurrent_ss)`. Each entry of `recurrent_target_batch` must unpack into
        four fields, in the order (tau, value, chance, choice).

    Returns:
      A tuple `(dynamics_input, tau_targets, chance_targets, choice_targets,
      value_targets, recurrent_ss)`. `value_targets` is a float32 array with a
      new axis inserted at position 1; the tau/chance/choice targets are
      tuples with one entry per sample; `recurrent_ss` is passed through.
    """
    action_images, target_batch, recurrent_ss = recurrent_data

    # Drop the singleton axis 1 of the action images before joining them to
    # the hidden states (np.squeeze raises if that axis is not of size 1).
    action_array = np.array(action_images, dtype=np.float32)
    action_array = np.squeeze(action_array, axis=1)  # type: np.ndarray
    dynamics_input = tf.concat((hidden_states, action_array), axis=1)  # type: DynamicsInput

    tau_targets: Tuple[TauPolicy, ...]
    value_tuple: Tuple[Value, ...]
    chance_targets: Tuple[ChancePolicy, ...]
    choice_targets: Tuple[ChoicePolicy, ...]

    # Transpose the per-sample target tuples into one tuple per target kind.
    tau_targets, value_tuple, chance_targets, choice_targets = zip(*target_batch)

    value_array = np.array(value_tuple, dtype=np.float32)
    value_targets = np.expand_dims(value_array, axis=1)  # type: np.ndarray
    return dynamics_input, tau_targets, chance_targets, choice_targets, value_targets, recurrent_ss


def extract_trainable_variables(models):
    """Flatten the `trainable_variables` of several models into one list.

    Args:
      models: Iterable of objects exposing an iterable `trainable_variables`
        attribute.

    Returns:
      A new list holding every trainable variable, ordered model by model.
    """
    collected = []
    for model in models:
        collected.extend(model.trainable_variables)
    return collected


def get_opponent_id(player):
    """Return the id of the other player, computed as `1 - player`.

    Presumably player ids are 0 and 1 (two-player setting), in which case
    this maps 0 -> 1 and 1 -> 0.
    """
    opponent = 1 - player
    return opponent


def np_masked_softmax(logits, legal_actions_mask) -> np.ndarray:
    """Returns the softmax over the valid actions defined by `legal_actions_mask`.

    Illegal actions receive probability exactly 0 and the remaining
    probability mass is distributed over the legal actions.

    Args:
      logits: A tensor [..., num_actions] (e.g. [num_actions] or [B, num_actions])
        representing the logits to mask.
      legal_actions_mask: The legal action mask, same shape as logits. 1 means
        it's a legal action, 0 means it's illegal.

    Returns:
      An array with the same shape as `logits` holding the masked softmax
      along the last axis. A row whose mask is all zeros yields NaN, since no
      valid distribution exists for it.
    """
    # log(1) == 0 leaves legal logits unchanged, while log(0) == -inf pushes
    # illegal ones to probability 0. The divide-by-zero is intentional, so it
    # is silenced locally instead of depending on a global np.seterr setting.
    with np.errstate(divide='ignore'):
        log_mask = np.log(legal_actions_mask)
    masked_logits = logits + log_mask
    # Subtract the row maximum for numerical stability before exponentiating.
    max_logit = np.amax(masked_logits, axis=-1, keepdims=True)
    exp_logit = np.exp(masked_logits - max_logit)
    return exp_logit / np.sum(exp_logit, axis=-1, keepdims=True)