import torch
import torchsde
import torch.nn as nn
import torch.nn.functional as F

from models.backbone import *

# Decoders: map a hidden embedding to per-class outputs.
class decoder_MLP(nn.Module):
    """Residual MLP head mapping hidden embeddings to ``c`` outputs.

    The layer width comes from ``cfg["hidden_channels"]`` and the dropout
    probability from ``cfg["dropout"]``.
    """

    def __init__(self, c, cfg):
        """Create the two linear layers.

        Args:
            c: Number of output features of the final linear layer.
            cfg: Mapping that provides ``"hidden_channels"`` and ``"dropout"``.
        """
        super().__init__()
        self.cfg = cfg
        width = cfg["hidden_channels"]
        self.m21 = nn.Linear(width, width)
        self.m22 = nn.Linear(width, c)

    def forward(self, x):
        """Apply dropout, a tanh residual layer, then the tanh output layer.

        Dropout is applied to the input, to the residual sum and to the
        final output; it is only active while the module is in training mode.
        """
        rate = self.cfg["dropout"]
        is_training = self.training

        hidden = F.dropout(x, rate, training=is_training)
        # Residual connection around the first linear layer.
        skip = self.m21(torch.tanh(hidden))
        hidden = F.dropout(hidden + skip, rate, training=is_training)
        projected = self.m22(torch.tanh(hidden))
        return F.dropout(projected, rate, training=is_training)


class decoder_MLP_simple(nn.Module):
    """Single linear head applied after a tanh non-linearity."""

    def __init__(self, c, cfg):
        """Create the output layer.

        Args:
            c: Number of output features.
            cfg: Mapping that provides ``"hidden_channels"`` (input width).
        """
        super().__init__()
        self.cfg = cfg
        in_features = cfg["hidden_channels"]
        self.m22 = nn.Linear(in_features, c)

    def forward(self, x):
        """Return ``m22(tanh(x))``."""
        activated = torch.tanh(x)
        return self.m22(activated)

# Encoder: maps input features to a hidden embedding.
class encoder_MLP(nn.Module):
    """Linear encoder that projects ``d`` input features to the hidden width."""

    def __init__(self, d, cfg):
        """Create the projection layer.

        Args:
            d: Number of input features.
            cfg: Mapping that provides ``"hidden_channels"``; ``forward`` also
                reads ``"device"`` and ``"input_dropout"`` from it.
        """
        super().__init__()
        self.cfg = cfg
        out_features = cfg["hidden_channels"]
        self.m11 = nn.Linear(d, out_features)

    def forward(self, x):
        """Move ``x`` to the configured device, apply input dropout, project."""
        settings = self.cfg
        on_device = x.to(settings["device"])
        # Input dropout is only active while the module is in training mode.
        regularised = F.dropout(
            on_device, settings["input_dropout"], training=self.training
        )
        return self.m11(regularised)



