import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.init as init
import math
import scipy.sparse as sp
import numpy as np

from ... import epsilon


class SpatialConvOrderK(nn.Module):
    """
    Spatial convolution of order K with possibly different diffusion matrices (useful for directed graphs)

    Efficient implementation inspired from graph-wavenet codebase

    For every support matrix the input is propagated ``order`` times; all
    propagated tensors (plus the input itself when ``include_self`` is set) are
    concatenated along the channel axis and mixed by a 1x1 convolution.

    :param c_in: Number of input channels per propagated tensor.
    :param c_out: Number of output channels of the 1x1 convolution.
    :param _len: Unused; kept so existing positional callers keep working.
    :param support_len: Number of support matrices expected in ``forward``.
    :param order: Number of propagation steps applied per support matrix.
    :param include_self: Whether the un-propagated input is concatenated too.
    """

    # Values of ``pattern`` for which ``forward`` reads the ``support_diag`` branch.
    _DIAG_PATTERNS = (
        "wo_self_loop_w_temporal",
        "w_self_loop_wo_temporal",
        "wo_self_loop_wo_temporal",
    )

    def __init__(self, c_in, c_out, _len=3, support_len=3, order=2, include_self=True):
        super(SpatialConvOrderK, self).__init__()
        self.include_self = include_self
        # One channel group per (support, step) pair, plus one for the raw input.
        c_in = (order * support_len + (1 if include_self else 0)) * c_in
        self.mlp = nn.Conv2d(c_in, c_out, kernel_size=1)
        self.order = order
        self.c_out = c_out

    @staticmethod
    def compute_support(adj, device=None):
        """
        Build the row-normalised forward and backward transition matrices.

        :param adj: Square adjacency matrix.
        :param device: Optional device ``adj`` is moved to first.
        :return: ``[adj_fwd, adj_bwd]`` where ``adj_bwd`` is built from ``adj.T``.
        """
        if device is not None:
            adj = adj.to(device)
        adj_bwd = adj.T
        # ``epsilon`` keeps rows that sum to zero from dividing by zero.
        adj_fwd = adj / (adj.sum(1, keepdims=True) + epsilon)
        adj_bwd = adj_bwd / (adj_bwd.sum(1, keepdims=True) + epsilon)
        return [adj_fwd, adj_bwd]

    @staticmethod
    def compute_support_orderK(adj, k, include_self=False, device=None):
        """
        Extend the first-order supports with their higher-order versions.

        :param adj: Either an adjacency matrix or an already computed
            list/tuple of support matrices.
        :param k: Highest order; ``k - 1`` extra matrices are built per support.
        :param include_self: When False the diagonal of every higher-order
            matrix is zeroed.
        :param device: Forwarded to ``compute_support`` when ``adj`` is a matrix.
        :return: List with the first-order supports followed by the
            higher-order ones.
        """
        if isinstance(adj, (list, tuple)):
            support = adj
        else:
            support = SpatialConvOrderK.compute_support(adj, device)
        supp_k = []
        for a in support:
            ak = a
            for _ in range(k - 1):
                ak = torch.matmul(ak, a.T)
                if not include_self:
                    # ``ak`` is a fresh matmul result, so the in-place edit
                    # never touches the caller's matrices.
                    ak.fill_diagonal_(0.)
                supp_k.append(ak)
        # ``list(...)`` so a tuple of supports can be concatenated as well.
        return list(support) + supp_k

    def _diffuse(self, x, support):
        """Propagate ``x`` through every support and concatenate on the channel axis."""
        if not isinstance(support, (list, tuple)):
            support = [support]
        out = [x] if self.include_self else []
        for a in support:
            # Contract the node axis of ``x`` with the second axis of ``a``.
            x1 = torch.einsum('ncvl,wv->ncwl', x, a).contiguous()
            out.append(x1)
            for _ in range(2, self.order + 1):
                x2 = torch.einsum('ncvl,wv->ncwl', x1, a).contiguous()
                out.append(x2)
                x1 = x2
        return torch.cat(out, dim=1)

    def forward(self, x, support, support_diag=None, pattern=None):
        """
        Apply the order-K spatial convolution.

        :param x: Input laid out as [batch, features, nodes, steps]; an input
            with fewer than four dimensions gets a trailing axis that is
            removed again from the output.
        :param support: A support matrix or a list/tuple of them.
        :param support_diag: Optional second set of supports used by ``pattern``.
        :param pattern: Optional name selecting which channel slices of the
            ``support`` result are replaced by the ``support_diag`` result.
            Unknown names leave the result untouched.
        :return: Output of the 1x1 convolution.
        :raises ValueError: If a known ``pattern`` is given without ``support_diag``.
        """
        # [batch, features, nodes, steps]
        squeeze = x.dim() < 4
        if squeeze:
            x = torch.unsqueeze(x, -1)

        out = self._diffuse(x, support)
        out_diag = None if support_diag is None else self._diffuse(x, support_diag)

        # out => b t*d n s, w/o self-loop and temporal
        # out_diag => b t*d n s, w/ self-loop and temporal
        # suppose t=0,1 (current),2

        if pattern is not None:
            if pattern in self._DIAG_PATTERNS and out_diag is None:
                raise ValueError(
                    "pattern %r requires support_diag to be provided" % (pattern,)
                )
            # NOTE(review): ``mid`` is used directly as an index on the channel
            # axis, not as a block index of width d -- confirm this is intended.
            t = out.size(1) // self.c_out
            mid = t // 2

            if pattern == "wo_self_loop_w_temporal":
                # if w/o self-loop, but w/ temporal
                out[:, :mid, :, :] = out_diag[:, :mid, :, :]
                out[:, mid+1:, :, :] = out_diag[:, mid+1:, :, :]
            elif pattern == "w_self_loop_wo_temporal":
                # if w/ self-loop, but w/o temporal
                out[:, mid, :, :] = out_diag[:, mid, :, :]
            elif pattern == "wo_self_loop_wo_temporal":
                out = out_diag

        out = self.mlp(out)
        if squeeze:
            out = out.squeeze(-1)
        return out


