import torch
from torch import nn

from ... import epsilon


class SpatialConvOrderK(nn.Module):
    """
    Spatial convolution of order K with possibly different diffusion matrices (useful for directed graphs).

    Efficient implementation inspired by the graph-wavenet codebase.

    For every diffusion matrix ``A`` in the support, the input is propagated
    ``order`` times (``A x``, ``A (A x)``, ...). The propagated tensors,
    optionally preceded by the input itself, are concatenated along the
    channel axis and mixed by a 1x1 convolution.
    """

    # Patterns that overwrite (part of) the diffusion output with the one
    # computed on ``support_diag``; they therefore require ``support_diag``.
    _DIAG_PATTERNS = (
        "wo_self_loop_w_temporal",
        "w_self_loop_wo_temporal",
        "wo_self_loop_wo_temporal",
    )

    def __init__(self, c_in, c_out, support_len=3, order=2, include_self=True):
        """
        Args:
            c_in: number of input channels.
            c_out: number of output channels.
            support_len: number of diffusion matrices ``forward`` will receive.
            order: number of propagation steps per diffusion matrix.
            include_self: if True, the un-propagated input is also fed to the
                1x1 convolution.
        """
        super().__init__()
        self.include_self = include_self
        # One c_in-wide block per (matrix, step), plus one for x itself.
        n_blocks = order * support_len + (1 if include_self else 0)
        self.mlp = nn.Conv2d(n_blocks * c_in, c_out, kernel_size=1)
        self.order = order
        self.c_out = c_out

    @staticmethod
    def compute_support(adj, device=None):
        """
        Build the forward and backward normalised matrices of ``adj``.

        Args:
            adj: adjacency tensor (assumed square 2-D, nodes x nodes).
            device: optional device ``adj`` is moved to first.

        Returns:
            ``[adj_fwd, adj_bwd]``: ``adj`` and ``adj.T``, each divided by its
            row sums (``epsilon`` avoids division by zero for empty rows).
        """
        if device is not None:
            adj = adj.to(device)
        adj_fwd = adj / (adj.sum(1, keepdim=True) + epsilon)
        adj_bwd = adj.T
        adj_bwd = adj_bwd / (adj_bwd.sum(1, keepdim=True) + epsilon)
        return [adj_fwd, adj_bwd]

    @staticmethod
    def compute_support_orderK(adj, k, include_self=False, device=None):
        """
        Extend a support with ``k - 1`` higher-order matrices per entry.

        Args:
            adj: a raw adjacency tensor (normalised via ``compute_support``) or
                an already computed list/tuple of diffusion matrices.
            k: highest order; ``k - 1`` extra matrices are produced per entry.
            include_self: if False, the diagonal of every extra matrix is zeroed.
            device: optional device the matrices are moved to.

        Returns:
            A list with the base matrices followed, for each base matrix in
            turn, by its ``k - 1`` extra matrices.
        """
        if isinstance(adj, (list, tuple)):
            # Copy into a list: a tuple cannot be concatenated with the list of
            # extra matrices below.
            support = list(adj)
            if device is not None:
                support = [a.to(device) for a in support]
        else:
            support = SpatialConvOrderK.compute_support(adj, device)
        supp_k = []
        for a in support:
            ak = a
            for _ in range(k - 1):
                # NOTE(review): this multiplies by ``a.T`` rather than ``a``, so
                # the result is not the matrix power a^k in general (it is for a
                # symmetric ``a``), while ``forward`` propagates with ``a``
                # itself. Confirm intent.
                ak = torch.matmul(ak, a.T)
                if not include_self:
                    ak.fill_diagonal_(0.)
                supp_k.append(ak)
        return support + supp_k

    def _diffuse(self, x, supports):
        """Concatenate ``x`` (if include_self) and ``A^1 x .. A^order x`` for each ``A`` on dim 1."""
        if not isinstance(supports, (list, tuple)):
            supports = [supports]
        blocks = [x] if self.include_self else []
        for a in supports:
            h = x
            for _ in range(self.order):
                # h[n, c, w, l] = sum_v a[w, v] * h[n, c, v, l]
                h = torch.einsum('ncvl,wv->ncwl', h, a).contiguous()
                blocks.append(h)
        return torch.cat(blocks, dim=1)

    def forward(self, x, support, support_diag=None, pattern=None):
        """
        Args:
            x: input laid out as [batch, features, nodes, steps]; a 3-D input
                without the last axis is accepted and the output squeezed back.
            support: a diffusion matrix or a list/tuple of them.
            support_diag: optional alternative support, only used by ``pattern``.
            pattern: optional name selecting which channels of the diffusion
                output are replaced by those computed on ``support_diag``.
                Other values (including None) leave the output untouched.

        Raises:
            ValueError: if ``pattern`` requires ``support_diag`` but it is None.
        """
        squeeze = x.dim() < 4
        if squeeze:
            x = torch.unsqueeze(x, -1)

        out = self._diffuse(x, support)

        if pattern in self._DIAG_PATTERNS:
            if support_diag is None:
                raise ValueError(f"pattern {pattern!r} requires support_diag")
            out_diag = self._diffuse(x, support_diag)

            # Per the original author's notes, channels are laid out as t*d
            # (t time offsets, e.g. t = 0, 1 (current), 2) and ``mid`` is the
            # current step.
            # NOTE(review): ``t`` is derived as out.size(1) // c_out and the
            # assignments below replace individual channels, not d-sized
            # blocks; verify this matches the intended layout.
            t = out.size(1) // self.c_out
            mid = t // 2

            if pattern == "wo_self_loop_w_temporal":
                # Keep channel ``mid`` from ``out``; take the rest from out_diag.
                out[:, :mid, :, :] = out_diag[:, :mid, :, :]
                out[:, mid + 1:, :, :] = out_diag[:, mid + 1:, :, :]
            elif pattern == "w_self_loop_wo_temporal":
                # Take only channel ``mid`` from out_diag.
                out[:, mid, :, :] = out_diag[:, mid, :, :]
            else:  # "wo_self_loop_wo_temporal"
                out = out_diag

        out = self.mlp(out)
        if squeeze:
            out = out.squeeze(-1)
        return out
