import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn

import dgl
import dgl.function as fn
from dgl.base import DGLError


class DiffConv(nn.Module):
    """Diffusion convolution layer from the DCRNN paper.

    The layer propagates projected node features over a set of pre-computed
    diffusion graphs (see :meth:`attach_graph`) and merges the results,
    together with a projection of the raw input, through a learnable
    weighted mean. It can be used for traffic prediction or pandemic models.

    Parameters
    ----------
    in_feats : int
        Number of input features.

    out_feats : int
        Number of output features.

    k : int
        Number of diffusion steps. Must be at least 2; ``k - 1`` diffusion
        graphs are used per direction, plus one non-diffused term.

    in_graph_list : list
        Diffusion graphs for the "in" direction, as returned (second) by
        :meth:`attach_graph`. Needs at least ``k - 1`` entries when `dir`
        is "both" or "in".

    out_graph_list : list
        Diffusion graphs for the "out" direction, as returned (first) by
        :meth:`attach_graph`. Needs at least ``k - 1`` entries when `dir`
        is "both" or "out".

    dir : str [both/in/out]
        Direction of diffusion convolution.
        Following the paper, the default is both directions.

    Raises
    ------
    ValueError
        If `dir` is not one of "both", "in", "out", or if ``k < 2``.
    """

    _DIRECTIONS = ("both", "in", "out")

    def __init__(
        self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir="both"
    ):
        super(DiffConv, self).__init__()
        if dir not in self._DIRECTIONS:
            raise ValueError(
                "dir must be one of {}, got {!r}".format(self._DIRECTIONS, dir)
            )
        if k < 2:
            # With fewer than two steps there is no projection layer at all,
            # and forward() has nothing to apply to the input.
            raise ValueError("k must be at least 2, got {!r}".format(k))
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.k = k
        self.dir = dir
        # k - 1 diffusion graphs per direction; "both" uses two directions.
        steps_per_direction = self.k - 1
        self.num_graphs = (
            2 * steps_per_direction
            if self.dir == "both"
            else steps_per_direction
        )
        self.project_fcs = nn.ModuleList(
            nn.Linear(self.in_feats, self.out_feats, bias=False)
            for _ in range(self.num_graphs)
        )
        # One merge weight per diffusion graph plus one for the raw input.
        self.merger = nn.Parameter(torch.randn(self.num_graphs + 1))
        self.in_graph_list = in_graph_list
        self.out_graph_list = out_graph_list

    @staticmethod
    def _graph_from_matrix(matrix, device):
        """Build a DGL graph on `device` whose float edge feature "weight"
        holds the non-zero entries of `matrix`."""
        graph = dgl.from_scipy(
            sparse.coo_matrix(matrix), eweight_name="weight"
        ).to(device)
        graph.edata["weight"] = graph.edata["weight"].float().to(device)
        return graph

    @staticmethod
    def attach_graph(g, k):
        """Pre-compute the diffusion graphs of `g` for both directions.

        Parameters
        ----------
        g : DGLGraph
            Graph carrying an edge feature named "weight".
        k : int
            Number of graphs to build per direction.

        Returns
        -------
        tuple of (list, list)
            ``(out_graph_list, in_graph_list)``, each holding `k` graphs.
            Note the order: the constructor takes ``in_graph_list`` first.
        """
        device = g.device
        wadj, ind, outd = DiffConv.get_weight_matrix(g)

        # NOTE(review): a zero degree makes the normalisation below a
        # division by zero - confirm callers never pass isolated nodes.
        out_graph_list = [
            DiffConv._graph_from_matrix(wadj / outd.cpu().numpy(), device)
        ]
        for _ in range(k - 1):
            out_graph_list.append(
                DiffConv.diffuse(out_graph_list[-1], wadj, outd)
            )

        # The "in" direction uses the transposed weights and in-degrees.
        in_graph_list = [
            DiffConv._graph_from_matrix(wadj.T / ind.cpu().numpy(), device)
        ]
        for _ in range(k - 1):
            in_graph_list.append(
                DiffConv.diffuse(in_graph_list[-1], wadj.T, ind)
            )
        return out_graph_list, in_graph_list

    @staticmethod
    def get_weight_matrix(g):
        """Return the weighted adjacency of `g` and its degree vectors.

        Returns
        -------
        tuple
            ``(adj, in_degrees, out_degrees)`` where `adj` is the SciPy COO
            adjacency of `g` with its data replaced by ``g.edata["weight"]``.
        """
        adj = g.adj(scipy_fmt="coo")
        ind = g.in_degrees()
        outd = g.out_degrees()
        # NOTE(review): assumes the COO entries are ordered like the edge
        # IDs of g, so that weights line up with entries - confirm.
        adj.data = g.edata["weight"].cpu().numpy()
        return adj, ind, outd

    @staticmethod
    def diffuse(progress_g, weighted_adj, degree):
        """Advance a diffusion graph by one step.

        Multiplies the weighted adjacency of `progress_g` by
        ``weighted_adj / degree`` and returns the product as a new graph on
        the same device as `progress_g`.
        """
        progress_adj = progress_g.adj(scipy_fmt="coo")
        progress_adj.data = progress_g.edata["weight"].cpu().numpy()
        transition = weighted_adj / degree.cpu().numpy()
        return DiffConv._graph_from_matrix(
            progress_adj @ transition, progress_g.device
        )

    def _select_graphs(self):
        """Return the ``k - 1`` diffusion graphs of each active direction."""
        if self.dir == "both":
            sources = (
                ("in_graph_list", self.in_graph_list),
                ("out_graph_list", self.out_graph_list),
            )
        elif self.dir == "in":
            sources = (("in_graph_list", self.in_graph_list),)
        else:
            sources = (("out_graph_list", self.out_graph_list),)

        steps_per_direction = self.k - 1
        selected = []
        for name, graphs in sources:
            if len(graphs) < steps_per_direction:
                raise ValueError(
                    "{} holds {} graphs but k={} requires at least {}".format(
                        name, len(graphs), self.k, steps_per_direction
                    )
                )
            selected.extend(graphs[:steps_per_direction])
        return selected

    def forward(self, g, x):
        """Apply the diffusion convolution to node features `x`.

        Parameters
        ----------
        g : DGLGraph
            Unused; kept for interface compatibility. The pre-computed
            graphs given to the constructor are used instead.
        x : torch.Tensor
            Node features whose last dimension is ``in_feats``.

        Returns
        -------
        torch.Tensor
            Merged features whose last dimension is ``out_feats``.

        Raises
        ------
        ValueError
            If a required graph list holds fewer than ``k - 1`` graphs.
        """
        feat_list = []
        for graph, project in zip(self._select_graphs(), self.project_fcs):
            with graph.local_scope():
                graph.ndata["n"] = project(x)
                graph.update_all(
                    fn.u_mul_e("n", "weight", "e"), fn.sum("e", "feat")
                )
                feat_list.append(graph.ndata["feat"])
        # The non-diffused term reuses the last projection layer.
        feat_list.append(self.project_fcs[-1](x))
        # Shape [num_graphs + 1, N, out_feats].
        stacked = torch.stack(feat_list).view(
            len(feat_list), -1, self.out_feats
        )
        # Scale each term by its merge weight, then average over the terms.
        return (self.merger.view(-1, 1, 1) * stacked).mean(0)