class D_GCN(nn.Module):
    """
    Neural network block that applies a diffusion graph convolution to sampled location
    """

    def __init__(self, in_channels, out_channels, orders, activation='relu', use_kaiming_init=True):
        """
        :param in_channels: Number of time steps; must match the last axis of
        the input passed to ``forward``.
        :param out_channels: Desired number of output features at each node in
        each time step.
        :param orders: The diffusion steps applied per support matrix.
        :param activation: ``'relu'`` or ``'selu'``; any other value leaves
        the output without a non-linearity.
        :param use_kaiming_init: Selects the initialisation used by
        ``reset_parameters``.
        """
        super(D_GCN, self).__init__()
        self.orders = orders
        self.activation = activation
        # One term for the input itself plus ``orders`` terms for each of the
        # two support matrices.
        self.num_matrices = 2 * self.orders + 1
        self.Theta1 = nn.Parameter(torch.FloatTensor(in_channels * self.num_matrices,
                                                     out_channels))
        self.bias = nn.Parameter(torch.FloatTensor(out_channels))
        self.use_kaiming = use_kaiming_init
        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise ``Theta1`` and ``bias`` according to ``use_kaiming``."""
        if self.use_kaiming:
            # NOTE(review): ``Theta1`` is stored as (in, out) while torch's fan
            # computation reads axis 1 as fan_in -- confirm the scale is intended.
            init.kaiming_normal_(self.Theta1, mode='fan_in', nonlinearity='relu')
            if self.bias is not None:
                fan_in = self.Theta1.shape[1]
                bound = 1 / math.sqrt(fan_in)
                self.bias.data.uniform_(-bound, bound)
        else:
            stdv = 1. / math.sqrt(self.Theta1.shape[1])
            self.Theta1.data.uniform_(-stdv, stdv)
            stdv1 = 1. / math.sqrt(self.bias.shape[0])
            self.bias.data.uniform_(-stdv1, stdv1)

    def _concat(self, x, x_):
        """Append ``x_`` as a new leading slice of ``x`` (kept for compatibility)."""
        x_ = x_.unsqueeze(0)
        return torch.cat([x, x_], dim=0)

    def forward(self, X, A_q, A_h):
        """
        :param X: Input data of shape (batch_size, num_nodes, num_timesteps)
        :A_q: The forward random walk matrix (num_nodes, num_nodes)
        :A_h: The backward random walk matrix (num_nodes, num_nodes)
        :return: Output data of shape (batch_size, num_nodes, num_features)
        """
        batch_size = X.shape[0]
        num_node = X.shape[1]
        input_size = X.size(2)  # time_length

        # (num_nodes, num_times * batch_size) so one ``mm`` serves all samples.
        x_in = X.permute(1, 2, 0)
        x_in = torch.reshape(x_in, shape=[num_node, input_size * batch_size])

        terms = [x_in]
        for support in (A_q, A_h):
            # Restart the recurrence from the input for every support;
            # otherwise the second support would start from the intermediate
            # terms of the first one.
            x0 = x_in
            x1 = torch.mm(support, x0)
            terms.append(x1)
            for _ in range(2, self.orders + 1):
                # Chebyshev-style recurrence: T_k = 2 * A * T_{k-1} - T_{k-2}
                x2 = 2 * torch.mm(support, x1) - x0
                terms.append(x2)
                x1, x0 = x2, x1

        # A single stack replaces the repeated concatenation in the loop.
        x = torch.stack(terms, dim=0)
        x = torch.reshape(x, shape=[self.num_matrices, num_node, input_size, batch_size])
        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, order)
        x = torch.reshape(x, shape=[batch_size, num_node, input_size * self.num_matrices])
        x = torch.matmul(x, self.Theta1)  # (batch_size, num_nodes, output_size)
        x = x + self.bias
        if self.activation == 'relu':
            x = F.relu(x)
        elif self.activation == 'selu':
            x = F.selu(x)

        return x

    @staticmethod
    def compute_support(adj, device=None, original=False):
        """
        Build the row-normalised forward and backward transition matrices.

        :param adj: Square adjacency matrix.
        :param device: Optional device ``adj`` is moved to first.
        :param original: When True the matrices are returned as computed;
            otherwise each one is transposed.
        :return: List ``[forward, backward]``.
        """
        if device is not None:
            adj = adj.to(device)
        adj_bwd = adj.T
        # ``epsilon`` keeps rows that sum to zero from dividing by zero.
        adj_fwd = adj / (adj.sum(1, keepdims=True) + epsilon)
        adj_bwd = adj_bwd / (adj_bwd.sum(1, keepdims=True) + epsilon)
        if original:
            support = [adj_fwd, adj_bwd]
        else:
            support = [adj_fwd.T, adj_bwd.T]  # Transpose needed since torch.mm(support, x0) is used later
        return support


class DGIN(nn.Module):
    """
    Direction-sensitive Graph Isomorphism Network (DGIN).
    Separately aggregates incoming and outgoing messages.

    :param in_channels: Size of the node feature vectors fed to ``forward``.
    :param out_channels: Size of the node feature vectors produced.
    :param activation: ``'relu'`` selects ReLU inside the MLPs; any other
        value selects SELU.
    :param eps: Initial value of both self-weighting terms.
    :param train_eps: When True the two terms are learnable parameters;
        otherwise they are fixed.
    """

    def __init__(self, in_channels, out_channels, activation='relu', eps=0.0, train_eps=True):
        super(DGIN, self).__init__()
        self.activation = activation

        if train_eps:
            self.eps_in = nn.Parameter(torch.Tensor([eps]))
            self.eps_out = nn.Parameter(torch.Tensor([eps]))
        else:
            # Buffers follow the module through ``.to(device)``; a plain tensor
            # attribute would be left behind on the CPU. ``persistent=False``
            # keeps them out of the state dict, so checkpoints are unaffected.
            self.register_buffer('eps_in', torch.tensor([eps]), persistent=False)
            self.register_buffer('eps_out', torch.tensor([eps]), persistent=False)

        self.mlp_in = self._build_mlp(in_channels, out_channels, activation)
        self.mlp_out = self._build_mlp(in_channels, out_channels, activation)
        self.reset_parameters()

    @staticmethod
    def _build_mlp(in_channels, out_channels, activation):
        """Linear -> BatchNorm -> activation -> Linear block used for each direction."""
        return nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU() if activation == 'relu' else nn.SELU(),
            nn.Linear(out_channels, out_channels)
        )

    def reset_parameters(self):
        """Re-initialise the linear and batch-norm layers of both MLPs."""
        def init_mlp(mlp):
            for layer in mlp:
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                    if layer.bias is not None:
                        nn.init.zeros_(layer.bias)
                elif isinstance(layer, nn.BatchNorm1d):
                    nn.init.ones_(layer.weight)
                    nn.init.zeros_(layer.bias)

        init_mlp(self.mlp_in)
        init_mlp(self.mlp_out)

    def forward(self, X, A_out, A_in):
        """
        :param X: Node features (B, N, F)
        :param A_out: Outgoing normalized adjacency matrix (N, N)
        :param A_in: Incoming normalized adjacency matrix (N, N)
        :return: Sum of the two directional embeddings, shape (B, N, out_channels)
        """
        # ``num_feat`` rather than ``F`` so torch.nn.functional is not shadowed.
        batch, num_nodes, num_feat = X.shape

        agg_out = torch.einsum('ij,bjf->bif', A_out, X)
        agg_in = torch.einsum('ij,bjf->bif', A_in, X)

        h_out = (1 + self.eps_out) * X + agg_out
        h_in = (1 + self.eps_in) * X + agg_in

        # BatchNorm1d expects (samples, features), so nodes are folded into
        # the batch axis for the MLP and unfolded afterwards.
        h_out = self.mlp_out(h_out.reshape(batch * num_nodes, num_feat)).reshape(batch, num_nodes, -1)
        h_in = self.mlp_in(h_in.reshape(batch * num_nodes, num_feat)).reshape(batch, num_nodes, -1)

        out = h_in + h_out  # Can be replaced with concat + Linear to enhance expressiveness
        return out

    @staticmethod
    def compute_support(adj, device=None, epsilon=1e-6):
        """
        Computes normalized forward and backward adjacency matrices.
        :param adj: (N, N) raw adjacency matrix
        :param device: Optional device ``adj`` is moved to first
        :param epsilon: Added to every row sum to avoid division by zero
        :return: A_out, A_in (normalized forward and backward)
        """
        if device is not None:
            adj = adj.to(device)
        adj_bwd = adj.T
        adj_fwd = adj / (adj.sum(1, keepdims=True) + epsilon)
        adj_bwd = adj_bwd / (adj_bwd.sum(1, keepdims=True) + epsilon)
        return adj_fwd, adj_bwd


class GAT(nn.Module):
    """
    Multi-head graph attention layer over batched node features.

    Each head projects the input to ``out_features // num_heads`` channels and
    attends over the entries of ``adj`` that are greater than zero. The heads
    are either concatenated and passed through ELU (``concat=True``) or
    averaged and passed through a linear projection (``concat=False``).
    """

    def __init__(self, in_features, out_features, num_heads=1, alpha=0.2, concat=True, dropout=0):
        super(GAT, self).__init__()
        assert out_features % num_heads == 0, "out_features must be divisible by num_heads"

        head_dim = out_features // num_heads

        self.in_features = in_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.concat = concat
        self.head_dim = head_dim
        self.alpha = alpha
        self.dropout = dropout

        # Per-head projection and attention vector.
        self.W = nn.Parameter(torch.empty(size=(num_heads, in_features, head_dim)))
        self.a = nn.Parameter(torch.empty(size=(num_heads, 2 * head_dim, 1)))
        for weight in (self.W, self.a):
            nn.init.xavier_uniform_(weight.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(alpha)

        if not concat:
            self.out_proj = nn.Linear(head_dim, out_features)

    def forward(self, h, adj):
        """
        :param h: Node features, three-dimensional (B, N, in_features)
        :param adj: Adjacency; entries > 0 mark the pairs that may attend
        :return: (B, N, out_features)
        """
        B, N, _ = h.size()  # also rejects inputs that are not three-dimensional
        is_edge = adj > 0

        head_outputs = []
        for head_idx in range(self.num_heads):
            projected = torch.matmul(h, self.W[head_idx])  # (B, N, head_dim)
            scores = self._prepare_attentional_mechanism_input(projected, self.a[head_idx])  # (B, N, N)

            # Non-edges get a very negative score so softmax sends them to ~0.
            masked_out = -9e15 * torch.ones_like(scores)
            weights = torch.where(is_edge, scores, masked_out)
            weights = F.softmax(weights, dim=2)
            weights = F.dropout(weights, self.dropout, training=self.training)

            head_outputs.append(torch.matmul(weights, projected))  # (B, N, head_dim)

        if not self.concat:
            # Average heads, then project back to the requested width.
            averaged = torch.stack(head_outputs, dim=0).mean(dim=0)  # (B, N, head_dim)
            return self.out_proj(averaged)  # (B, N, out_features)

        # Concatenate along feature dimension: num_heads * head_dim == out_features
        return F.elu(torch.cat(head_outputs, dim=2))  # (B, N, out_features)

    def _prepare_attentional_mechanism_input(self, Wh, a):
        """
        Wh: (B, N, head_dim)
        a:  (2 * head_dim, 1)
        returns: e (B, N, N)
        """
        width = Wh.size(-1)
        a_src, a_dst = a[:width], a[width:]  # each (head_dim, 1)

        src_score = torch.matmul(Wh, a_src)  # (B, N, 1)
        dst_score = torch.matmul(Wh, a_dst)  # (B, N, 1)

        # Broadcasting (B, N, 1) + (B, 1, N) yields every pairwise sum.
        pairwise = src_score + dst_score.transpose(1, 2)  # (B, N, N)
        return self.leakyrelu(pairwise)
