import dgl
import torch
from graph_learning.data_setting import DataSettingConfig, DataTransform
from graph_learning.dataset.graph import GLGraph, gl_batch

@DataSettingConfig.register('to', help='Send data to specific device.')
class GraphToConfig(DataSettingConfig):
    """Config registered as the ``to`` data setting; builds :class:`GraphTo`.

    Adds a ``--device`` option. Presumably its parsed value becomes the
    ``device`` argument of ``GraphTo`` (TODO confirm how DataSettingConfig
    maps options onto builder arguments).
    """

    def __init__(self, args, context):
        # Kept explicit: the base-class signature is not visible here.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class instantiated by this config."""
        return GraphTo

    @classmethod
    def define_parser(cls, parser):
        """Extend the base parser with the target ``--device`` option."""
        super().define_parser(parser)
        parser.add_argument('--device')

class GraphTo(DataTransform):
    """Move a graph, and the tensors/graphs held in its ``gdata``, to a device.

    Args:
        device: value passed unchanged to every ``.to()`` call, including the
            graph's own. Presumably a device string such as ``'cuda:0'`` or a
            ``torch.device`` (TODO confirm against the ``--device`` option).
    """

    def __init__(self, device):
        self.device = device

    def _move_list(self, items):
        """Return a new list with every movable element sent to ``self.device``.

        Nested lists are handled recursively. Elements exposing a callable
        ``to`` are moved. Anything else (strings, ints, ...) is kept as-is
        instead of raising ``AttributeError``, which is what a plain
        ``[x.to(device) for x in items]`` would do.
        """
        moved = []
        for item in items:
            if isinstance(item, list):
                moved.append(self._move_list(item))
            elif callable(getattr(item, 'to', None)):
                moved.append(item.to(self.device))
            else:
                moved.append(item)
        return moved

    def _transform(self, graph):
        """Return ``graph`` moved to ``self.device`` with its ``gdata`` moved too.

        Handling of ``gdata`` entries:

        * ``torch.Tensor`` and ``dgl.DGLGraph`` values are moved directly.
        * ``list`` values are moved element-wise, recursing into nested lists.
        * All other values are left untouched.

        ``gdata`` is updated on the source graph *before* ``graph.to`` is
        called. Presumably ``graph.adapt`` then carries those entries over to
        the moved graph (TODO confirm).
        """
        updates = {}
        for name, value in graph.gdata.items():
            if isinstance(value, (torch.Tensor, dgl.DGLGraph)):
                updates[name] = value.to(self.device)
            elif isinstance(value, list):
                updates[name] = self._move_list(value)
        # Write back after iterating so the mapping is never mutated mid-loop.
        for name, value in updates.items():
            graph.gdata[name] = value
        return graph.adapt(graph.to(self.device))


@DataSettingConfig.register('prepare-gir',
                            help='GIR propagation paths for each layer.')
