import tensorflow.keras.layers
from spektral.data import BatchLoader
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
from spektral.layers import GINConvBatch, GlobalSumPool, GlobalMaxPool, GlobalAvgPool, DiffPool
import tensorflow as tf
from invertible_neural_networks.flow import NVP
from invertible_neural_networks.flow_BNN import NVP_BNN
from models.TransformerAE import Decoder
from models.base import PositionalEmbedding


class Graph_Model(Model):
    """GIN-based regressor that maps a batch of graphs to a flat vector of per-epoch accuracies.

    Pipeline: ``GINConvBatch`` -> batch norm -> global max pool -> dropout -> linear head.
    The width of the linear head is chosen from ``epochs`` / ``is_only_validation_data``:

    * ``is_only_validation_data=True``: ``epochs`` outputs;
    * ``epochs == 12``: ``3 * epochs`` outputs (train_acc, valid_acc, test_acc per epoch);
    * ``epochs == 200``: ``2 * epochs`` outputs (train_acc, valid_acc per epoch);
    * otherwise ``NotImplementedError('epochs')`` is raised.
    """

    def __init__(self, n_hidden, mlp_hidden, activation: str, epochs, dropout=0., is_only_validation_data=False):
        super().__init__()
        self.graph_conv = GINConvBatch(
            n_hidden,
            mlp_hidden=mlp_hidden,
            mlp_activation=activation,
            mlp_batchnorm=True,
            activation=activation,
        )
        self.bn = tensorflow.keras.layers.BatchNormalization()
        self.pool = GlobalMaxPool()
        self.dropout = tensorflow.keras.layers.Dropout(dropout)

        if is_only_validation_data:
            out_units = epochs
        elif epochs == 12:
            out_units = 3 * epochs  # (train_acc, valid_acc, test_acc) per epoch
        elif epochs == 200:
            out_units = 2 * epochs  # (train_acc, valid_acc) per epoch
        else:
            raise NotImplementedError('epochs')
        self.dense = Dense(out_units)

    def call(self, inputs):
        """Run ``inputs`` through conv -> bn -> pool -> dropout -> dense, in that order."""
        out = inputs
        for layer in (self.graph_conv, self.bn, self.pool, self.dropout, self.dense):
            out = layer(out)
        return out


class GINEncoder(Model):
    """Node-level GIN encoder producing two ``n_hidden``-wide heads per node.

    Pipeline: ``GINConvBatch`` -> batch norm -> dropout, then two independent ``Dense`` heads.
    No pooling is applied, so the heads are computed for every node.
    The autoencoders in this file treat the second head as a log-variance
    (they pass it through ``exp`` in sampling and in the KL term).
    """

    def __init__(self, n_hidden, mlp_hidden, activation: str, dropout=0.):
        super().__init__()
        self.graph_conv = GINConvBatch(
            n_hidden,
            mlp_hidden=mlp_hidden,
            mlp_activation=activation,
            mlp_batchnorm=True,
            activation=activation,
        )
        self.bn = tensorflow.keras.layers.BatchNormalization()
        self.dropout = tensorflow.keras.layers.Dropout(dropout)
        self.mean, self.var = Dense(n_hidden), Dense(n_hidden)

    def call(self, inputs):
        """Return ``(mean, var)`` heads computed from the shared node embeddings."""
        hidden = self.dropout(self.bn(self.graph_conv(inputs)))
        return self.mean(hidden), self.var(hidden)


