import dgl
import dgl.function as fn

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import accumulate

from .sage import SAGEConv
from graph_learning.module import ModuleConfig, get_module

from .fusion import *

@ModuleConfig.register('gir',
                       help='[Encoder] GIR')
class GIRLModuleConfig(ModuleConfig):
    """Command-line configuration for the GIR encoder; builds :class:`GIRL`."""

    def __init__(self, args, context):
        super().__init__(args, context)
        # Keep a handle on the logger exposed by the global context.
        self.logger = context.global_.logger

    @property
    def builder(self):
        """Module class that this configuration instantiates."""
        return GIRL

    @classmethod
    def define_parser(cls, parser):
        """Register the GIR-specific command-line options on ``parser``."""
        super().define_parser(parser)
        add = parser.add_argument

        # Layer widths / channel count.
        add('--in-size', type=int)
        add('--hidden-size', type=int)
        add('--num-channels', type=int, default=1, help='GIR-MIX if >1.')
        add('--out-size', type=int)

        # Network shape, regularisation and parameter sharing.
        add('--depth', type=int)
        add('--dropout', type=float)
        add('--bn', action='store_true', help='use batchnorm')
        add('--aggr', choices=['mean', 'pool'], default='mean')
        add('--mix-share', action='store_true',
            help='share parameters for all anchor sets in GIR-MIX.')
        add('--layer-share', action='store_true')

        # Edge weighting and output format.
        add('--weighted', action='store_true')
        add('--return-hiddens', action='store_true')


def get_prop_graphs(g, seeds, depth, device=None):
    """Build one propagation subgraph per hop, expanding outward from ``seeds``.

    Hop ``i`` keeps exactly the out-edges of the current frontier. The
    frontier then advances to the de-duplicated destinations of those edges.
    Self-loops are stripped from ``g`` first, so they never appear in the
    returned subgraphs.

    Args:
        g: DGL graph to propagate over.
        seeds: node IDs of the initial frontier. Presumably a 1-D ID tensor
            on ``g.device`` (TODO confirm with callers).
        depth: number of hops, i.e. the number of subgraphs returned.
        device: device for the returned subgraphs; defaults to ``g.device``.

    Returns:
        A list of ``depth`` edge-induced subgraphs of ``g``, one per hop,
        each placed on ``device``. They are built with
        ``preserve_nodes=True``, so node IDs line up with ``g``.
    """
    if device is None:
        device = g.device
    g = g.remove_self_loop()
    # edge_subgraph runs on a CPU copy of the graph. The graph does not change
    # between hops, so make that copy once instead of once per hop.
    g_cpu = g.to('cpu')

    frontier = seeds
    subgraphs = []
    for _ in range(depth):
        # A single query yields both the destinations (next frontier) and
        # the edge IDs that define this hop's subgraph.
        _, dst, eids = g.out_edges(frontier, form='all')
        subgraphs.append(
            g_cpu.edge_subgraph(eids.cpu(), preserve_nodes=True).to(device))
        frontier = torch.unique(dst)
    return subgraphs

