import numpy as np
import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from . import DeepSet, GAT, RRWPEncoder



def init_weights(m):
    """Re-initialise a single module in place; intended for ``nn.Module.apply``.

    Only ``nn.Linear`` modules are touched: the weight is drawn with Kaiming
    normal initialisation (``nonlinearity='leaky_relu'``) and the bias, when
    the layer has one, is set to zero. Any other module is left unchanged.

    :param m: module visited by ``apply``.
    :return: None
    """
    if not isinstance(m, nn.Linear):
        return

    init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')

    bias = m.bias
    if bias is not None:
        init.constant_(bias, 0)


class AutoVisualNet(nn.Module):
    """Multi-view graph network that predicts a low-dimensional output per node.

    Pipeline implemented by :meth:`forward`:

    1. per-node input features are projected by
       ``RRWPEncoder.RRWPLinearNodeEncoder(input_dim, hidden_pe_dim)``;
    2. the batched graph views are processed by ``GAT.GAT`` (stored under the
       attribute name ``gin``);
    3. the per-view outputs are concatenated along the feature axis, giving
       ``gnn_out * n_graph_view`` features per row;
    4. ``n_transforms`` ``nn.TransformerEncoderLayer`` blocks are applied;
    5. a two-layer MLP maps the result to ``out_dim`` features.

    The ``loss_fn_*`` methods compare a target pairwise-distance matrix with
    the pairwise distances of the network output through KL-style objectives.

    :param input_dim: size of the raw per-node feature vector.
    :param hidden_pe_dim: output size of the node encoder / input size of the GNN.
    :param gnn_hidden: hidden size passed to the GNN.
    :param gnn_out: output size of the GNN for a single graph view.
    :param out_dim: size of the final per-row output.
    :param n_graph_view: number of graph views whose outputs are concatenated.
    :param device: device used for the GNN and for tensors moved inside the
        ``*_forward`` helpers.
    """

    def __init__(self, input_dim=16, hidden_pe_dim=64, gnn_hidden=128, gnn_out=32, out_dim=2, n_graph_view=5, device='cuda:0'):
        super().__init__()
        # Number of datasets handled per inner iteration. gin_forward only
        # works with 1 at the moment (see the todo there).
        self.sub_batch_size = 1

        self.input_dim, self.gnn_hidden, self.gnn_out, self.out_dim = input_dim, gnn_hidden, gnn_out, out_dim
        self.hidden_pe_dim = hidden_pe_dim

        self.device = device

        self.rrwp_encoder = RRWPEncoder.RRWPLinearNodeEncoder(self.input_dim, self.hidden_pe_dim)

        # NOTE: the attribute is called ``gin`` although it holds a GAT; the
        # name is kept so existing state dicts keep loading.
        # self.deepset = DeepSet.DeepSet(1, self.deepset_dim).to(device)
        self.gin = GAT.GAT(self.hidden_pe_dim, self.gnn_hidden, self.gnn_out).to(device)

        self.n_graph_view = n_graph_view

        self.mlp = nn.Sequential(
            # nn.LayerNorm(self.gnn_out*self.n_graph_view),
            nn.Linear(self.gnn_out * self.n_graph_view, self.gnn_out),
            nn.SELU(),
            nn.Linear(self.gnn_out, self.out_dim)
        )
        self.n_transforms = 3
        self.gt_blocks = self.build_transformer_blocks()

    def build_transformer_blocks(self):
        """Create the stack of transformer encoder layers.

        :return: ``nn.ModuleList`` holding ``self.n_transforms`` layers, each
            with ``d_model = gnn_out * n_graph_view`` and 8 attention heads.
        """
        # d_model has to be divisible by nhead (8) for multi-head attention;
        # the defaults give 32 * 5 = 160.
        d_model = self.gnn_out * self.n_graph_view
        return nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=d_model, nhead=8) for _ in range(self.n_transforms)]
        )

    def deepset_forward(self, cdists):
        """Run the optional DeepSet encoder over square matrices.

        The constructor does not create ``self.deepset`` (that line is
        disabled), so this method only works after a ``deepset`` module has
        been attached to the instance.

        :param cdists: sequence of square tensors; an ``(n, n)`` matrix
            contributes ``n * n`` scalar elements and ``n`` block sizes of
            value ``n``.
        :return: list with one DeepSet output per sub-batch (second element of
            the tuple returned by the DeepSet module).
        :raises AttributeError: if no ``deepset`` module is attached.
        """
        deepset = getattr(self, 'deepset', None)
        if deepset is None:
            raise AttributeError(
                "AutoVisualNet has no 'deepset' module; attach one before calling deepset_forward"
            )

        step = self.sub_batch_size
        node_features = []
        for sub_batch_idx in range(len(cdists) // step):
            chunk = cdists[sub_batch_idx * step: (sub_batch_idx + 1) * step]

            # One entry per matrix row, each entry being that matrix's size.
            block_sizes = []
            for c in chunk:
                n_nodes = c.shape[0]
                block_sizes += [n_nodes] * n_nodes
            block_sizes = torch.from_numpy(np.array(block_sizes)).to(self.device)

            flat = torch.cat([c.reshape(-1) for c in chunk]).to(self.device)
            _, node_f = deepset(flat.unsqueeze(1), block_sizes)
            node_features.append(node_f)
        return node_features

    def gin_forward(self, graphs, node_features):
        """Encode node features and run the GNN over every dataset's graph views.

        :param graphs: ``[[dgl graph_i] * n_graph_view, ...]`` - one list of
            graph views per dataset.
        :param node_features: per dataset, a sequence of tensors that are
            concatenated with ``torch.cat`` to form the node features of the
            batched graph (presumably one tensor per view - confirm against
            the dataset).
        :return: tensor whose rows carry the ``n_graph_view`` per-view outputs
            concatenated along dim 1, stacked over all datasets.
        """
        n_ds = len(graphs)
        outputs = []
        for sub_batch_idx in range(n_ds // self.sub_batch_size):
            # todo: crash when sub_batch_size > 1
            batched = dgl.batch(graphs[sub_batch_idx]).to(self.device)
            batch_pe = torch.cat(node_features[sub_batch_idx]).to(self.device)
            batch_pe = self.rrwp_encoder(batch_pe)
            output = self.gin(batched, batch_pe)
            # Split the batched output back into the views (along dim 0) and
            # put them side by side; this requires every view to contribute
            # the same number of rows.
            outputs.append(torch.cat(output.chunk(self.n_graph_view), dim=1))
        # torch.cat instead of the torch.concatenate alias, which only exists
        # in recent PyTorch releases.
        return torch.cat(outputs)

    def forward(self, x, graphs):
        """Compute the network output.

        :param x: node features, forwarded to :meth:`gin_forward`.
        :param graphs: graph views, forwarded to :meth:`gin_forward`.
        :return: tensor with ``out_dim`` features per row.
        """
        n_view_out = self.gin_forward(graphs, x)  # (n_node_i*n_graph, d*n_view)
        # The 2-D tensor is fed to the encoder layers as is.
        # NOTE(review): presumably treated as one unbatched sequence - confirm
        # with the installed PyTorch version.
        for blk in self.gt_blocks:
            n_view_out = blk(n_view_out)
        return self.mlp(n_view_out)

    def reset_parameters(self):
        """ Initialize the weights and bias.
        :return: None
        """
        self.apply(init_weights)

    # ------------------------------------------------------------------
    # Shared building blocks of the loss functions
    # ------------------------------------------------------------------

    @staticmethod
    def _diagonal_mask(mat):
        """Boolean identity mask matching the first dimension and device of ``mat``."""
        return torch.eye(mat.shape[0], dtype=torch.bool, device=mat.device)

    @staticmethod
    def _softmax_prob(scores):
        """Row-wise softmax of ``scores`` that ignores the diagonal.

        The diagonal is set to ``-inf`` before the softmax, ``1e-9`` is added
        to every entry, and the diagonal is finally set to 1 so that it
        contributes ``1 * (log 1 - log 1) = 0`` to a KL sum.
        """
        mask = AutoVisualNet._diagonal_mask(scores)
        prob = F.softmax(scores.masked_fill(mask, float('-inf')), dim=1) + 1e-9
        return prob.masked_fill(mask, 1.0)

    @staticmethod
    def _student_t_prob(zdist, degrees_of_freedom, row_normalize):
        """Student-t kernel similarities derived from a distance matrix.

        Kernel: ``(1 + d^2 / df) ** (-(df + 1) / 2)`` with a zeroed diagonal.
        The kernel is normalised by its total sum, or per row when
        ``row_normalize`` is true, floored at ``1e-9`` and the diagonal is set
        to 1 so it drops out of the KL sum.
        """
        kernel = (zdist ** 2 / degrees_of_freedom + 1.) ** ((degrees_of_freedom + 1.0) / -2.0)
        mask = AutoVisualNet._diagonal_mask(kernel)
        kernel = kernel.masked_fill(mask, 0.0)
        if row_normalize:
            # keepdim=True so that row i is divided by the sum of row i.
            norm = kernel.sum(dim=1, keepdim=True)
        else:
            norm = kernel.sum()
        prob = torch.clamp(kernel / norm, min=1e-9)
        return prob.masked_fill(mask, 1.0)

    @staticmethod
    def _kl_terms(p, q):
        """Element-wise ``p * (log p - log q)``."""
        return p * (torch.log(p) - torch.log(q))

    # ------------------------------------------------------------------
    # Losses
    # ------------------------------------------------------------------

    def loss_fn_kl(self, x, graph, zdist):
        """KL loss with inverse-distance scores and a row-wise softmax.

        Scores are ``1 / (d + 1e-9)``; no symmetrisation is applied.

        :param x: node features for :meth:`forward`.
        :param graph: graph views for :meth:`forward`.
        :param zdist: square target distance matrix.
        :return: mean over rows of the per-row KL sum.
        """
        p = self._softmax_prob(1 / (zdist + 1e-9))
        z_hat = self.forward(x, graph)
        q = self._softmax_prob(1 / (torch.cdist(z_hat, z_hat) + 1e-9))
        return self._kl_terms(p, q).sum(dim=1).mean()

    def loss_fn_kl2(self, x, graph, zdist):
        """KL loss with negative-distance scores and a row-wise softmax.

        Scores are ``-d``; no symmetrisation is applied.

        :param x: node features for :meth:`forward`.
        :param graph: graph views for :meth:`forward`.
        :param zdist: square target distance matrix.
        :return: mean over rows of the per-row KL sum.
        """
        p = self._softmax_prob(-1 * zdist)
        z_hat = self.forward(x, graph)
        q = self._softmax_prob(-1 * torch.cdist(z_hat, z_hat))
        return self._kl_terms(p, q).sum(dim=1).mean()

    def loss_fn_kl_inverse_only(self, x, graph, zdist):
        """KL-style loss on raw inverse distances.

        Both matrices are ``1 / (d + 1e-9)`` with the diagonal set to 1; no
        softmax or other normalisation is applied, so the rows are not
        probability distributions.

        :param x: node features for :meth:`forward`.
        :param graph: graph views for :meth:`forward`.
        :param zdist: square target distance matrix.
        :return: mean over rows of the per-row sum of ``p * (log p - log q)``.
        """
        def inverse_scores(dist):
            inv = 1 / (dist + 1e-9)
            return inv.masked_fill(self._diagonal_mask(inv), 1.0)

        p = inverse_scores(zdist)
        z_hat = self.forward(x, graph)
        q = inverse_scores(torch.cdist(z_hat, z_hat))
        return self._kl_terms(p, q).sum(dim=1).mean()

    def loss_fn_kl_t(self, x, graph, zdist, df=1):
        """KL loss with globally normalised Student-t similarities.

        :param x: node features for :meth:`forward`.
        :param graph: graph views for :meth:`forward`.
        :param zdist: square target distance matrix.
        :param df: degrees of freedom of the Student-t kernel.
        :return: ``2 * (df + 1) / df`` times the sum of all KL terms.
        """
        p = self._student_t_prob(zdist, df, row_normalize=False)
        z_hat = self.forward(x, graph)
        q = self._student_t_prob(torch.cdist(z_hat, z_hat), df, row_normalize=False)
        return 2 * (df + 1) / df * self._kl_terms(p, q).sum()

    def loss_fn_kl_t2(self, x, graph, zdist, df=1):
        """KL loss with row-normalised Student-t similarities.

        ``df`` is used both for the kernel and for the loss scale, and every
        row of the similarity matrix is normalised by its own sum. For
        symmetric distance matrices and ``df=1`` the value equals the one
        produced by the previous implementation.

        :param x: node features for :meth:`forward`.
        :param graph: graph views for :meth:`forward`.
        :param zdist: square target distance matrix.
        :param df: degrees of freedom of the Student-t kernel.
        :return: ``(df + 1) / df`` times the mean over rows of the per-row KL sum.
        """
        p = self._student_t_prob(zdist, df, row_normalize=True)
        z_hat = self.forward(x, graph)
        q = self._student_t_prob(torch.cdist(z_hat, z_hat), df, row_normalize=True)
        return (df + 1) / df * self._kl_terms(p, q).sum(dim=1).mean()


def list_collate_fn(items):
    """Transpose a batch of samples into a list of per-field tuples.

    ``[(a0, b0), (a1, b1)]`` becomes ``[(a0, a1), (b0, b1)]``; the fields are
    left as plain tuples, nothing is stacked or converted.

    :param items: iterable of equally structured samples.
    :return: list with one tuple per sample field.
    """
    return list(zip(*items))


if __name__ == '__main__':
    # Disabled manual smoke test, kept for reference only: it builds a graph
    # dataset, instantiates AutoVisualNet on 'cuda:0', pulls one batch of two
    # samples through a DataLoader that transposes the samples with
    # list(zip(*x)), and prints the output shape. With everything commented
    # out this guard is a no-op.
    #
    # import GraphDatasets
    # from dgl.dataloading import GraphDataLoader
    #
    # ds = GraphDatasets.DatasetGraphDataset(data_names=['mnist_group1'], cdist_path='../prepare_data/clip/features',
    #                          visual_path='../prepare_data/bo/res-2')
    #
    # net = AutoVisualNet().to('cuda:0')
    #
    # train_loader = torch.utils.data.DataLoader(ds,
    #                                batch_size=2,
    #                                shuffle=True,
    #                                num_workers=0,
    #                                collate_fn=lambda x: list(zip(*x))
    #                                )
    # it = iter(train_loader)
    # a = next(it)
    #
    # # out = net([ds[0][0], ds[1][0]], [ds[0][1], ds[1][1]])
    # out = net(a[0], a[1])
    #
    # print(out.shape)
    pass