class TransformerDecoder(Decoder):
    """Transformer decoder that turns a sequence of latents into op and adjacency predictions.

    ``call(x)`` returns ``(ops_cls, adj_cls)``:

    * ``ops_cls``: softmax over ``num_ops`` classes at every sequence position;
    * ``adj_cls``: two-way softmax (presumably no-edge / edge) for every entry of a
      ``seq_len x num_nodes`` grid, flattened to ``(batch, seq_len * num_nodes, 2)``.

    ``self.dropout``, ``self.num_layers`` and ``self.dec_layers`` come from the base
    ``Decoder`` (not visible here).
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_length, num_ops, num_nodes, num_adjs,
                 dropout_rate=0.0):
        super(TransformerDecoder, self).__init__(num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                                                 dropout_rate)
        self.pos_embedding = PositionalEmbedding(d_model=d_model, input_length=input_length)

        # Per-position op classifier.
        self.ops_cls = tf.keras.layers.Dense(num_ops, activation='softmax')
        # Projects each position's hidden state to one score per node (one adjacency row per position).
        self.adj_weight = tf.keras.layers.Dense(num_nodes, activation='relu')
        # Two-class classifier applied independently to every adjacency entry.
        self.adj_cls = tf.keras.layers.Dense(2, activation='softmax')

    def call(self, x):
        x = self.pos_embedding(x)
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = self.dec_layers[i](x)

        ops_cls = self.ops_cls(x)  # (batch, seq_len, num_ops)

        x = self.adj_weight(x)  # (batch, seq_len, hidden) -> (batch, seq_len, num_nodes)
        x = tf.reshape(x, (tf.shape(x)[0], -1, 1))  # (batch, seq_len * num_nodes, 1)
        adj_cls = self.adj_cls(x)  # (batch, seq_len * num_nodes, 2)
        return ops_cls, adj_cls

class GraphAutoencoder(tf.keras.Model):
    """Variational graph autoencoder: GIN encoder -> reparameterized latent -> Transformer decoder.

    The encoder yields per-node ``(mean, log_var)``. A sample whose noise is scaled by
    ``eps_scale`` is decoded into op and adjacency class probabilities.
    """

    def __init__(self, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs, eps_scale=0.01,
                 dropout_rate=0.0):
        super().__init__()
        # In-memory weight snapshot managed by get_weights_to_self_ckpt / set_weights_from_self_ckpt.
        self.ckpt_weights = None
        self.d_model = d_model
        self.num_ops = num_ops
        self.num_adjs = num_adjs
        self.num_nodes = num_nodes
        self.latent_dim = latent_dim
        self.eps_scale = eps_scale
        self.encoder = GINEncoder(self.latent_dim, [128, 128, 128, 128], 'relu', dropout_rate)
        self.decoder = TransformerDecoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            dff=dff,
            input_length=num_nodes,
            num_ops=num_ops,
            num_nodes=num_nodes,
            num_adjs=num_adjs,
            dropout_rate=dropout_rate,
        )

    def sample(self, mean, log_var, eps_scale=0.01):
        """Reparameterization: ``mean + exp(log_var / 2) * N(0, 1) * eps_scale``."""
        noise = tf.random.normal(shape=tf.shape(mean))
        std = tf.exp(log_var * 0.5)
        return mean + std * noise * eps_scale

    @staticmethod
    def _gaussian_kl(mean, log_var):
        """KL(N(mean, exp(log_var)) || N(0, 1)), summed over the last axis."""
        return -0.5 * tf.reduce_sum(1 + log_var - tf.square(mean) - tf.exp(log_var), axis=-1)

    def call(self, inputs, kl_reduction='mean'):
        """Encode, sample, decode.

        Args:
            inputs: passed unchanged to the GIN encoder.
            kl_reduction: ``'mean'`` -> scalar KL; ``'none'`` -> KL averaged over the last
                remaining axis only; any other value leaves the KL unreduced.

        Returns:
            ``(ops_cls, adj_cls, kl_loss, latent_mean)``.
        """
        latent_mean, latent_var = self.encoder(inputs)
        c = self.sample(latent_mean, latent_var, self.eps_scale)

        kl_loss = self._gaussian_kl(latent_mean, latent_var)
        if kl_reduction == 'mean':
            kl_loss = tf.reduce_mean(kl_loss)
        elif kl_reduction == 'none':
            kl_loss = tf.reduce_mean(kl_loss, axis=-1)

        ops_cls, adj_cls = self.decoder(c)
        return ops_cls, adj_cls, kl_loss, latent_mean

    def encode(self, inputs):
        """Return the encoder's ``(mean, log_var)`` for ``inputs``."""
        return self.encoder(inputs)

    def decode(self, inputs):
        """Decode latents into hard predictions plus the raw class probabilities.

        Returns:
            ``(ops, adj, ops_cls, adj_cls)``. ``ops`` is a Python list with one argmax
            tensor per batch element. ``adj`` is the float32 argmax of ``adj_cls``.
        """
        ops_cls, adj_cls = self.decoder(inputs)
        ops = [tf.argmax(graph_probs, axis=-1) for graph_probs in ops_cls]
        adj = tf.cast(tf.argmax(adj_cls, axis=-1), tf.float32)
        return ops, adj, ops_cls, adj_cls

    def get_weights_to_self_ckpt(self):
        """Snapshot the current weights into ``self.ckpt_weights`` (kept in memory)."""
        self.ckpt_weights = self.get_weights()

    def set_weights_from_self_ckpt(self):
        """Restore the snapshot taken by ``get_weights_to_self_ckpt``.

        Raises:
            ValueError: if no snapshot has been taken.
        """
        snapshot = self.ckpt_weights
        if snapshot is None:
            raise ValueError('No weights to set')
        self.set_weights(snapshot)


