import torch
from torch import nn
import torch_sparse
import torch.nn.functional as F
from base_classes import ODEFunc
from utils import MaxNFEException
from torch_geometric.utils.loop import add_remaining_self_loops,remove_self_loops
from torch_geometric.utils import get_laplacian

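# Define the ODE function (the right-hand side evaluated by the ODE solver).
# Input:
# --- t: A tensor with shape [], meaning the current time.
# --- x_full: A tensor with shape [#nodes, dims], meaning the value of the state at t.
#             forward() reads it as consecutive column blocks of width opt['hidden_dim'].
# Output:
# --- dx/dt: A tensor with the same shape as x_full, meaning the derivative of the state at t.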
class LaplacianODEFunc_multiterm(ODEFunc):
  """ODE right-hand side operating on a state made of stacked column blocks.

  The state ``x_full`` is read as consecutive blocks of ``opt['hidden_dim']``
  columns. The first block is propagated through a sparse graph operator; the
  last ``opt['end_terms']`` blocks are subtracted from that result, each scaled
  by its own learnable scalar in ``weight_all``.

  NOTE(review): ``edge_index``, ``edge_weight``, ``attention_weights``, ``nfe``,
  ``x0`` and ``beta_train`` are read but never assigned in this class --
  presumably they come from ``ODEFunc`` or are set by the caller; confirm.
  """

  # currently requires in_features = out_features
  def __init__(self, in_features, out_features, opt, data, device):
    """Register the learnable parameters.

    Args:
      in_features: stored on the instance; not otherwise used in this class.
      out_features: stored on the instance; not otherwise used in this class.
      opt: option mapping; this class reads 'hidden_dim', 'num_terms',
        'end_terms', 'block' and 'add_source' from it.
      data: forwarded unchanged to ``ODEFunc.__init__``.
      device: device on which the per-term scalar weights are created; also
        forwarded to ``ODEFunc.__init__``.
    """
    super().__init__(opt, data, device)

    hidden_dim = opt['hidden_dim']

    self.in_features = in_features
    self.out_features = out_features
    self.opt = opt

    # These four parameters are not referenced by forward() or
    # sparse_multiply() below; registration order is kept so that
    # parameters() / state_dict() layout stays the same.
    self.w = nn.Parameter(torch.eye(hidden_dim))
    self.d = nn.Parameter(torch.ones(hidden_dim))
    self.alpha_sc = nn.Parameter(torch.ones(1))
    self.beta_sc = nn.Parameter(torch.ones(1))

    # One learnable scalar per term, each initialised to 0.01.
    per_term_weights = []
    for _ in range(opt['num_terms']):
      scalar = nn.Parameter(torch.tensor(0.01, device=device), requires_grad=True)
      per_term_weights.append(scalar)
    self.weight_all = nn.ParameterList(per_term_weights)

  def sparse_multiply(self, x):
    """Multiply the sparse graph operator with ``x``.

    The edge values depend on ``opt['block']``:
      * 'attention' / 'att_frac': ``attention_weights`` averaged over dim 1;
      * 'mixed' / 'hard_attention': ``attention_weights`` as-is;
      * anything else: ``edge_weight``.

    Args:
      x: dense matrix; its row count is used as both dimensions of the
        sparse operator.

    Returns:
      The result of ``torch_sparse.spmm`` applied to ``x``.
    """
    block_type = self.opt['block']
    if block_type in ('attention', 'att_frac'):
      # adj is a multihead attention -> collapse the head dimension
      edge_values = self.attention_weights.mean(dim=1)
    elif block_type in ('mixed', 'hard_attention'):
      # adj is a torch sparse matrix
      edge_values = self.attention_weights
    else:
      # adj is a torch sparse matrix
      edge_values = self.edge_weight

    num_rows = x.shape[0]
    return torch_sparse.spmm(self.edge_index, edge_values, num_rows, num_rows, x)

  def forward(self, t, x_full):  # the t param is needed by the ODE solver.
    """Evaluate the derivative of the stacked state.

    Args:
      t: current time; not used in the computation.
      x_full: 2-D tensor read as column blocks of width ``opt['hidden_dim']``.

    Returns:
      Tensor formed by concatenating ``x_full`` without its first block and
      the newly computed block, along dim 1 (so each block's derivative is
      the next block, and the last block's derivative is the computed term).
      When ``opt['add_source']`` is truthy, ``beta_train * x0`` is added.
    """
    self.nfe += 1  # count function evaluations

    width = self.opt['hidden_dim']
    leading = x_full[:, 0:width]

    # Graph propagation of the first block minus the block itself.
    rhs = self.sparse_multiply(leading) - leading

    # Subtract the trailing `end_terms` blocks, each scaled by its own weight.
    total_terms = self.opt['num_terms']
    first_term = total_terms - self.opt['end_terms']
    for idx in range(first_term, total_terms):
      term_block = x_full[:, idx * width:(idx + 1) * width]
      rhs = rhs - self.weight_all[idx] * term_block

    # Shift the blocks left by one and append the computed block at the end.
    shifted = x_full[:, width:]
    f = torch.cat([shifted, rhs], dim=1)

    if self.opt['add_source']:
      f = f + self.beta_train * self.x0
    return f
