import sys
import tensorflow as tf
from config.config_mahjong import Config
from algorithms.base.algorithm import Algorithm as Algorithm_Base


class Algorithm(Algorithm_Base):
    """Loss construction and optimizer for the Mahjong learner (TF1 graph mode).

    The total loss built by ``build_graph`` combines four terms:

    * a policy term that weights mean-centred policy logits by thresholded
      per-action advantages derived from the model's Q-values
      (NOTE(review): this resembles a NeuRD-style update — confirm intent);
    * a PPO-style clipped regression of the Q-value of the sampled action
      towards ``sampled_value + sampled_advantage``;
    * a softmax-entropy bonus over the policy logits;
    * optional L2 weight decay on trainable variables.

    Hyper-parameters are read from ``Config`` (config.config_mahjong).
    """

    def __init__(self, model):
        """Initialise the base algorithm and copy hyper-parameters from ``Config``.

        Args:
            model: object exposing ``inference(feature, legal_action)`` and,
                after that call, the tensors ``pi_logits``, ``prob`` and
                ``q_value`` read by the loss methods. Presumably stored as
                ``self.model`` by the base class — not visible here.
        """
        super(Algorithm, self).__init__(model)
        # NOTE(review): ``self.init_learning_rate`` (used in build_graph) is not
        # assigned in this class; presumably Algorithm_Base provides it — confirm.
        self.init_clip_param = Config.INIT_CLIP_PARAM
        self.epsilon = Config.EPSILON
        self.max_step = Config.MAX_STEP
        self.alpha = Config.ALPHA  # weight of the value loss in the total loss
        self.beta = Config.BETA  # weight of the entropy bonus in the total loss
        self.data_split_shape = Config.DATA_SPLIT_SHAPE
        # Mapping with keys 'decay' (L2 coefficient) and 'exclude' (name substrings).
        self.weight_decay = Config.WEIGHT_DECAY

    def build_graph(self, datas, update):
        """Build the training graph and return the loss plus monitoring tensors.

        Args:
            datas: packed sample batch, split along axis 1 into the six
                segments sized by ``Config.DATA_SPLIT_SHAPE`` (see ``_split_data``).
            update: training-step counter; cast to float32 and used as the
                global step for cosine decay of both the learning rate and the
                value-clip range over ``Config.MAX_STEP`` steps.

        Returns:
            ``(loss, metrics)`` where ``metrics`` is, in order: policy loss,
            value loss, entropy, mean reward sum, mean advantage, current
            learning rate and weight-decay loss.
        """
        self.update = tf.cast(update, tf.float32)
        self.learning_rate = tf.train.cosine_decay(self.init_learning_rate, self.update, self.max_step)
        self.clip_range = tf.train.cosine_decay(self.init_clip_param, self.update, self.max_step)
        self._split_data(datas)
        input_shape = self.feature.get_shape().as_list()
        print("SHOW: input_shape", input_shape)
        self.model.inference(self.feature, self.sampled_legal_action)
        self._calculate_loss()
        metrics = [
            self.policy_loss,
            self.value_loss,
            self.entropy_loss,
            tf.reduce_mean(self.sampled_reward_sum),
            tf.reduce_mean(self.sampled_normalized_advantage),
            self.learning_rate,
            self.decay_loss,
        ]
        return self.loss, metrics

    def _split_data(self, datas):
        """Unpack ``datas`` into the per-sample tensors used by the losses.

        ``datas`` is split along axis 1 into: feature, advantage, action,
        neg-log-prob, value and legal-action mask. Features are reshaped to
        ``[-1, Config.INPUT_DIM[0]]``, the legal mask to
        ``[-1, Config.ACTION_DIM]``, and every other segment is flattened to
        rank 1. Static rank assertions guard the result.
        """
        (self.feature, self.sampled_advantage, self.sampled_action,
         self.sampled_neg_log_pi, self.sampled_value,
         self.sampled_legal_action) = tf.split(datas, self.data_split_shape, axis=1)
        self.feature = tf.reshape(self.feature, [-1, Config.INPUT_DIM[0]])
        # NOTE(review): named "normalized", but no normalization happens here;
        # presumably the advantage is normalized before packing — confirm.
        self.sampled_normalized_advantage = tf.reshape(self.sampled_advantage, [-1])
        self.sampled_action = tf.reshape(self.sampled_action, [-1])
        # Not referenced by any loss in this class.
        self.sampled_neg_log_pi = tf.reshape(self.sampled_neg_log_pi, [-1])
        self.sampled_value = tf.reshape(self.sampled_value, [-1])
        self.sampled_legal_action = tf.reshape(self.sampled_legal_action, [-1, Config.ACTION_DIM])
        # Regression target for the value loss: stored value + advantage.
        self.sampled_reward_sum = self.sampled_value + self.sampled_normalized_advantage

        # Per-sample scalars are rank 1; feature and legal mask carry one extra axis.
        self.action_rank = self.sampled_action.shape.ndims
        self.feature.shape.assert_has_rank(self.action_rank + 1)
        self.sampled_legal_action.shape.assert_has_rank(self.action_rank + 1)
        self.sampled_normalized_advantage.shape.assert_has_rank(self.action_rank)
        self.sampled_action.shape.assert_has_rank(self.action_rank)
        self.sampled_neg_log_pi.shape.assert_has_rank(self.action_rank)
        self.sampled_value.shape.assert_has_rank(self.action_rank)
        self.sampled_reward_sum.shape.assert_has_rank(self.action_rank)

    def _calculate_loss(self):
        """Build every loss term and combine them into ``self.loss`` (minimized).

        ``loss = -policy_loss + alpha * value_loss - beta * entropy_loss + decay_loss``
        """
        self._calc_policy_loss()
        self._calc_value_loss()
        self._calc_entropy_loss()
        self._calc_decay_loss()
        self.loss = -(self.policy_loss - self.alpha * self.value_loss + self.beta * self.entropy_loss - self.decay_loss)

    def thresholded(self, logits, regrets, threshold=6.0):
        """Zero out regrets that would push `logits` further past +/- `threshold`.

        Negative regrets are kept only where ``logits > -threshold`` and
        positive regrets only where ``logits < threshold``; all other entries
        become 0.

        Args:
            logits: tensor compared element-wise against the threshold.
            regrets: tensor broadcast-compatible with ``logits``.
            threshold: bound on the logit magnitude (default 6.0).

        Returns:
            The masked regrets, as float32.
        """
        can_decrease = tf.cast(tf.greater(logits, -threshold), tf.float32)
        can_increase = tf.cast(tf.less(logits, threshold), tf.float32)
        return can_decrease * tf.minimum(regrets, 0.0) + can_increase * tf.maximum(regrets, 0.0)

    def _calc_policy_loss(self):
        """Build ``self.policy_loss`` from Q-value advantages and centred logits.

        Per sample, ``advantage = q - sum(prob * q)``, with Q-values and
        probabilities treated as constants. The policy logits are centred by
        their mean over legal actions, and the advantages are thresholded on
        the centred logits. ``centred_logit * advantage`` is then summed over
        legal actions and averaged over the batch.
        """
        # The critic and current policy act as fixed targets for this term.
        q_value = tf.stop_gradient(self.model.q_value)
        prob = tf.stop_gradient(self.model.prob)
        baseline = tf.reduce_sum(tf.multiply(prob, q_value), axis=-1, keepdims=True)
        advantage = q_value - baseline

        # Mean logit over legal actions; max(count, 1) avoids 0/0 on rows with
        # no legal action. Assumes the legal mask holds 0/1 values — confirm.
        legal_mask = self.sampled_legal_action
        legal_count = tf.reduce_sum(legal_mask, axis=-1, keepdims=True)
        legal_logit_sum = tf.reduce_sum(tf.multiply(self.model.pi_logits, legal_mask), axis=-1, keepdims=True)
        legal_logit_mean = tf.divide(legal_logit_sum, tf.maximum(legal_count, 1.0))

        pi_logit = self.model.pi_logits - legal_logit_mean
        clipped_adv = self.thresholded(pi_logit, advantage)
        weighted = tf.multiply(pi_logit, clipped_adv)

        self.policy_loss = tf.reduce_mean(tf.reduce_sum(tf.multiply(weighted, legal_mask), axis=-1),
                                          name="policy_loss")

    def _calc_value_loss(self):
        """Build ``self.value_loss``: PPO-style clipped regression of Q(s, a_sampled).

        The model's Q-value for the sampled action is regressed towards
        ``sampled_reward_sum``. A second prediction keeps its deviation from
        the stored ``sampled_value`` within ``+/- clip_range``. The loss is
        the larger of the two squared errors, scaled by 0.5 and averaged over
        the batch.
        """
        action = tf.expand_dims(tf.cast(self.sampled_action, tf.int32), 1)
        value = tf.batch_gather(self.model.q_value, action)
        # Squeeze only the gathered axis. A bare squeeze() would also drop a
        # size-1 batch dimension, and with an unknown batch size it yields an
        # unknown static shape, so the rank assertion below would check nothing.
        value = tf.squeeze(value, axis=-1)
        value.shape.assert_has_rank(self.action_rank)
        clipped_value = tf.add(self.sampled_value,
                               tf.clip_by_value(value - self.sampled_value, -self.clip_range, self.clip_range),
                               name="clipped_value")
        self.value_loss = tf.multiply(0.5, tf.reduce_mean(
            tf.maximum(tf.square(value - self.sampled_reward_sum), tf.square(clipped_value - self.sampled_reward_sum))),
                                      name="vf_loss")

    def _calc_entropy_loss(self):
        """Build ``self.entropy_loss``: batch-mean softmax entropy of ``pi_logits``.

        Logits are shifted by their row max for numerical stability. Entropy
        is computed as ``sum(p * (log(sum(exp(l))) - l))``, which equals
        ``-sum(p * log p)``.
        NOTE(review): illegal actions are not masked here — confirm intended.
        """
        pi_logits = self.model.pi_logits
        shifted = pi_logits - tf.reduce_max(pi_logits, axis=-1, keepdims=True)
        exp_logits = tf.exp(shifted)
        exp_logits_sum = tf.reduce_sum(exp_logits, axis=-1, keepdims=True)
        p = exp_logits / exp_logits_sum
        per_sample_entropy = tf.reduce_sum(p * (tf.log(exp_logits_sum) - shifted), axis=-1)
        self.entropy_loss = tf.reduce_mean(per_sample_entropy, name="entropy_loss")

    def _calc_decay_loss(self):
        """Build ``self.decay_loss``: L2 penalty on non-excluded trainable variables.

        A variable is excluded when any ``weight_decay['exclude']`` key occurs
        (case-insensitively) in its op name. The penalty is
        ``decay * sum(tf.nn.l2_loss(v))``. It is a zero constant when
        ``decay <= 1e-8`` or when no variable remains after exclusion.
        """
        decay = self.weight_decay['decay']
        if decay > 1e-8:
            exclude_keys = [key.lower() for key in self.weight_decay['exclude']]
            trainable_vars = tf.trainable_variables()
            l2_terms = []
            decayed_names = []
            for var in trainable_vars:
                name_lowcase = var.op.name.lower()
                if any(key in name_lowcase for key in exclude_keys):
                    continue
                l2_terms.append(tf.nn.l2_loss(var))
                decayed_names.append(name_lowcase)
            print('Totally %d trainable parameters' % len(trainable_vars))
            print('Add L2 weight decay', decay, 'to following %d variables:' % len(decayed_names))
            print(decayed_names)
            if l2_terms:
                self.decay_loss = decay * tf.add_n(l2_terms)
            else:
                # tf.add_n rejects an empty list; with nothing to decay the penalty is zero.
                self.decay_loss = tf.constant(0.0, dtype=tf.float32)
        else:
            print('No L2 weight decay')
            self.decay_loss = tf.constant(0.0, dtype=tf.float32)

    def get_optimizer(self):
        """Return an Adam optimizer using the cosine-decayed learning rate.

        Must be called after ``build_graph``, which creates ``self.learning_rate``.
        """
        return tf.train.AdamOptimizer(learning_rate=self.learning_rate, epsilon=self.epsilon)