class GraphAutoencoderNVP(GraphAutoencoder):
    """``GraphAutoencoder`` plus an invertible flow (``NVP``) applied to the flattened latent mean.

    The per-node latent mean is flattened to ``num_nodes * latent_dim`` values and zero-padded
    to ``nvp_config['inp_dim']`` before being fed to the flow.
    """

    def __init__(self, nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                 eps_scale=0.01, dropout_rate=0.0):
        super(GraphAutoencoderNVP, self).__init__(latent_dim, num_layers, d_model, num_heads, dff, num_ops,
                                                  num_nodes, num_adjs, eps_scale, dropout_rate)
        flat_latent_dim = latent_dim * num_nodes
        if nvp_config.get('inp_dim') is None:
            # The flow consumes the flattened per-node latent. Defaulting to latent_dim alone
            # made pad_dim negative (and call() fail) whenever num_nodes > 1.
            # The value is written back into the caller's dict, as before.
            nvp_config['inp_dim'] = flat_latent_dim

        # Number of zeros appended to the flattened latent to reach the flow's input width.
        self.pad_dim = nvp_config['inp_dim'] - flat_latent_dim
        self.nvp = NVP(**nvp_config)

    def call(self, inputs, kl_reduction='mean'):
        """Return ``(ops_cls, adj_cls, kl_loss, reg, padded_flat_latent_mean)``."""
        ops_cls, adj_cls, kl_loss, latent_mean = super().call(inputs, kl_reduction)
        batch_size = tf.shape(latent_mean)[0]
        latent_mean = tf.reshape(latent_mean, (batch_size, -1))
        padding = tf.zeros((batch_size, self.pad_dim), dtype=latent_mean.dtype)
        latent_mean = tf.concat([latent_mean, padding], axis=-1)
        reg = self.nvp(latent_mean)
        return ops_cls, adj_cls, kl_loss, reg, latent_mean

    def inverse(self, z):
        """Map ``z`` back through the flow's inverse."""
        return self.nvp.inverse(z)

