import sys
import tensorflow as tf
from config.config_mahjong import Config
from algorithms.base.algorithm import Algorithm as Algorithm_Base


class Algorithm(Algorithm_Base):
    """Builds a clipped-surrogate (PPO-style) training loss on top of a model.

    The total loss combines a clipped policy objective, a clipped value loss,
    an entropy bonus and an optional L2 weight-decay term. Hyper-parameters
    are read from ``Config``.

    NOTE(review): ``self.model`` and ``self.init_learning_rate`` are used here
    but never assigned in this class; presumably ``Algorithm_Base.__init__``
    provides them -- confirm against the base class.
    """

    def __init__(self, model):
        """Initialise the base class and copy hyper-parameters from ``Config``.

        Args:
            model: Forwarded unchanged to ``Algorithm_Base.__init__``.
        """
        super(Algorithm, self).__init__(model)
        self.init_clip_param = Config.INIT_CLIP_PARAM
        self.epsilon = Config.EPSILON
        self.max_step = Config.MAX_STEP
        self.alpha = Config.ALPHA  # weight of the value loss in the total loss
        self.beta = Config.BETA  # weight of the entropy term in the total loss
        self.data_split_shape = Config.DATA_SPLIT_SHAPE
        self.weight_decay = Config.WEIGHT_DECAY

    def build_graph(self, datas, update):
        """Build the loss graph and return the loss plus monitoring tensors.

        Args:
            datas: Tensor holding one packed sample per row; it is split along
                axis 1 according to ``Config.DATA_SPLIT_SHAPE``.
            update: Current training step; cast to float32 and used as the
                step of both cosine-decay schedules.

        Returns:
            A tuple ``(loss, info)`` where ``info`` is the list
            ``[policy_loss, value_loss, entropy_loss, mean(reward_sum),
            mean(advantage), learning_rate, decay_loss]``.
        """
        self.update = tf.cast(update, tf.float32)
        # Learning rate and clip range share the same cosine schedule over
        # ``max_step`` steps.
        self.learning_rate = tf.train.cosine_decay(self.init_learning_rate, self.update, self.max_step)
        self.clip_range = tf.train.cosine_decay(self.init_clip_param, self.update, self.max_step)
        self._split_data(datas)
        input_shape = self.feature.get_shape().as_list()
        print("SHOW: input_shape ", end='')
        print(input_shape)
        self.model.inference(self.feature, self.sampled_legal_action)
        self._calculate_loss()
        return self.loss,\
               [
                   self.policy_loss,
                   self.value_loss,
                   self.entropy_loss,
                   tf.reduce_mean(self.sampled_reward_sum),
                   tf.reduce_mean(self.sampled_normalized_advantage),
                   self.learning_rate,
                   self.decay_loss,
               ]

    def _split_data(self, datas):
        """Unpack ``datas`` into the per-sample training tensors.

        The columns are split, in order, into: feature, advantage, action,
        sampled negative log-probability, sampled value and legal-action mask.
        Scalar fields are flattened to rank 1; ``feature`` and the legal-action
        mask are reshaped to rank 2.

        Args:
            datas: Tensor to split along axis 1 with ``self.data_split_shape``.
        """
        self.feature, self.sampled_advantage, self.sampled_action, self.sampled_neg_log_pi, self.sampled_value,self.sampled_legal_action = tf.split(datas, self.data_split_shape, axis = 1)
        self.feature = tf.reshape(self.feature, [-1, Config.INPUT_DIM[0]])
        self.sampled_normalized_advantage = tf.reshape(self.sampled_advantage, [-1])
        self.sampled_action = tf.reshape(self.sampled_action, [-1])
        self.sampled_neg_log_pi = tf.reshape(self.sampled_neg_log_pi, [-1])
        self.sampled_value = tf.reshape(self.sampled_value, [-1])
        self.sampled_legal_action = tf.reshape(self.sampled_legal_action,[-1, Config.ACTION_DIM])
        # Value-regression target: sampled value plus advantage.
        self.sampled_reward_sum = self.sampled_value + self.sampled_normalized_advantage

        # Static rank checks: per-sample scalars share the action's rank,
        # per-sample vectors have one extra dimension.
        self.action_rank = self.sampled_action.shape.ndims
        self.feature.shape.assert_has_rank(self.action_rank + 1)
        self.sampled_legal_action.shape.assert_has_rank(self.action_rank+1)
        self.sampled_normalized_advantage.shape.assert_has_rank(self.action_rank)
        self.sampled_action.shape.assert_has_rank(self.action_rank)
        self.sampled_neg_log_pi.shape.assert_has_rank(self.action_rank)
        self.sampled_value.shape.assert_has_rank(self.action_rank)
        self.sampled_reward_sum.shape.assert_has_rank(self.action_rank)

    def _calculate_loss(self):
        """Compute every loss component and combine them into ``self.loss``.

        The bracketed expression is the objective to maximise (policy objective
        and entropy bonus minus value and decay penalties); it is negated so
        that an optimizer can minimise ``self.loss``.
        """
        self._calc_policy_loss()
        self._calc_value_loss()
        self._calc_entropy_loss()
        self._calc_decay_loss()
        self.loss = -(self.policy_loss - self.alpha * self.value_loss + self.beta * self.entropy_loss - self.decay_loss)

    def _calc_policy_loss(self):
        """Compute the clipped surrogate policy objective into ``self.policy_loss``."""
        # 1.0 for samples whose legal-action mask does not sum to exactly 1,
        # 0.0 otherwise; the latter get a zero advantage below.
        # NOTE(review): presumably this skips states with a single legal
        # action -- confirm the mask is 0/1 valued.
        is_train = tf.cast(tf.not_equal(tf.reduce_sum(self.sampled_legal_action, axis=-1), 1.0), tf.float32)
        neg_log_pi = self.model.neg_log_prob(self.sampled_action, "neg_log_pi")
        neg_log_pi.shape.assert_has_rank(self.action_rank)
        # exp(old_neg_log_pi - new_neg_log_pi) == pi_new / pi_old.
        ratio = tf.exp(self.sampled_neg_log_pi - neg_log_pi, name="ratio")
        # Hard bound on the unclipped surrogate's ratio.
        clipped_ratio = tf.clip_by_value(ratio, 0.0, 3.0)
        norm_adv = is_train * self.sampled_normalized_advantage
        surr1 = clipped_ratio * norm_adv
        surr2 = tf.clip_by_value(ratio, 1.0 - self.clip_range, 1.0 + self.clip_range) * norm_adv
        self.policy_loss = tf.reduce_mean(tf.minimum(surr1, surr2), name="policy_loss")

    def _calc_value_loss(self):
        """Compute the clipped value loss into ``self.value_loss``.

        The prediction is also evaluated after clipping its deviation from the
        sampled value to ``[-clip_range, clip_range]``; the larger of the two
        squared errors against ``sampled_reward_sum`` is averaged and halved.
        """
        value = self.model.value
        value.shape.assert_has_rank(self.action_rank)
        clipped_value = tf.add(
            self.sampled_value,
            tf.clip_by_value(value - self.sampled_value, - self.clip_range, self.clip_range), name="clipped_value")
        self.value_loss = \
            0.5 * tf.reduce_mean(
                tf.maximum(
                    tf.square(value - self.sampled_reward_sum),
                    tf.square(clipped_value - self.sampled_reward_sum)))

    def _calc_entropy_loss(self):
        """Compute the mean softmax entropy of ``model.pi_logits``.

        The result is stored in ``self.entropy_loss``.
        """
        pi_logits = self.model.pi_logits
        # Subtract the per-row maximum before exponentiating for numerical
        # stability; the softmax is unchanged by this shift.
        logits = pi_logits - tf.reduce_max(pi_logits, axis=-1, keepdims=True)
        exp_logits = tf.exp(logits)
        exp_logits_sum = tf.reduce_sum(exp_logits, axis = -1, keepdims = True)
        p = exp_logits / exp_logits_sum
        # -sum(p * log p), with log p = logits - log(sum(exp(logits))).
        temp_entropy_loss = tf.reduce_sum(p * (tf.log(exp_logits_sum) - logits), axis = -1)
        self.entropy_loss = tf.reduce_mean(temp_entropy_loss, name="entropy_loss")

    def _calc_decay_loss(self):
        """Compute the L2 weight-decay term into ``self.decay_loss``.

        Reads ``self.weight_decay['decay']`` (coefficient) and, when the
        coefficient exceeds 1e-8, ``self.weight_decay['exclude']`` (substrings;
        any trainable variable whose lower-cased op name contains one of them,
        case-insensitively, is skipped). If decay is disabled, or no variable
        remains after exclusion, the term is a zero constant.
        """
        decay = self.weight_decay['decay']
        if decay > 1e-8:
            # Lower-case the exclude keys once instead of once per variable.
            exclude_keys = [key.lower() for key in self.weight_decay['exclude']]
            trainable_vars = tf.trainable_variables()
            list_var = []
            list_name = []
            for var in trainable_vars:
                name_lowcase = var.op.name.lower()
                if any(key in name_lowcase for key in exclude_keys):
                    continue
                list_var.append(tf.nn.l2_loss(var))
                list_name.append(name_lowcase)
            print('Totally %d trainable parameters' % len(trainable_vars))
            print('Add L2 weight decay', decay, 'to following %d variables:' % len(list_name))
            print(list_name)
            if list_var:
                self.decay_loss = decay * tf.add_n(list_var)
            else:
                # tf.add_n rejects an empty list; with nothing to regularise
                # the decay term is simply zero.
                self.decay_loss = tf.constant(0.0, dtype=tf.float32)
        else:
            print('No L2 weight decay')
            self.decay_loss = tf.constant(0.0, dtype=tf.float32)

    def get_optimizer(self):
        """Return an Adam optimizer built from ``self.learning_rate`` and ``self.epsilon``.

        ``self.learning_rate`` is created in ``build_graph``, so that method
        must have been called first.
        """
        return tf.train.AdamOptimizer(learning_rate=self.learning_rate, epsilon=self.epsilon)
