import torch
from torch import nn
from torch_geometric.utils import softmax
import torch_sparse
import torch.nn.functional as F
from torch_geometric.utils.loop import add_remaining_self_loops
from data import get_dataset
from utils import MaxNFEException
from base_classes import ODEFunc


class ODEFuncAtt_graphcon_terms(ODEFunc):
  """ODE right-hand side using multi-head graph attention over a stack of feature blocks.

  The integrator state ``x_full`` is read as ``opt['num_terms']`` consecutive
  column blocks, each ``opt['hidden_dim']`` wide: ``[x_0, x_1, ..., x_{K-1}]``
  (assumes ``x_full`` has exactly ``num_terms * hidden_dim`` columns — TODO
  confirm with callers). ``forward`` returns a tensor of the same width in which
  block k is replaced by block k+1, and the last block is

      A(x_0) x_0 - x_0 - sum_{n odd, n < num_terms} weight_all[n] * x_n

  where ``A(x_0)`` is the head-averaged sparse attention operator built from
  ``SpGraphAttentionLayer``. Because the attention layer projects ``x_0`` with
  a weight of shape (in_features, attention_dim), ``opt['hidden_dim']`` must
  equal ``in_features``.
  """

  def __init__(self, in_features, out_features, opt, data, device):
    super(ODEFuncAtt_graphcon_terms, self).__init__(opt, data, device)
    # __repr__ reads these attributes, so set them explicitly here rather than
    # relying on the base class to provide them.
    self.in_features = in_features
    self.out_features = out_features

    if opt['self_loop_weight'] > 0:
      # Add self loops, weighted by opt['self_loop_weight'], for nodes that lack one.
      self.edge_index, self.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr,
                                                                   fill_value=opt['self_loop_weight'])
    else:
      self.edge_index, self.edge_weight = data.edge_index, data.edge_attr
    # NOTE: self.edge_weight is stored but not used by forward.

    self.multihead_att_layer = SpGraphAttentionLayer(in_features, out_features, opt,
                                                     device).to(device)
    try:
      self.attention_dim = opt['attention_dim']
    except KeyError:
      self.attention_dim = out_features

    assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size"
    self.d_k = self.attention_dim // opt['heads']

    # One learnable scalar per term, initialised to 0.01. forward only reads the
    # odd-indexed entries; the even-indexed ones are never used.
    self.weight_all = nn.ParameterList(
      [nn.Parameter(torch.tensor(0.01, device=device), requires_grad=True) for _ in range(opt['num_terms'])])

  def multiply_attention(self, x, attention, wx):
    """Apply the sparse attention operator for every head and average the results.

    For each head ``idx``, the N x N sparse matrix with indices ``self.edge_index``
    and values ``attention[:, idx]`` multiplies dense node features. The per-head
    products are stacked and averaged.

    Args:
      x: node features; aggregated directly when ``opt['mix_features']`` is false.
      attention: per-edge attention coefficients, one column per head.
      wx: projected features (``x @ W`` from the attention layer). When
        ``opt['mix_features']`` is true these are aggregated instead, and the
        head-averaged result is mapped back through ``Wout``.

    Returns:
      The head-averaged aggregated node features.
    """
    if self.opt['mix_features']:
      wx = torch.mean(torch.stack(
        [torch_sparse.spmm(self.edge_index, attention[:, idx], wx.shape[0], wx.shape[0], wx) for idx in
         range(self.opt['heads'])], dim=0),
        dim=0)
      ax = torch.mm(wx, self.multihead_att_layer.Wout)
    else:
      ax = torch.mean(torch.stack(
        [torch_sparse.spmm(self.edge_index, attention[:, idx], x.shape[0], x.shape[0], x) for idx in
         range(self.opt['heads'])], dim=0),
        dim=0)
    return ax

  def forward(self, t, x_full):  # t is needed when called by the integrator
    """Return the time derivative of the stacked state ``x_full``.

    Args:
      t: integration time (unused; required by the integrator's call signature).
      x_full: stacked state of ``opt['num_terms']`` blocks, each
        ``opt['hidden_dim']`` columns wide.

    Returns:
      A tensor shaped like ``x_full``. Each block's derivative is the following
      block; the final block's derivative is the attention term described on the
      class. With ``opt['add_source']``, that term is blended with the first
      block of ``self.x0`` using ``sigmoid(self.beta_train)``.
    """
    # NOTE(review): opt['max_nfe'] is not checked here (MaxNFEException is
    # imported but never raised) — confirm this is intended.
    self.nfe += 1
    dim_x = self.opt['hidden_dim']
    x_0 = x_full[:, :dim_x]

    attention, wx = self.multihead_att_layer(x_0, self.edge_index)
    ax = self.multiply_attention(x_0, attention, wx) - x_0

    # Subtract a learned multiple of every odd-indexed block.
    for n_order in range(1, self.opt['num_terms'], 2):
      x_n = x_full[:, n_order * dim_x:(n_order + 1) * dim_x]
      ax = ax - self.weight_all[n_order] * x_n

    if self.opt['add_source']:
      beta = torch.sigmoid(self.beta_train)
      ax = (1. - beta) * ax + beta * self.x0[:, :dim_x]

    # Shift blocks down by one (d x_k/dt = x_{k+1}) and append ax as the last block.
    return torch.cat([x_full[:, dim_x:], ax], dim=1)

  def __repr__(self):
    return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class SpGraphAttentionLayer(nn.Module):
  """
  Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903

  Projects node features with ``W`` and splits them into ``opt['heads']`` heads
  of size ``d_k``. Each edge is then scored per head with a shared attention
  vector ``a`` followed by LeakyReLU, and the scores are softmax-normalised per
  node. ``Wout`` (attention_dim x in_features) is not used here; it is exposed
  for callers that map aggregated projected features back to the input width.
  """

  def __init__(self, in_features, out_features, opt, device, concat=True):
    super(SpGraphAttentionLayer, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.alpha = opt['leaky_relu_slope']
    self.concat = concat  # kept for interface compatibility; not used by forward
    self.device = device
    self.opt = opt
    self.h = opt['heads']

    try:
      self.attention_dim = opt['attention_dim']
    except KeyError:
      self.attention_dim = out_features

    assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size"
    self.d_k = self.attention_dim // opt['heads']

    # Allocate every parameter directly on `device`. Calling `.to(device)` on an
    # nn.Parameter returns a plain tensor whenever a copy is needed (e.g. CPU ->
    # CUDA). That tensor would not be registered, so it would be missing from
    # parameters()/state_dict() and never trained.
    self.W = nn.Parameter(torch.zeros(size=(in_features, self.attention_dim), device=device))
    nn.init.xavier_normal_(self.W.data, gain=1.414)

    self.Wout = nn.Parameter(torch.zeros(size=(self.attention_dim, self.in_features), device=device))
    nn.init.xavier_normal_(self.Wout.data, gain=1.414)

    # Shared attention vector, broadcast over edges and heads in forward.
    self.a = nn.Parameter(torch.zeros(size=(2 * self.d_k, 1, 1), device=device))
    nn.init.xavier_normal_(self.a.data, gain=1.414)

    self.leakyrelu = nn.LeakyReLU(self.alpha)

  def forward(self, x, edge):
    """Compute per-edge, per-head attention coefficients.

    Args:
      x: 2-D node feature matrix with ``in_features`` columns.
      edge: edge index with two rows; ``edge[0]`` and ``edge[1]`` index the two
        endpoints of each edge.

    Returns:
      ``(attention, wx)``. ``attention`` has one row per edge and one column per
      head, softmax-normalised over edges sharing the same node in row
      ``opt['attention_norm_idx']`` of ``edge``. ``wx = x @ W`` is (N, attention_dim).
    """
    wx = torch.mm(x, self.W)  # (N, attention_dim)
    h = wx.view(-1, self.h, self.d_k).transpose(1, 2)  # (N, d_k, heads)

    # Concatenate both endpoints' head features per edge: (E, 2*d_k, heads) -> (2*d_k, E, heads).
    edge_h = torch.cat((h[edge[0, :], :, :], h[edge[1, :], :, :]), dim=1).transpose(0, 1).to(
      self.device)
    # Dot each edge's features with the shared vector `a` per head: (E, heads).
    edge_e = self.leakyrelu(torch.sum(self.a * edge_h, dim=0)).to(self.device)
    attention = softmax(edge_e, edge[self.opt['attention_norm_idx']])
    return attention, wx

  def __repr__(self):
    return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


if __name__ == '__main__':
  # Smoke test: build the ODE function on Cora and evaluate it once.
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'heads': 2, 'K': 10,
         'attention_norm_idx': 0, 'add_source': False, 'alpha_dim': 'sc', 'beta_dim': 'vc', 'max_nfe': 1000,
         'mix_features': False, 'num_terms': 2}
  dataset = get_dataset(opt, '../data', False)
  # Keep graph tensors on the same device as the module's parameters.
  data = dataset.data.to(device)
  # The attention layer consumes the first state block, so hidden_dim must equal the input width.
  opt['hidden_dim'] = data.num_features
  t = 1
  func = ODEFuncAtt_graphcon_terms(data.num_features, 6, opt, data, device)
  # forward expects num_terms stacked blocks of hidden_dim columns each.
  x_full = data.x.repeat(1, opt['num_terms'])
  out = func(t, x_full)
