import numpy as np
import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import ot

from . import DeepSet, GIN, RRWPEncoder, GAT
from . import sinkhorn


def init_weights(m):
    """Initialise a single ``nn.Linear`` layer in place; every other module type is ignored.

    Weights get Kaiming-normal initialisation with the ``leaky_relu`` gain and the bias
    (when the layer has one) is zeroed. Meant to be passed to ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
    # Alternative tried previously: init.xavier_uniform_(m.weight)
    if m.bias is not None:
        init.constant_(m.bias, 0)


class AutoVisualNet(nn.Module):
    """Multi-view GIN + graph-transformer network producing an ``out_dim`` vector per node.

    Every view ``i`` (``0 <= i < n_graph_view``) owns one GIN (``self.gnns[str(i)]``) and one stack
    of ``nn.TransformerEncoderLayer`` blocks (``self.gts[str(i)]``). View ``i`` reads node features
    ``ndata['pe{i}']`` and edge weights ``edata['weight{i}']``. Per-view outputs are concatenated
    along dim 1 and fed to the MLP head ``self.mlp``.

    Wiring (``is_gt_parallel``):

    * ``False``: GIN -> transformer blocks, per view (:meth:`gin_gt_forward`).
    * ``True``: the GIN and the transformer blocks both read ``pe{i}`` side by side, and their
      outputs are concatenated (:meth:`forward_parallel`).

    The ``loss_fn_*`` methods run :meth:`forward` and compare distances between its outputs with
    reference distances or embeddings supplied by the caller.
    """

    def __init__(self, input_dim=16, gnn_hidden=128, gnn_out=32, out_dim=2, n_graph_view=1, n_transformer=8,
                 n_out_subnet=1, is_gt_parallel=False,
                 device='cuda:0'):
        """
        :param input_dim: width of each per-view node feature ``ndata['pe{i}']``.
        :param gnn_hidden: hidden width of each GIN.
        :param gnn_out: output width of each GIN.
        :param out_dim: number of outputs per node.
        :param n_graph_view: number of views.
        :param n_transformer: number of transformer encoder layers per view.
        :param n_out_subnet: only 1 is supported in serial mode (other values raise
            ``NotImplementedError`` there).
        :param is_gt_parallel: run GIN and transformer side by side instead of serially.
        :param device: device the sub-networks are moved to at construction time.
        """
        super().__init__()
        self.sub_batch_size = 1

        self.input_dim, self.gnn_hidden, self.gnn_out, self.out_dim = input_dim, gnn_hidden, gnn_out, out_dim
        self.n_out_subnet = n_out_subnet
        self.is_gt_parallel = is_gt_parallel
        self.device = device
        self.n_graph_view = n_graph_view

        # Only the per-view ("separate") GNN layout is built; the shared ``self.gin`` branch below is
        # kept for gin_forward_uni_gnn.
        self.use_sep_gnn = True
        self.n_transforms = n_transformer

        # NOTE(review): loss_fn_sinkhorn_w_umap needs ``self.sinkhorn_fn`` (e.g.
        # sinkhorn.SinkhornDistance(device=device)). It is not created here; assign it before using
        # that loss.

        if self.use_sep_gnn:
            self.gnns = nn.ModuleDict()
            self.gts = nn.ModuleDict()
            for i in range(self.n_graph_view):
                # Creation order (GIN i, then transformer i) is kept for reproducible seeded init.
                self.gnns[str(i)] = GIN.GIN(self.input_dim, self.gnn_hidden, self.gnn_out).to(device)
                self.gts[str(i)] = self.build_transformer_blocks().to(device)
        else:
            self.gin = GIN.GIN(self.input_dim, self.gnn_hidden, self.gnn_out).to(device)

        if self.is_gt_parallel:
            # Input per view: GIN output (gnn_out) + transformer output on raw features (input_dim).
            self.mlp = nn.Sequential(
                nn.Linear((self.gnn_out + self.input_dim) * self.n_graph_view, self.gnn_out * self.n_graph_view),
                nn.SELU(),
                nn.Linear(self.gnn_out * self.n_graph_view, self.gnn_out),
                nn.SELU(),
                nn.Linear(self.gnn_out, self.gnn_out),
                nn.SELU(),
                nn.Linear(self.gnn_out, self.out_dim)
            )
        else:
            if self.n_out_subnet == 1:
                self.mlp = nn.Sequential(
                    nn.Linear(self.gnn_out*self.n_graph_view, self.gnn_out*self.n_graph_view//2),
                    nn.SELU(),
                    nn.Linear(self.gnn_out * self.n_graph_view //2, self.gnn_out),
                    nn.SELU(),
                    nn.Linear(self.gnn_out, self.gnn_out),
                    nn.SELU(),
                    nn.Linear(self.gnn_out, self.out_dim)
                )
            else:
                # A multi-subnet output head (shared MLP + one small MLP per subnet in
                # ``self.out_subnets``) was sketched but never implemented.
                assert self.n_out_subnet > 1
                raise NotImplementedError

    def build_transformer_blocks(self):
        """Build ``n_transforms`` 4-head ``TransformerEncoderLayer``s for one view.

        ``d_model`` is ``input_dim`` in parallel mode (the blocks read raw node features) and
        ``gnn_out`` in serial mode (the blocks read GIN output). It must be divisible by 4.
        """
        inp_d = self.input_dim if self.is_gt_parallel else self.gnn_out
        return nn.ModuleList(nn.TransformerEncoderLayer(d_model=inp_d, nhead=4) for _ in range(self.n_transforms))

    # ------------------------------------------------------------------ tensor helpers

    @staticmethod
    def _offsets(sizes):
        """Return int64 offsets ``[0, s0, s0+s1, ...]`` for a 1-D tensor of block sizes, on its device."""
        return torch.cat([sizes.new_zeros(1), sizes]).cumsum(0).long()

    @staticmethod
    def _block_ids(sizes, device):
        """Return, for each element of a flat tensor split into blocks of ``sizes``, its block index."""
        return torch.repeat_interleave(torch.arange(len(sizes), device=device), sizes)

    @classmethod
    def _block_sum(cls, values, sizes):
        """Sum a flat tensor within consecutive blocks whose lengths are given by ``sizes``."""
        out = torch.zeros((len(sizes),), dtype=values.dtype, device=values.device)
        return out.index_add_(0, cls._block_ids(sizes, values.device), values)

    @classmethod
    def _block_mean(cls, values, sizes):
        """Mean of a flat tensor within consecutive blocks whose lengths are given by ``sizes``."""
        return cls._block_sum(values, sizes) / sizes

    @staticmethod
    def _student_t_kernel(dist, degrees_of_freedom=1):
        """Unnormalised Student-t similarity ``(1 + d**2 / df) ** (-(df + 1) / 2)``, element-wise."""
        return (1.0 + dist.pow(2) / degrees_of_freedom).pow(-(degrees_of_freedom + 1.0) / 2.0)

    @staticmethod
    def _umap_kernel(dist, _a=1.93, _b=0.79):
        """Unnormalised UMAP-style similarity ``1 / (1 + a * d ** (2b))``, element-wise."""
        return 1 / (1 + _a * dist ** (2 * _b))

    @classmethod
    def _student_t_block_prob(cls, dist, sizes, degrees_of_freedom=1, eps=1e-9):
        """Student-t similarities normalised to sum to 1 inside each block, floored at ``eps``.

        The floor uses ``clamp`` so it works on any device (the previous version hard-coded CUDA).
        """
        weights = cls._student_t_kernel(dist, degrees_of_freedom)
        norm = cls._block_sum(weights, sizes)[cls._block_ids(sizes, weights.device)]
        return torch.clamp(weights / norm, min=eps)

    @classmethod
    def _student_t_row_prob(cls, zdist, degrees_of_freedom=1, eps=1e-9):
        """Row-normalised Student-t similarities of a square distance matrix.

        The diagonal is excluded from the normalisation and then set to 1.0, so ``log`` of it is 0.
        Off-diagonal entries are floored at ``eps``.
        """
        mask = torch.eye(zdist.shape[0], dtype=torch.bool, device=zdist.device)
        weights = cls._student_t_kernel(zdist, degrees_of_freedom).masked_fill(mask, 0.0)
        # keepdim=True: divide row i by its own row sum (without it the division broadcasts over columns).
        q = torch.clamp(weights / weights.sum(dim=1, keepdim=True), min=eps)
        return q.masked_fill(mask, 1.0)

    @staticmethod
    def _masked_row_softmax(scores):
        """Row-wise softmax of a square score matrix with the diagonal excluded.

        Off-diagonal probabilities get ``+1e-9`` and the diagonal is set to 1.0.
        """
        mask = torch.eye(scores.shape[0], dtype=torch.bool, device=scores.device)
        p = F.softmax(scores.masked_fill(mask, float('-inf')), dim=1) + 1e-9
        return p.masked_fill(mask, 1.0)

    @staticmethod
    def _row_kl(p, q):
        """Sum ``p * (log p - log q)`` over dim 1 and average the row totals."""
        return (p * (torch.log(p) - torch.log(q))).sum(dim=1).mean()

    @staticmethod
    def _per_graph_pdist(z, offsets, n_graphs, cols=slice(None)):
        """Concatenate ``torch.pdist`` of each graph's rows of ``z`` (restricted to columns ``cols``)."""
        return torch.cat([torch.pdist(z[offsets[i]:offsets[i + 1], cols]) for i in range(n_graphs)])

    # ------------------------------------------------------------------ forward passes

    def _gt_forward(self, x, gt_blocks, bs, batch_graph_size_cumsum):
        """Run ``gt_blocks`` separately on each graph's rows of ``x`` and re-concatenate the results.

        Because each graph is sliced out first, attention only mixes nodes of the same graph.

        :param x: node rows of the whole batch, graph after graph along dim 0.
        :param gt_blocks: transformer layers applied in order.
        :param bs: number of graphs in the batch.
        :param batch_graph_size_cumsum: node offsets; graph ``j`` owns rows ``[cumsum[j], cumsum[j+1])``.
        """
        per_graph = []
        for j in range(bs):
            h = x[batch_graph_size_cumsum[j]:batch_graph_size_cumsum[j + 1]]
            for blk in gt_blocks:
                h = blk(h)
            per_graph.append(h)
        return torch.cat(per_graph)

    def gt_mv_forward(self, n_view_graphs):
        """Transformer-only branch: apply each view's blocks to ``ndata['pe{i}']``; concat views on dim 1."""
        offsets = self._offsets(n_view_graphs.batch_num_nodes())
        outputs = [self._gt_forward(n_view_graphs.ndata[f'pe{i}'], self.gts[str(i)],
                                    n_view_graphs.batch_size, offsets)
                   for i in range(self.n_graph_view)]
        return torch.cat(outputs, dim=1)

    def _view_gnn(self, i, edge_only_graph, n_view_graph):
        """Run view ``i``'s GIN on ``edge_only_graph`` with that view's node features and edge weights."""
        return self.gnns[str(i)](edge_only_graph, n_view_graph.ndata[f'pe{i}'],
                                 edge_weight=n_view_graph.edata[f'weight{i}'])

    def gin_forward(self, edge_only_graph, n_view_graph):
        """GIN-only branch: run every view's GIN and concatenate the outputs on dim 1."""
        outputs = [self._view_gnn(i, edge_only_graph, n_view_graph) for i in range(self.n_graph_view)]
        return torch.cat(outputs, dim=1)

    def gin_gt_forward(self, edge_only_graph, n_view_graph):
        """Serial branch: per view, GIN then that view's transformer blocks; concatenate views on dim 1."""
        # Offsets are built on the same device as batch_num_nodes() (previously a CPU zero was
        # concatenated with it, which fails for graphs on GPU).
        offsets = self._offsets(n_view_graph.batch_num_nodes())
        outputs = [self._gt_forward(self._view_gnn(i, edge_only_graph, n_view_graph), self.gts[str(i)],
                                    n_view_graph.batch_size, offsets)
                   for i in range(self.n_graph_view)]
        return torch.cat(outputs, dim=1)

    def gin_forward_uni_gnn(self, graphs, node_features):
        """Legacy shared-GIN path: batch each sub-batch of graphs, run ``self.gin``, split views into columns.

        NOTE(review): ``self.gin`` only exists when ``use_sep_gnn`` is False, which ``__init__`` never sets.
        """
        outputs = []
        for sub_batch_idx in range(len(graphs) // self.sub_batch_size):
            # todo: crash when sub_batch_size > 1
            batched = dgl.batch(graphs[sub_batch_idx]).to(self.device)
            batch_pe = torch.cat(node_features[sub_batch_idx]).to(self.device)
            output = self.gin(batched, batch_pe)
            # Assumes rows are stacked view after view; each view's chunk becomes a column block.
            outputs.append(torch.cat(output.chunk(self.n_graph_view), dim=1))
        return torch.cat(outputs)

    def _output_head(self, n_view_out):
        """Apply the MLP head (and the per-subnet heads when ``n_out_subnet > 1``)."""
        if self.n_out_subnet == 1:
            return self.mlp(n_view_out)
        hidden = self.mlp(n_view_out)
        return torch.cat([self.out_subnets[str(k)](hidden) for k in range(self.n_out_subnet)], dim=1)

    def forward(self, edge_only_graph, n_view_graphs):
        """Return one ``out_dim`` row per node of the batched graph."""
        if self.is_gt_parallel:
            return self.forward_parallel(edge_only_graph, n_view_graphs)
        n_view_out = self.gin_gt_forward(edge_only_graph, n_view_graphs)  # (total nodes, gnn_out * n_view)
        return self._output_head(n_view_out)

    def forward_parallel(self, edge_only_graph, n_view_graphs):
        """Parallel wiring: concat GIN outputs and transformer outputs (both read ``pe{i}``), then the head."""
        n_view_graphs = n_view_graphs.to(self.device)
        gnn_n_view_out = self.gin_forward(edge_only_graph, n_view_graphs)
        gt_n_view_out = self.gt_mv_forward(n_view_graphs)
        return self._output_head(torch.cat([gnn_n_view_out, gt_n_view_out], dim=1))

    def reset_parameters(self):
        """ Initialize the weights and bias of every ``nn.Linear`` submodule via ``init_weights``.
        :return: None
        """
        self.apply(init_weights)

    # ------------------------------------------------------------------ dense (n x n) losses

    def loss_fn_kl(self, x, graph, zdist):
        """KL(P || Q) with P, Q from a row-wise softmax of inverse distances (diagonal excluded).

        :param zdist: square reference distance matrix.
        :return: KL summed per row, averaged over rows.
        """
        p = self._masked_row_softmax(1 / (zdist + 1e-9))
        z_hat = self.forward(x, graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = self._masked_row_softmax(1 / (zhat_dist + 1e-9))
        return self._row_kl(p, q)

    def loss_fn_kl2(self, x, graph, zdist):
        """KL(P || Q) with P, Q from a row-wise softmax of negative distances (diagonal excluded)."""
        p = self._masked_row_softmax(-1 * zdist)
        z_hat = self.forward(x, graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = self._masked_row_softmax(-1 * zhat_dist)
        return self._row_kl(p, q)

    def loss_fn_kl_inverse_only(self, x, graph, zdist):
        """KL-style loss on raw (unnormalised) inverse distances; the diagonal is set to 1.0."""
        def inverse_weights(dist):
            mask = torch.eye(dist.shape[0], dtype=torch.bool, device=dist.device)
            return (1 / (dist + 1e-9)).masked_fill(mask, 1.0)

        p = inverse_weights(zdist)
        z_hat = self.forward(x, graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = inverse_weights(zhat_dist)
        return self._row_kl(p, q)

    def loss_fn_kl_t(self, graph, zdist, df=1):
        """t-SNE-style KL with Student-t similarities normalised over the whole matrix.

        :return: ``2 * (df + 1) / df * sum(KL)``.
        """
        def global_student_t_prob(dist, degrees_of_freedom=1):
            mask = torch.eye(dist.shape[1], dtype=torch.bool, device=dist.device)
            weights = self._student_t_kernel(dist, degrees_of_freedom).masked_fill(mask, 0.0)
            return (weights / torch.sum(weights) + 1e-9).masked_fill(mask, 1.0)

        p = global_student_t_prob(zdist, degrees_of_freedom=df)
        # NOTE(review): forward() takes (edge_only_graph, n_view_graphs); this one-argument call raises
        # TypeError with this class's forward. Confirm the intended inputs before using this loss.
        z_hat = self.forward(graph)
        zhat_dist = torch.cdist(z_hat, z_hat)
        q = global_student_t_prob(zhat_dist, degrees_of_freedom=df)
        kl = p * (torch.log(p) - torch.log(q))
        return 2 * (df + 1) / df * kl.sum()

    def loss_fn_kl_t2(self, x, graph, zdist, df=1):
        """KL(P || Q) with per-row-normalised Student-t similarities (diagonal excluded).

        :param df: Student-t degrees of freedom, used both for P/Q and for the ``(df + 1) / df`` scale.
        :return: ``(df + 1) / df`` times the row-averaged KL.
        """
        p = self._student_t_row_prob(zdist, degrees_of_freedom=df)
        z_hat = self.forward(x, graph)
        q = self._student_t_row_prob(torch.cdist(z_hat, z_hat), degrees_of_freedom=df)
        return (df + 1) / df * self._row_kl(p, q)

    # ------------------------------------------------------------------ condensed (pdist) losses

    def loss_fn_kl_t_pdist(self, n_view_graphs, edge_only_graph, zdist, batch_zdist_size, df=1):
        '''t-SNE-style KL on condensed (pdist-form) distances, normalised and summed per graph.

        :param zdist: reference pairwise distances of all graphs, concatenated. They must line up
            element-wise with the per-graph ``torch.pdist`` of the network output.
        :param batch_zdist_size: number of entries of ``zdist`` per graph (presumably n_i*(n_i-1)/2).
        :param df: Student-t degrees of freedom.
        :return: ``(loss, z_hat)``; loss is ``2 * (df + 1) / df`` times the mean per-graph KL.
        '''
        p = self._student_t_block_prob(zdist, batch_zdist_size, degrees_of_freedom=df)
        z_hat = self.forward(edge_only_graph, n_view_graphs)

        offsets = self._offsets(n_view_graphs.batch_num_nodes())
        zhat_dist = self._per_graph_pdist(z_hat, offsets, n_view_graphs.batch_size)
        q = self._student_t_block_prob(zhat_dist, batch_zdist_size, degrees_of_freedom=df)

        kl_per_graph = self._block_sum(p * (torch.log(p) - torch.log(q)), batch_zdist_size)
        return 2 * (df + 1) / df * kl_per_graph.mean(), z_hat

    def loss_fn_kl_t_pdist_w_umap(self, n_view_graphs, edge_only_graph, tsne_zdist, umap_zdist, batch_zdist_size, df=1):
        '''Two per-graph KL terms: ``z_hat[:, :2]`` vs ``tsne_zdist`` and ``z_hat[:, 2:]`` vs ``umap_zdist``.

        Both terms use block-normalised Student-t probabilities. The UMAP term always uses df=1 for
        the kernel but is scaled with the caller's ``df``.

        :return: ``(tsne_loss + umap_loss, z_hat)``.
        '''
        tsne_p = self._student_t_block_prob(tsne_zdist, batch_zdist_size, degrees_of_freedom=df)
        umap_p = self._student_t_block_prob(umap_zdist, batch_zdist_size)
        z_hat = self.forward(edge_only_graph, n_view_graphs)

        offsets = self._offsets(n_view_graphs.batch_num_nodes())
        n_graphs = n_view_graphs.batch_size
        tsne_zhat_dist = self._per_graph_pdist(z_hat, offsets, n_graphs, slice(None, 2))
        umap_zhat_dist = self._per_graph_pdist(z_hat, offsets, n_graphs, slice(2, None))

        tsne_q = self._student_t_block_prob(tsne_zhat_dist, batch_zdist_size, degrees_of_freedom=df)
        umap_q = self._student_t_block_prob(umap_zhat_dist, batch_zdist_size)

        kl_tsne = self._block_sum(tsne_p * (torch.log(tsne_p) - torch.log(tsne_q)), batch_zdist_size)
        kl_tsne_loss = 2 * (df + 1) / df * kl_tsne.mean()

        kl_umap = self._block_sum(umap_p * (torch.log(umap_p) - torch.log(umap_q)), batch_zdist_size)
        kl_umap_loss = kl_umap.mean() * 2 * (df + 1) / df
        return kl_tsne_loss + kl_umap_loss, z_hat

    def loss_fn_l2_pdist_w_umap(self, n_view_graphs, edge_only_graph, tsne_zdist, umap_zdist, batch_zdist_size, df=1):
        '''Squared error between unnormalised similarity kernels, averaged per graph then over graphs.

        ``z_hat[:, :2]`` is compared with ``tsne_zdist`` through the Student-t kernel, and
        ``z_hat[:, 2:]`` with ``umap_zdist`` through the UMAP kernel.

        :return: ``(tsne_loss + umap_loss, z_hat)``.
        '''
        tsne_p = self._student_t_kernel(tsne_zdist, df)
        umap_p = self._umap_kernel(umap_zdist)
        z_hat = self.forward(edge_only_graph, n_view_graphs)

        # Offsets live on batch_num_nodes()'s device (previously built from a CPU zero, failing on GPU).
        offsets = self._offsets(n_view_graphs.batch_num_nodes())
        n_graphs = n_view_graphs.batch_size
        tsne_q = self._student_t_kernel(self._per_graph_pdist(z_hat, offsets, n_graphs, slice(None, 2)), df)
        umap_q = self._umap_kernel(self._per_graph_pdist(z_hat, offsets, n_graphs, slice(2, None)))

        tsne_loss = self._block_mean((tsne_p - tsne_q) ** 2, batch_zdist_size).mean()
        umap_loss = self._block_mean((umap_p - umap_q) ** 2, batch_zdist_size).mean()
        return tsne_loss + umap_loss, z_hat

    def loss_fn_sinkhorn_w_umap(self, n_view_graphs, edge_only_graph, tsne_z, umap_z, batch_zdist_size, df=1):
        '''Per-graph Sinkhorn loss between output point sets and reference embeddings.

        :param tsne_z: reference rows compared against ``z_hat[:, :2]``.
        :param umap_z: reference rows compared against ``z_hat[:, 2:]``.
        :param batch_zdist_size: number of rows per graph (used to slice both tensors).
        :param df: unused; kept for signature parity with the other losses.
        :return: ``(0.1 * tsne_loss + umap_loss, z_hat)``.
        :raises AttributeError: if ``self.sinkhorn_fn`` has not been assigned.
        '''
        sinkhorn_fn = getattr(self, 'sinkhorn_fn', None)
        if sinkhorn_fn is None:
            raise AttributeError("loss_fn_sinkhorn_w_umap requires self.sinkhorn_fn, "
                                 "which __init__ does not create; assign it first")

        def per_graph_sinkhorn(ref, pred, sizes):
            offsets = self._offsets(sizes)
            out = torch.zeros((len(sizes),), dtype=ref.dtype, device=ref.device)
            ref = ref.reshape(-1, 2, 1)
            pred = pred.reshape(-1, 2, 1)
            for i in range(len(sizes)):
                lo, hi = offsets[i], offsets[i + 1]
                out[i] = sinkhorn_fn(pred[lo:hi], ref[lo:hi])
            return out

        z_hat = self.forward(edge_only_graph, n_view_graphs)
        kl_tsne_loss = per_graph_sinkhorn(tsne_z, z_hat[:, :2], batch_zdist_size).mean()
        kl_umap_loss = per_graph_sinkhorn(umap_z, z_hat[:, 2:], batch_zdist_size).mean()
        return 0.1 * kl_tsne_loss + kl_umap_loss, z_hat


def list_collate_fn(items):
    """Transpose a batch of samples into per-field tuples without stacking anything.

    ``[(a0, b0), (a1, b1)]`` becomes ``[(a0, a1), (b0, b1)]``. Intended as a DataLoader
    ``collate_fn`` when the fields (e.g. graphs) should stay as Python objects.
    """
    return [*zip(*items)]


if __name__ == '__main__':
    # Manual smoke test, currently disabled: every statement below is commented out, so this
    # guard does nothing. Run as a plain script (not via `python -m`), the file would already
    # fail on the relative `from . import ...` lines at the top.
    # NOTE(review): the snippet calls `net(a[0], a[1])`; confirm the loader's output order matches
    # forward(edge_only_graph, n_view_graphs) before re-enabling it.
    # import GraphDatasets
    # from dgl.dataloading import GraphDataLoader
    #
    # ds = GraphDatasets.DatasetGraphDataset(data_names=['mnist_group1'], cdist_path='../prepare_data/clip/features',
    #                          visual_path='../prepare_data/bo/res-2')
    #
    # net = AutoVisualNet().to('cuda:0')
    #
    # train_loader = torch.utils.data.DataLoader(ds,
    #                                batch_size=2,
    #                                shuffle=True,
    #                                num_workers=0,
    #                                collate_fn=lambda x: list(zip(*x))
    #                                )
    # it = iter(train_loader)
    # a = next(it)
    #
    # # out = net([ds[0][0], ds[1][0]], [ds[0][1], ds[1][1]])
    # out = net(a[0], a[1])
    #
    # print(out.shape)
    pass



