import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GATConv, GCNConv, SAGEConv, GCN2Conv, GINConv
from torch_geometric.utils import to_dense_adj
#from models.layers import DenseGCNConv, MLP
import math

from models.layers import (ResidualModuleWrapper, FeedForwardModule, GCNModule, SAGEModule, GATModule, GATSepModule,
                     TransformerAttentionModule, TransformerAttentionSepModule)

def SinusoidalPosEmb(x, num_steps, dim, rescale=4):
    """Return sinusoidal embeddings for a 1-D tensor of step values.

    Args:
        x: 1-D tensor of step values (it is indexed as ``x[:, None]``).
        num_steps: Value that ``x`` is divided by and then multiplied by
            again in the scaling expression below.
        dim: Requested embedding width. ``dim // 2`` frequencies are used,
            so the result has ``2 * (dim // 2)`` columns (``dim - 1`` when
            ``dim`` is odd).
        rescale: Constant factor applied to ``x`` before the frequencies.

    Returns:
        Tensor of shape ``(len(x), 2 * (dim // 2))`` on ``x``'s device: all
        sine terms followed by all cosine terms.
    """
    # Algebraically this is ``x * rescale``; the order of operations is kept
    # as-is so floating-point results are unchanged for existing callers.
    x = x / num_steps * num_steps * rescale
    half_dim = dim // 2
    # Guard the denominator: with a single frequency (dim of 2 or 3) the
    # unguarded ``half_dim - 1`` is zero and would raise ZeroDivisionError.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
    angles = x[:, None] * freqs[None, :]
    return torch.cat((angles.sin(), angles.cos()), dim=-1)



class Simple_Model(nn.Module):
    """Input projection, a stack of residual GATSepModule blocks, and a
    log-softmax output head.

    Args:
        nfeat: Number of input features per row of ``x``.
        nlabel: Number of output classes.
        num_layers: Number of residual blocks to stack.
        num_linears: Accepted but not used by this class.
        nhid: Hidden width used throughout the network.
        nhead: Accepted but not used by this class; every block is built
            with ``num_heads=8``.
        skip: Accepted but not used by this class.
    """

    def __init__(self, nfeat, nlabel,num_layers, num_linears, nhid, nhead=4, skip=False):
        super().__init__()

        norm_cls = nn.LayerNorm

        # Input stem: linear projection -> dropout -> GELU.
        self.input_linear = nn.Linear(in_features=nfeat, out_features=nhid)
        self.dropout = nn.Dropout(p=0.5)
        self.act = nn.GELU()

        # Residual stack; every block shares the same fixed hyper-parameters.
        blocks = [
            ResidualModuleWrapper(
                module=GATSepModule,
                normalization=norm_cls,
                dim=nhid,
                hidden_dim_multiplier=1,
                num_heads=8,
                dropout=0.5,
            )
            for _ in range(num_layers)
        ]
        self.residual_modules = nn.ModuleList(blocks)

        # Output head: normalization followed by a linear classifier.
        self.output_normalization = norm_cls(nhid)
        self.output_linear = nn.Linear(in_features=nhid, out_features=nlabel)

    def forward(self, x, graph):
        """Return log-probabilities over ``nlabel`` classes (softmax on dim 1).

        ``graph`` is passed unchanged as the first argument of every
        residual block.
        """
        hidden = self.act(self.dropout(self.input_linear(x)))

        # NOTE(review): ``train=True`` is passed unconditionally, regardless
        # of the module's training mode; confirm this is intended.
        for block in self.residual_modules:
            hidden = block(graph, hidden, train=True)

        logits = self.output_linear(self.output_normalization(hidden))
        return F.log_softmax(logits, dim=1)


class Denoising_Model(torch.nn.Module):
    """Denoising network built from residual GATSepModule blocks,
    conditioned on a step embedding and on ``q_Y_sample``.

    Note that the constructor takes ``nlabel`` before ``nfeat``, the reverse
    of ``Simple_Model``.

    Args:
        nlabel: Number of label channels; also the output width.
        nfeat: Number of input features per row of ``x``.
        num_layers: Number of residual blocks to stack.
        num_linears: Accepted but not used by this class.
        nhid: Hidden width used throughout the network.
        nhead: Accepted but not used by this class; every block is built
            with ``num_heads=8``.
        skip: Accepted but not used by this class.
    """

    # Width of the sinusoidal step embedding fed to ``time_mlp``. Kept in one
    # place so the embedding built in ``forward`` always matches the MLP input.
    _TIME_EMB_DIM = 128

    def __init__(self, nlabel, nfeat, num_layers, num_linears, nhid, nhead=4, skip=False):
        super().__init__()

        normalization = nn.LayerNorm
        module = GATSepModule

        # Input stem: linear projection -> (optional) dropout -> GELU.
        self.input_linear = nn.Linear(in_features=nfeat, out_features=nhid)
        self.dropout = nn.Dropout(p=0.5)
        self.act = nn.GELU()

        self.residual_modules = nn.ModuleList()
        for _ in range(num_layers):
            residual_module = ResidualModuleWrapper(module=module,
                                                    normalization=normalization,
                                                    dim=nhid,
                                                    hidden_dim_multiplier=1,
                                                    num_heads=8,
                                                    dropout=0.5,
                                                    nlabel=nlabel)
            self.residual_modules.append(residual_module)

        self.output_normalization = normalization(nhid)
        # The head consumes the hidden state concatenated with q_Y_sample.
        self.output_linear = nn.Linear(in_features=nhid + nlabel, out_features=nlabel)

        # Maps the sinusoidal step embedding to the hidden width so it can be
        # added to the output of each residual block.
        self.time_mlp = nn.Sequential(
            nn.Linear(self._TIME_EMB_DIM, self._TIME_EMB_DIM),
            nn.ELU(),
            nn.Linear(self._TIME_EMB_DIM, nhid)
        )

    def forward(self, x, q_Y_sample, adj, t, num_steps, train=False):
        """Run the denoiser and return a tensor with ``nlabel`` output columns.

        Args:
            x: Input features with ``nfeat`` columns.
            q_Y_sample: Tensor passed to every residual block and
                concatenated to the hidden state along the last dimension
                before the output layer (presumably ``nlabel`` columns, to
                match ``output_linear``; confirm against the caller).
            adj: Passed unchanged as the first argument of every residual
                block.
            t: 1-D tensor of step values, embedded via ``SinusoidalPosEmb``.
            num_steps: Forwarded to ``SinusoidalPosEmb``.
            train: When truthy, the input dropout layer is applied; the flag
                is also forwarded to every residual block.
        """
        t = SinusoidalPosEmb(t, num_steps, self._TIME_EMB_DIM)
        t = self.time_mlp(t)

        x = self.input_linear(x)
        if train:
            x = self.dropout(x)
        x = self.act(x)

        # One loop over the whole stack: every block (including the last) has
        # the step embedding added to its output. This also makes an empty
        # stack (num_layers == 0) valid instead of raising IndexError.
        for residual_module in self.residual_modules:
            x = residual_module(adj, x, q_Y_sample, train) + t

        x = self.output_normalization(x)
        x = torch.cat([x, q_Y_sample], dim=-1)
        x = self.output_linear(x)
        return x




