import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F

from GDSS.models.layers import DenseGCNConv, MLP
from GDSS.utils.graph_utils import mask_adjs, pow_tensor
from GDSS.models.attention import  Attention
from GDSS.utils.graph_utils import mask_x, node_feature_to_matrix

class TimeEmbeddingReLu(nn.Module):
    """Embed a scalar diffusion time with a two-layer ReLU MLP.

    The module maps each time value to a non-negative vector of size
    ``embed_dim`` (``Linear(1, embed_dim) -> ReLU -> Linear -> ReLU``).
    """

    def __init__(self, embed_dim):
        """Create the two linear layers.

        Args:
            embed_dim: Width of the hidden layer and of the returned embedding.
        """
        super(TimeEmbeddingReLu, self).__init__()
        self.lin1 = nn.Linear(1, embed_dim)
        self.lin2 = nn.Linear(embed_dim, embed_dim)

    def forward(self, t_scalar):
        """Return the time embedding.

        Args:
            t_scalar: A Python number, a sequence/array of numbers, or a
                tensor of any shape. It is flattened to one time value per
                row.

        Returns:
            Tensor of shape ``(num_times, embed_dim)``; a plain scalar gives
            ``(1, embed_dim)``.
        """
        ref = self.lin1.weight
        # Match the layer's device *and* dtype: integer or double timesteps
        # would otherwise fail inside nn.Linear. reshape (not view) also
        # accepts non-contiguous tensors.
        t = torch.as_tensor(t_scalar, dtype=ref.dtype, device=ref.device).reshape(-1, 1)
        h = F.relu(self.lin1(t))
        return F.relu(self.lin2(h))
        
