from __future__ import annotations
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import functools, numpy as np
from algorithms.abstract.evaluator import Evaluator
from algorithms.alpha_zero import utils
from .dense import build_dense
from .utils import np_masked_softmax
from typing import TYPE_CHECKING
from algorithms.utils.types import LossValues
if TYPE_CHECKING:
    from algorithms.utils.types import LossValues, SpielState
    from ..utils.params import Params


class AlphaZeroEvaluator(Evaluator):
    """An AlphaZero MCTS evaluator backed by the network from ``build_dense``.

    The network is called on extracted state features and yields a pair
    ``(policy_logits, value)``. Results of :meth:`value_and_prior` are
    memoised, and the memo is dropped whenever the network weights change.
    """

    def __init__(self, params: 'Params'):
        """Initialise the base evaluator and build the network.

        Args:
          params: configuration object forwarded unchanged to
            ``Evaluator.__init__`` and to ``build_dense``.

        Note:
          This class reads ``self._device``, ``self._extractor`` and
          ``self._l2_regularization`` but does not assign them itself;
          presumably the base ``Evaluator`` sets them from ``params`` --
          TODO confirm against the base class.
        """
        super().__init__(params)

        self.model = build_dense(params)

    def set_weights(self, weights) -> None:
        """Load ``weights`` into the network and drop cached evaluations."""
        self.model.set_weights(weights)
        # Cached (values, policy) pairs were computed with the previous
        # weights and must not be served any more.
        self.value_and_prior.cache_clear()

    def get_weights(self):
        """Return the network weights as given by ``model.get_weights()``."""
        return self.model.get_weights()

    # NOTE: the cache lives on the class-level function, so it is keyed by
    # ``(self, state)``, shared by every instance, and holds strong references
    # to the instances it has seen. ``cache_clear()`` therefore empties it for
    # all evaluators. ``state`` has to be hashable.
    @functools.lru_cache(maxsize=2 ** 12)
    def value_and_prior(self, state: SpielState):
        """Evaluate ``state`` with the network (memoised).

        Args:
          state: game state; it must provide ``legal_actions_mask()``,
            ``legal_actions()`` and ``current_player()``.

        Returns:
          A tuple ``(values, policy)`` where

          * ``values`` is a length-2 numpy array: ``[v, -v]`` when the current
            player is 0, ``[-v, v]`` when it is 1, and zeros for any other
            current player id, ``v`` being the network's value output.
          * ``policy`` is a list of ``(action, probability)`` pairs, one per
            legal action, taken from ``np_masked_softmax`` applied to the
            policy output and the legal-action mask.
        """
        state_feature = self._extractor.state_feature_extractor(state)
        with self._device:
            policy, value = self.model(state_feature)

        # Renormalise the policy over legal actions (first batch row only).
        policy = np.array(policy)[0]
        mask = np.array(state.legal_actions_mask())
        policy = np_masked_softmax(policy, mask)
        policy = [(action, policy[action]) for action in state.legal_actions()]

        # The value is required to be an array over both players.
        value = value[0, 0].numpy()
        player = state.current_player()
        if player == 0:
            values = np.array([value, -value])
        elif player == 1:
            values = np.array([-value, value])
        else:
            # Any other player id: return an array as well, so that callers
            # always receive the same container type.
            values = np.zeros(2)

        return values, policy

    def evaluate(self, state):
        """Return only the per-player values from :meth:`value_and_prior`."""
        return self.value_and_prior(state)[0]

    def prior(self, state):
        """Return only the ``(action, probability)`` list for ``state``."""
        return self.value_and_prior(state)[1]

    def update_with_grads(self, optimizer, grads):
        """Apply precomputed ``grads`` to the trainable variables.

        Args:
          optimizer: object exposing ``apply_gradients``.
          grads: gradients ordered like ``self.model.trainable_variables``.
        """
        optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        # The weights changed, so memoised evaluations are out of date.
        self.value_and_prior.cache_clear()

    def update(self, training_examples, optimizer=None, k=None, use_scale=False):
        """Compute the training loss and gradients for a batch of examples.

        The loss is the sum of the mean softmax cross-entropy between the
        policy logits and the policy targets, the mean squared error between
        predicted and target values, and an L2 penalty over all trainable
        variables scaled by ``self._l2_regularization``.

        Args:
          training_examples: iterable of records exposing ``state_feature``,
            ``target_value`` and ``target_policy``.
          optimizer: if given, the gradients are applied with it.
          k: accepted for interface compatibility; not used here.
          use_scale: accepted for interface compatibility; not used here.

        Returns:
          A tuple ``(grads, losses)``: the gradients with respect to
          ``self.model.trainable_variables`` and a ``LossValues`` holding the
          total, policy (``choice``), value and L2 losses as floats, with
          ``tau`` and ``chance`` fixed to 0.

        Raises:
          ValueError: if ``training_examples`` is empty.
        """
        # Materialise once so that one-shot iterables are supported too.
        examples = list(training_examples)
        if not examples:
            raise ValueError("update() requires at least one training example")

        state_features = np.vstack([r.state_feature for r in examples])
        value_targets = np.vstack([r.target_value for r in examples])
        policy_targets = np.vstack([r.target_policy for r in examples])
        mse = tf.keras.losses.MeanSquaredError()

        with self._device:
            with tf.GradientTape() as tape:
                policy_logits, values = self.model(state_features, training=True)
                loss_policy = tf.nn.softmax_cross_entropy_with_logits(
                    logits=policy_logits, labels=tf.stop_gradient(policy_targets))
                loss_policy = tf.reduce_mean(loss_policy)
                loss_value = mse(values, tf.stop_gradient(value_targets))
                loss_l2 = sum(
                    self._l2_regularization * tf.nn.l2_loss(weights)
                    for weights in self.model.trainable_variables)
                loss = loss_policy + loss_value + loss_l2

            grads = tape.gradient(loss, self.model.trainable_variables)

            if optimizer is not None:
                optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        self.value_and_prior.cache_clear()

        return grads, LossValues(total=float(loss),
                                 tau=float(0),
                                 chance=float(0),
                                 choice=float(loss_policy),
                                 value=float(loss_value),
                                 l2=float(loss_l2))
