import torch
from torch import nn
import torch.nn.functional as F
from base_classes import BaseGNN
from model_configurations import set_block, set_function

from torch_geometric.nn import MLP, GINConv, global_add_pool
# Define the GNN model.
class GNN_OGB(BaseGNN):
  """GNN that pushes encoded node features through an ODE block.

  Node features are encoded with ``self.m1``, integrated by ``self.odeblock``
  and decoded with ``self.m2`` into per-node log-probabilities.

  The attributes ``m1``, ``m11``, ``m12``, ``m2``, ``fc``, ``bn_in``, ``T``,
  ``num_classes``, ``opt``, ``device`` and ``regularization_fns`` are not set
  in this class; they are expected to be provided by ``BaseGNN.__init__``.
  """

  def __init__(self, opt, dataset, device=torch.device('cpu')):
    """Build the ODE function and the ODE block.

    Args:
      opt: option mapping; forwarded to the base class, ``set_function``,
        ``set_block`` and the block constructor.
      dataset: dataset object; ``dataset.data`` is handed to the ODE block.
      device: device on which the time tensor and the ODE block are placed.
    """
    super().__init__(opt, dataset, device)
    self.f = set_function(opt)
    block = set_block(opt)
    # Integration interval [0, T], created directly on the target device.
    time_tensor = torch.tensor([0, self.T], device=device)
    self.odeblock = block(self.f, self.regularization_fns, opt, dataset.data, device, t=time_tensor).to(device)

  def forward(self, x, adjs):
    """Run the model for every entry of ``adjs`` and decode the last result.

    Args:
      x: node feature matrix. When ``opt['use_labels']`` is set, the last
        ``num_classes`` columns are split off as labels and concatenated
        back after the optional MLP stage.
      adjs: iterable of 3-tuples ``(edge_index, _, size)``; only
        ``edge_index`` is used (presumably the output of a neighbour
        sampler -- confirm against the caller).

    Returns:
      ``log_softmax`` over the last dimension of the decoded output computed
      for the final entry of ``adjs``.

    Raises:
      ValueError: if ``adjs`` yields no entries (previously this surfaced as
        an ``UnboundLocalError`` on ``z``).
    """
    opt = self.opt
    use_labels = opt['use_labels']

    if use_labels:
      # Split the trailing label columns from the raw features.
      y = x[:, -self.num_classes:]
      x = x[:, :-self.num_classes]

    # Encode each node based on its features.
    x = F.dropout(x, opt['input_dropout'], training=self.training)
    x = self.m1(x)

    # NOTE(review): x is carried over between iterations while z is recomputed
    # from scratch; only the z of the last entry is returned. With
    # 'use_labels' or 'augment' enabled x grows wider on every pass -- confirm
    # this is intended when adjs has more than one entry.
    z = None
    for edge_index, _, _ in adjs:
      if opt['use_mlp']:
        # Two residual MLP layers, each followed by dropout.
        x = F.dropout(x, opt['dropout'], training=self.training)
        x = F.dropout(x + self.m11(F.relu(x)), opt['dropout'], training=self.training)
        x = F.dropout(x + self.m12(F.relu(x)), opt['dropout'], training=self.training)
      # TODO: investigate if some input non-linearity solves the problem with
      # smooth deformations identified in the ANODE paper.

      if use_labels:
        x = torch.cat([x, y], dim=-1)

      if opt['batch_norm']:
        x = self.bn_in(x)

      # Solve the initial value problem of the ODE.
      if opt['augment']:
        # Augmented state: pad with as many zero channels as x has.
        c_aux = torch.zeros(x.shape, device=self.device)
        x = torch.cat([x, c_aux], dim=1)

      self.odeblock.set_x0(x)

      if self.training and self.odeblock.nreg > 0:
        # NOTE(review): unlike the branch below, edge_index is not passed
        # here -- confirm the ODE block accepts this call signature.
        z, self.reg_states = self.odeblock(x)
      else:
        z = self.odeblock(x, edge_index)

      if opt['augment']:
        # Drop the auxiliary half added above.
        z = torch.split(z, x.shape[1] // 2, dim=1)[0]

      # Activation.
      z = F.relu(z)

      if opt['fc_out']:
        z = self.fc(z)
        z = F.relu(z)

      # Dropout.
      z = F.dropout(z, opt['dropout'], training=self.training)

      # Decode each node embedding to get node label.
      z = self.m2(z)

    if z is None:
      raise ValueError("adjs must contain at least one (edge_index, e_id, size) entry")

    return z.log_softmax(dim=-1)

