import numpy as np
import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from . import DeepSet, GIN, RRWPEncoder, GAT, SignNet2



def init_weights(m):
    """Re-initialise a single module in place; intended for ``nn.Module.apply``.

    Only ``nn.Linear`` modules are touched: the weight is drawn with Kaiming
    normal initialisation (``nonlinearity='leaky_relu'``) and the bias, when
    present, is filled with zeros.  Every other module type is left unchanged.

    :param m: module visited by ``apply``.
    :return: None
    """
    if not isinstance(m, nn.Linear):
        return

    init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')

    bias = m.bias
    if bias is not None:
        init.constant_(bias, 0)


class AutoVisualNet(nn.Module):
    """Multi-view graph network with KL-style losses on pairwise distances.

    Pipeline implemented by :meth:`forward`:

    1. For every graph view ``i`` a ``SignNet2.GINDeepSigns`` module followed
       by a ``GIN.GIN`` module turns node features into per-node embeddings of
       width ``gnn_out`` (or, when ``use_sep_gnn`` is False, a single shared
       ``GIN.GIN`` is applied to the batched views).
    2. The per-view embeddings are concatenated along the feature axis.
    3. ``n_transforms`` ``nn.TransformerEncoderLayer`` blocks are applied.
    4. An MLP head projects to ``out_dim``.

    The ``loss_fn_*`` methods compare a target pairwise-distance matrix with
    the pairwise distances of the network output.

    NOTE: only the GNN / sign-net sub-modules are moved to ``device`` inside
    ``__init__``; the transformer blocks and the MLP head are created on the
    default device, so callers should still call ``.to(device)`` on the model.
    """

    def __init__(self, input_dim=16, gnn_hidden=128, gnn_out=32, out_dim=2, n_graph_view=1, device='cuda:0'):
        """Create the per-view encoders, the transformer stack and the MLP head.

        :param input_dim: node feature width given to each view encoder.
        :param gnn_hidden: hidden width passed to the GIN / GINDeepSigns constructors.
        :param gnn_out: output width of each view's GIN module.
        :param out_dim: width of the final MLP output.
        :param n_graph_view: number of graph views per sample.
        :param device: device the GNN sub-modules are moved to, and to which
            graphs / features are sent inside the forward helpers.

        ``gnn_out * n_graph_view`` is used as the transformer ``d_model`` with
        ``nhead=8``, so it has to be compatible with 8 attention heads.
        """
        super().__init__()
        # Number of samples consumed per loop iteration in the GNN helpers.
        # TODO: the helpers index ``graphs[sub_batch_idx]`` directly, so only
        # a value of 1 is supported at the moment.
        self.sub_batch_size = 1

        self.input_dim, self.gnn_hidden, self.gnn_out, self.out_dim = input_dim, gnn_hidden, gnn_out, out_dim
        self.device = device
        self.n_graph_view = n_graph_view

        # Hard-wired switch: one (sign-net, GIN) pair per view vs. one shared GIN.
        self.use_sep_gnn = True
        if self.use_sep_gnn:
            self.gnns = nn.ModuleDict()
            self.sign_nets = nn.ModuleDict()
            # Construction order (GIN first, then the sign net, view by view)
            # is kept stable so seeded initialisation stays reproducible.
            for i in range(self.n_graph_view):
                self.gnns[str(i)] = GIN.GIN(self.input_dim, self.gnn_hidden, self.gnn_out).to(device)
                self.sign_nets[str(i)] = SignNet2.GINDeepSigns(
                    in_channels=input_dim,
                    hidden_channels=gnn_hidden,
                    out_channels=self.input_dim,
                    num_layers=4,
                ).to(device)
        else:
            self.gin = GIN.GIN(self.input_dim, self.gnn_hidden, self.gnn_out).to(device)

        self.mlp = nn.Sequential(
            nn.Linear(self.gnn_out * self.n_graph_view, self.gnn_out),
            nn.SELU(),
            nn.LayerNorm(self.gnn_out),
            nn.Linear(self.gnn_out, self.out_dim),
        )
        self.n_transforms = 5
        self.gt_blocks = self.build_transformer_blocks()

    def build_transformer_blocks(self):
        """Return a ``ModuleList`` of ``n_transforms`` transformer encoder layers.

        Each layer uses ``d_model = gnn_out * n_graph_view`` and 8 heads.
        """
        return nn.ModuleList(
            nn.TransformerEncoderLayer(d_model=self.gnn_out * self.n_graph_view, nhead=8)
            for _ in range(self.n_transforms)
        )

    def gin_forward(self, graphs, node_features):
        """Encode every sample with the per-view (sign-net -> GIN) pairs.

        :param graphs: ``[[graph_view_0, ..., graph_view_{V-1}], ...]`` - one
            inner list of graphs per sample (per the original author's note
            these are DGL graphs).
        :param node_features: same nesting as ``graphs``; one feature tensor
            per view.
        :return: per-view outputs concatenated on dim 1 within a sample, then
            all samples concatenated on dim 0.
        """
        n_ds = len(graphs)
        outputs = []
        # TODO: indexing by ``sub_batch_idx`` is only valid for sub_batch_size == 1.
        for sub_batch_idx in range(n_ds // self.sub_batch_size):
            sample_graphs = graphs[sub_batch_idx]
            sample_features = node_features[sub_batch_idx]
            per_view = []
            for i in range(self.n_graph_view):
                key = str(i)
                g = sample_graphs[i].to(self.device)
                feats = sample_features[i].to(self.device)
                encoded = self.sign_nets[key](g, feats)
                per_view.append(self.gnns[key](g, encoded))
            outputs.append(torch.cat(per_view, dim=1))
        return torch.cat(outputs)

    def gin_forward_uni_gnn(self, graphs, node_features):
        """Encode every sample with the single shared GIN (``self.gin``).

        The views of one sample are merged with ``dgl.batch``, run through the
        shared GIN, then the output is split back into ``n_graph_view`` chunks
        along dim 0 and those chunks are concatenated along dim 1.

        :param graphs: ``[[graph_view_0, ..., graph_view_{V-1}], ...]``.
        :param node_features: same nesting as ``graphs``.
        :return: all per-sample results concatenated on dim 0.
        """
        n_ds = len(graphs)
        outputs = []
        # TODO: indexing by ``sub_batch_idx`` is only valid for sub_batch_size == 1.
        for sub_batch_idx in range(n_ds // self.sub_batch_size):
            batched_graph = dgl.batch(graphs[sub_batch_idx]).to(self.device)
            batched_features = torch.cat(node_features[sub_batch_idx]).to(self.device)
            encoded = self.gin(batched_graph, batched_features)
            outputs.append(torch.cat(encoded.chunk(self.n_graph_view), dim=1))
        return torch.cat(outputs)

    def forward(self, x, graphs):
        """Run the full pipeline: GNN encoders -> transformer blocks -> MLP.

        :param x: nested node features, forwarded as ``node_features`` to the
            GNN helper.
        :param graphs: nested graphs, forwarded to the GNN helper.
        :return: output of the MLP head (last dimension ``out_dim``).
        """
        if self.use_sep_gnn:
            n_view_out = self.gin_forward(graphs, x)  # (n_node_i*n_graph, d*n_view)
        else:
            n_view_out = self.gin_forward_uni_gnn(graphs, x)  # (n_node_i*n_graph, d*n_view)

        for blk in self.gt_blocks:
            n_view_out = blk(n_view_out)

        return self.mlp(n_view_out)

    def reset_parameters(self):
        """ Initialize the weights and bias.
        :return: None
        """
        self.apply(init_weights)

    # ------------------------------------------------------------------
    # Private helpers shared by the loss functions
    # ------------------------------------------------------------------
    @staticmethod
    def _diagonal_mask(mat):
        """Boolean identity mask sized by ``mat.shape[0]`` on ``mat``'s device."""
        return torch.eye(mat.shape[0], dtype=torch.bool, device=mat.device)

    @staticmethod
    def _kl_terms(p, q):
        """Element-wise KL integrand ``p * (log p - log q)``."""
        return p * (torch.log(p) - torch.log(q))

    @staticmethod
    def _softmax_prob(scores):
        """Row-softmax of ``scores`` with the diagonal excluded.

        The diagonal is set to ``-inf`` before the softmax, ``1e-9`` is added
        afterwards, and the diagonal is finally set to 1.0 so that its log is
        zero and it contributes nothing to the KL terms.
        """
        mask = AutoVisualNet._diagonal_mask(scores)
        scores = scores.masked_fill(mask, float('-inf'))
        prob = F.softmax(scores, dim=1) + 1e-9
        return prob.masked_fill(mask, float(1.0))

    @staticmethod
    def _student_t_kernel(zdist, degrees_of_freedom):
        """Student-t kernel ``(1 + d^2 / df) ** (-(df + 1) / 2)`` with a zeroed diagonal.

        :return: ``(kernel, diagonal_mask)``.
        """
        kernel = (zdist ** 2 / degrees_of_freedom + 1.) ** ((degrees_of_freedom + 1.0) / -2.0)
        mask = AutoVisualNet._diagonal_mask(kernel)
        return kernel.masked_fill(mask, float(0.0)), mask

    # ------------------------------------------------------------------
    # Losses
    # ------------------------------------------------------------------
    def loss_fn_kl(self, x, graph, zdist: torch.Tensor) -> torch.Tensor:
        """KL loss using a row-softmax over inverse distances.

        Scores are ``1 / (d + 1e-9)``; see :meth:`_softmax_prob` for the
        normalisation.  The result is the row-wise KL sum averaged over rows.

        :param x: node features forwarded to :meth:`forward`.
        :param graph: graphs forwarded to :meth:`forward`.
        :param zdist: square matrix of target pairwise distances.
        :return: scalar loss tensor.
        """
        p = self._softmax_prob(1 / (zdist + 1e-9))
        z_hat = self.forward(x, graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = self._softmax_prob(1 / (zhat_dist + 1e-9))
        return self._kl_terms(p, q).sum(dim=1).mean()

    def loss_fn_kl2(self, x, graph, zdist: torch.Tensor) -> torch.Tensor:
        """KL loss using a row-softmax over negated distances.

        Identical to :meth:`loss_fn_kl` except that the scores are ``-d``
        instead of ``1 / d``.

        :param x: node features forwarded to :meth:`forward`.
        :param graph: graphs forwarded to :meth:`forward`.
        :param zdist: square matrix of target pairwise distances.
        :return: scalar loss tensor.
        """
        p = self._softmax_prob(-1 * zdist)
        z_hat = self.forward(x, graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = self._softmax_prob(-1 * zhat_dist)
        return self._kl_terms(p, q).sum(dim=1).mean()

    def loss_fn_kl_inverse_only(self, x, graph, zdist: torch.Tensor) -> torch.Tensor:
        """KL-style loss on raw inverse distances (no softmax, no normalisation).

        The affinities are ``1 / (d + 1e-9)`` with the diagonal set to 1.0;
        they are NOT normalised into probability distributions.

        :param x: node features forwarded to :meth:`forward`.
        :param graph: graphs forwarded to :meth:`forward`.
        :param zdist: square matrix of target pairwise distances.
        :return: scalar loss tensor.
        """
        def inverse_affinity(dist):
            inv = 1 / (dist + 1e-9)
            return inv.masked_fill(self._diagonal_mask(inv), float(1.0))

        p = inverse_affinity(zdist)
        z_hat = self.forward(x, graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = inverse_affinity(zhat_dist)
        return self._kl_terms(p, q).sum(dim=1).mean()

    def loss_fn_kl_t(self, x, graph, zdist: torch.Tensor, df=1) -> torch.Tensor:
        """KL loss with a Student-t kernel normalised over the whole matrix.

        The kernel (zero diagonal) is divided by its total sum, ``1e-9`` is
        added, and the diagonal is set to 1.0.  The summed KL terms are scaled
        by ``2 * (df + 1) / df``.

        :param x: node features forwarded to :meth:`forward`.
        :param graph: graphs forwarded to :meth:`forward`.
        :param zdist: square matrix of target pairwise distances.
        :param df: degrees of freedom of the Student-t kernel.
        :return: scalar loss tensor.
        """
        def joint_prob(dist):
            kernel, mask = self._student_t_kernel(dist, df)
            prob = kernel / (torch.sum(kernel)) + 1e-9
            return prob.masked_fill(mask, float(1.0))

        p = joint_prob(zdist)
        z_hat = self.forward(x, graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = joint_prob(zhat_dist)
        return 2 * (df + 1) / df * self._kl_terms(p, q).sum()

    def loss_fn_kl_t2(self, x, graph, zdist: torch.Tensor, df=1) -> torch.Tensor:
        """KL loss with a Student-t kernel normalised per row.

        Each row of the kernel (zero diagonal) is divided by its own sum,
        floored at ``1e-9``, and the diagonal is set to 1.0.  The row-wise KL
        sums are averaged and scaled by ``(df + 1) / df``.

        Two defects of the earlier implementation are corrected here:

        * ``df`` is now actually used for the kernel (it previously only
          affected the final scale factor, the kernel always used 1);
        * the row sums keep their dimension, so row ``i`` is divided by the
          sum of row ``i``.  Without ``keepdim`` the sums were broadcast over
          columns; for symmetric distance matrices the loss value is the same,
          for asymmetric input the rows were not normalised.

        :param x: node features forwarded to :meth:`forward`.
        :param graph: graphs forwarded to :meth:`forward`.
        :param zdist: square matrix of target pairwise distances.
        :param df: degrees of freedom of the Student-t kernel.
        :return: scalar loss tensor.
        """
        def row_prob(dist):
            kernel, mask = self._student_t_kernel(dist, df)
            row_sums = torch.sum(kernel, dim=1, keepdim=True)
            prob = torch.clamp(kernel / row_sums, min=1e-9)
            return prob.masked_fill(mask, float(1.0))

        p = row_prob(zdist)
        z_hat = self.forward(x, graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = row_prob(zhat_dist)
        return (df + 1) / df * self._kl_terms(p, q).sum(dim=1).mean()


def list_collate_fn(items):
    """Transpose a batch of samples into a list of per-field tuples.

    ``[(a0, b0), (a1, b1)]`` becomes ``[(a0, a1), (b0, b1)]``.  As with
    ``zip``, the result is truncated to the shortest sample, and an empty
    batch yields an empty list.

    :param items: iterable of equally structured samples.
    :return: list with one tuple per field.
    """
    return [*zip(*items)]


if __name__ == '__main__':
    # Disabled manual smoke test, kept for reference only.  When enabled it
    # builds a dataset and a DataLoader, runs one batch through AutoVisualNet
    # and prints the output shape.
    # NOTE(review): it relies on a GraphDatasets module and on local data
    # paths that are not visible in this file - confirm they still exist
    # before re-enabling.
    #
    # import GraphDatasets
    # from dgl.dataloading import GraphDataLoader
    #
    # ds = GraphDatasets.DatasetGraphDataset(data_names=['mnist_group1'], cdist_path='../prepare_data/clip/features',
    #                          visual_path='../prepare_data/bo/res-2')
    #
    # net = AutoVisualNet().to('cuda:0')
    #
    # train_loader = torch.utils.data.DataLoader(ds,
    #                                batch_size=2,
    #                                shuffle=True,
    #                                num_workers=0,
    #                                collate_fn=lambda x: list(zip(*x))
    #                                )
    # it = iter(train_loader)
    # a = next(it)
    #
    # # out = net([ds[0][0], ds[1][0]], [ds[0][1], ds[1][1]])
    # out = net(a[0], a[1])
    #
    # print(out.shape)
    pass