class SGNN(torchsde.SDEIto):
    """Ito SDE model with graph-conditioned drift and diffusion networks.

    Input features are encoded, integrated through an SDE whose drift is
    ``f_net`` and whose diffusion is ``g_net``, and the state at the final
    time point is decoded by ``output_decoder``.

    The graph used by ``f_net``/``g_net`` is chosen by ``ind_flag`` between
    ``ind_edge_index`` and ``ood_edge_index``; both start as ``None`` and must
    be assigned by the caller before ``forward`` is invoked.
    """

    def __init__(self, d, c, cfg):
        """Build the encoder, drift/diffusion networks and decoder.

        Args:
            d: Number of input features (passed to ``encoder_MLP``).
            c: Number of decoder outputs (passed to ``decoder_MLP_simple``).
            cfg: Mapping read here for ``"hidden_channels"``, ``"dropout"``,
                ``"use_bn"``, ``"time"``, ``"N"``, ``"device"`` and
                ``"adjoint"``; ``forward`` additionally reads ``"method"``,
                ``"dt"``, ``"adaptive"``, ``"rtol"`` and ``"atol"``.
        """
        super(SGNN, self).__init__(noise_type="diagonal")
        width = cfg["hidden_channels"]
        backbone_options = {
            "num_layers": 1,
            "dropout": cfg["dropout"],
            "use_bn": cfg["use_bn"],
        }

        # Submodules are registered in the same order as before so that
        # parameter initialisation and state_dict layout are unchanged.
        self.input_encoder = encoder_MLP(d, cfg)
        self.bnin = nn.BatchNorm1d(width)
        self.f_encoder = Drift(width, width, width, **backbone_options)
        # SFN receives the graph alongside the state; per the original
        # author's note it models the dependency of noises.
        self.g_encoder = SFN(width, width, width, **backbone_options)
        # Registered but not used anywhere in this class.
        self.bng = nn.BatchNorm1d(width)
        self.output_decoder = decoder_MLP_simple(c, cfg)

        self.cfg = cfg
        self.c_size = c
        self.device = cfg["device"]
        self.time = cfg["time"]
        self.N = cfg["N"]
        # Two-point time grid; forward() builds its own N-point grid instead.
        self.ts = torch.tensor([0, self.time])

        # Graph selection state, expected to be filled in by the caller.
        self.ind_flag = True
        self.ind_edge_index = None
        self.ood_edge_index = None
        # Edge weights are stored but not read in this class.
        self.ind_edge_weight = None
        self.ood_edge_weight = None

        if cfg["adjoint"]:
            self.sdeint_fn = torchsde.sdeint_adjoint
        else:
            self.sdeint_fn = torchsde.sdeint

    def reset_parameters(self):
        """Reset the drift and diffusion networks only."""
        for module in (self.f_encoder, self.g_encoder):
            module.reset_parameters()

    def _active_edge_index(self):
        """Return the edge index selected by ``ind_flag``, moved to ``self.device``."""
        # Equality with True (not truthiness) is kept on purpose so that
        # non-bool flags select the same branch as before.
        if self.ind_flag == True:  # noqa: E712
            chosen = self.ind_edge_index
        else:
            chosen = self.ood_edge_index
        return chosen.to(self.device)

    def f_net(self, t, y):
        """Drift term: ``f_encoder(y, edge_index) - y``. ``t`` is unused."""
        edge_index = self._active_edge_index()
        return self.f_encoder(y, edge_index) - y

    def g_net(self, t, y):
        """Diffusion term: ``y - g_encoder(y, edge_index)``. ``t`` is unused."""
        edge_index = self._active_edge_index()
        return y - self.g_encoder(y, edge_index)

    def forward(self, x, flag, device):
        """Encode ``x``, integrate the SDE and decode the final state.

        Args:
            x: Input features, passed to ``input_encoder``.
            flag: Stored in ``ind_flag``; selects which edge index
                ``f_net``/``g_net`` use.
            device: Device the integration time grid is moved to.

        Returns:
            The output of ``output_decoder`` applied to the last time point
            of the solver's result.
        """
        self.ind_flag = flag

        initial_state = self.input_encoder(x)
        if self.cfg["use_bn"]:
            initial_state = self.bnin(initial_state)

        # N evenly spaced time points on [0, self.time].
        time_grid = torch.linspace(0, self.time, self.N).to(device)

        solver_options = {
            "method": self.cfg["method"],
            "dt": self.cfg["dt"],
            "adaptive": self.cfg["adaptive"],
            "rtol": self.cfg["rtol"],
            "atol": self.cfg["atol"],
            # Tell torchsde which methods provide the drift and diffusion.
            "names": {'drift': 'f_net', 'diffusion': 'g_net'},
        }
        trajectory = self.sdeint_fn(
            sde=self,
            y0=initial_state,
            ts=time_grid,
            **solver_options,
        )

        final_state = trajectory[-1]
        return self.output_decoder(final_state)