class TimeEmbeddingSiLu(nn.Module):
    """Embed a scalar diffusion time with a two-layer SiLU MLP.

    Same layout as the ReLU variant, but expressed as a single
    ``nn.Sequential`` with SiLU activations after both linear layers.
    """

    def __init__(self, embed_dim):
        """Build the ``Linear -> SiLU -> Linear -> SiLU`` stack.

        Args:
            embed_dim: Width of the hidden layer and of the returned embedding.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
        )

    def forward(self, t_scalar):
        """Return the time embedding.

        Args:
            t_scalar: A Python number, a sequence/array of numbers, or a
                tensor of any shape. It is flattened to one time value per
                row.

        Returns:
            Tensor of shape ``(num_times, embed_dim)``; a plain scalar gives
            ``(1, embed_dim)``.
        """
        ref = self.net[0].weight
        # Cast to the layer's device and dtype so integer/double timesteps do
        # not crash in nn.Linear; reshape also tolerates non-contiguous input.
        t = torch.as_tensor(t_scalar, dtype=ref.dtype, device=ref.device).reshape(-1, 1)
        return self.net(t)

import math
import torch
import torch.nn as nn



class BaselineNetworkLayer(torch.nn.Module):
    """One joint node-feature / adjacency update using per-channel GCN convolutions.

    Each of the ``input_dim`` adjacency channels gets its own ``DenseGCNConv``.
    The per-channel node outputs are concatenated and fused by an MLP, and a
    second MLP produces ``output_dim`` new adjacency channels.
    """

    def __init__(self, num_linears, conv_input_dim, conv_output_dim, input_dim, output_dim, batch_norm=False):
        """Build the convolutions and the two MLPs.

        Args:
            num_linears: Number of linear layers in the adjacency MLP.
            conv_input_dim: Node-feature width fed to every convolution.
            conv_output_dim: Node-feature width produced by every convolution
                and by the fused node output.
            input_dim: Number of incoming adjacency channels.
            output_dim: Number of outgoing adjacency channels.
            batch_norm: Accepted for interface compatibility; not used here
                (both MLPs are built with ``use_bn=False``).
        """
        super().__init__()

        # One independent convolution per adjacency channel.
        self.convs = torch.nn.ModuleList(
            [DenseGCNConv(conv_input_dim, conv_output_dim) for _ in range(input_dim)]
        )

        self.hidden_dim = max(input_dim, output_dim)
        # Adjacency-MLP input: the pairwise node features plus the incoming
        # adjacency channels (presumably node_feature_to_matrix yields
        # 2*conv_output_dim features per node pair -- confirm against helper).
        self.mlp_in_dim = input_dim + 2*conv_output_dim
        self.mlp = MLP(
            num_linears, self.mlp_in_dim, self.hidden_dim, output_dim,
            use_bn=False, activate_func=F.elu,
        )
        self.multi_channel = MLP(
            2, input_dim*conv_output_dim, self.hidden_dim, conv_output_dim,
            use_bn=False, activate_func=F.elu,
        )

    def forward(self, x, adj, flags):
        """Update node features and adjacency channels.

        Args:
            x: Node features passed to every convolution.
            adj: Adjacency tensor indexed as ``adj[:, channel, :, :]``.
            flags: Node mask forwarded to ``mask_x`` and ``mask_adjs``.

        Returns:
            Tuple ``(x_out, adj_out)`` of the new node features and the new,
            symmetrised and masked adjacency channels.
        """
        # Run each channel's convolution on its own adjacency slice.
        per_channel = [conv(x, adj[:, ch, :, :]) for ch, conv in enumerate(self.convs)]
        fused = self.multi_channel(torch.cat(per_channel, dim=-1))
        x_out = torch.tanh(mask_x(fused, flags))

        pair_feats = node_feature_to_matrix(x_out)
        edge_in = torch.cat([pair_feats, adj.permute(0, 2, 3, 1)], dim=-1)
        batch, rows, cols, feat = edge_in.shape

        # The MLP works on a flat list of node pairs; restore the grid after.
        edge_out = self.mlp(edge_in.view(-1, feat))
        new_adj = edge_out.view(batch, rows, cols, -1).permute(0, 3, 1, 2)
        # Adding the transpose makes every output channel symmetric.
        new_adj = new_adj + new_adj.transpose(-1, -2)

        return x_out, mask_adjs(new_adj, flags)


class BaselineNetwork(torch.nn.Module):
    """Stack of ``BaselineNetworkLayer`` blocks that outputs one score per node pair.

    The adjacency is first expanded into ``c_init`` channels by ``pow_tensor``.
    Every layer then emits a new set of adjacency channels; all of them are
    concatenated and reduced to a single ``(B, N, N)`` score by a final MLP.
    """

    def __init__(self, max_feat_num, max_node_num, nhid, num_layers, num_linears,
                    c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
        """Build the layer stack and the final read-out MLP.

        Args:
            max_feat_num: Input node-feature width (first layer's conv input).
            max_node_num: Stored on the instance; not used in this class.
            nhid: Hidden node-feature width.
            num_layers: Number of stacked layers.
            num_linears: Linear layers in each layer's adjacency MLP.
            c_init: Adjacency channels entering the first layer.
            c_hid: Adjacency channels between layers.
            c_final: Adjacency channels leaving the last layer.
            adim: Stored on the instance; not used in this class.
            num_heads: Stored on the instance; not used in this class.
            conv: Stored on the instance; not used in this class.
        """
        super(BaselineNetwork, self).__init__()

        self.nfeat = max_feat_num
        self.max_node_num = max_node_num
        self.nhid = nhid
        self.num_layers = num_layers
        self.num_linears = num_linears
        self.c_init = c_init
        self.c_hid = c_hid
        self.c_final = c_final

        self.adim = adim
        self.num_heads = num_heads
        self.conv = conv

        self.layers = torch.nn.ModuleList()
        for idx in range(self.num_layers):
            # The first-layer branch is checked first, so a single-layer
            # network maps c_init -> c_hid channels.
            if idx == 0:
                self.layers.append(BaselineNetworkLayer(self.num_linears, self.nfeat, self.nhid, self.c_init, self.c_hid))
            elif idx == self.num_layers - 1:
                self.layers.append(BaselineNetworkLayer(self.num_linears, self.nhid, self.nhid, self.c_hid, self.c_final))
            else:
                self.layers.append(BaselineNetworkLayer(self.num_linears, self.nhid, self.nhid, self.c_hid, self.c_hid))

        # Width of the concatenation fed to the read-out MLP.
        self.fdim = self.c_hid*(self.num_layers-1) + self.c_final + self.c_init
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=1,
                            use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags=None):
        """Compute the pairwise score matrix.

        Args:
            x: Node features.
            adj: Adjacency tensor handed to ``pow_tensor``.
            flags: Optional per-node mask (``(B, N)``; non-zero / True marks a
                real node). ``None`` is forwarded to the masking helpers, which
                are expected to accept it -- confirm against ``mask_x`` /
                ``mask_adjs``.

        Returns:
            ``(B, N, N)`` score tensor with a zeroed diagonal, masked by
            ``mask_adjs``.

        Raises:
            AssertionError: If ``flags`` is given and an entry in a column
                belonging to a padded node is non-zero after masking.
        """
        adjc = pow_tensor(adj, self.c_init)

        adj_list = [adjc]
        for layer in self.layers:
            x, adjc = layer(x, adjc, flags)
            adj_list.append(adjc)

        # Put channels last so the read-out MLP acts on each node pair.
        adjs = torch.cat(adj_list, dim=1).permute(0, 2, 3, 1)
        out_shape = adjs.shape[:-1]
        score = self.final(adjs).view(*out_shape)
        B, N, _ = score.shape

        # Zero the diagonal (no self-loop scores).
        mask = torch.ones(N, N, device=score.device) - torch.eye(N, device=score.device)
        mask = mask.unsqueeze(0).expand(B, -1, -1)
        score = score * mask

        score = mask_adjs(score, flags)

        # Sanity check only when a mask exists: the signature allows
        # flags=None, and `~` is undefined for floating-point masks, so the
        # mask is converted to bool before being inverted.
        if flags is not None:
            padded = ~flags.bool().unsqueeze(1).expand(-1, N, N)
            assert torch.all(score[padded] == 0), "Padded nodes not fully masked"
        return score



class AdaRMSNorm(nn.Module):
    """RMS normalisation over the last axis with an externally supplied gain.

    The module has no learnable parameters; the scale ``gamma`` is passed to
    ``forward`` by the caller.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Store the numerical-stability constant.

        Args:
            dim: Feature width; accepted for interface symmetry but not used.
            eps: Added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` by its root mean square and scale by ``gamma``.

        Args:
            x: Input tensor; statistics are taken over its last dimension.
            gamma: Scale tensor. While it has fewer dimensions than ``x``,
                singleton axes are inserted at position 1 so it broadcasts.

        Returns:
            Tensor with the broadcast shape of ``x`` and the expanded ``gamma``.
        """
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalised = x * torch.rsqrt(mean_square + self.eps)

        # Insert broadcast axes right after the leading dimension.
        for _ in range(x.dim() - gamma.dim()):
            gamma = gamma.unsqueeze(1)

        return normalised * gamma



class SinusoidalTimeEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of a scalar time value.

    Frequencies are spaced geometrically from ``1`` down to ``1/10000``. The
    output concatenates the sines and then the cosines of ``t * scale * freq``.
    """

    def __init__(self, dim: int, scale: float = 2*math.pi):
        """Precompute the inverse-frequency table.

        Args:
            dim: Embedding width; must be even (half sine, half cosine).
            scale: Multiplier applied to the time before the frequencies.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"SinusoidalPosEmb: dim must be even, got {dim}.")
        self.dim = dim
        self.scale = scale
        half = dim // 2
        # With dim == 2 there is a single frequency and ``half - 1`` is zero;
        # clamp the denominator so that case yields frequency 1 instead of
        # raising ZeroDivisionError. Larger dims are unaffected.
        denom = max(half - 1, 1)
        inv_freq = torch.exp(-torch.arange(half, dtype=torch.float32) *
                             (math.log(10000.0) / denom))
        # Non-persistent: follows .to()/.cuda() but is not saved in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed a batch of times.

        Args:
            t: Tensor of any shape (flattened to 1-D), or a Python
                number/sequence, which is converted to a float32 tensor on the
                buffer's device.

        Returns:
            Tensor of shape ``(num_times, dim)`` laid out as ``[sin | cos]``.
        """
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, dtype=torch.float32, device=self.inv_freq.device)
        # reshape (not view) so non-contiguous inputs are accepted as well.
        t = t.reshape(-1) * self.scale
        inv = self.inv_freq.to(device=t.device, dtype=t.dtype)
        angles = t[:, None] * inv[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)


class AttentionLayer(torch.nn.Module):
    """Time-conditioned node/adjacency update built from per-channel attention.

    Node features and the adjacency-MLP input are both RMS-normalised and
    scaled by gains predicted from the time embedding. One ``Attention``
    module is applied per adjacency channel.
    """

    def __init__(self, num_linears, conv_input_dim, attn_dim, conv_output_dim, input_dim, output_dim,
                 num_heads=4, conv='GCN', time_embed_dim=32):
        """Build the attention blocks, MLPs, norms and time-conditioning heads.

        Args:
            num_linears: Number of linear layers in the adjacency MLP.
            conv_input_dim: Incoming node-feature width.
            attn_dim: Attention width passed to every ``Attention`` block.
            conv_output_dim: Outgoing node-feature width.
            input_dim: Number of incoming adjacency channels.
            output_dim: Number of outgoing adjacency channels.
            num_heads: Attention heads per block.
            conv: Convolution type forwarded to ``Attention``.
            time_embed_dim: Width of the time embedding given to ``forward``.
        """
        super().__init__()

        # One attention block per adjacency channel.
        self.attn = torch.nn.ModuleList(
            [
                Attention(conv_input_dim, attn_dim, conv_output_dim,
                          num_heads=num_heads, conv=conv)
                for _ in range(input_dim)
            ]
        )

        self.hidden_dim = 2*max(input_dim, output_dim)
        # Adjacency MLP input: one attention map plus one adjacency slice per channel.
        self.mlp = MLP(num_linears, 2*input_dim, self.hidden_dim, output_dim,
                       use_bn=False, activate_func=F.elu)
        self.multi_channel = MLP(2, input_dim*conv_output_dim, self.hidden_dim, conv_output_dim,
                                 use_bn=False, activate_func=F.elu)

        self.norm_x = AdaRMSNorm(dim=conv_input_dim)
        self.norm_adj = AdaRMSNorm(dim=2*input_dim)

        # Time embedding -> per-feature gain for the node normalisation.
        self.time_mlp_x = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_embed_dim, conv_input_dim),
        )
        # Time embedding -> per-feature gain for the adjacency normalisation.
        time_hidden = time_embed_dim * 4
        self.time_mlp_adj = nn.Sequential(
            nn.Linear(time_embed_dim, time_hidden),
            nn.SiLU(),
            nn.Linear(time_hidden, 2*input_dim),
        )

    def forward(self, x, adj, flags, time_emb):
        """Update node features and adjacency channels for one diffusion time.

        Args:
            x: Node features.
            adj: Adjacency tensor indexed as ``adj[:, channel, :, :]``.
            flags: Node mask forwarded to the attention blocks and mask helpers.
            time_emb: Time embedding with ``time_embed_dim`` features.

        Returns:
            Tuple ``(x_out, adj_out)`` of the new node features and the new,
            symmetrised and masked adjacency channels.
        """
        x_scale = self.time_mlp_x(time_emb)
        adj_scale = self.time_mlp_adj(time_emb)
        x_normed = self.norm_x(x, x_scale)

        attn_maps = []
        node_outs = []
        for ch, attn_block in enumerate(self.attn):
            node_feat, attn_map = attn_block(x_normed, adj[:, ch, :, :], flags)
            attn_maps.append(attn_map.unsqueeze(-1))
            node_outs.append(node_feat)

        fused = self.multi_channel(torch.cat(node_outs, dim=-1))
        x_out = torch.tanh(mask_x(fused, flags))

        # Per node pair: attention maps first, then the incoming adjacency channels.
        stacked_maps = torch.cat(attn_maps, dim=-1)
        edge_in = torch.cat([stacked_maps, adj.permute(0, 2, 3, 1)], dim=-1)
        edge_in = self.norm_adj(edge_in, adj_scale)

        dims = edge_in.shape
        edge_out = self.mlp(edge_in.view(-1, dims[-1]))

        new_adj = edge_out.view(dims[0], dims[1], dims[2], -1).permute(0, 3, 1, 2)
        # Adding the transpose makes every output channel symmetric.
        new_adj = new_adj + new_adj.transpose(-1, -2)

        return x_out, mask_adjs(new_adj, flags)

class DenoiseNetworkA(BaselineNetwork):
    """Time-conditioned variant of ``BaselineNetwork`` built from ``AttentionLayer``.

    The parent constructor runs first; its layer stack and read-out MLP are
    then replaced by attention-based layers and a freshly built read-out.
    """

    def __init__(self, max_feat_num, max_node_num, nhid, num_layers, num_linears,
                 c_init, c_hid, c_final, adim, num_heads=4, conv='GCN'):
        """Build the attention layer stack, time embedding and read-out MLP.

        Arguments have the same meaning as in ``BaselineNetwork``; ``adim``,
        ``num_heads`` and ``conv`` are forwarded to the attention layers.
        """
        super().__init__(max_feat_num, max_node_num, nhid, num_layers, num_linears,
                         c_init, c_hid, c_final, adim, num_heads, conv)

        self.time_embed_dim = 32
        self.time_embed = SinusoidalTimeEmbedding(self.time_embed_dim)

        # Replace the layers created by the parent constructor.
        self.layers = torch.nn.ModuleList()
        last = self.num_layers - 1
        for idx in range(self.num_layers):
            # The first-layer case takes priority, so a single-layer network
            # uses the (nfeat, c_init -> c_hid) configuration.
            if idx == 0:
                feat_in, attn_dim, chan_in, chan_out = self.nfeat, self.nhid, self.c_init, self.c_hid
            elif idx == last:
                feat_in, attn_dim, chan_in, chan_out = self.nhid, self.adim, self.c_hid, self.c_final
            else:
                feat_in, attn_dim, chan_in, chan_out = self.nhid, self.adim, self.c_hid, self.c_hid
            self.layers.append(
                AttentionLayer(self.num_linears, feat_in, attn_dim, self.nhid, chan_in,
                               chan_out, self.num_heads, self.conv, self.time_embed_dim)
            )

        # Width of the concatenated adjacency channels fed to the read-out.
        self.fdim = self.c_hid*(self.num_layers-1) + self.c_final + self.c_init
        self.final = MLP(num_layers=3, input_dim=self.fdim, hidden_dim=2*self.fdim, output_dim=1,
                         use_bn=False, activate_func=F.elu)

    def forward(self, x, adj, flags, t):
        """Compute the pairwise score matrix at diffusion time ``t``.

        Args:
            x: Node features.
            adj: Adjacency tensor; if it is 4-D only index 0 along dim 1 is used.
            flags: Node mask forwarded to the layers and to ``mask_adjs``.
            t: Diffusion time(s) passed to the sinusoidal embedding.

        Returns:
            ``(B, N, N)`` score tensor with a zeroed diagonal, masked by
            ``mask_adjs``.
        """
        t_emb = self.time_embed(t)

        if adj.ndim == 4:
            adj = adj[:, 0]
        channels = pow_tensor(adj, self.c_init)

        collected = [channels]
        for block in self.layers:
            x, channels = block(x, channels, flags, t_emb)
            collected.append(channels)

        # Put channels last so the read-out MLP acts on each node pair.
        stacked = torch.cat(collected, dim=1).permute(0, 2, 3, 1)
        batch, nodes = stacked.shape[0], stacked.shape[1]
        score = self.final(stacked).view(batch, nodes, nodes)

        # Zero the diagonal (no self-loop scores), then mask padded nodes.
        off_diag = torch.ones(nodes, nodes, device=score.device) - torch.eye(nodes, device=score.device)
        score = score * off_diag.unsqueeze(0)
        return mask_adjs(score, flags)