class GraphAutoencoder2part(tf.keras.Model):
    """Variational graph autoencoder that encodes and decodes a graph as two halves.

    Nodes are split into the first ``num_nodes // 2`` and the rest. Each half is encoded by
    a shared GIN encoder and decoded by a shared Transformer decoder sized for
    ``num_nodes // 2`` nodes. Edges between the two halves are dropped on input and
    predicted as all-zero probabilities on output (block-diagonal adjacency).

    NOTE(review): assumes ``num_nodes`` is even. With an odd count the second half has one
    more node than the decoder was built for.
    """

    def __init__(self, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                 eps_scale=0.01, dropout_rate=0.0):
        super(GraphAutoencoder2part, self).__init__()
        # In-memory weight snapshot managed by get_weights_to_self_ckpt / set_weights_from_self_ckpt.
        self.ckpt_weights = None
        self.d_model = d_model
        self.num_ops = num_ops
        self.num_adjs = num_adjs
        self.num_nodes = num_nodes
        self.latent_dim = latent_dim
        self.eps_scale = eps_scale
        self.encoder = GINEncoder(self.latent_dim, [128, 128, 128, 128], 'relu', dropout_rate)

        # One decoder reused for both halves. num_adjs // 4 presumably equals (num_nodes // 2) ** 2
        # when num_adjs == num_nodes ** 2 -- TODO confirm.
        self.decoder = TransformerDecoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                                          dff=dff, input_length=num_nodes // 2, num_ops=num_ops,
                                          num_nodes=num_nodes // 2, num_adjs=num_adjs // 4,
                                          dropout_rate=dropout_rate)

    def sample(self, mean, log_var, eps_scale=0.01):
        """Reparameterization: ``mean + exp(log_var / 2) * N(0, 1) * eps_scale``."""
        eps = tf.random.normal(shape=tf.shape(mean))
        return mean + tf.exp(log_var * 0.5) * eps * eps_scale

    def split_graph(self, op_list, adj_list):
        """Split the op matrices and adjacency matrices into two sub-graphs.

        Parameters:
            op_list (tf.Tensor): op matrices, shape (batch_size, num_nodes, num_ops).
            adj_list (tf.Tensor): adjacency matrices, shape (batch_size, num_nodes, num_nodes).

        Returns:
            tuple: ``(split_op1, split_op2, split_adj1, split_adj2)``. The adjacency halves are
            the two diagonal blocks; cross-half edges are discarded.
        """
        half_nodes = self.num_nodes // 2
        split_op1 = op_list[:, :half_nodes, :]
        split_op2 = op_list[:, half_nodes:, :]
        split_adj1 = adj_list[:, :half_nodes, :half_nodes]
        split_adj2 = adj_list[:, half_nodes:, half_nodes:]
        return split_op1, split_op2, split_adj1, split_adj2

    def _encode_halves(self, inputs):
        """Split ``inputs = (op_list, adj_list)`` and encode each half with the shared encoder."""
        op_list, adj_list = inputs
        split_op1, split_op2, split_adj1, split_adj2 = self.split_graph(op_list, adj_list)
        # Spektral batch-mode GIN layers unpack their input as ``x, a = inputs``
        # (node features first, then adjacency), matching how GraphAutoencoder feeds the encoder.
        mean1, var1 = self.encoder((split_op1, split_adj1))
        mean2, var2 = self.encoder((split_op2, split_adj2))
        return mean1, var1, mean2, var2

    def _merge_adj_cls(self, adj_cls1, adj_cls2):
        """Assemble two half-graph edge predictions into one flattened block-diagonal prediction.

        Both inputs are reshaped to ``(batch, half, half, -1)``. The cross-half blocks are
        filled with zeros, and the result is flattened to ``(batch, (2 * half) ** 2, 2)``.
        """
        batch_size = tf.shape(adj_cls1)[0]
        half_num_nodes = self.num_nodes // 2

        adj_cls1 = tf.reshape(adj_cls1, (batch_size, half_num_nodes, half_num_nodes, -1))
        adj_cls2 = tf.reshape(adj_cls2, (batch_size, half_num_nodes, half_num_nodes, -1))

        # Cross-half edges are never predicted; both off-diagonal blocks are zero.
        zeros = tf.zeros((batch_size, half_num_nodes, half_num_nodes, 2), dtype=adj_cls1.dtype)
        top_half = tf.concat([adj_cls1, zeros], axis=2)
        bottom_half = tf.concat([zeros, adj_cls2], axis=2)
        adj_cls = tf.concat([top_half, bottom_half], axis=1)

        return tf.reshape(adj_cls, (batch_size, -1, 2))

    def call(self, inputs, kl_reduction='mean'):
        """Forward pass.

        Parameters:
            inputs (tuple): ``(op_list, adj_list)`` with shapes (batch_size, num_nodes, num_ops)
                and (batch_size, num_nodes, num_nodes).
            kl_reduction (str): ``'mean'`` -> scalar KL; ``'none'`` -> KL averaged over the
                node axis (one value per batch element); any other value leaves it unreduced.

        Returns:
            tuple: ``(ops_cls, adj_cls, kl_loss, combined_latent_mean)``. The combined mean has
            shape ``(batch, num_nodes, latent_dim)`` (both halves stacked along the node axis),
            the layout ``decode`` expects.
        """
        mean1, var1, mean2, var2 = self._encode_halves(inputs)
        combined_mean = tf.concat([mean1, mean2], axis=1)

        c1 = self.sample(mean1, var1, self.eps_scale)
        c2 = self.sample(mean2, var2, self.eps_scale)

        # Per-node KL of node i in half 1 plus node i in half 2.
        kl_loss1 = -0.5 * tf.reduce_sum(1 + var1 - tf.square(mean1) - tf.exp(var1), axis=-1)
        kl_loss2 = -0.5 * tf.reduce_sum(1 + var2 - tf.square(mean2) - tf.exp(var2), axis=-1)
        kl_loss = kl_loss1 + kl_loss2

        if kl_reduction == 'mean':
            kl_loss = tf.reduce_mean(kl_loss)
        elif kl_reduction == 'none':
            kl_loss = tf.reduce_mean(kl_loss, axis=-1)

        ops_cls1, adj_cls1 = self.decoder(c1)
        ops_cls2, adj_cls2 = self.decoder(c2)
        ops_cls = tf.concat([ops_cls1, ops_cls2], axis=1)
        adj_cls = self._merge_adj_cls(adj_cls1, adj_cls2)

        return ops_cls, adj_cls, kl_loss, combined_mean

    def encode(self, inputs):
        """Encode ``(op_list, adj_list)`` and return the latent mean and log-variance.

        Returns:
            tuple: ``(combined_mean, combined_var)``, each of shape (batch, num_nodes, latent_dim),
            with the two halves stacked along the node axis.
        """
        mean1, var1, mean2, var2 = self._encode_halves(inputs)
        combined_mean = tf.concat([mean1, mean2], axis=1)
        combined_var = tf.concat([var1, var2], axis=1)
        return combined_mean, combined_var

    def decode(self, latent_mean):
        """Decode per-node latents into ops and adjacency.

        Parameters:
            latent_mean (tf.Tensor): shape (batch_size, num_nodes, latent_dim). The first
                ``num_nodes // 2`` rows are decoded as the first half, the rest as the second.

        Returns:
            tuple: ``(ops, adj, ops_cls, adj_cls)``. ``ops`` is a Python list with one argmax
            tensor per batch element. ``adj`` is the float32 argmax of ``adj_cls``.
        """
        half_num_nodes = self.num_nodes // 2

        ops_cls1, adj_cls1 = self.decoder(latent_mean[:, :half_num_nodes, :])
        ops_cls2, adj_cls2 = self.decoder(latent_mean[:, half_num_nodes:, :])

        ops_cls = tf.concat([ops_cls1, ops_cls2], axis=1)
        adj_cls = self._merge_adj_cls(adj_cls1, adj_cls2)

        adj = tf.cast(tf.argmax(adj_cls, axis=-1), tf.float32)
        ops = [tf.argmax(graph_probs, axis=-1) for graph_probs in ops_cls]

        return ops, adj, ops_cls, adj_cls

    def get_weights_to_self_ckpt(self):
        """Snapshot the current weights into ``self.ckpt_weights`` (kept in memory)."""
        self.ckpt_weights = self.get_weights()

    def set_weights_from_self_ckpt(self):
        """Restore the snapshot taken by ``get_weights_to_self_ckpt``.

        Raises:
            ValueError: if no snapshot has been taken.
        """
        if self.ckpt_weights is None:
            raise ValueError('No weights to set')
        self.set_weights(self.ckpt_weights)

