import collections, pyspiel, tensorflow as tf, numpy as np
from enum import Enum
from typing import NewType, Tuple, List, NamedTuple
from typing_extensions import TypedDict

# --- OpenSpiel handles and integer ids ---------------------------------------
# Player and SpielAction are plain ints at runtime; the NewType wrappers only
# let a static type checker tell them apart.
SpielGame = NewType('SpielGame', pyspiel.Game)
SpielState = NewType('SpielState', pyspiel.State)
SpielAction = NewType('SpielAction', int)
Player = NewType('Player', int)

# --- Array-valued quantities (all np.ndarray at runtime) ---------------------
NumpyArray = NewType('NumpyArray', np.ndarray)
# Policy is derived from the NumpyArray NewType rather than np.ndarray directly.
Policy = NewType('Policy', NumpyArray)
Value = NewType('Value', np.ndarray)
ChoicePolicy = NewType('ChoicePolicy', np.ndarray)
ChancePolicy = NewType('ChancePolicy', np.ndarray)
TauPolicy = NewType('TauPolicy', np.ndarray)
ActionImage = NewType('ActionImage', np.ndarray)
StateFeature = NewType('StateFeature', np.ndarray)
HiddenState = NewType('HiddenState', np.ndarray)

# --- Miscellaneous -------------------------------------------------------------
QueueName = NewType('QueueName', str)
AdamOptimizer = NewType('AdamOptimizer', tf.keras.optimizers.Adam)

# Distinct static names for the Keras models; each is tf.keras.Model at runtime.
TensorflowModel = NewType('TensorflowModel', tf.keras.Model)
RepresentationModel = NewType('RepresentationModel', tf.keras.Model)
PredictionModel = NewType('PredictionModel', tf.keras.Model)
DynamicsModel = NewType('DynamicsModel', tf.keras.Model)

# A variable-length tuple of lists of arrays. Presumably one weight list per
# model (as returned by Keras `get_weights()`) -- confirm against callers.
ModelWeights = Tuple[List[NumpyArray], ...]

# Untyped lightweight records; field types are not declared anywhere here.
AlphaZeroResult = collections.namedtuple(
    typename="AlphaZeroResult",
    field_names=('state_feature', 'target_value', 'target_policy', 'node'),
)

# One entry per loss component, plus the overall total.
LossValues = collections.namedtuple(
    typename="LossValues",
    field_names=('total', 'tau', 'chance', 'choice', 'value', 'l2'),
)


class TrainingTarget(NamedTuple):
    """Immutable bundle of training targets: three policies and a value.

    Positional order is tau_policy, value, chance_policy, choice_policy.
    """

    tau_policy: TauPolicy        # target for the tau policy head
    value: Value                 # value target (np.ndarray at runtime)
    chance_policy: ChancePolicy  # target for the chance policy head
    choice_policy: ChoicePolicy  # target for the choice policy head


# Array fed to the representation model (inferred from the name -- confirm).
RepresentationInput = NewType('RepresentationInput', np.ndarray)

# Three-part tuple: representation input, a tuple of TrainingTarget, and a
# tuple of state strings.
ProcessedRootData = Tuple[
    RepresentationInput,
    Tuple[TrainingTarget, ...],
    Tuple[str, ...]
]

# Four parallel lists: state features, action images, training targets and
# state strings.
GameSampleData = Tuple[
    List[StateFeature],
    List[ActionImage],
    List[TrainingTarget],
    List[str]
]

# One training batch as a 6-tuple. The first three entries describe the root
# positions. The last three are lists whose elements mirror the layout of
# RecurrentData (action images, training targets, state strings) --
# presumably one element per unrolled recurrent step; confirm in the batcher.
TrainingBatch = Tuple[Tuple[StateFeature, ...],  # root state features
                      Tuple[TrainingTarget, ...],  # root training targets
                      Tuple[str, ...],  # root state strings
                      List[Tuple[ActionImage, ...]],  # recurrent action images
                      List[Tuple[TrainingTarget, ...]],  # recurrent training targets
                      List[Tuple[str, ...]]]  # recurrent state strings

# --- Loss tensors ------------------------------------------------------------
TauLoss = NewType('TauLoss', tf.Tensor)
ChoiceLoss = NewType('ChoiceLoss', tf.Tensor)
ChanceLoss = NewType('ChanceLoss', tf.Tensor)
ValueLoss = NewType('ValueLoss', tf.Tensor)
# NOTE(review): this orders choice before chance, whereas the LossValues
# namedtuple orders chance before choice -- keep unpacking sites consistent.
LossTensors = Tuple[TauLoss, ChoiceLoss, ChanceLoss, ValueLoss]
MSE = NewType('MSE', tf.keras.losses.MeanSquaredError)

# --- Other tensor-valued names -----------------------------------------------
HiddenStatesTensor = NewType('HiddenStatesTensor', tf.Tensor)
DynamicsInput = NewType('DynamicsInput', tf.Tensor)

# Three parallel tuples: action images, training targets, state strings.
RecurrentData = Tuple[
    Tuple[ActionImage, ...],
    Tuple[TrainingTarget, ...],
    Tuple[str, ...]
]

# Six-tuple: dynamics input; tuples of tau, chance and choice policies; a bare
# ndarray (presumably the stacked value targets -- confirm); state strings.
ProcessedRecurrentData = Tuple[
    DynamicsInput,
    Tuple[TauPolicy, ...],
    Tuple[ChancePolicy, ...],
    Tuple[ChoicePolicy, ...],
    np.ndarray,
    Tuple[str, ...]
]

# Sequence of (state, hidden state, int list, action, flag) entries.
Subtrajectory = List[
    Tuple[SpielState, HiddenState, List[int], SpielAction, bool]
]

# Six int lists -- the same arity as DynamicsTestLog's six List[int] fields.
DynamicsTestResult = Tuple[
    List[int], List[int], List[int],
    List[int], List[int], List[int]
]


class NodeType(Enum):
    """Category of a node.

    Members carry the integer values 1-6 in declaration order; UNKNOWN and
    NONE are distinct members with distinct values.
    """

    CHANCE = 1
    CHOICE = 2
    TAU = 3
    TERMINAL = 4
    UNKNOWN = 5
    NONE = 6


class DynamicsTestLog(TypedDict):
    """Dict of integer lists recorded while testing the dynamics model."""

    # pass-action statistics
    choice_pass_action: List[int]
    chance_pass_action: List[int]
    # choice-head accuracy variants
    choice_strict_acc: List[int]
    choice_top_acc: List[int]
    choice_single_acc: List[int]
    # tau-head accuracy
    tau_acc: List[int]


class MasterEvaluationLog(TypedDict):
    """Dict of float metric series, plus an int `step` series.

    The first six keys share their names with DynamicsTestLog but hold floats.
    """

    # dynamics-test metrics
    choice_pass_action: List[float]
    chance_pass_action: List[float]
    choice_strict_acc: List[float]
    choice_top_acc: List[float]
    choice_single_acc: List[float]
    tau_acc: List[float]
    # step counter
    step: List[int]
    # further float series (meaning defined by the producer -- confirm)
    bell: List[float]
    rand: List[float]


# Four-part tuple: an int, two int lists, and a list of DynamicsTestLog
# records (slot meanings are defined by the producer -- confirm there).
EvaluationResult = Tuple[
    int,
    List[int],
    List[int],
    List[DynamicsTestLog]
]