class GirPropGraphsConfig(DataSettingConfig):
    """Config registered as the ``prepare-gir`` data setting; builds GirPropGraphs.

    Adds ``--depth`` and ``--device`` options. Presumably their parsed values
    become the ``depth``/``device`` arguments of ``GirPropGraphs`` (TODO
    confirm how DataSettingConfig maps options onto builder arguments).
    """

    def __init__(self, args, context):
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class instantiated by this config."""
        return GirPropGraphs

    @classmethod
    def define_parser(cls, parser):
        """Extend the base parser with the GIR propagation options."""
        super().define_parser(parser)
        # NOTE(review): no default and not required, so with argparse's usual
        # default --depth is None when omitted.
        parser.add_argument('--depth', type=int,
                            help='number of propagation layers.')
        parser.add_argument('--device')

class GirPropGraphs(DataTransform):
    """Precompute GIR propagation graphs for every seed-label set of a graph.

    Args:
        depth: forwarded to ``get_prop_graphs`` (number of propagation layers).
        device: forwarded to ``get_prop_graphs`` unchanged.
    """

    def __init__(self, depth, device):
        self.depth, self.device = depth, device

    def _transform(self, graph):
        """Replace ``gdata['seed_labels']`` with ``gdata['_gir_prop_graphs']``.

        ``seed_labels`` is popped from ``gdata`` and iterated. Presumably each
        entry is a per-node label tensor (TODO confirm). For every entry, the
        nodes labelled ``1`` are the seeds passed to ``get_prop_graphs``. The
        resulting list (one item per entry) is stored and registered with the
        ``edge_batch`` batching schema. ``gdata.pop`` raises if
        ``seed_labels`` is missing (``KeyError`` for a dict).
        """
        # Function-local imports, as before (presumably to avoid import cycles).
        from graph_learning.module.modules.encoders.gir import get_prop_graphs
        from graph_learning.dataset.graph import edge_batch

        label_sets = graph.gdata.pop('seed_labels')
        graph.gdata['_gir_prop_graphs'] = [
            get_prop_graphs(graph, graph.nodes()[labels == 1],
                            self.depth, self.device)
            for labels in label_sets
        ]
        graph.add_batch_schema('_gir_prop_graphs', edge_batch)
        return graph

@DataSettingConfig.register('prepare-agnn', help='AGNN propagation paths.')
class PrepareAGNNConfig(DataSettingConfig):
    """Config registered as the ``prepare-agnn`` data setting; builds PrepareAGNN.

    It adds no options of its own beyond the base parser's.
    """

    def __init__(self, args, context):
        # Kept explicit: the base-class signature is not visible here.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class instantiated by this config."""
        return PrepareAGNN

    @classmethod
    def define_parser(cls, parser):
        """Delegate entirely to the base parser definition."""
        super().define_parser(parser)

class PrepareAGNN(DataTransform):
    """Attach an AGNN anchor graph as ``gdata['_anchor_graph']``."""

    def _transform(self, graph):
        """Build the anchor graph from nodes whose ``ndata['seed_labels']`` is 1.

        The result of ``get_anchor_graph`` is stored under ``_anchor_graph``
        and registered with the ``graph_batch`` batching schema. The graph is
        modified in place and returned.
        """
        from graph_learning.module.modules.encoders.agnn import get_anchor_graph
        from graph_learning.dataset.graph import graph_batch

        node_ids = graph.nodes()
        seeds = node_ids[graph.ndata['seed_labels'] == 1]
        graph.gdata['_anchor_graph'] = get_anchor_graph(graph, seeds)
        graph.add_batch_schema('_anchor_graph', graph_batch)
        return graph


@DataSettingConfig.register(
    'pgnn-precompute-dist',
    help='Compute distance information for PGNN and AGNN.',
)
class PGNNPrecomputeDistConfig(DataSettingConfig):
    """Config registered as ``pgnn-precompute-dist``; builds PGNNPrecomputeDist.

    Adds an integer ``--approximate`` option (default 0). Presumably it is
    passed as the ``approximate`` argument of the builder (TODO confirm).
    """

    def __init__(self, args, context):
        # Kept explicit: the base-class signature is not visible here.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class instantiated by this config."""
        return PGNNPrecomputeDist

    @classmethod
    def define_parser(cls, parser):
        """Extend the base parser with the ``--approximate`` option."""
        super().define_parser(parser)
        parser.add_argument('--approximate', default=0, type=int)

class PGNNPrecomputeDist(DataTransform):
    """Precompute PGNN distance data and attach it as ``gdata['_pgnn_dist']``.

    Args:
        approximate: forwarded unchanged as the ``approximate`` keyword of
            ``precompute_dist_data``.
    """

    def __init__(self, approximate):
        self.approximate = approximate

    def _transform(self, graph):
        """Compute distances on a networkx copy and store them on ``graph``.

        ``precompute_dist_data`` receives a networkx conversion of a CPU copy
        of the graph plus the node count. Its result must be a NumPy array
        (``torch.from_numpy`` requires one); it is converted to a tensor on
        ``graph.device``. The tensor is registered with the
        ``blockdiag_batch`` batching schema.
        """
        from graph_learning.module.modules.encoders.pgnn import precompute_dist_data
        from graph_learning.dataset.graph import blockdiag_batch

        # The distance routine consumes a networkx graph, so convert a CPU copy.
        nx_graph = graph.cpu().to_networkx()
        num_nodes = graph.number_of_nodes()
        dist_np = precompute_dist_data(nx_graph, num_nodes,
                                       approximate=self.approximate)

        # Move the NumPy result back onto the graph's own device.
        graph.gdata['_pgnn_dist'] = torch.from_numpy(dist_np).to(graph.device)
        graph.add_batch_schema('_pgnn_dist', blockdiag_batch)
        return graph


