import sys
import tensorflow as tf
from config.config_mahjong import Config
from algorithms.base.algorithm import Algorithm as Algorithm_Base


class Algorithm(Algorithm_Base):
    """Actor-critic training loss for the mahjong model (TensorFlow 1.x graph).

    The policy term masks the sampled action's advantage with
    ``thresholded`` (logit bounds plus a PPO-style ratio window) and weights
    it by the inverse behaviour probability. The value term is a clipped
    value loss. The learning rate and clip range both follow a cosine decay
    over ``Config.MAX_STEP`` updates.
    """

    def __init__(self, model):
        super(Algorithm, self).__init__(model)
        self.init_clip_param = Config.INIT_CLIP_PARAM  # clip range at update 0 (cosine-decayed)
        self.epsilon = Config.EPSILON  # Adam epsilon, see get_optimizer
        self.max_step = Config.MAX_STEP  # decay horizon for lr and clip range
        self.alpha = Config.ALPHA  # value-loss coefficient in the total loss
        self.beta = Config.BETA  # entropy-bonus coefficient in the total loss
        self.data_split_shape = Config.DATA_SPLIT_SHAPE  # column sizes for tf.split in _split_data
        self.weight_decay = Config.WEIGHT_DECAY  # dict read with keys 'decay' and 'exclude'

    def build_graph(self, datas, update):
        """Build the training graph and return the loss plus monitoring tensors.

        Args:
            datas: packed sample tensor, split along axis 1 by
                ``Config.DATA_SPLIT_SHAPE`` (see ``_split_data``).
            update: current update counter; drives the cosine decay of the
                learning rate and the clip range.

        Returns:
            ``(loss, [policy_loss, value_loss, entropy_loss,
            mean(sampled_reward_sum), mean(sampled_normalized_advantage),
            learning_rate, decay_loss])``.
        """
        self.update = tf.cast(update, tf.float32)
        # NOTE(review): init_learning_rate is not assigned in this class;
        # presumably Algorithm_Base.__init__ sets it -- confirm.
        self.learning_rate = tf.train.cosine_decay(self.init_learning_rate, self.update, self.max_step)
        self.clip_range = tf.train.cosine_decay(self.init_clip_param, self.update, self.max_step)
        self._split_data(datas)
        input_shape = self.feature.get_shape().as_list()
        print("SHOW: input_shape ", end='')
        print(input_shape)
        # The model is expected to populate self.model.pi_logits / self.model.value.
        self.model.inference(self.feature, self.sampled_legal_action)
        self._calculate_loss()
        return self.loss, \
               [
                   self.policy_loss,
                   self.value_loss,
                   self.entropy_loss,
                   tf.reduce_mean(self.sampled_reward_sum),
                   tf.reduce_mean(self.sampled_normalized_advantage),
                   self.learning_rate,
                   self.decay_loss,
               ]

    def _split_data(self, datas):
        """Split ``datas`` into its six per-sample fields and flatten them.

        Fields, in column order: feature, advantage, action, old
        neg-log-prob, old value, legal-action mask. The feature is reshaped
        to ``[-1, Config.INPUT_DIM[0]]``, the mask to
        ``[-1, Config.ACTION_DIM]``, and the remaining fields to ``[-1]``.
        """
        (self.feature, self.sampled_advantage, self.sampled_action,
         self.sampled_neg_log_pi, self.sampled_value,
         self.sampled_legal_action) = tf.split(datas, self.data_split_shape, axis=1)
        self.feature = tf.reshape(self.feature, [-1, Config.INPUT_DIM[0]])
        # NOTE(review): the name suggests advantages are already normalised
        # upstream; nothing here normalises them -- confirm.
        self.sampled_normalized_advantage = tf.reshape(self.sampled_advantage, [-1])
        self.sampled_action = tf.reshape(self.sampled_action, [-1])
        self.sampled_neg_log_pi = tf.reshape(self.sampled_neg_log_pi, [-1])
        self.sampled_value = tf.reshape(self.sampled_value, [-1])
        self.sampled_legal_action = tf.reshape(self.sampled_legal_action, [-1, Config.ACTION_DIM])
        # Value target used by _calc_value_loss: old value + advantage.
        self.sampled_reward_sum = self.sampled_value + self.sampled_normalized_advantage

        # Static rank checks: per-sample scalars share one rank (1 after the
        # reshapes above); feature and mask carry one extra dimension.
        self.action_rank = self.sampled_action.shape.ndims
        self.feature.shape.assert_has_rank(self.action_rank + 1)
        self.sampled_legal_action.shape.assert_has_rank(self.action_rank + 1)
        self.sampled_normalized_advantage.shape.assert_has_rank(self.action_rank)
        self.sampled_action.shape.assert_has_rank(self.action_rank)
        self.sampled_neg_log_pi.shape.assert_has_rank(self.action_rank)
        self.sampled_value.shape.assert_has_rank(self.action_rank)
        self.sampled_reward_sum.shape.assert_has_rank(self.action_rank)

    def _calculate_loss(self):
        """Combine the component losses into ``self.loss`` (to be minimised).

        Equivalent to ``-policy + alpha * value - beta * entropy + decay``.
        """
        self._calc_policy_loss()
        self._calc_value_loss()
        self._calc_entropy_loss()
        self._calc_decay_loss()
        self.loss = -(self.policy_loss - self.alpha * self.value_loss + self.beta * self.entropy_loss - self.decay_loss)

    def thresholded(self, logits, regrets, ratio, threshold=6.0, max_ratio=3.0):
        """Zero out ``regrets`` outside the allowed logit and ratio bounds.

        Negative regrets are kept only where ``logits > -threshold`` and
        ``1 - clip_range < ratio < max_ratio``. Positive regrets are kept only
        where ``logits < threshold`` and ``ratio < 1 + clip_range``. Every
        other entry becomes 0.

        Args:
            logits: per-sample logits (same shape as ``regrets``).
            regrets: per-sample advantages / regrets.
            ratio: per-sample new/old probability ratio.
            threshold: absolute logit bound.
            max_ratio: upper ratio bound applied to negative regrets.

        Returns:
            float32 tensor of masked regrets, same shape as ``regrets``.
        """
        can_decrease_1 = tf.cast(tf.greater(logits, -threshold), tf.float32)
        can_increase_1 = tf.cast(tf.less(logits, threshold), tf.float32)

        can_decrease_2_1 = tf.cast(tf.greater(ratio, 1.0 - self.clip_range), tf.float32)
        can_decrease_2_2 = tf.cast(tf.less(ratio, max_ratio), tf.float32)
        can_increase_2 = tf.cast(tf.less(ratio, 1.0 + self.clip_range), tf.float32)

        regrets_negative = tf.minimum(regrets, 0.0)
        regrets_positive = tf.maximum(regrets, 0.0)
        return can_decrease_1 * can_decrease_2_1 * can_decrease_2_2 * regrets_negative + \
               can_increase_1 * can_increase_2 * regrets_positive

    def _calc_policy_loss(self):
        """Build ``self.policy_loss`` plus several scalar diagnostics.

        The sampled action's logit is centred on the mean logit over legal
        actions. Its advantage is masked by ``thresholded``, and the product
        is divided by the behaviour probability ``exp(-sampled_neg_log_pi)``,
        floored at 0.01.
        """
        legal_mask = self.sampled_legal_action
        legal_action_num = tf.reduce_sum(legal_mask, axis=-1)  # per-sample legal count
        # Samples whose legal count is exactly 1 get zero policy signal; they
        # still count in the batch mean below.
        # NOTE(review): exact float comparison assumes a 0/1 mask -- confirm.
        is_train = tf.cast(tf.not_equal(legal_action_num, 1.0), tf.float32)

        neg_log_pi = self.model.neg_log_prob(self.sampled_action, "neg_log_pi")
        neg_log_pi.shape.assert_has_rank(self.action_rank)
        ratio = tf.exp(self.sampled_neg_log_pi - neg_log_pi, name="ratio")

        # Zero illegal logits, then take the mean over legal actions only.
        # NOTE(review): if the model fills illegal logits with -inf, the
        # multiply yields NaN here -- confirm the model uses a finite value.
        legal_pi_logit = tf.multiply(self.model.pi_logits, legal_mask)
        legal_pi_logit_mean = tf.divide(
            tf.reduce_sum(legal_pi_logit, axis=-1, keepdims=True),
            tf.maximum(tf.expand_dims(legal_action_num, -1), 1.0))

        action = tf.expand_dims(tf.cast(self.sampled_action, tf.int32), 1)
        pi_logit = tf.reshape(
            tf.batch_gather(self.model.pi_logits - legal_pi_logit_mean, action), [-1])

        # Scalar diagnostics exposed for monitoring.
        self.legal_pi_logit = tf.reduce_min(legal_pi_logit)
        self.legal_action_num = tf.reduce_min(legal_action_num)
        self.legal_pi_logit_mean = tf.reduce_max(legal_pi_logit_mean)
        self.pi_logit = tf.reduce_min(pi_logit)
        self.min_index = tf.argmin(pi_logit, axis=0)

        clipped_adv = is_train * self.thresholded(pi_logit, self.sampled_normalized_advantage, ratio)
        # Behaviour probability, floored at 0.01 to bound the inverse weight.
        old_prob = tf.maximum(tf.exp(-1 * self.sampled_neg_log_pi), 0.01)
        self.policy_loss = tf.reduce_mean(
            tf.divide(tf.multiply(pi_logit, clipped_adv), old_prob),
            name="policy_loss")

    def _calc_value_loss(self):
        """Build ``self.value_loss``: 0.5 * mean of the larger squared error.

        The two errors compare the raw prediction and the prediction clipped
        to within ``clip_range`` of the old value, each against
        ``sampled_reward_sum``.
        """
        value = self.model.value
        value.shape.assert_has_rank(self.action_rank)
        clipped_value = tf.add(
            self.sampled_value,
            tf.clip_by_value(value - self.sampled_value, - self.clip_range, self.clip_range), name="clipped_value")
        self.value_loss = \
            0.5 * tf.reduce_mean(
                tf.maximum(
                    tf.square(value - self.sampled_reward_sum),
                    tf.square(clipped_value - self.sampled_reward_sum)))

    def _calc_entropy_loss(self):
        """Build ``self.entropy_loss``: mean entropy of softmax(pi_logits).

        Uses the max-shifted logits for numerical stability. No legal-action
        mask is applied here, so every action column contributes.
        """
        pi_logits = self.model.pi_logits
        logits = pi_logits - tf.reduce_max(pi_logits, axis=-1, keepdims=True)
        exp_logits = tf.exp(logits)
        exp_logits_sum = tf.reduce_sum(exp_logits, axis=-1, keepdims=True)
        p = exp_logits / exp_logits_sum
        # -sum(p * log p), with log p = logits - log(sum(exp(logits))).
        temp_entropy_loss = tf.reduce_sum(p * (tf.log(exp_logits_sum) - logits), axis=-1)
        self.entropy_loss = tf.reduce_mean(temp_entropy_loss, name="entropy_loss")

    def _calc_decay_loss(self):
        """Build ``self.decay_loss`` = decay * sum of ``tf.nn.l2_loss`` terms.

        A trainable variable is skipped when its lower-cased op name contains
        any lower-cased entry of ``weight_decay['exclude']``. The result is a
        zero constant when decay is disabled or no variable remains, because
        ``tf.add_n`` rejects an empty list.
        """
        decay = self.weight_decay['decay']
        if decay > 1e-8:
            exclude_keys = [e.lower() for e in self.weight_decay['exclude']]
            list_var = []
            list_name = []
            trainable_vars = tf.trainable_variables()
            for var in trainable_vars:
                name_lowcase = var.op.name.lower()
                if any(e in name_lowcase for e in exclude_keys):
                    continue
                list_var.append(tf.nn.l2_loss(var))
                list_name.append(name_lowcase)
            print('Totally %d trainable parameters' % len(trainable_vars))
            print('Add L2 weight decay', decay, 'to following %d variables:' % len(list_name))
            print(list_name)
            if list_var:
                self.decay_loss = decay * tf.add_n(list_var)
            else:
                self.decay_loss = tf.constant(0.0, dtype=tf.float32)
        else:
            print('No L2 weight decay')
            self.decay_loss = tf.constant(0.0, dtype=tf.float32)

    def get_optimizer(self):
        """Return an Adam optimizer using the decayed learning rate from build_graph."""
        return tf.train.AdamOptimizer(learning_rate=self.learning_rate, epsilon=self.epsilon)
