from __future__ import annotations
import os

# Raise the threshold of TensorFlow's native (C++) logging so its console output is suppressed.
# This must be set before `tensorflow` is imported on the next line to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf, numpy as np
from algorithms.abstract.evaluator import Evaluator
from algorithms.mu_zero import utils, dense
from algorithms.mu_zero.utils import np_softmax, np_masked_softmax
from typing import TYPE_CHECKING, Tuple, List
from algorithms.utils.types import LossValues, HiddenState, ModelWeights, SpielAction, TrainingBatch, \
    AdamOptimizer, TauPolicy, Value, ChoicePolicy, ChancePolicy, HiddenStatesTensor, LossTensors, MSE, RecurrentData, \
    TauLoss, ChanceLoss, ValueLoss, ChoiceLoss, SpielState

if TYPE_CHECKING:
    from algorithms.utils.params import Params

# Restrict TensorFlow's Python-side logger to messages of level ERROR and above.
tf.get_logger().setLevel('ERROR')


class MuZeroEvaluator(Evaluator):
    """
    Handles network activations for the ND MuZero MCTS as well as model updates.

    Wraps the three models returned by `dense.build_dense`: a representation model (`rep_model`), a prediction
    model (`pred_model`) and a dynamics model (`dyn_model`).

    NOTE(review): `self._device`, `self._extractor` and `self._l2_regularization` are read here but not assigned
    in this class; presumably they are set up by `Evaluator.__init__` -- confirm against the base class.
    """

    def __init__(self, params: Params, worker_id: int = 0) -> None:
        """
        Initializes the base evaluator, builds the three models from `params` and stores the
        `params.scale_gradient` flag used by `update`.
        """
        super().__init__(params, worker_id)
        self.rep_model, self.pred_model, self.dyn_model = dense.build_dense(params)
        self._scale_gradient = params.scale_gradient

    def root_representation(self, state: SpielState) -> HiddenState:
        """
        Performs the `h` function of NDMZ. Uses the extractor to generate a state feature ndarray. Passes this through the
        representation model to generate a hidden state.

        The second output of the representation model (the tau logits) is discarded here.
        """
        state_feature = self._extractor.state_feature_extractor(state)
        with self._device:
            hidden_state, _ = self.rep_model(state_feature, training=False)  # type: tf.Tensor, tf.Tensor
        return hidden_state.numpy()

    def get_weights(self) -> ModelWeights:
        """
        Returns the model weights as lists of numpy arrays, ordered as (representation, prediction, dynamics).
        """
        return self.rep_model.get_weights(), self.pred_model.get_weights(), self.dyn_model.get_weights()

    def set_weights(self, model_weights: ModelWeights) -> None:
        """
        Sets the weights of all models. This is invoked to synchronize the states of all the
        workers, as well as after receiving model weights from the training service via AMQP.

        `model_weights` must be ordered as (representation, prediction, dynamics), matching `get_weights`.
        """
        rep_weights, pred_weights, dyn_weights = model_weights
        self.rep_model.set_weights(rep_weights)
        self.pred_model.set_weights(pred_weights)
        self.dyn_model.set_weights(dyn_weights)

    def root_prediction(self,
                        hidden_state: HiddenState,
                        actions: List[int],
                        mask) -> Tuple[float, List[int], List[float]]:
        """
        Used in MuZero MCTS at the root node. Performs the `f` function of NDMZ.
        Uses a masked softmax over the choice logits and then keeps only the priors of the given `actions`.
        The chance logits of the prediction model are ignored.

        `mask` is forwarded unchanged to `np_masked_softmax` (presumably a legal-action mask -- TODO confirm).

        Returns the value (element [0, 0] of the value output), the `actions` list as given, and the priors
        in the same order as `actions`.
        """
        with self._device:
            choice_logits, _, value = self.pred_model(hidden_state, training=False)
        masked_priors = np_masked_softmax(choice_logits.numpy().flatten(), mask)
        priors = [masked_priors[action] for action in actions]
        value = value[0, 0].numpy()
        return value, actions, priors

    def prediction(self, hidden_state: HiddenState) -> Tuple[np.ndarray, np.ndarray, float]:
        """
        Runs the prediction model on `hidden_state` and applies an (unmasked) softmax to both policy heads.

        Note the return order: (chance priors, choice priors, value), which differs from the order in which the
        model emits its outputs (choice, chance, value). The value is element [0, 0] of the value output.
        """
        with self._device:
            choice_logits, chance_logits, value = self.pred_model(hidden_state, training=False)
        choice_priors = np_softmax(choice_logits.numpy().flatten())
        chance_priors = np_softmax(chance_logits.numpy().flatten())
        value = value[0, 0].numpy()
        return chance_priors, choice_priors, value

    def dynamics(self, hidden_state: HiddenState, action: SpielAction) -> Tuple[np.ndarray, np.ndarray]:
        """
        Runs the dynamics model on `hidden_state` concatenated (along axis 1) with the image of `action`
        produced by the extractor.

        Returns the next hidden state as a numpy array and the softmaxed tau logits.
        """
        action_image = self._extractor.action_to_image(action)
        dynamics_input = tf.concat((hidden_state, action_image), axis=1)
        with self._device:
            next_hidden_state, tau_logits = self.dyn_model(dynamics_input, training=True if False else False)
        tau_priors = np_softmax(tau_logits.numpy().flatten())
        return next_hidden_state.numpy(), tau_priors

    def get_root_loss_and_hidden(self,
                                 training_examples: TrainingBatch,
                                 mse: MSE) -> Tuple[LossTensors, HiddenStatesTensor]:
        """
        Computes the losses of the root step of a training batch.

        The batch is preprocessed by `utils.process_root_data`; the per-example targets are unzipped into
        tau, value, chance and choice targets (in that order). The representation model produces the hidden
        states and the root tau logits, the prediction model produces the choice/chance logits and values.
        Policy heads use softmax cross entropy averaged over the batch, the value head uses `mse`.

        Returns ((tau_loss, choice_loss, chance_loss, value_loss), hidden_states).
        """
        # The third element returned by `process_root_data` is not needed here.
        representation_input, root_targets, _ = utils.process_root_data(training_examples)

        root_tau_targets: Tuple[TauPolicy, ...]
        root_value_targets: Tuple[Value, ...]
        root_chance_targets: Tuple[ChancePolicy, ...]
        root_choice_targets: Tuple[ChoicePolicy, ...]
        root_tau_targets, root_value_targets, root_chance_targets, root_choice_targets = zip(*root_targets)

        hidden_states, root_tau_logits = self.rep_model(representation_input, training=True)
        root_tau_loss_all = tf.nn.softmax_cross_entropy_with_logits(
            logits=root_tau_logits, labels=tf.stop_gradient(tf.squeeze(root_tau_targets)))
        root_tau_loss = tf.reduce_mean(root_tau_loss_all)  # type: TauLoss

        root_choice_logits, root_chance_logits, root_values = self.pred_model(hidden_states, training=True)

        root_choice_loss_all = tf.nn.softmax_cross_entropy_with_logits(
            logits=root_choice_logits, labels=tf.stop_gradient(root_choice_targets))
        root_choice_loss = tf.reduce_mean(root_choice_loss_all)  # type: ChoiceLoss

        root_chance_loss_all = tf.nn.softmax_cross_entropy_with_logits(
            logits=root_chance_logits, labels=tf.stop_gradient(root_chance_targets))
        root_chance_loss = tf.reduce_mean(root_chance_loss_all)  # type: ChanceLoss

        root_value_loss = mse(root_values, tf.stop_gradient(root_value_targets))  # type: ValueLoss
        return (root_tau_loss, root_choice_loss, root_chance_loss, root_value_loss), hidden_states

    def get_recurrent_loss_and_hidden(self,
                                      hidden_states: HiddenStatesTensor,
                                      recurrent_data: RecurrentData,
                                      mse: MSE) -> Tuple[LossTensors, HiddenStatesTensor]:
        """
        Applies one unroll step: feeds the current hidden states (combined with `recurrent_data` by
        `utils.process_recurrent_data`) through the dynamics model, then runs the prediction model on the
        resulting hidden states and computes the losses of that step.

        Policy heads use softmax cross entropy averaged over the batch, the value head uses `mse`.

        Returns ((tau_loss, choice_loss, chance_loss, value_loss), next_hidden_states).
        """
        processed_recurrent_data = utils.process_recurrent_data(hidden_states, recurrent_data)
        # The last element of the processed data is not needed here.
        dynamics_input, tau_targets, chance_targets, choice_targets, value_targets, _ = processed_recurrent_data

        hidden_states, tau_logits = self.dyn_model(dynamics_input, training=True)
        tau_loss_all = tf.nn.softmax_cross_entropy_with_logits(
            logits=tau_logits, labels=tf.stop_gradient(tf.squeeze(tau_targets)))
        tau_loss = tf.reduce_mean(tau_loss_all)  # type: TauLoss

        choice_logits, chance_logits, values = self.pred_model(hidden_states, training=True)

        choice_loss_all = tf.nn.softmax_cross_entropy_with_logits(
            logits=choice_logits, labels=tf.stop_gradient(choice_targets))
        choice_loss = tf.reduce_mean(choice_loss_all)  # type: ChoiceLoss

        chance_loss_all = tf.nn.softmax_cross_entropy_with_logits(
            logits=chance_logits, labels=tf.stop_gradient(chance_targets))
        chance_loss = tf.reduce_mean(chance_loss_all)  # type: ChanceLoss

        value_loss = mse(values, tf.stop_gradient(value_targets))  # type: ValueLoss

        return (tau_loss, choice_loss, chance_loss, value_loss), hidden_states

    def update_with_grads(self, optimizer, grads) -> None:
        """
        Applies precomputed gradients to the trainable variables of all three models.

        `grads` is zipped with the variables returned by `utils.extract_trainable_variables` for
        (rep_model, pred_model, dyn_model), so it must follow that same ordering -- e.g. the gradient list
        returned by `update`.
        """
        models = (self.rep_model, self.pred_model, self.dyn_model)
        trainable_variables = utils.extract_trainable_variables(models)
        optimizer.apply_gradients(zip(grads, trainable_variables))

    def update(self,
               training_examples: TrainingBatch,
               optimizer: AdamOptimizer,
               k=12) -> Tuple[List[np.ndarray], LossValues]:
        """
        Computes the full training loss for a batch (root step plus every unrolled recurrent step plus L2
        regularization), derives the gradients and, if `optimizer` is truthy, applies them.

        :param training_examples: 6-tuple batch; its last three elements (action images, recurrent targets and
            a third per-step sequence) are iterated in lockstep, one entry per unroll step.
        :param optimizer: optimizer used to apply the gradients; pass a falsy value to only compute them.
        :param k: sets the gradient scale `1 / (k + 1)` applied to each recurrent step's loss when
            `params.scale_gradient` is enabled (presumably the number of unroll steps -- TODO confirm).
        :return: the gradients as numpy arrays (ordered like the trainable variables) and the loss values.
            The per-head values in `LossValues` are unscaled sums over the root and all recurrent steps.
        """
        def scale_gradient(tensor, scale: float):
            # Leaves the forward value untouched while only a `scale` fraction of the gradient flows back.
            return (1. - scale) * tf.stop_gradient(tensor) + scale * tensor

        gradient_scale = 1.0 / (k + 1)
        models = (self.rep_model, self.pred_model, self.dyn_model)
        trainable_variables = utils.extract_trainable_variables(models)

        loss = 0
        tau_loss_total = 0
        choice_loss_total = 0
        chance_loss_total = 0
        value_loss_total = 0
        mse = tf.keras.losses.MeanSquaredError()

        with self._device:
            with tf.GradientTape() as tape:
                # We first get the losses for the root state, and the hidden states
                # from the representation model which will be unrolled through the dynamics model.
                losses, hidden_states = self.get_root_loss_and_hidden(training_examples, mse)
                tau_loss, choice_loss, chance_loss, value_loss = losses
                tau_loss_total += tau_loss
                value_loss_total += value_loss
                chance_loss_total += chance_loss
                choice_loss_total += choice_loss

                # The root step has to be part of the optimized loss as well (at full gradient scale);
                # otherwise the root heads receive no training signal besides the L2 penalty.
                loss += tau_loss + chance_loss + choice_loss + value_loss

                _, _, _, action_images, recurrent_targets, recurrent_ss = training_examples
                for recurrent_data in zip(action_images, recurrent_targets, recurrent_ss):
                    losses, hidden_states = self.get_recurrent_loss_and_hidden(hidden_states, recurrent_data, mse)
                    tau_loss, choice_loss, chance_loss, value_loss = losses
                    tau_loss_total += tau_loss
                    chance_loss_total += chance_loss
                    choice_loss_total += choice_loss
                    value_loss_total += value_loss

                    unscaled_loss = tau_loss + chance_loss + choice_loss + value_loss
                    if self._scale_gradient:
                        loss += scale_gradient(unscaled_loss, gradient_scale)
                        # Halve the gradient flowing back through the hidden state at every unroll step.
                        hidden_states = scale_gradient(hidden_states, 0.5)
                    else:
                        loss += unscaled_loss

                # L2 penalty over every trainable variable of the three models.
                l2_loss = 0
                for weights in trainable_variables:
                    l2_loss += self._l2_regularization * tf.nn.l2_loss(weights)
                loss += l2_loss
            grads = tape.gradient(loss, trainable_variables)

            if optimizer:
                optimizer.apply_gradients(zip(grads, trainable_variables))

        grads = [grad.numpy() for grad in grads]
        return grads, LossValues(total=float(loss),
                                 tau=float(tau_loss_total),
                                 chance=float(chance_loss_total),
                                 choice=float(choice_loss_total),
                                 value=float(value_loss_total),
                                 l2=float(l2_loss))