class GIRL(nn.Module):
    """GIR encoder: per-channel stacks of SAGE convolutions over hop graphs.

    Each channel (anchor set; GIR-MIX when ``num_channels > 1``) owns a stack
    of ``depth`` ``SAGEConv`` layers. At layer ``li``, channel ``gi``
    convolves its own running hidden state over that channel's ``li``-th
    propagation graph. The per-channel states are then concatenated along the
    last dimension to form the layer output.

    Args:
        in_size: input feature width (layer 0 input).
        hidden_size: output width of every ``SAGEConv``, per channel.
        num_channels: number of channel stacks to build. Only one stack is
            built when ``mix_share`` is set.
        out_size: if not ``None``, the last layer skips batch-norm and
            dropout. NOTE(review): it does not change any layer's width; each
            channel still outputs ``hidden_size`` features.
        depth: number of ``SAGEConv`` layers per channel.
        dropout: dropout probability applied after each layer (see ``out_size``).
        bn: apply ``BatchNorm1d`` (with ``track_running_stats=False``) after
            each layer.
        mix_share: build a single stack and reuse it for every channel.
        layer_share: give layers 0 and 1 their own modules and reuse layer 1's
            module for every deeper layer.
        name: stored as ``self.name``.
        aggr: ``aggregator_type`` passed to ``SAGEConv``.
        weighted: ``weighted`` flag passed to ``SAGEConv``.
        return_hiddens: return every layer's output plus a mask, instead of
            only the last layer's output.
    """

    def __init__(self, in_size, hidden_size, num_channels, out_size,
                 depth, dropout, bn, mix_share, layer_share,
                 name, aggr, weighted, return_hiddens):
        super().__init__()
        self.name = name
        self.depth = depth
        self.hidden_size = hidden_size
        self.num_channels = num_channels
        self.out_size = out_size
        self.mix_share = mix_share
        self.layer_share = layer_share
        self.return_hiddens = return_hiddens

        self.activation = nn.ReLU()
        self.dropout_layer = nn.Dropout(p=dropout)
        self.use_bn = bn

        # gir_layers[channel][layer] -> SAGEConv. bn_layers mirrors it when use_bn.
        self.gir_layers = nn.ModuleList()
        if self.use_bn:
            self.bn_layers = nn.ModuleList()

        # With mix_share a single stack is built and forward() reuses it for
        # every channel (at most one stack, never more than num_channels).
        num_stacks = min(num_channels, 1) if mix_share else num_channels
        for _ in range(num_stacks):
            convs = nn.ModuleList()
            norms = nn.ModuleList()
            layer = None
            for i in range(depth):
                # Under layer_share, layer 0 (which reads in_size features) and
                # layer 1 are distinct; deeper layers reuse layer 1's module.
                if not self.layer_share or i <= 1:
                    layer = SAGEConv(in_size if i == 0 else hidden_size, hidden_size,
                                     activation=self.activation,
                                     aggregator_type=aggr, weighted=weighted)
                convs.append(layer)
                if self.use_bn:
                    norms.append(nn.BatchNorm1d(hidden_size, track_running_stats=False))

            self.gir_layers.append(convs)
            if self.use_bn:
                self.bn_layers.append(norms)

    def forward(self, data, x):
        """Encode features ``x`` over the precomputed propagation graphs.

        Args:
            data: object whose ``graph()`` returns a graph exposing
                ``gdata['_gir_prop_graphs']``. That entry holds one list of
                per-hop graphs per channel (presumably built by
                ``get_prop_graphs``; TODO confirm).
            x: input features fed to every channel's first layer.

        Returns:
            ``{'hidden': ret, 'outputs': {}}``. With ``return_hiddens``,
            ``ret`` is ``(per_layer_outputs, mask)``, where ``mask`` is an
            all-true bool tensor of shape ``(rows, num_layers)``. Otherwise
            ``ret`` is the last layer's concatenated output, or ``x`` itself
            if no layer ran.
        """
        g = data.graph()
        g = g.adapt(g.local_var())

        # Keep each channel's first `depth` hop graphs, then transpose to
        # per-layer lists (zip stops at the shortest channel).
        per_channel = [prop_graphs[:self.depth]
                       for prop_graphs in g.gdata['_gir_prop_graphs']]
        per_layer = [list(graphs) for graphs in zip(*per_channel)]

        h = x
        hs_ret = []
        # Every channel starts from the same input. If there are no layers
        # (depth 0 or no channels), there is nothing to track.
        hs = [x] * len(per_layer[0]) if per_layer else []
        last = self.depth - 1
        for li, prop_graphs in enumerate(per_layer):
            for gi, prop_graph in enumerate(prop_graphs):
                gi_ = 0 if self.mix_share else gi
                hg = self.gir_layers[gi_][li](prop_graph, hs[gi])

                # With an explicit out_size, the final layer is left without
                # batch-norm and dropout.
                if not (self.out_size is not None and li == last):
                    if self.use_bn:
                        hg = self.bn_layers[gi_][li](hg)
                    hg = self.dropout_layer(hg)

                hs[gi] = hg

            h = torch.cat(hs, -1)
            hs_ret.append(h)

        if self.return_hiddens:
            # Every row of every layer is marked valid. Stacking per-layer
            # ones required equal row counts and devices anyway, so one
            # allocation is equivalent, and it also works when no layer ran.
            mask = torch.ones(h.size(0), len(hs_ret), dtype=torch.bool, device=h.device)
            ret = (hs_ret, mask)
        else:
            ret = h

        return {'hidden': ret,
                'outputs': {}}

