import dgl
from dgl import DGLGraph
import torch

class GLGraph:
    """Wrapper around a DGL graph that carries extra user data in ``gdata``.

    ``gdata`` is a plain dict. Its reserved ``'__schema'`` entry maps a data
    key to a ``(batch_schema, unbatch_schema)`` pair registered through
    :meth:`add_batch_schema`. Keys starting with ``'__'`` are treated as
    reserved by ``gl_batch`` / ``gl_unbatch``.

    Attribute lookups that fail on the wrapper itself are resolved first
    against ``gdata`` keys and then forwarded to the wrapped graph. This lets
    graph methods (e.g. ``number_of_nodes``) be called directly on a
    ``GLGraph``.
    """

    def __init__(self, graph=None):
        """Wrap ``graph``; a new empty ``DGLGraph`` is created when it is None."""
        self.graph = graph
        if self.graph is None:
            self.graph = DGLGraph()
        self.gdata = {'__schema': {}}

    def __getattr__(self, name):
        """Return ``gdata[name]`` if present, else the wrapped graph's attribute.

        Python only calls this hook after normal lookup fails. The
        ``object.__getattribute__`` calls bypass this hook, so a half-built
        instance (no ``gdata``/``graph`` yet) raises AttributeError instead of
        recursing.
        """
        gdata = object.__getattribute__(self, 'gdata')
        if name in gdata:
            # Return the stored value for this key, not the whole dict.
            return gdata[name]
        graph = object.__getattribute__(self, 'graph')
        return getattr(graph, name)

    def adapt(self, dgl_graph):
        """Wrap ``dgl_graph`` in a new GLGraph that shares this graph's data.

        The returned object's ``gdata`` is the *same* dict object as
        ``self.gdata`` (not a copy), so mutations are visible through both.
        """
        graph = GLGraph(dgl_graph)
        graph.gdata = self.gdata
        return graph

    def add_batch_schema(self, name, batch_schema, unbatch_schema=None):
        """Register how ``gdata[name]`` is merged and split.

        ``gl_batch`` calls ``batch_schema(graphs, values)`` with the list of
        per-graph values. ``gl_unbatch`` calls
        ``unbatch_schema(batched_graph, value)``. Without an unbatch schema,
        the entry is not carried over when unbatching.
        """
        self.gdata['__schema'][name] = (batch_schema, unbatch_schema)

    def __str__(self):
        return str(self.graph)

    def __repr__(self):
        return self.graph.__repr__()

def edge_batch(graphs, edges):
    """Batch per-graph edge sets into the node-ID space of ``dgl.batch(graphs)``.

    Each element of ``edges`` is expected to be either a ``(src, dst)`` tuple
    or a tensor that unpacks into ``src`` and ``dst`` (presumably a ``(2, E)``
    tensor; TODO confirm). If the first element is neither, ``edges`` is
    treated as a per-graph sequence of several edge sets. It is transposed,
    and each edge set is batched recursively.

    For every graph, a throwaway DGL graph with the same node count and the
    given edges is built. These are batched, and the batched graph's edges are
    returned: as a ``(src, dst)`` pair for tuple input, or stacked along dim 0
    for tensor input.
    """
    head = edges[0]
    if not isinstance(head, (tuple, torch.Tensor)):
        fields = [list(group) for group in zip(*edges)]
        return [edge_batch(graphs, field) for field in fields]

    def _as_edge_graph(owner, edge):
        # Same node count as the owner graph, so batching offsets IDs identically.
        edge_graph = dgl.DGLGraph().to(edge[0].device)
        edge_graph.add_nodes(owner.number_of_nodes())
        edge_graph.add_edges(*edge)
        return edge_graph

    batched = dgl.batch([_as_edge_graph(g, e) for g, e in zip(graphs, edges)])
    src_dst = batched.edges()
    if isinstance(head, tuple):
        return src_dst
    return torch.stack(src_dst, 0)

def graph_batch(graphs, gdata_gs):
    """Batch schema: merge per-graph DGL graphs stored in ``gdata`` via ``dgl.batch``.

    ``graphs`` is unused. It is accepted so the function matches the
    ``batch_schema(graphs, values)`` call made by ``gl_batch``.
    """
    batched_graph = dgl.batch(gdata_gs)
    return batched_graph

def blockdiag_batch(graphs, gdata_ps):
    """Batch schema: combine per-graph tensors into one block-diagonal tensor.

    ``graphs`` is unused. It is accepted so the function matches the
    ``batch_schema(graphs, values)`` call made by ``gl_batch``.
    """
    blocks = list(gdata_ps)
    return torch.block_diag(*blocks)

def gl_batch(batch):
    """Batch a list of ``GLGraph`` objects into a single ``GLGraph``.

    The wrapped DGL graphs are merged with ``dgl.batch``. For each
    non-reserved ``gdata`` key of ``batch[0]`` (keys not starting with
    ``'__'``), the per-graph values are collected into a list. If ``batch[0]``
    registered a batch schema for that key, the list is replaced by
    ``batch_schema(batch, values)``. Reserved entries (including
    ``'__schema'``) are copied by reference from ``batch[0]``.

    Args:
        batch: non-empty list of ``GLGraph``. Every graph must have all of
            ``batch[0]``'s non-reserved gdata keys, or KeyError is raised.

    Returns:
        ``batch[0]`` itself for a single-element batch. Otherwise, a new
        ``GLGraph`` wrapping the batched graph with the merged ``gdata``.
    """
    if len(batch) == 1:
        return batch[0]

    first = batch[0]
    schemas = first.gdata.get('__schema', {})
    gdata_b = {'__schema': {}}
    for k, v in first.gdata.items():
        if k.startswith('__'):
            gdata_b[k] = v
            continue
        values = [graph.gdata[k] for graph in batch]
        if k in schemas:
            batch_schema, _ = schemas[k]
            values = batch_schema(batch, values)
        gdata_b[k] = values

    # Pass the underlying DGL graphs, not the wrappers, to dgl.batch.
    batch_g = GLGraph(dgl.batch([graph.graph for graph in batch]))
    batch_g.gdata = gdata_b
    return batch_g

def gl_unbatch(batch_g):
    """Split a batched ``GLGraph`` back into a list of per-graph ``GLGraph``.

    The wrapped graph is split with ``dgl.unbatch``. Each non-reserved
    ``gdata`` entry with a registered unbatch schema is split by calling
    ``unbatch_schema(batch_g, value)``. The result must be indexable by graph
    position. Non-reserved entries without an unbatch schema are not carried
    over. Reserved entries (keys starting with ``'__'``, including
    ``'__schema'``) are shared by reference with every output graph.

    Returns:
        ``[batch_g]`` when ``batch_g.batch_size == 1``. Otherwise, a list of new
        ``GLGraph`` objects, one per component graph.
    """
    if batch_g.batch_size == 1:
        return [batch_g]

    gdata_b = batch_g.gdata
    schemas = gdata_b.get('__schema', {})
    # Unbatch the underlying DGL graph rather than the wrapper.
    gs = [GLGraph(g) for g in dgl.unbatch(batch_g.graph)]

    per_graph = {}
    for k, v in gdata_b.items():
        if k.startswith('__'):
            per_graph[k] = [v] * len(gs)
            continue
        _, unbatch_schema = schemas.get(k, (None, None))
        if unbatch_schema is not None:
            per_graph[k] = unbatch_schema(batch_g, v)

    for i, g in enumerate(gs):
        for k, values in per_graph.items():
            g.gdata[k] = values[i]
    return gs

