import dgl
import dgl.function as fn
import torch
import torch.nn.functional as F
from graph_learning.data_setting import DataSettingConfig, DataTransform
from collections import defaultdict
from graph_learning.dataset.graph import GLGraph
import networkx as nx
import networkit as nk
import numpy as np

def get_seed_label(seeds, n_nodes, device):
    """Build a float32 indicator vector marking the given seed nodes.

    Args:
        seeds: iterable of node indices to mark.
        n_nodes: length of the returned vector.
        device: torch device the vector is allocated on.

    Returns:
        Tensor of shape ``(n_nodes,)`` holding 1 at every seed index, 0 elsewhere.
    """
    seed_idx = list(seeds)
    indicator = torch.zeros(n_nodes, dtype=torch.float32, device=device)
    indicator[seed_idx] = 1
    return indicator

@DataSettingConfig.register('ga-mpca',
                            help='GA-MPCA anchor selecting algorithm.')
class GAMPCAConfig(DataSettingConfig):
    """Command-line configuration for the ``ga-mpca`` data setting.

    Builds a :class:`GAMPCA` transform.
    """

    @classmethod
    def define_parser(cls, parser):
        """Register the GA-MPCA options on top of the base data-setting ones."""
        super().define_parser(parser)
        add_option = parser.add_argument
        # No default is given, so an omitted --max-size parses as None.
        add_option('--max-size', type=int,
                   help='number of anchor nodes.')
        add_option('--num-channels', type=int, default=1,
                   help='select multiple anchor sets if >1 .')

    @property
    def builder(self):
        """Transform class constructed from this configuration."""
        return GAMPCA

class GAMPCA(DataTransform):
    """GA-MPCA anchor selection by greedy maximum successor coverage.

    Repeatedly picks the node whose successors contain the most not-yet-covered
    nodes, then marks that node and all of its successors as covered. Selection
    stops when no node covers anything new or when ``max_size`` anchors have
    been chosen.

    Args:
        max_size: upper bound on the number of anchors; ``None`` means no bound.
        num_channels: number of anchor sets the selected anchors are split into.
    """

    def __init__(self, max_size, num_channels):
        self.max_size = max_size
        self.num_channels = num_channels

    def _transform(self, graph):
        """Select anchors on ``graph`` and store them as seed labels.

        Writes ``graph.ndata['seed_labels']``, a float32 indicator over all
        selected anchors. Also writes ``graph.gdata['seed_labels']``, a list
        with one indicator per channel. Channel ``k`` holds anchors ``k``,
        ``k + C``, ``k + 2C``, ... in selection order, where ``C`` is
        ``num_channels``. At most ``min(C, number of anchors)`` channels are
        produced.

        Returns:
            The same ``graph`` object, modified in place.

        Raises:
            ValueError: if ``num_channels`` is smaller than 1.
        """
        num_channels = self.num_channels
        if num_channels < 1:
            raise ValueError(f'num_channels must be >= 1, got {num_channels}')
        max_size = self.max_size
        n = graph.number_of_nodes()
        device = graph.device

        # The reversed graph turns every edge u->v of ``graph`` into v->u.
        # Summing the sources' 'unmasked' flags therefore counts, for each
        # node u, its still-uncovered successors in ``graph``.
        rg = graph.reverse(copy_ndata=False)
        rg.ndata['unmasked'] = torch.ones(n, device=device, dtype=torch.float)

        def coverage_counts():
            # Per-node number of uncovered successors. copy_u is the current
            # name of the deprecated copy_src builtin.
            rg.update_all(fn.copy_u('unmasked', 'c'), fn.sum('c', 'deg'))
            return rg.ndata['deg'].long()

        anchors = []
        out_degrees = coverage_counts()
        # Counts are non-negative, so a positive sum guarantees the argmax node
        # covers at least one new node. The budget is checked before selecting,
        # which makes max_size=0 yield no anchors.
        while out_degrees.sum() > 0 and (max_size is None
                                         or len(anchors) < max_size):
            anchor = int(torch.argmax(out_degrees))
            anchors.append(anchor)
            rg.ndata['unmasked'][anchor] = 0
            rg.ndata['unmasked'][graph.successors(anchor)] = 0
            out_degrees = coverage_counts()

        # Round-robin split: channel k gets anchors[k::C]. Slicing keeps every
        # anchor even when len(anchors) is not a multiple of C. A chunk-then-zip
        # approach would truncate all channels to the shortest chunk.
        channels = [anchors[k::num_channels]
                    for k in range(min(num_channels, len(anchors)))]
        graph.gdata['seed_labels'] = [get_seed_label(c, n, device)
                                      for c in channels]
        graph.ndata['seed_labels'] = get_seed_label(anchors, n, device)
        return graph


