import torch
from torch import nn
import torch_sparse
import torch.nn.functional as F
from .base_classes import ODEFunc
# from utils import MaxNFEException
from torch_geometric.nn.conv import GCNConv
from torch_geometric.utils.loop import add_remaining_self_loops,remove_self_loops
# from utils import get_rw_adj
import numpy as np
from torch_geometric.utils import softmax, degree
from torch.nn.utils import spectral_norm
# def batch_jacobian(func, x, create_graph=False):
#   # x in shape (Batch, Length)
#   def _func_sum(x):
#     return func(x).sum(dim=0)        ###readout, pooling, mean, square, norm2, asb,l1 norm
#
#   return torch.autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1, 2, 0)

def batch_jacobian(func, x, create_graph=False):
  """Return the Jacobian of ``func`` at ``x`` with its leading axis moved last.

  ``torch.autograd.functional.jacobian`` is called on ``func`` exactly as
  given (no reduction over the batch axis is applied here), and the result is
  reordered with ``permute(1, 2, 0)``. That fixed permutation only succeeds
  when the Jacobian is 3-D, e.g. ``x`` of shape (Batch, Length) and ``func``
  returning a 1-D tensor of size K, which yields shape (Batch, Length, K).

  Args:
    func: callable taking ``x`` and returning a tensor.
    x: tensor at which the Jacobian is evaluated.
    create_graph: forwarded unchanged to ``jacobian``.

  Returns:
    The Jacobian tensor with axes (0, 1, 2) rearranged to (1, 2, 0).
  """
  jac = torch.autograd.functional.jacobian(func, x, create_graph=create_graph)
  return jac.permute(1, 2, 0)


class attention_H(nn.Module):
  """Two-layer GCN over a fixed graph, reduced to a norm over the node axis.

  Pipeline: GCNConv(2 * size_in -> size_in) -> tanh -> GCNConv(size_in -> 1)
  -> ``torch.norm(..., dim=0)``. The graph is supplied once at construction.

  Original author note: replace this module by an aggregation function.
  """

  def __init__(self, size_in, edge_index):
    super().__init__()
    self.dim = size_in
    self.edge_index = edge_index

    self.layer1 = GCNConv(size_in * 2, size_in, normalize=True)
    self.layer2 = GCNConv(size_in, 1, normalize=True)
    # Registered as a submodule but not applied in forward().
    self.dropout = nn.Dropout(p=0.4)

  def forward(self, x):
    """Return the norm along dim 0 of the second GCN layer's output for ``x``."""
    hidden = torch.tanh(self.layer1(x, self.edge_index))
    scores = self.layer2(hidden, self.edge_index)
    return torch.norm(scores, dim=0)

class H_x(nn.Module):
  """Two-layer GCN reduced to a norm over the node axis; graph given per call.

  Pipeline: GCNConv(size_in -> size_in) -> tanh -> GCNConv(size_in -> 1)
  -> ``torch.norm(..., dim=0)``.

  Original author note: replace this module by an aggregation function.
  """

  def __init__(self, size_in):
    super().__init__()
    self.dim = size_in
    self.layer1 = GCNConv(size_in, size_in, normalize=True)
    self.layer2 = GCNConv(size_in, 1, normalize=True)

  def forward(self, x, edge_index):
    """Evaluate the module on ``x`` over ``edge_index``.

    The most recent ``edge_index`` is also stored on ``self.edge_index``.
    """
    self.edge_index = edge_index
    hidden = torch.tanh(self.layer1(x, edge_index))
    scores = self.layer2(hidden, edge_index)
    return torch.norm(scores, dim=0)

class H_x_linear(nn.Module):
  """Two-layer MLP reduced to a norm over dim 0 (no graph propagation).

  Pipeline: Linear(size_in -> size_in) -> tanh -> Linear(size_in -> 1)
  -> ``torch.norm(..., dim=0)``. ``edge_index`` is stored but not used by
  ``forward``.

  Original author note: replace this module by an aggregation function.
  """

  def __init__(self, size_in, edge_index):
    super().__init__()
    self.dim = size_in
    self.edge_index = edge_index
    self.layer1 = nn.Linear(size_in, size_in)
    self.layer2 = nn.Linear(size_in, 1)

  def forward(self, x):
    """Return the norm along dim 0 of the MLP output for ``x``."""
    hidden = torch.tanh(self.layer1(x))
    scores = self.layer2(hidden)
    return torch.norm(scores, dim=0)

class H_derivatie(nn.Module):
  """Two GCN layers of width ``2 * size_in`` with a tanh in between.

  The graph (``edge_index``) is fixed at construction. Unlike the scalar
  ``H`` modules in this file, the output is not reduced with a norm.

  Original author note: replace this module by an aggregation function.
  """

  def __init__(self, size_in, edge_index):
    super().__init__()
    width = size_in * 2
    self.dim = size_in
    self.edge_index = edge_index
    self.layer1 = GCNConv(width, width, normalize=True)
    self.layer2 = GCNConv(width, width, normalize=True)

  def forward(self, x):
    """Apply GCNConv -> tanh -> GCNConv to ``x`` over the stored graph."""
    hidden = torch.tanh(self.layer1(x, self.edge_index))
    return self.layer2(hidden, self.edge_index)

class H_derivatie_x(nn.Module):
  """Two GCN layers of width ``size_in`` with a tanh in between.

  The graph is passed on every call rather than at construction.

  Original author note: replace this module by an aggregation function.
  """

  def __init__(self, size_in):
    super().__init__()
    self.dim = size_in
    self.layer1 = GCNConv(size_in, size_in, normalize=True)
    self.layer2 = GCNConv(size_in, size_in, normalize=True)

  def forward(self, x, edge_index):
    """Apply GCNConv -> tanh -> GCNConv to ``x`` over ``edge_index``.

    The most recent ``edge_index`` is also stored on ``self.edge_index``.
    """
    self.edge_index = edge_index
    hidden = torch.tanh(self.layer1(x, edge_index))
    return self.layer2(hidden, edge_index)


class H_derivatie_x_linear(nn.Module):
  """Two linear layers of width ``size_in`` with a tanh in between.

  ``edge_index`` is stored but not used by ``forward``.

  Original author note: replace this module by an aggregation function.
  """

  def __init__(self, size_in, edge_index):
    super().__init__()
    self.dim = size_in
    self.edge_index = edge_index
    self.layer1 = nn.Linear(size_in, size_in)
    self.layer2 = nn.Linear(size_in, size_in)

  def forward(self, x):
    """Apply Linear -> tanh -> Linear to ``x``."""
    hidden = torch.tanh(self.layer1(x))
    return self.layer2(hidden)

class linear_H(nn.Module):
  """Linear layers with one sparse propagation step, reduced to a norm.

  Pipeline: Linear(2 * size_in -> size_in) -> sin -> sparse multiply with the
  stored (edge_index, edge_weight) matrix -> Linear(size_in -> 1) -> sin
  -> ``torch.norm(..., dim=0)``.

  Original author note: replace this module by an aggregation function.
  """

  def __init__(self, size_in, edge_index, edge_weight):
    super().__init__()
    self.dim = size_in
    self.edge_index = edge_index
    self.edge_weight = edge_weight
    self.layer1 = nn.Linear(size_in * 2, size_in)
    self.layer2 = nn.Linear(size_in, 1)

  def forward(self, x):
    """Return the norm along dim 0 of the propagated, squashed features."""
    # The sparse matrix is treated as square with side x.shape[0].
    n_rows = x.shape[0]
    hidden = torch.sin(self.layer1(x))
    hidden = torch_sparse.spmm(self.edge_index, self.edge_weight, n_rows, n_rows, hidden)
    scores = torch.sin(self.layer2(hidden))
    return torch.norm(scores, dim=0)

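# ODE function contract (this describes the forward(t, x) of the ODE function
# class defined further below, e.g. HAMGCNFunc_SKEW, not the helper module
# that immediately follows).
# Input:
# --- t: a tensor with shape [], holding the current time.
# --- x: a tensor with shape [#batches, dims], holding the value of x at time t.
# Output:
# --- dx/dt: a tensor with shape [#batches, dims], holding the derivative of x at time t.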

class H_gcn(nn.Module):
  """Two-layer GCN reduced to a norm over the node axis; graph given per call.

  Pipeline: GCNConv(2 * size_in -> size_in) -> tanh -> GCNConv(size_in -> 1)
  -> ``torch.norm(..., dim=0)``.

  Original author note: replace this module by an aggregation function.
  """

  def __init__(self, size_in):
    super().__init__()
    self.dim = size_in
    self.layer1 = GCNConv(size_in * 2, size_in, normalize=True)
    self.layer2 = GCNConv(size_in, 1, normalize=True)
    # Registered as a submodule but not applied in forward().
    self.dropout = nn.Dropout(p=0.4)

  def forward(self, x, edge_index):
    """Return the norm along dim 0 of the second GCN layer's output."""
    hidden = torch.tanh(self.layer1(x, edge_index))
    scores = self.layer2(hidden, edge_index)
    return torch.norm(scores, dim=0)

class HAMGCNFunc_SKEW(ODEFunc):
  """ODE right-hand side over a state split into two halves ``[x | y]``.

  Attention weights are computed from the ``x`` half and used to propagate
  both halves over the graph; ``forward`` returns
  ``hstack([-A(x), A(y)])`` where ``A(.)`` is ``multiply_attention``.

  NOTE(review): ``self.edge_index``, ``self.edge_weight``,
  ``self.attention_weights``, ``self.beta_train``, ``self.x0`` and
  ``self.opt`` are not assigned in this class; presumably they come from
  ``ODEFunc`` or are set by the caller -- confirm.
  """

  # currently requires in_features = out_features
  def __init__(self, in_features, out_features, opt, device):
    super().__init__(opt, device)
    hidden_dim = opt['hidden_dim']

    self.in_features = in_features
    self.out_features = out_features

    # Learnable parameters registered here; none of them is read by the
    # methods of this class.
    self.w = nn.Parameter(torch.eye(hidden_dim))
    self.d = nn.Parameter(torch.ones(hidden_dim))
    self.alpha_sc = nn.Parameter(torch.ones(1))
    self.beta_sc = nn.Parameter(torch.ones(1))

    # self.H is constructed but not called in forward().
    self.H = H_gcn(in_features).to(device)
    self.multihead_att_layer = SpGraphTransAttentionLayer(in_features, out_features).to(device)

  def sparse_multiply(self, x):
    """Multiply ``x`` by the sparse matrix selected by ``opt['block']``.

    'attention' uses the head-averaged ``self.attention_weights``;
    'mixed' / 'hard_attention' use ``self.attention_weights`` as-is;
    anything else uses ``self.edge_weight``.
    """
    block = self.opt['block']
    if block in ['attention']:
      weights = self.attention_weights.mean(dim=1)
    elif block in ['mixed', 'hard_attention']:
      weights = self.attention_weights
    else:
      weights = self.edge_weight
    n_rows = x.shape[0]
    return torch_sparse.spmm(self.edge_index, weights, n_rows, n_rows, x)

  def multiply_attention(self, x, attention, v=None):
    """Propagate ``x`` over the graph with head-averaged ``attention``.

    ``mix_features`` is hard-coded to 0, so the per-head branch that mixes
    ``v`` through ``Wout`` is currently never taken and ``v`` is ignored.
    """
    num_heads = 4
    mix_features = 0
    if not mix_features:
      n_rows = x.shape[0]
      head_mean = attention.mean(dim=1)
      return torch_sparse.spmm(self.edge_index, head_mean, n_rows, n_rows, x)

    per_head = [
      torch_sparse.spmm(self.edge_index, attention[:, head], v.shape[0], v.shape[0], v[:, :, head])
      for head in range(num_heads)
    ]
    vx = torch.mean(torch.stack(per_head, dim=0), dim=0)
    return self.multihead_att_layer.Wout(vx)

  def forward(self, t, x_full):  # the t param is needed by the ODE solver.
    """Return the time derivative of the stacked state ``x_full``.

    ``x_full`` is split at column ``opt['hidden_dim']`` into ``x`` (left) and
    ``y`` (right). ``t`` is not used in the computation.
    """
    split = self.opt['hidden_dim']
    x = x_full[:, :split]
    y = x_full[:, split:]

    attention, values = self.multihead_att_layer(x, self.edge_index)
    # Naming follows the original: ``ax`` is attention applied to y,
    # ``ay`` is attention applied to x.
    ax = self.multiply_attention(y, attention, values)
    ay = self.multiply_attention(x, attention, values)

    if self.opt['add_source']:
      # Blend with the right half of the initial state through a sigmoid gate.
      source = self.x0[:, split:]
      ax = (1. - torch.sigmoid(self.beta_train)) * ax + torch.sigmoid(self.beta_train) * source

    return torch.hstack([-ay, ax])



class SpGraphTransAttentionLayer(nn.Module):
    """
    Sparse multi-head scaled dot-product attention evaluated on graph edges.

    Queries, keys and values are linear projections of the node features into
    ``attention_dim`` (fixed at 64) split across ``h`` (fixed at 4) heads.
    For every edge the per-head score is ``<q[edge[0]], k[edge[1]]> / sqrt(d_k)``
    and scores are normalised with ``softmax`` grouped by ``edge[0]``.

    ``out_features`` is only used by ``__repr__``; ``alpha``, ``device`` and
    ``activation`` are stored but not read by this class.
    """

    def __init__(self, in_features, out_features, edge_weights=None, device=None):
        super(SpGraphTransAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = .2
        self.device = device
        self.h = int(4)  # num of attention heads
        self.edge_weights = edge_weights
        self.reweight_attention = 0
        self.sn = False  # wrap Q/K/V in spectral_norm when True

        # Constant width: the former try/except KeyError around this
        # assignment could never raise, so its fallback was unreachable.
        self.attention_dim = 64

        assert self.attention_dim % self.h == 0, "Number of heads ({}) must be a factor of the dimension size ({})".format(
            self.h, self.attention_dim)
        self.d_k = self.attention_dim // self.h

        # Creation order (Q, V, K) is kept so random initialisation of the
        # biases consumes the RNG in the same sequence as before.
        self.Q = self._make_projection()
        self.V = self._make_projection()
        self.K = self._make_projection()

        self.activation = nn.Sigmoid()  # nn.LeakyReLU(self.alpha)
        self.Wout = spectral_norm(nn.Linear(self.d_k, in_features))
        self.init_weights(self.Wout)

    def _make_projection(self):
        """Build one in_features -> attention_dim projection and initialise it."""
        proj = nn.Linear(self.in_features, self.attention_dim)
        if self.sn:
            proj = spectral_norm(proj)
        self.init_weights(proj)
        return proj

    def init_weights(self, m):
        """Fill the weight of a linear module with the constant 1e-5.

        Biases are left at their default initialisation; non-linear modules
        are ignored.
        """
        if isinstance(m, nn.Linear):
            # nn.init.xavier_uniform_(m.weight, gain=1.414)
            # m.bias.data.fill_(0.01)
            nn.init.constant_(m.weight, 1e-5)

    def forward(self, x, edge):
        """
        Compute per-edge, per-head attention weights and per-head values.

        x might be [features, augmentation, positional encoding, labels]

        Args:
            x: node features with ``in_features`` columns.
            edge: index tensor whose rows 0 and 1 select the query node and
                the key node of every edge.

        Returns:
            ``(attention, v)``: ``attention`` has one row per edge and one
            column per head; ``v`` has shape [n_nodes, d_k, h].
        """
        # Project, then split the attention_dim axis into h heads of size d_k.
        q = self.Q(x).view(-1, self.h, self.d_k)
        k = self.K(x).view(-1, self.h, self.d_k)
        v = self.V(x).view(-1, self.h, self.d_k)

        # Transpose to get dimensions [n_nodes, d_k, n_heads].
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        src = q[edge[0, :], :, :]
        dst_k = k[edge[1, :], :, :]

        # Dot product over the d_k axis gives one score per (edge, head).
        prods = torch.sum(src * dst_k, dim=1) / np.sqrt(self.d_k)
        if self.reweight_attention and self.edge_weights is not None:
            prods = prods * self.edge_weights.unsqueeze(dim=1)
        attention = softmax(prods, edge[0])

        return attention, v  # (v, prods)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
