import tensorflow as tf
from config.config_mahjong import Config
import warnings
from functools import reduce
from algorithms.base.model import Model as Model_Base
import os

# Global switch: when False, Model.batch_norm returns its input unchanged.
USE_BN = True
# Number of trailing-axis entries kept per card row in Model.split_data.
MAX_LENGTH = 24
# Number of residual blocks per stage passed to Model.resnet by Model._model.
REPEAT = 3
# Default `decay` argument (batch-norm momentum) of Model.batch_norm.
DECAY = 0.99
# Output channels for each resnet stage (indexed by stage number).
CHANNEL = [32, 64, 128]
# Convolution kernel size for each resnet stage (indexed by stage number).
KERNEL = [3, 3, 3]


class Model(Model_Base):
    def __init__(self, with_random=True):
        """Detect the runtime role (actor or learner) and set up shared state.

        The role is inferred from marker scripts looked up two directory
        levels above the directory containing this file: ``debug_actor.py``
        selects the actor role (``is_training=False``) and
        ``debug_trainer.py`` selects the learner role (``is_training=True``).
        Both roles use the ``NHWC`` data layout.

        Args:
            with_random: Accepted for interface compatibility; it is not
                stored by this constructor.
                NOTE(review): ``_sample`` reads ``self.with_random``, which is
                presumably set by the base class -- confirm.

        Raises:
            RuntimeError: If neither marker script exists.
        """
        super(Model, self).__init__()

        file_path = os.path.dirname(os.path.dirname(os.path.split(os.path.realpath(__file__))[0]))
        if os.path.exists(os.path.join(file_path, 'debug_actor.py')):
            self.type = 'actor'
            self.is_training = False
            self.data_format = 'NHWC'
        elif os.path.exists(os.path.join(file_path, 'debug_trainer.py')):
            self.type = 'learner'
            self.is_training = True
            self.data_format = 'NHWC'
        else:
            # The previous `assert -1` could never fire (-1 is truthy) and
            # asserts are stripped under `python -O`; fail loudly instead of
            # leaving the role/layout attributes unset.
            raise RuntimeError(
                'Cannot determine model role: neither debug_actor.py nor '
                'debug_trainer.py was found in ' + file_path)

        # Default weight initializer used by get_variable. (An orthogonal
        # initializer used to be assigned first and was immediately
        # overwritten, so only this one ever took effect.)
        self.initializer = tf.variance_scaling_initializer(scale=1.0, mode='fan_in', distribution='truncated_normal')
        # Per-layer [name, count] bookkeeping filled in by fc/conv/batch_norm.
        self.MACs = []
        self.MEMs = []

    def get_shape(self, x):
        return x.get_shape().as_list()

    def mul_all(self, mul_list):
        return reduce(lambda x, y: x * y, mul_list)

    def activation(self, x):
        """Apply the network-wide non-linearity (ReLU) to ``x``."""
        activated = tf.nn.relu(x)
        return activated

    def get_variable(self, shape, name, initializer=None, trainable=True):
        """Create or fetch a TF variable via ``tf.get_variable``.

        Args:
            shape: Variable shape.
            name: Variable name within the current variable scope.
            initializer: Initializer to use; ``None`` selects
                ``self.initializer``.
            trainable: Forwarded to ``tf.get_variable``.
        """
        chosen = self.initializer if initializer is None else initializer
        return tf.get_variable(
            name=name,
            shape=shape,
            initializer=chosen,
            trainable=trainable,
        )

    def window(self, stride_or_ksize):
        if isinstance(stride_or_ksize, int):
            stride_or_ksize = [stride_or_ksize] * 2
        return [1, 1] + stride_or_ksize if self.data_format is 'NCHW' \
            else [1] + stride_or_ksize + [1]

    def reshape(self, x, shape=None):
        """Reshape ``x`` while keeping a free (``-1``) leading dimension.

        Args:
            x: Tensor to reshape.
            shape: List of target dimensions after the leading one. When
                ``None``, all non-leading dimensions are flattened into one.
        """
        if shape is None:
            trailing_dims = self.get_shape(x)[1:]
            shape = [self.mul_all(trailing_dims)]
        return tf.reshape(x, [-1] + shape)

    def fc(self, x, c_out, bias=False, name='fc'):
        """Fully connected layer ``x @ W (+ b)`` with cost bookkeeping.

        Args:
            x: Input tensor; its last static dimension is the input width.
            c_out: Number of output units.
            bias: When true, add a zero-initialised bias named
                ``name + '_bias'``.
            name: Name of the weight variable and of the bookkeeping entry.

        Returns:
            The layer output tensor.
        """
        in_features = self.get_shape(x)[-1]
        weight = self.get_variable([in_features, c_out], name)
        out = tf.matmul(x, weight)
        if bias:
            offset = self.get_variable(
                [c_out], name + '_bias', initializer=tf.constant_initializer(0.0))
            out = out + offset

        # Record multiply-accumulate count and output size for this layer.
        self.MACs.append([name, in_features * c_out])
        self.MEMs.append([name, c_out])

        return out

    def conv(self, x, ksize, c_out=None, stride=1, padding='SAME', bias=False, name='conv'):
        """2-D convolution with cost bookkeeping.

        Args:
            x: 4-D input tensor laid out according to ``self.data_format``.
            ksize: Kernel size; an int (square kernel) or a two-element
                list/tuple ``[kh, kw]``.
            c_out: Number of output channels; ``None`` keeps the input
                channel count.
            stride: Stride, expanded by ``self.window``.
            padding: Padding mode forwarded to ``tf.nn.conv2d``.
            bias: When true, add a zero-initialised bias named
                ``name + '_b'``.
            name: Name of the kernel variable, the conv op and the
                bookkeeping entry.

        Returns:
            The convolution output tensor.
        """
        data_format = self.data_format
        shape_in = self.get_shape(x)
        # Compare by value: `is` tests object identity, which only works for
        # strings by accident of interning (and warns on Python 3.8+).
        c_in = shape_in[1] if data_format == 'NCHW' else shape_in[-1]
        if c_out is None:
            c_out = c_in
        if isinstance(ksize, int):
            ksize = [ksize, ksize]
        else:
            # Copy into a list so tuples can be concatenated below.
            ksize = list(ksize)
        W = self.get_variable(ksize + [c_in, c_out], name)
        x = tf.nn.conv2d(x, W, self.window(stride), padding=padding, data_format=data_format, name=name)
        if bias:
            b = self.get_variable([c_out], name + '_b', initializer=tf.initializers.zeros)
            x = tf.nn.bias_add(x, b, data_format=data_format)

        # MEMs: number of output elements excluding the leading axis.
        # MACs: one c_in * kh * kw dot product per output element.
        shape_out = self.get_shape(x)
        MEMs = self.mul_all(shape_out[1:])
        MACs = c_in * ksize[0] * ksize[1] * MEMs
        self.MACs.append([name, MACs])
        self.MEMs.append([name, MEMs])

        return x

    def batch_norm(self, x, center=True, scale=True, decay=DECAY, epsilon=1e-3):
        """Apply batch normalization along the channel axis.

        Returns ``x`` unchanged when the module-level ``USE_BN`` switch is
        off. Otherwise wraps ``tf.layers.batch_normalization`` and records the
        output size in ``self.MEMs``.

        Args:
            x: Input tensor.
            center: Forwarded as ``center``.
            scale: Forwarded as ``scale``.
            decay: Forwarded as ``momentum``.
            epsilon: Forwarded as ``epsilon``.

        Returns:
            The normalized tensor (or ``x`` itself when ``USE_BN`` is false).
        """
        if not USE_BN:
            return x

        # Compare by value: `is` tests object identity, which only works for
        # strings by accident of interning (and warns on Python 3.8+).
        channel_axis = -1 if self.data_format == 'NHWC' else 1
        x = tf.layers.batch_normalization(
            x,
            axis=channel_axis,
            momentum=decay,
            epsilon=epsilon,
            center=center,
            scale=scale,
            training=self.is_training
        )

        # Output element count excluding the leading axis.
        shape_out = self.get_shape(x)
        MEMs = self.mul_all(shape_out[1:])
        self.MEMs.append(['batch_norm', MEMs])
        return x

    def _decay(self):
        """Return the weighted L2 penalty over trainable variables.

        Variables whose name contains ``"batch_normalization"`` or ``"bias"``
        are excluded; the summed ``tf.nn.l2_loss`` terms are scaled by
        ``self._global_params.weight_decay``.
        """
        coefficient = self._global_params.weight_decay
        penalties = []
        for variable in tf.trainable_variables():
            if "batch_normalization" in variable.name or "bias" in variable.name:
                continue
            penalties.append(tf.nn.l2_loss(variable))
        return coefficient * tf.add_n(penalties)

    def count_parameters(self):
        """Print every trainable variable and tally parameter counts by kind.

        Each variable is assigned to the first category in
        ``['batch_norm', 'conv', 'fc', 'emb']`` that occurs in its lower-cased
        op name. Variables matching none of them trigger a warning and are
        tallied under ``'unknown'``.

        Returns:
            dict mapping category name to the total number of scalar
            parameters in that category.
        """
        key_list = ['batch_norm', 'conv', 'fc', 'emb']
        dict_parameters = {}

        for var in tf.trainable_variables():
            print(var.device, var.op.name, var.shape.as_list())
            name_lowcase = var.op.name.lower()
            # Initial value 1 makes scalar (shape []) variables count as one
            # parameter instead of raising on an empty reduce.
            num = reduce(lambda x, y: x * y, var.get_shape().as_list(), 1)

            category = next((key for key in key_list if key in name_lowcase), None)
            if category is None:
                warnings.warn('Unknown parameter named ' + name_lowcase)
                # Previously this reused the leaked loop variable, silently
                # booking unmatched variables under the last key ('emb').
                category = 'unknown'
            dict_parameters[category] = dict_parameters.get(category, 0) + num

        total = sum(dict_parameters.values())
        print('Parameters:', total, dict_parameters)

        return dict_parameters

    def count_MACs(self):
        total = 0
        for MAC in self.MACs:
            total += MAC[1]
        print('MACs:', total)
        return total

    def count_MEMs(self):
        total = 0
        for MEM in self.MEMs:
            total += MEM[1]
        total = total * self.batch_size * 4 // (1024 * 1024)
        print('MEMs:', total)
        return total

    def split_data(self, feature_list):
        """Split the flat feature tensor into its six named components.

        The split sizes along axis 1 come from ``Config.SPLIT_SHAPE``. The
        ``cards``, ``actions`` and ``user_id`` parts are exposed as named
        identity ops. ``cards`` then drops its last column, is reshaped to
        ``[-1, 7, 34]``, truncated to ``MAX_LENGTH`` entries on the last axis
        and cast to int32; ``flowers`` is reshaped to ``[-1, 2, 8]``.

        Returns:
            Tuple ``(cards, flowers, actions, callings, readys, user_id)``.
        """
        pieces = tf.split(feature_list, Config.SPLIT_SHAPE, axis=1)
        cards, flowers, actions, callings, readys, user_id = pieces

        # Named handles so these tensors can be fetched from the graph by name.
        cards = tf.identity(cards, "cards")
        actions = tf.identity(actions, "actions")
        user_id = tf.identity(user_id, "user_id")

        card_rows = tf.reshape(cards[..., :-1], [-1, 7, 34])
        cards = tf.cast(card_rows[..., :MAX_LENGTH], tf.int32)
        flowers = tf.reshape(flowers, [-1, 2, 8])

        return cards, flowers, actions, callings, readys, user_id

    def inference(self, feature, legal_action):
        """Build the full inference graph and expose its tensors as attributes.

        Args:
            feature: Flat feature tensor; split into components by
                ``split_data``.
            legal_action: Legal-action indicator tensor; used (as float) to
                mask the policy logits in ``_create_policy_network``.

        Sets, among others, ``self.pi_logits``, ``self.q_value``,
        ``self.prob``, ``self.action``, ``self.neg_log_pi`` and
        ``self.value``. Returns nothing.
        """
        # Named copies of the inputs so they can be looked up in the graph.
        self.feature = tf.identity(feature, name="feature")
        self.feature_float = tf.to_float(feature, name="feature_float")
        self.cards, self.flowers, self.actions, self.callings, self.readys, self.user_id = self.split_data(
            self.feature_float)
        # Legal-action mask, raw and as float.
        self.legal_action = tf.identity(legal_action, name="legal_action")
        self.legal_action_float = tf.to_float(self.legal_action, name="legal_action_float")
        with tf.variable_scope("model"):
            # Shared trunk first; both heads read self.h.
            self.h = self._model()
            self.pi_logits = self._create_policy_network()
            self.q_value = self._create_value_network()
        # Diagnostics: print parameter and MAC totals for the built graph.
        self.count_parameters()
        self.count_MACs()
        # self._action is the argmax of the logits; self.action comes from
        # _sample (sampled or argmax depending on self.with_random).
        self.prob, self._action = self._prob(self.pi_logits)
        self.action = self._sample(self.pi_logits)
        self.neg_log_pi = self.neg_log_prob(self.action, "neg_log_pi_old")
        # State value as the probability-weighted sum of per-action values.
        self.value = tf.reduce_sum(tf.multiply(self.q_value, self.prob), axis=-1)

    def neg_log_prob(self, action, name):
        """Return the softmax cross-entropy of ``action`` against the policy.

        Args:
            action: Action indices; cast to int32 and one-hot encoded with
                depth ``self.action_dim``.
            name: Name given to the resulting op.

        Returns:
            Softmax cross-entropy between ``self.pi_logits`` and the one-hot
            actions, i.e. the negative log-probability of each action.
        """
        action_index = tf.cast(action, tf.int32)
        target = tf.one_hot(action_index, self.action_dim)
        return tf.nn.softmax_cross_entropy_with_logits(
            logits=self.pi_logits, labels=target, dim=-1, name=name)

    def _prob(self, pi_logits):
        """Return temperature-adjusted probabilities and the greedy action.

        The softmax of ``pi_logits`` is taken, its log is divided by
        ``self.T``, and the result is exponentiated and renormalised along
        the last axis.

        Returns:
            Tuple ``(probabilities, argmax of pi_logits over the last axis)``.
        """
        with tf.variable_scope('prob'):
            base_probs = tf.nn.softmax(pi_logits, axis=-1)
            scaled_log_probs = tf.log(base_probs) / self.T
            unnormalised = tf.exp(scaled_log_probs)
            normaliser = tf.reduce_sum(unnormalised, axis=-1, keep_dims=True)
            tempered_probs = unnormalised / normaliser
            greedy_action = tf.argmax(pi_logits, axis=-1)
            return tempered_probs, greedy_action

    def _sample(self, pi_logits):
        """Pick an action from the policy logits.

        When ``self.with_random`` is false the argmax is returned (op named
        ``"action"``). Otherwise one action per row is drawn with
        ``tf.multinomial`` from the temperature-adjusted distribution
        (softmax, log divided by ``self.T``, exp, renormalise).
        """
        if not self.with_random:
            return tf.argmax(pi_logits, axis=-1, name="action")

        base_probs = tf.nn.softmax(pi_logits, axis=-1)
        scaled_log_probs = tf.log(base_probs) / self.T
        unnormalised = tf.exp(scaled_log_probs)
        normaliser = tf.reduce_sum(unnormalised, axis=-1, keep_dims=True)
        tempered_probs = unnormalised / normaliser
        draws = tf.multinomial(tf.log(tempered_probs), 1)
        return tf.reshape(draws, [-1])

    def _create_policy_network(self):
        """Build the policy head on top of ``self.h``.

        Two 512-unit fc + batch-norm + activation layers are followed by a
        biased fc layer of width ``Config.ACTION_DIM``. Entries where
        ``self.legal_action_float`` is 0 then have ``10 ** 20`` subtracted
        from their logit.

        Returns:
            The masked logits (op named ``"policy_result_after_mask"``).
        """
        with tf.variable_scope("policy_network"):
            hidden = self.h
            for layer_name in ("policy_fc1", "policy_fc2"):
                hidden = self.fc(hidden, 512, name=layer_name)
                hidden = self.batch_norm(hidden)
                hidden = self.activation(hidden)
            raw_logits = self.fc(hidden, Config.ACTION_DIM, bias=True, name="policy_fc3")

            # Large penalty on every action whose legal flag is 0.
            illegal_penalty = (1 - self.legal_action_float) * tf.pow(10.0, 20.0)
            masked_logits = tf.identity(
                raw_logits - illegal_penalty, "policy_result_after_mask")
        return masked_logits

    def _create_value_network(self):
        """Build the value head on top of ``self.h``.

        Two 512-unit fc + batch-norm + activation layers are followed by a
        biased fc layer that outputs one value per action
        (``Config.ACTION_DIM`` wide).
        """
        with tf.variable_scope("value_network"):
            hidden = self.h
            for layer_name in ("value_fc1", "value_fc2"):
                hidden = self.fc(hidden, 512, name=layer_name)
                hidden = self.batch_norm(hidden)
                hidden = self.activation(hidden)
            per_action_value = self.fc(
                hidden, Config.ACTION_DIM, bias=True, name="value_fc3")
        return per_action_value

    def onehot_emb(self, x, depth):
        """Cast ``x`` to int32, one-hot encode with ``depth`` and flatten."""
        indices = tf.cast(x, tf.int32)
        encoded = tf.one_hot(indices, depth=depth)
        return tf.layers.flatten(encoded)

    def flower_emb(self, x, depth):
        """One-hot encode ``x``, sum over the second-to-last axis and flatten.

        ``x`` is cast to int32 before encoding; the sum turns the one-hot
        vectors along that axis into per-class counts.
        """
        indices = tf.cast(x, tf.int32)
        encoded = tf.one_hot(indices, depth=depth)
        counts = tf.reduce_sum(encoded, axis=-2)
        return tf.layers.flatten(counts)

    def card_emb(self, x, depth):
        """Turn per-row card indices into a stacked count-threshold tensor.

        Args:
            x: Integer card tensor indexed below by row as ``x[:, row, ...]``
                with rows 0-6 (assumes the ``[-1, 7, MAX_LENGTH]`` output of
                ``split_data`` -- TODO confirm).
            depth: One-hot depth used for the card indices.

        Returns:
            A 4-D tensor with values clipped to ``[0, 1]``; transposed with
            permutation ``[0, 2, 3, 1]`` when ``self.data_format`` is
            ``'NHWC'``.
        """
        # One-hot adds a trailing class axis of size `depth`.
        x = tf.one_hot(x, depth=depth)
        # Rows 3 and 6 keep their per-position one-hot vectors and are joined
        # along axis 1 (presumably the two discard sequences -- confirm).
        drops = tf.concat([x[:, 3, :], x[:, 6, :]], axis=1)
        # Row 5 alone, and rows 0, 1, 2 and 4 together, keep the row axis.
        oppo_conceal = x[:, 5:6, :]
        others = tf.concat([x[:, :3, :], x[:, 4:5, :]], axis=1)

        # Summing over the position axis yields a per-class count per row.
        oppo_conceal = tf.reduce_sum(oppo_conceal, axis=-2)
        # Floor-divide by 4: a class only stays non-zero once its count
        # reaches 4.
        oppo_conceal = oppo_conceal // 4
        others = tf.reduce_sum(others, axis=-2)

        # Stack all rows along axis 1.
        x = tf.concat([others, oppo_conceal, drops], axis=1)
        # Keep only class indices 9..17 and 27 onward
        # (NOTE(review): presumably the tile kinds in play -- confirm).
        x = tf.concat([x[..., 9:18], x[..., 27:]], axis=-1)
        # After clipping to [0, 1], the four planes flag counts of
        # >= 4, >= 1, >= 2 and >= 3 respectively.
        x = tf.stack([x - 3, x, x - 1, x - 2], axis=2)
        x = tf.clip_by_value(x, 0, 1)
        if self.data_format == 'NHWC':
            # Move the row axis (axis 1) to the last (channel) position.
            x = tf.transpose(x, [0, 2, 3, 1])
        return x

    def drop_emb(self, x, depth):
        """One-hot encode rows 3 and 6 of ``x`` joined along axis 1.

        Returns:
            A single-element list containing the encoded tensor.
        """
        joined_rows = tf.concat([x[:, 3, :], x[:, 6, :]], axis=1)
        encoded = tf.one_hot(joined_rows, depth=depth)
        return [encoded]

    # CNN base network
    def resnet(self, x, stage, repeat, channel, kernel):
        """Residual CNN trunk; returns flattened features.

        Layout: an initial conv, then ``stage`` stages of ``repeat`` residual
        blocks each, with a 1x1 transition conv between consecutive stages to
        change the channel count, and a final 1x1 conv back to ``channel[0]``
        channels. Every conv is followed by batch norm and the activation.

        Args:
            x: 4-D input tensor laid out according to ``self.data_format``.
            stage: Number of stages.
            repeat: Residual blocks per stage.
            channel: Per-stage channel counts; indexed up to ``stage - 1``.
            kernel: Per-stage kernel sizes; indexed up to ``stage - 1``.

        Returns:
            The flattened output of the final conv block.
        """
        def residual(x, c_out, kernel, name='residual'):
            # Two conv/BN/activation layers plus an identity shortcut; the
            # input must already have `c_out` channels for the add to work.
            shortcut = x
            x = self.conv(x, kernel, c_out, name=name + "_conv1")
            x = self.batch_norm(x)
            x = self.activation(x)
            x = self.conv(x, kernel, c_out, name=name + "_conv2")
            x = self.batch_norm(x)
            x = self.activation(x)
            return x + shortcut

        def transition(x, c_out, name='transition'):
            # 1x1 conv that changes the channel count between stages.
            x = self.conv(x, [1, 1], c_out, name=name + '_conv')
            x = self.batch_norm(x)
            x = self.activation(x)
            return x

        x = self.conv(x, kernel[0], channel[0], name='conv_init')
        x = self.batch_norm(x)
        x = self.activation(x)

        for i in range(stage):
            for j in range(repeat):
                x = residual(x, channel[i], kernel[i], name='S%dR%d' % (i, j))
            # Compare by value: `is not` on ints tests identity and only works
            # for small cached integers.
            if i != stage - 1:
                x = transition(x, channel[i + 1], name='T%d' % i)

        x = self.conv(x, 1, channel[0], name='conv_final')
        x = self.batch_norm(x)
        x = self.activation(x)
        x = tf.layers.flatten(x)
        return x

    def _model(self):
        """Build the shared trunk and return its 1024-wide feature tensor.

        One-hot embeddings of user id, ready flags, flowers and actions are
        concatenated with the flattened resnet output for the card tensor,
        then passed through a 1024-unit fc + batch-norm + activation layer.
        Also sets ``self.onehot_concat`` and ``self.card_concat``.
        """
        features = [
            self.onehot_emb(self.user_id, Config.PLAYER_SIZE),
            self.onehot_emb(self.readys, Config.PLAYER_SIZE),
            self.flower_emb(self.flowers, 8),
            self.onehot_emb(self.actions, Config.ACTION_DIM),
        ]
        self.onehot_concat = tf.concat(features, axis=1, name='onehot_concat')

        # 53 channels in total: 1 hand, 2 chow/pong/kong, 2 concealed kong,
        # 2 * 24 drops.
        self.card_concat = self.card_emb(self.cards, Config.CARD_DIM + 1)

        with tf.variable_scope('card'):
            card_features = self.resnet(
                self.card_concat,
                stage=3,
                repeat=REPEAT,
                channel=CHANNEL,
                kernel=KERNEL,
            )
        features.append(card_features)

        trunk = tf.concat(features, axis=1, name='x_concat')
        trunk = self.fc(trunk, 1024, name="concate_fc")
        trunk = self.batch_norm(trunk)
        return self.activation(trunk)