@DataSettingConfig.register('all-anchors',
                            help='Use all nodes as anchors (For -O experiment settings).')
class GAMPCACConfig(DataSettingConfig):
    """Configuration for the ``all-anchors`` data setting.

    Builds an :class:`AllAnchors` transform.
    """
    # NOTE(review): the class name looks like a copy-paste leftover from
    # GAMPCAConfig (something like AllAnchorsConfig was probably intended).
    # It is kept unchanged in case it is referenced elsewhere.

    @property
    def builder(self):
        """Transform class constructed from this configuration."""
        transform_cls = AllAnchors
        return transform_cls

class AllAnchors(DataTransform):
    """Transform that marks every node of the graph as an anchor."""

    def _transform(self, graph):
        """Set ``graph.ndata['seed_labels']`` to all ones and return ``graph``."""
        num_nodes = graph.number_of_nodes()
        # NOTE(review): this writes torch.long seed labels, whereas GAMPCA in
        # this file writes float32 ones. Confirm that downstream consumers
        # accept both dtypes.
        all_ones = torch.ones(num_nodes, dtype=torch.long, device=graph.device)
        graph.ndata['seed_labels'] = all_ones
        return graph

@DataSettingConfig.register('anchor-id-as-feature',
                            help='Anchor labeling trick.')
class AnchorIdAsFeatureConfig(DataSettingConfig):
    """Configuration for the ``anchor-id-as-feature`` data setting.

    Builds an :class:`AnchorIdAsFeature` transform. It adds no command-line
    options of its own.
    """

    def __init__(self, args, context):
        # Pure delegation to the base configuration.
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Register only the base data-setting arguments."""
        super().define_parser(parser)

    @property
    def builder(self):
        """Transform class constructed from this configuration."""
        return AnchorIdAsFeature

class AnchorIdAsFeature(DataTransform):
    """Append a one-hot anchor-identity block to the node features ``x``.

    With ``k`` anchors (the nonzero entries of ``ndata['seed_labels']``),
    ``k`` columns are appended. The j-th anchor's row gets a 1 in column j,
    and all non-anchor rows get zeros.
    """

    def __init__(self):
        # Intentionally does nothing; the transform has no configuration.
        pass

    def _transform(self, graph):
        """Concatenate the anchor one-hot block onto ``graph.ndata['x']``."""
        anchor_idx = torch.nonzero(graph.ndata['seed_labels'])[:, 0]
        num_anchors = anchor_idx.size(0)
        dev = anchor_idx.device
        one_hot = torch.zeros(graph.number_of_nodes(), num_anchors, device=dev)
        # Scatter an identity matrix into the anchor rows: row anchor_idx[j]
        # receives a 1 in column j.
        one_hot[anchor_idx, torch.arange(num_anchors, device=dev)] = 1
        graph.ndata['x'] = torch.cat((graph.ndata['x'], one_hot), dim=-1)
        return graph

@DataSettingConfig.register('anchor-as-feature',
                            help='Anchor labeling trick.')
class AnchorAsFeatureConfig(DataSettingConfig):
    """Configuration for the ``anchor-as-feature`` data setting.

    Builds an :class:`AnchorAsFeature` transform. It adds no command-line
    options of its own.
    """

    def __init__(self, args, context):
        # Pure delegation to the base configuration.
        super().__init__(args, context)

    @classmethod
    def define_parser(cls, parser):
        """Register only the base data-setting arguments."""
        super().define_parser(parser)

    @property
    def builder(self):
        """Transform class constructed from this configuration."""
        return AnchorAsFeature

class AnchorAsFeature(DataTransform):
    """Append a single binary "is anchor" column to the node features ``x``.

    Nodes with a nonzero ``ndata['seed_labels']`` entry get 1 in the new
    column; all other nodes get 0.
    """

    def __init__(self):
        # Intentionally does nothing; the transform has no configuration.
        pass

    def _transform(self, graph):
        """Concatenate the anchor flag column onto ``graph.ndata['x']``."""
        anchor_idx = torch.nonzero(graph.ndata['seed_labels'])[:, 0]
        flag = torch.zeros(graph.number_of_nodes(), 1, device=anchor_idx.device)
        flag[anchor_idx, 0] = 1
        graph.ndata['x'] = torch.cat((graph.ndata['x'], flag), dim=-1)
        return graph
