import torch
from torch import nn
import torch.nn.functional as F
from base_classes import BaseGNN
from model_configurations import set_block, set_function

from torch_geometric.utils import get_laplacian,to_dense_adj,to_scipy_sparse_matrix,add_remaining_self_loops
import numpy as np
# Define the GNN model.


class GNN_energy(BaseGNN):
  """Encoder -> ODE block -> decoder GNN.

  The ODE function and the ODE block class are selected from ``opt`` through
  ``set_function`` / ``set_block``. The block is constructed with the time
  interval ``[0, self.T]``. ``forward`` can return the raw ODE output
  (``cal_energy=True``) instead of the decoder output, presumably so callers
  can measure quantities such as the Dirichlet energy of the hidden state.
  """

  def __init__(self, opt, dataset, device=torch.device('cpu')):
    """Build the ODE function/block and keep the graph connectivity.

    Args:
      opt: configuration dict; passed to ``BaseGNN``, ``set_function``,
        ``set_block`` and the ODE block constructor.
      dataset: dataset object; ``dataset.data`` is handed to the ODE block and
        ``dataset.data.edge_index`` is stored on ``self.edge_index``.
      device: torch device for the integration-time tensor and the ODE block.
    """
    super(GNN_energy, self).__init__(opt, dataset, device)
    self.f = set_function(opt)
    block = set_block(opt)
    # NOTE(review): self.T and self.regularization_fns are not set in this
    # class; presumably BaseGNN.__init__ provides them — confirm.
    time_tensor = torch.tensor([0, self.T]).to(device)
    self.odeblock = block(self.f, self.regularization_fns, opt, dataset.data, device, t=time_tensor).to(device)
    self.edge_index = dataset.data.edge_index

  def forward(self, x, pos_encoding=None,cal_energy=False):
    """Encode node features, integrate the ODE block, then decode.

    Args:
      x: node feature tensor. When ``opt['use_labels']`` is set, its last
        ``self.num_classes`` columns are split off before encoding and
        concatenated back afterwards, bypassing the encoder.
      pos_encoding: positional encoding; only read when ``opt['beltrami']``.
      cal_energy: if True, return the ODE output instead of the decoder
        output. The returned tensor is the ODE output after the per-branch
        column slicing, but before the ``augment`` split, activation and
        decoding. The decoder still runs in this case.

    Returns:
      The output of ``self.m2``, or the cloned ODE output when
      ``cal_energy`` is True.
    """
    # Encode each node based on its feature.
    if self.opt['use_labels']:
      y = x[:, -self.num_classes:]
      x = x[:, :-self.num_classes]

    if self.opt['beltrami']:
      # Features and positional encodings are encoded separately, then joined.
      x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.mx(x)
      p = F.dropout(pos_encoding, self.opt['input_dropout'], training=self.training)
      p = self.mp(p)
      x = torch.cat([x, p], dim=1)
    else:
      x = F.dropout(x, self.opt['input_dropout'], training=self.training)
      x = self.m1(x)

    if self.opt['use_mlp']:
      # Two residual MLP layers (x + m(relu(x))), each followed by dropout.
      x = F.dropout(x, self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m11(F.relu(x)), self.opt['dropout'], training=self.training)
      x = F.dropout(x + self.m12(F.relu(x)), self.opt['dropout'], training=self.training)
    # todo investigate if some input non-linearity solves the problem with smooth deformations identified in the ANODE paper

    if self.opt['use_labels']:
      x = torch.cat([x, y], dim=-1)

    if self.opt['batch_norm']:
      x = self.bn_in(x)

    # Solve the initial value problem of the ODE.
    if self.opt['augment']:
      # Augmented state: append zero channels of the same width as x.
      c_aux = torch.zeros(x.shape).to(self.device)
      x = torch.cat([x, c_aux], dim=1)
    x_init = x.clone()

    # Branch on substrings of the configured ODE function name. Each branch
    # registers the (possibly widened) initial state via set_x0 before solving.
    # When training with active regularizers (nreg > 0) the block returns
    # (z, reg_states) instead of z alone.
    if 'graphcon' in self.opt['function']:
      if 'term' in self.opt['function']:
        # Append num_terms - 1 extra blocks to the state. While
        # j < num_terms / 2 the block is a copy of x_init, otherwise zeros
        # (j // (num_terms / 2) is float floor division, so it is 0 exactly then).
        for j in range(self.opt['num_terms'] - 1):
          if j//(self.opt['num_terms']/2) == 0:
            x2 = x_init.clone()
          else:
            x2 = torch.zeros_like(x_init, device=self.device)
          x = torch.cat((x, x2), dim=1)
        self.odeblock.set_x0(x)
        if self.training and self.odeblock.nreg > 0:
          z, self.reg_states = self.odeblock(x)
        else:
          z = self.odeblock(x)
        # Keep only the first hidden_dim columns of the widened state.
        z = z[:, 0:self.opt['hidden_dim']]
      else:

        # State is [x, x] (two copies side by side).
        x = torch.cat([x, x], dim=-1)
        self.odeblock.set_x0(x)

        if self.training and self.odeblock.nreg > 0:
          z, self.reg_states = self.odeblock(x)
        else:
          z = self.odeblock(x)
        # Keep the columns from hidden_dim onwards. NOTE(review): this equals
        # the second copy only if x has exactly hidden_dim columns here (not
        # the case with augment/use_labels/beltrami widening) — confirm.
        z = z[:,self.opt['hidden_dim']:]
    elif 'term' in self.opt['function']:


      # Pad the state with num_terms - 1 zero blocks of x's width.
      x2 = torch.zeros_like(x, device=self.device)
      for _ in range(self.opt['num_terms'] - 1):
        x = torch.cat((x, x2), dim=1)
      self.odeblock.set_x0(x)

      if self.training and self.odeblock.nreg > 0:
        z, self.reg_states = self.odeblock(x)
      else:
        z = self.odeblock(x)
      # Keep only the first hidden_dim columns of the widened state.
      z = z[:,0:self.opt['hidden_dim']]


    else:

      self.odeblock.set_x0(x)

      if self.training and self.odeblock.nreg > 0:
        z, self.reg_states = self.odeblock(x)
      else:
        z = self.odeblock(x)

    # Snapshot of the ODE output, returned instead of logits when cal_energy.
    z_ode = z.clone()

    if self.opt['augment']:
      # Drop the augmented channels by keeping the first half of the width.
      # NOTE(review): x.shape[1] is the width after any graphcon/term
      # widening above, not z's width — confirm augment is never combined
      # with those branches.
      z = torch.split(z, x.shape[1] // 2, dim=1)[0]

    # Activation.
    z = F.relu(z)

    if self.opt['fc_out']:
      z = self.fc(z)
      z = F.relu(z)

    # Dropout.
    z = F.dropout(z, self.opt['dropout'], training=self.training)

    # Decode each node embedding to get node label.
    z = self.m2(z)

    if cal_energy:
      return z_ode
    else:
      return z

def scipy_sparse_to_torch_sparse(x):
  """Convert a SciPy sparse matrix to a torch sparse COO tensor.

  Args:
    x: any SciPy sparse matrix; it is converted with ``x.tocoo()`` first.

  Returns:
    An (uncoalesced) ``torch.sparse_coo_tensor`` of shape ``x.shape`` with
    int64 indices and values of the same dtype as ``x``'s data.
  """
  coo = x.tocoo()
  # Sparse COO indices must be int64; SciPy row/col arrays may be int32.
  indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
  values = torch.from_numpy(coo.data)
  # torch.sparse.FloatTensor(indices, values, shape) is deprecated; PyTorch
  # recommends torch.sparse_coo_tensor instead.
  return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

def cal_dirichlet_energy(x, edge_index):
  """Return the Dirichlet energy ``trace(x^T L x)`` of node features ``x``.

  ``L`` is the symmetrically normalized Laplacian returned by
  ``get_laplacian(edge_index, normalization="sym", num_nodes=x.shape[0])``.

  Args:
    x: node feature tensor whose first dimension indexes the nodes,
      typically of shape ``(num_nodes, dim)``.
    edge_index: graph connectivity in the format accepted by ``get_laplacian``.

  Returns:
    A 0-dim tensor on ``x.device`` with ``x``'s dtype.
  """
  num_nodes = x.shape[0]
  lap_edge_index, lap_edge_weight = get_laplacian(edge_index, edge_weight=None, normalization="sym",
                                                  num_nodes=num_nodes)
  # Use the Laplacian's nonzeros directly, on x's device and in x's dtype.
  # This avoids a CPU round trip through SciPy and a dense
  # num_nodes x num_nodes matrix.
  row, col = lap_edge_index.to(x.device)
  weight = lap_edge_weight.to(device=x.device, dtype=x.dtype)
  # One weight per Laplacian entry, broadcast over x's trailing (feature) dims.
  weight = weight.view(-1, *([1] * (x.dim() - 1)))
  # (L x)[r] = sum_c L[r, c] * x[c], accumulated one Laplacian entry at a time.
  lap_x = torch.zeros_like(x).index_add_(0, row, weight * x[col])
  # trace(x^T (L x)) equals the sum of the elementwise product x * (L x),
  # so the dim x dim matrix x^T L x never has to be formed.
  return (x * lap_x).sum()