from numpy.core.numeric import zeros_like
import dgl
import networkx as nx
import torch
import numpy as np

from graph_learning.data_setting import DataSettingConfig, DataTransform
from graph_learning.dataset.graph import GLGraph, gl_batch

@DataSettingConfig.register('a-position-aware-labels-all')
class APositionAwareLabelsAllConfig(DataSettingConfig):
    """Data-setting config registered as ``a-position-aware-labels-all``.

    Its builder is :class:`APositionAwareLabelsAll`, whose constructor takes
    ``depth`` and ``uw``; this config adds the matching ``--depth`` and
    ``--uw`` command-line options (presumably forwarded to the builder by the
    base class — confirm in ``DataSettingConfig``).
    """

    def __init__(self, args, context):
        # No extra state; defer entirely to the base config.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class constructed for this data setting."""
        return APositionAwareLabelsAll

    @classmethod
    def define_parser(cls, parser):
        """Register the base options, then ``--depth`` and ``--uw``."""
        super().define_parser(parser)
        extra_options = (
            ('--depth', dict(type=int)),
            ('--uw', dict(action='store_true')),
        )
        for flag, spec in extra_options:
            parser.add_argument(flag, **spec)

def _min_weighted_costs(G, source, depth):
    """Return the cheapest weighted simple-path cost per reached node and hop count.

    Performs one depth-first search from ``source`` over every simple directed
    path with at most ``depth`` edges. The result maps each reached node to a
    dict ``{hops: cost}`` where ``cost`` is the minimum, over those paths that
    end at the node with exactly ``hops`` edges, of the summed ``'weight'``
    edge attribute. ``source`` itself never appears in the result.

    The number of simple paths grows exponentially with ``depth``; this search
    is still done only once per source instead of once per (source, target).
    """
    costs = {}
    on_path = {source}

    def extend(node, hops, cost):
        if hops >= depth:
            return
        for nbr in G.successors(node):
            if nbr in on_path:
                continue  # keep the path simple (this also skips self-loops)
            # Same left-to-right accumulation as nx.path_weight.
            new_cost = cost + G[node][nbr]['weight']
            by_hops = costs.setdefault(nbr, {})
            prev = by_hops.get(hops + 1)
            if prev is None or new_cost < prev:
                by_hops[hops + 1] = new_cost
            on_path.add(nbr)
            extend(nbr, hops + 1, new_cost)
            on_path.remove(nbr)

    extend(source, 0, 0)
    return costs

def get_pair_dists(data, depth, unweighted):
    """Compute hop-bounded distances from every anchor node to every node.

    Anchors are the nodes with a non-zero ``ndata['seed_labels']`` entry.
    Entry ``[i, l, j]`` of the result is the minimum cost of a simple directed
    path from anchor ``i`` to node ``j`` using at most ``l + 1`` edges. It is
    ``inf`` when no such path exists and ``0`` when ``j`` is the anchor
    itself. The cost is the hop count when ``unweighted`` is true; otherwise
    it is the sum of the ``'weight'`` edge attribute along the path.

    Args:
        data: DGL graph carrying ``ndata['seed_labels']`` and, optionally,
            ``edata['weight']``.
        depth: number of hop budgets, i.e. the maximum path length in edges.
        unweighted: measure hop counts instead of summed edge weights.

    Returns:
        Float tensor of shape ``(num_anchors, depth, num_nodes)``.
    """
    anchors = data.ndata['seed_labels'].nonzero()[:, 0].tolist()
    edge_attrs = ['weight'] if 'weight' in data.edata else []
    # nx.DiGraph keeps a single edge per (u, v) pair of the DGL multigraph.
    G = nx.DiGraph(dgl.to_networkx(data.cpu(), edge_attrs=edge_attrs))

    n = G.number_of_nodes()
    # float('inf') rather than np.infty: that alias was removed in NumPy 2.0.
    pair_dists = torch.full((len(anchors), depth, n), float('inf'))

    for i, anchor in enumerate(anchors):
        if unweighted:
            # The shortest simple path within l + 1 hops is the BFS shortest
            # path whenever its length d <= l + 1, so node j gets d in every
            # hop-budget slot from index d - 1 onwards.
            hop_dist = nx.single_source_shortest_path_length(G, anchor, cutoff=depth)
            for j, d in hop_dist.items():
                pair_dists[i, max(d - 1, 0):, j] = d
        else:
            for j, by_hops in _min_weighted_costs(G, anchor, depth).items():
                # Running minimum over exact hop counts 1..l+1.
                best = None
                for l in range(depth):
                    cost = by_hops.get(l + 1)
                    if cost is not None and (best is None or cost < best):
                        best = cost
                    if best is not None:
                        pair_dists[i, l, j] = best
        # Zero-length path from the anchor to itself, for every hop budget.
        pair_dists[i, :, anchor] = 0

    return pair_dists