class GraphAutoencoderNVP2part(GraphAutoencoder2part):
    """``GraphAutoencoder2part`` plus an invertible flow (``NVP``) on the flattened latent mean.

    The combined latent mean is flattened (``num_nodes * latent_dim`` values for an even
    ``num_nodes``) and zero-padded to ``nvp_config['inp_dim']`` before entering the flow.
    """

    def __init__(self, nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                 eps_scale=0.01, dropout_rate=0.0):
        super(GraphAutoencoderNVP2part, self).__init__(latent_dim, num_layers, d_model, num_heads, dff, num_ops,
                                                       num_nodes, num_adjs, eps_scale, dropout_rate)
        flat_latent_dim = latent_dim * num_nodes
        if nvp_config.get('inp_dim') is None:
            # The flow consumes the flattened latent. Defaulting to latent_dim alone made
            # pad_dim negative (and call() fail) whenever num_nodes > 1.
            # The value is written back into the caller's dict, as before.
            nvp_config['inp_dim'] = flat_latent_dim

        # Number of zeros appended to the flattened latent to reach the flow's input width.
        self.pad_dim = nvp_config['inp_dim'] - flat_latent_dim
        self.nvp = NVP(**nvp_config)

    def call(self, inputs, kl_reduction='mean'):
        """Return ``(ops_cls, adj_cls, kl_loss, reg, padded_flat_latent_mean)``."""
        ops_cls, adj_cls, kl_loss, latent_mean = super().call(inputs, kl_reduction)
        batch_size = tf.shape(latent_mean)[0]
        latent_mean = tf.reshape(latent_mean, (batch_size, -1))
        padding = tf.zeros((batch_size, self.pad_dim), dtype=latent_mean.dtype)
        latent_mean = tf.concat([latent_mean, padding], axis=-1)
        reg = self.nvp(latent_mean)
        return ops_cls, adj_cls, kl_loss, reg, latent_mean

    def inverse(self, z):
        """Map ``z`` back through the flow's inverse."""
        return self.nvp.inverse(z)