@DataSettingConfig.register('pgnn-select-anchor',
                            help='PGNN anchor selecting.')
class PGNNSelectAnchorConfig(DataSettingConfig):
    """Config registered as ``pgnn-select-anchor``; builds PGNNSelectAnchor.

    Adds ``--layer-num``, ``--anchor-num`` and ``--anchor-size-num``.
    Presumably these map to the builder's ``layer_num``, ``anchor_num`` and
    ``anchor_size_num`` arguments (TODO confirm the mapping in
    DataSettingConfig).
    """

    def __init__(self, args, context):
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class instantiated by this config."""
        return PGNNSelectAnchor

    @classmethod
    def define_parser(cls, parser):
        """Extend the base parser with the anchor-selection options."""
        super().define_parser(parser)
        parser.add_argument('--layer-num', type=int,
                            help='number of layers.')
        parser.add_argument('--anchor-num', type=int,
                            help='number of anchor sets.')
        parser.add_argument('--anchor-size-num', type=int,
                            help='size of each anchor set.')

class PGNNSelectAnchor(DataTransform):
    """Run PGNN anchor pre-selection and store the results as node data.

    Args:
        layer_num: forwarded to ``pgnn_preselect_anchor``.
        anchor_num: forwarded to ``pgnn_preselect_anchor``.
        anchor_size_num: forwarded to ``pgnn_preselect_anchor``.
    """

    def __init__(self, layer_num, anchor_num, anchor_size_num):
        self.layer_num = layer_num
        self.anchor_num = anchor_num
        self.anchor_size_num = anchor_size_num

    def _transform(self, graph):
        """Store the two outputs of ``pgnn_preselect_anchor`` in ``ndata``.

        The outputs are saved as ``dists_max`` and ``dists_argmax``. Because
        they go into ``ndata``, each presumably has one row per node (TODO
        confirm). The graph is modified in place and returned.
        """
        from graph_learning.module.modules.encoders.pgnn import pgnn_preselect_anchor

        anchor_info = pgnn_preselect_anchor(
            graph,
            self.layer_num,
            self.anchor_num,
            self.anchor_size_num,
            graph.device,
        )
        graph.ndata['dists_max'], graph.ndata['dists_argmax'] = anchor_info
        return graph


@DataSettingConfig.register('add-self-loop', help='Add self loop.')
class AddSelfLoopConfig(DataSettingConfig):
    """Config registered as the ``add-self-loop`` data setting; builds AddSelfLoop.

    It adds no options of its own beyond the base parser's.
    """

    def __init__(self, args, context):
        # Kept explicit: the base-class signature is not visible here.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class instantiated by this config."""
        return AddSelfLoop

    @classmethod
    def define_parser(cls, parser):
        """Delegate entirely to the base parser definition."""
        super().define_parser(parser)

class AddSelfLoop(DataTransform):
    """Add a self loop to every node using ``dgl.add_self_loop``.

    NOTE(review): according to the DGL docs, ``dgl.add_self_loop`` adds a
    loop to each node even when one already exists. Graphs that already
    contain self loops therefore end up with duplicates. Confirm whether a
    ``dgl.remove_self_loop`` call should precede this.
    """

    def _transform(self, graph):
        """Return ``graph.adapt`` applied to the new self-looped graph."""
        return graph.adapt(dgl.add_self_loop(graph))