class APositionAwareLabelsAll(DataTransform):
    """Label each node with its distance to the nearest anchor, per hop budget.

    Sets ``data.ndata['labels']`` to a ``(num_nodes, depth, 1)`` tensor.
    Entry ``[j, l, 0]`` is the smallest ``get_pair_dists`` value over all
    anchors at hop budget ``l``. Nodes that no anchor reaches within that
    budget get ``nan``.
    """

    def __init__(self, depth, uw):
        # Both are forwarded to get_pair_dists (uw -> its `unweighted` flag).
        self.depth, self.uw = depth, uw

    def _transform(self, data):
        dists = get_pair_dists(data, self.depth, self.uw)
        # Reduce over anchors: (anchors, depth, nodes) -> (depth, nodes).
        # NOTE(review): with no anchors, this reduction over an empty dim raises.
        nearest = torch.min(dists, dim=0).values
        # (depth, nodes) -> (nodes, depth, 1)
        labels = nearest.t().unsqueeze(-1)
        # Unreachable (inf) entries become nan.
        data.ndata['labels'] = labels.masked_fill(torch.isinf(labels), np.nan)
        return data

@DataSettingConfig.register('a-position-aware-labels-each')
class APositionAwareLabelsEachConfig(DataSettingConfig):
    """Data-setting config registered as ``a-position-aware-labels-each``.

    Its builder is :class:`APositionAwareLabelsEach`, whose constructor takes
    ``depth`` and ``uw``; this config adds the matching ``--depth`` and
    ``--uw`` command-line options (presumably forwarded to the builder by the
    base class — confirm in ``DataSettingConfig``).
    """

    def __init__(self, args, context):
        # No extra state; defer entirely to the base config.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class constructed for this data setting."""
        return APositionAwareLabelsEach

    @classmethod
    def define_parser(cls, parser):
        """Register the base options, then ``--depth`` and ``--uw``."""
        super().define_parser(parser)
        extra_options = (
            ('--depth', dict(type=int)),
            ('--uw', dict(action='store_true')),
        )
        for flag, spec in extra_options:
            parser.add_argument(flag, **spec)

class APositionAwareLabelsEach(DataTransform):
    """Label each node with its distance to every anchor, per hop budget.

    Sets ``data.ndata['labels']`` to a ``(num_nodes, depth, num_anchors)``
    tensor, which is ``get_pair_dists`` with its first and last axes swapped.
    Anchor/node pairs that are unreachable within a budget get ``nan``.
    """

    def __init__(self, depth, uw):
        # Both are forwarded to get_pair_dists (uw -> its `unweighted` flag).
        self.depth, self.uw = depth, uw

    def _transform(self, data):
        dists = get_pair_dists(data, self.depth, self.uw)
        # (anchors, depth, nodes) -> (nodes, depth, anchors)
        labels = dists.transpose(0, 2)
        # Unreachable (inf) entries become nan.
        labels = labels.masked_fill(labels.isinf(), np.nan)
        data.ndata['labels'] = labels
        return data

@DataSettingConfig.register('a-position-aware-each-as-feature')
class APositionAwareFeatEachConfig(DataSettingConfig):
    """Data-setting config registered as ``a-position-aware-each-as-feature``.

    Its builder is :class:`APositionAwareFeatEach`, whose constructor takes
    ``depth``, ``uw`` and ``im``; this config adds the matching ``--depth``,
    ``--uw`` and ``--im`` command-line options (presumably forwarded to the
    builder by the base class — confirm in ``DataSettingConfig``).
    """

    def __init__(self, args, context):
        # No extra state; defer entirely to the base config.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class constructed for this data setting."""
        return APositionAwareFeatEach

    @classmethod
    def define_parser(cls, parser):
        """Register the base options, then ``--depth``, ``--uw`` and ``--im``."""
        super().define_parser(parser)
        extra_options = (
            ('--depth', dict(type=int)),
            ('--uw', dict(action='store_true')),
            ('--im', dict(action='store_true')),
        )
        for flag, spec in extra_options:
            parser.add_argument(flag, **spec)

class APositionAwareFeatEach(DataTransform):
    """Append anchor distances to each node's ``x`` features.

    With ``im`` set, every hop budget is appended: ``depth * num_anchors``
    columns, from the ``(nodes, depth, anchors)`` layout flattened
    per node. Otherwise only the largest hop budget is appended
    (``num_anchors`` columns). Unreachable (inf) distances are written as 0.
    """

    def __init__(self, depth, uw, im):
        # depth/uw are forwarded to get_pair_dists; im selects all hop budgets.
        self.depth, self.uw, self.im = depth, uw, im

    def _transform(self, data):
        dists = get_pair_dists(data, self.depth, self.uw)
        if not self.im:
            # Largest hop budget only: (anchors, nodes) -> (nodes, anchors).
            feats = dists[:, -1].t()
        else:
            # (anchors, depth, nodes) -> (nodes, depth, anchors) -> (nodes, depth * anchors)
            feats = dists.transpose(0, 2).flatten(start_dim=1)
        feats = feats.masked_fill(feats.isinf(), 0)
        data.ndata['x'] = torch.cat((data.ndata['x'], feats), dim=-1)
        return data