class GraphAutoencoderNVP_BNN(GraphAutoencoder):
    """``GraphAutoencoder`` plus an ``NVP_BNN`` flow applied to the flattened latent mean.

    ``NVP_BNN`` is a project class (presumably a Bayesian variant of ``NVP``; not visible here).
    """

    def __init__(self, nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes, num_adjs,
                 eps_scale=0.01, dropout_rate=0.0):
        super(GraphAutoencoderNVP_BNN, self).__init__(latent_dim, num_layers, d_model, num_heads, dff, num_ops,
                                                      num_nodes, num_adjs, eps_scale, dropout_rate)
        flat_latent_dim = latent_dim * num_nodes
        if nvp_config.get('inp_dim') is None:
            # The flow consumes the flattened per-node latent. Defaulting to latent_dim alone
            # made pad_dim negative (and call() fail) whenever num_nodes > 1.
            # The value is written back into the caller's dict, as before.
            nvp_config['inp_dim'] = flat_latent_dim

        # Number of zeros appended to the flattened latent to reach the flow's input width.
        self.pad_dim = nvp_config['inp_dim'] - flat_latent_dim
        # NVP_BNN is already imported at module level.
        self.nvp = NVP_BNN(**nvp_config)

    def call(self, inputs, kl_reduction='mean'):
        """Return ``(ops_cls, adj_cls, kl_loss, reg, padded_flat_latent_mean)``.

        ``kl_reduction`` is forwarded to ``GraphAutoencoder.call`` (default ``'mean'``, as before).
        """
        ops_cls, adj_cls, kl_loss, latent_mean = super().call(inputs, kl_reduction)
        batch_size = tf.shape(latent_mean)[0]
        latent_mean = tf.reshape(latent_mean, (batch_size, -1))
        padding = tf.zeros((batch_size, self.pad_dim), dtype=latent_mean.dtype)
        latent_mean = tf.concat([latent_mean, padding], axis=-1)
        reg = self.nvp(latent_mean)
        return ops_cls, adj_cls, kl_loss, reg, latent_mean

    def inverse(self, z):
        """Map ``z`` back through the flow's inverse."""
        return self.nvp.inverse(z)


class GraphAutoencoderEnsembleNVP(GraphAutoencoder):
    """``GraphAutoencoder`` plus an ensemble of ``num_nvp`` independent ``NVP`` flows.

    Every flow receives the same zero-padded flattened latent mean. Their outputs are stacked
    and transposed with ``(1, 0, 2)``, giving ``(batch, num_nvp, dim)`` when each flow returns
    a rank-2 tensor (assumed -- TODO confirm).
    """

    def __init__(self, num_nvp, nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes,
                 num_adjs,
                 eps_scale=0.01, dropout_rate=0.0):
        super(GraphAutoencoderEnsembleNVP, self).__init__(latent_dim, num_layers, d_model, num_heads, dff, num_ops,
                                                          num_nodes, num_adjs, eps_scale, dropout_rate)
        flat_latent_dim = latent_dim * num_nodes
        if nvp_config.get('inp_dim') is None:
            # The flows consume the flattened per-node latent. Defaulting to latent_dim alone
            # made pad_dim negative (and call() fail) whenever num_nodes > 1.
            # The value is written back into the caller's dict, as before.
            nvp_config['inp_dim'] = flat_latent_dim

        self.num_nvp = num_nvp
        # Number of zeros appended to the flattened latent to reach the flows' input width.
        self.pad_dim = nvp_config['inp_dim'] - flat_latent_dim
        self.nvp_list = [NVP(**nvp_config) for _ in range(num_nvp)]

    def call(self, inputs, kl_reduction='mean'):
        """Return ``(ops_cls, adj_cls, kl_loss, reg, padded_flat_latent_mean)``."""
        ops_cls, adj_cls, kl_loss, latent_mean = super().call(inputs, kl_reduction)
        batch_size = tf.shape(latent_mean)[0]
        latent_mean = tf.reshape(latent_mean, (batch_size, -1))
        padding = tf.zeros((batch_size, self.pad_dim), dtype=latent_mean.dtype)
        latent_mean = tf.concat([latent_mean, padding], axis=-1)
        reg = tf.transpose(tf.stack([nvp(latent_mean) for nvp in self.nvp_list]), (1, 0, 2))
        return ops_cls, adj_cls, kl_loss, reg, latent_mean

    def inverse(self, z):
        """Invert ``z`` through every flow; outputs are stacked and transposed like ``reg``."""
        return tf.transpose(tf.stack([nvp.inverse(z) for nvp in self.nvp_list]), (1, 0, 2))


class GraphAutoencoderEnsembleNVP_BNN(GraphAutoencoder):
    """``GraphAutoencoder`` plus an ensemble of ``num_nvp`` independent ``NVP_BNN`` flows.

    Every flow receives the same zero-padded flattened latent mean. Their outputs are stacked
    and transposed with ``(1, 0, 2)``, giving ``(batch, num_nvp, dim)`` when each flow returns
    a rank-2 tensor (assumed -- TODO confirm).
    """

    def __init__(self, num_nvp, nvp_config, latent_dim, num_layers, d_model, num_heads, dff, num_ops, num_nodes,
                 num_adjs,
                 eps_scale=0.01, dropout_rate=0.0):
        super(GraphAutoencoderEnsembleNVP_BNN, self).__init__(latent_dim, num_layers, d_model, num_heads, dff, num_ops,
                                                              num_nodes, num_adjs, eps_scale, dropout_rate)
        flat_latent_dim = latent_dim * num_nodes
        if nvp_config.get('inp_dim') is None:
            # The flows consume the flattened per-node latent. Defaulting to latent_dim alone
            # made pad_dim negative (and call() fail) whenever num_nodes > 1.
            # The value is written back into the caller's dict, as before.
            nvp_config['inp_dim'] = flat_latent_dim

        self.num_nvp = num_nvp
        # Number of zeros appended to the flattened latent to reach the flows' input width.
        self.pad_dim = nvp_config['inp_dim'] - flat_latent_dim
        self.nvp_list = [NVP_BNN(**nvp_config) for _ in range(num_nvp)]

    def call(self, inputs, kl_reduction='mean'):
        """Return ``(ops_cls, adj_cls, kl_loss, reg, padded_flat_latent_mean)``.

        ``kl_reduction`` is forwarded to ``GraphAutoencoder.call`` (default ``'mean'``, as before).
        """
        ops_cls, adj_cls, kl_loss, latent_mean = super().call(inputs, kl_reduction)
        batch_size = tf.shape(latent_mean)[0]
        latent_mean = tf.reshape(latent_mean, (batch_size, -1))
        padding = tf.zeros((batch_size, self.pad_dim), dtype=latent_mean.dtype)
        latent_mean = tf.concat([latent_mean, padding], axis=-1)
        reg = tf.transpose(tf.stack([nvp(latent_mean) for nvp in self.nvp_list]), (1, 0, 2))
        return ops_cls, adj_cls, kl_loss, reg, latent_mean

    def inverse(self, z):
        """Invert ``z`` through every flow; outputs are stacked and transposed like ``reg``."""
        return tf.transpose(tf.stack([nvp.inverse(z) for nvp in self.nvp_list]), (1, 0, 2))


def bpr_loss(y_true, y_pred):
    """Pairwise BPR ranking loss, computed independently for every output column.

    For column ``i`` and every ordered pair of batch rows ``(k, j)`` with
    ``y_true[k, i] > y_true[j, i]``, the pair contributes
    ``-log(sigmoid(y_pred[k, i] - y_pred[j, i]))``. Each column's sum is divided by ``N ** 2``,
    where ``N`` is the batch size.

    Args:
        y_true: rank-2 tensor (batch, columns) of target scores.
        y_pred: rank-2 tensor with the same shape as ``y_true``.

    Returns:
        Rank-1 tensor with one loss value per column, in ``y_pred``'s dtype.
    """
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.convert_to_tensor(y_pred)
    n = tf.cast(tf.shape(y_true)[0], y_pred.dtype)

    # outranks[k, j, i] is True when row k beats row j in column i.
    outranks = tf.expand_dims(y_true, 1) > tf.expand_dims(y_true, 0)
    # log_sigmoid is the numerically stable form of log(sigmoid(x)). It stays finite, and
    # keeps gradients finite through tf.where, for large negative score differences.
    pair_loss = -tf.math.log_sigmoid(tf.expand_dims(y_pred, 1) - tf.expand_dims(y_pred, 0))
    pair_loss = tf.where(outranks, pair_loss, tf.zeros_like(pair_loss))
    # NOTE: materializes an (N, N, columns) tensor, so memory grows quadratically with batch size.
    return tf.reduce_sum(pair_loss, axis=[0, 1]) / (n * n)


def get_rank_weight(y_true):
    """Return a per-sample weight that shrinks as more samples outrank it.

    For sample ``i`` the weight is ``1 / (r_i + N * 10e-3)``, where ``r_i`` counts samples
    ``j`` with ``y_true[j] > y_true[i]`` and ``N`` is the batch size (``10e-3 == 0.01``).
    The top sample gets the largest weight, ``1 / (0.01 * N)``.

    NOTE(review): the transpose-and-subtract trick assumes ``y_true`` has shape (N, 1).
    """
    n_samples = tf.cast(tf.shape(y_true)[0], tf.float32)
    # gaps[i, j] = y_true[i] - y_true[j]; a negative entry means sample j beats sample i.
    gaps = tf.subtract(y_true, tf.transpose(y_true))
    outranked_by = tf.reduce_sum(tf.where(gaps < 0, 1., 0.), axis=1)
    return tf.math.reciprocal(outranked_by + n_samples * 10e-3)


def weighted_mse(y_true, y_pred):
    """Rank-weighted sum of per-sample mean squared errors.

    Each sample's MSE (``tf.keras.losses.mse`` averages over the last axis) is multiplied by
    ``get_rank_weight(y_true)``. Samples with higher targets therefore weigh more. The
    weighted errors are summed to a scalar.

    NOTE(review): ``get_rank_weight`` assumes ``y_true`` has shape (N, 1).
    """
    mse = tf.keras.losses.mse(y_true, y_pred)
    weight = get_rank_weight(y_true)
    return tf.reduce_sum(tf.multiply(mse, weight))
