import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphUNetDenoiser(nn.Module):
    """Fully-connected network arranged like a U-Net with skip connections.

    The feature width goes ``hidden_dim -> hidden_dim -> 2 * hidden_dim`` on
    the way down, stays at ``2 * hidden_dim`` in the middle, and comes back to
    ``hidden_dim`` on the way up. Each up stage receives the output of the
    matching down stage concatenated along the last dimension.

    Args:
        hidden_dim: Size of the last dimension of the input and the output.
        time_embed_dim: Size of the last dimension of ``time_emb``. Falls back
            to ``hidden_dim`` when ``None``.
    """

    def __init__(self, hidden_dim, time_embed_dim=None):
        super().__init__()
        self.hidden_dim = hidden_dim

        narrow = hidden_dim
        wide = hidden_dim * 2
        t_dim = hidden_dim if time_embed_dim is None else time_embed_dim

        # Brings the time embedding to `hidden_dim` so it can be added to x.
        self.time_proj = nn.Sequential(
            nn.Linear(t_dim, narrow),
            nn.SiLU(),
            nn.Linear(narrow, narrow),
        )

        # Contracting path.
        self.down1 = nn.Sequential(*self._stage(narrow, narrow))
        self.down2 = nn.Sequential(*self._stage(narrow, wide))

        # Bottleneck: two stages at the widest resolution.
        self.mid = nn.Sequential(
            *self._stage(wide, wide),
            *self._stage(wide, wide),
        )

        # Expanding path; input widths are doubled by the skip concatenation.
        self.up1 = nn.Sequential(
            *self._stage(wide + wide, wide),
            nn.Linear(wide, narrow),
        )
        self.up2 = nn.Sequential(
            *self._stage(narrow + narrow, narrow),
            nn.Linear(narrow, narrow),
        )

        self.final = nn.Sequential(
            nn.Linear(narrow, narrow),
            nn.SiLU(),
            nn.Linear(narrow, narrow),
        )

    @staticmethod
    def _stage(in_features, out_features):
        """Return the layers of one Linear -> LayerNorm -> SiLU stage."""
        return [
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.SiLU(),
        ]

    def forward(self, x, time_emb):
        """Run the network on ``x`` conditioned on ``time_emb``.

        Args:
            x: Tensor whose last dimension is ``hidden_dim``.
            time_emb: Tensor whose last dimension is ``time_embed_dim``; after
                projection it is added to ``x``, so the two must broadcast.

        Returns:
            Tensor whose last dimension is ``hidden_dim``.
        """
        conditioned = x + self.time_proj(time_emb)

        skip_narrow = self.down1(conditioned)
        skip_wide = self.down2(skip_narrow)

        bottleneck = self.mid(skip_wide)

        merged_wide = torch.cat([bottleneck, skip_wide], dim=-1)
        up_narrow = self.up1(merged_wide)

        merged_narrow = torch.cat([up_narrow, skip_narrow], dim=-1)
        return self.final(self.up2(merged_narrow))
    
class NodeFeatureNoiser(nn.Module):
    """Corrupt features with transformed Gaussian noise at a chosen step.

    The per-step noise levels are stored in a learnable parameter initialised
    to ``noise_schedule_size`` evenly spaced values from 0.05 to 0.2. The step
    indices are kept in the ``timesteps`` buffer.

    Args:
        hidden_dim: Size of the last dimension of the features.
        noise_schedule_size: Number of steps in the noise schedule.
    """

    def __init__(self, hidden_dim, noise_schedule_size=10):
        super().__init__()
        self.hidden_dim = hidden_dim

        initial_levels = torch.linspace(0.05, 0.2, noise_schedule_size)
        self.noise_schedule = nn.Parameter(initial_levels)
        self.register_buffer('timesteps', torch.arange(noise_schedule_size))

        half_dim = hidden_dim // 2
        # Embeds the scalar step index into a `hidden_dim` vector.
        self.time_encoder = nn.Sequential(
            nn.Linear(1, half_dim),
            nn.SiLU(),
            nn.Linear(half_dim, hidden_dim),
        )

        # Transforms raw Gaussian noise before it is scaled and added.
        self.residual_noise_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, training=True):
        """Return a noised copy of ``x`` plus the noise and step information.

        Args:
            x: Feature tensor; its first dimension is used as the batch size
                and its last dimension must be ``hidden_dim``.
            training: When truthy a step is drawn uniformly at random (one
                step shared by the whole call); otherwise the fixed step
                ``len(noise_schedule) // 3`` is used.

        Returns:
            Tuple ``(noisy_x, residual_noise, t, t_emb)`` where ``noisy_x`` is
            ``(1 - level) * x + residual_noise``, ``residual_noise`` is the
            transformed noise scaled by ``level``, ``t`` holds the step index
            repeated once per batch row, and ``t_emb`` is its embedding.
        """
        schedule_len = len(self.noise_schedule)
        if training:
            step = torch.randint(0, schedule_len, (1,)).item()
        else:
            step = schedule_len // 3

        level = self.noise_schedule[step]
        t = self.timesteps[step].expand(x.size(0))
        t_emb = self.time_encoder(t.float().unsqueeze(-1))

        gaussian = torch.randn_like(x)
        residual_noise = self.residual_noise_layer(gaussian) * level

        keep = 1.0 - level
        return keep * x + residual_noise, residual_noise, t, t_emb


class DiffusionNodeEnhancer(nn.Module):
    """Pair a feature noiser with a denoiser for a noise/denoise round trip.

    ``forward`` corrupts the features and remembers the noise and the time
    embedding it used; ``denoise`` later predicts and removes noise from
    (possibly further processed) features.

    Args:
        hidden_dim: Size of the last dimension of the features.
        noise_schedule_size: Number of steps passed to the noiser's schedule.
        training_noise_scale: Stored on the instance; not read by this class.
    """

    def __init__(self, hidden_dim, noise_schedule_size=10, training_noise_scale=1.0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.training_noise_scale = training_noise_scale
        self.noiser = NodeFeatureNoiser(hidden_dim, noise_schedule_size)
        self.denoiser = GraphUNetDenoiser(hidden_dim)

        # Populated by forward(); initialised so that denoise() can be called
        # first without hitting an AttributeError.
        self.added_noise = None
        self.timestep_emb = None

    def forward(self, x, edge_index=None, batch=None, pos=None, training=True):
        """Add noise to ``x`` and cache what is needed to denoise later.

        Args:
            x: Feature tensor handed to the noiser.
            edge_index: Accepted for interface compatibility; unused here.
            batch: Accepted for interface compatibility; unused here.
            pos: Accepted for interface compatibility; unused here.
            training: Forwarded to the noiser to select its step policy.

        Returns:
            Tuple ``(noisy_x, added_noise, t_emb)``. ``added_noise`` and
            ``t_emb`` are also stored as ``self.added_noise`` and
            ``self.timestep_emb``.
        """
        noisy_x, added_noise, _, t_emb = self.noiser(x, training=training)
        self.added_noise = added_noise
        self.timestep_emb = t_emb
        return noisy_x, added_noise, t_emb

    def _default_time_emb(self, reference):
        """Embed the constant timestep 0.5 once per row of ``reference``."""
        default_t = torch.full(
            (reference.size(0), 1),
            0.5,
            device=reference.device,
            dtype=reference.dtype,
        )
        return self.noiser.time_encoder(default_t)

    def denoise(self, processed_features, external_t_emb=None, return_pred=False):
        """Predict the noise in ``processed_features`` and subtract most of it.

        Args:
            processed_features: Features to clean; last dimension must be
                ``hidden_dim``.
            external_t_emb: Time embedding to condition on. When ``None`` the
                embedding cached by the last ``forward`` call is used.
            return_pred: When truthy, also return the predicted noise.

        Returns:
            The denoised features, or ``(denoised_features, pred_noise)`` when
            ``return_pred`` is truthy.
        """
        time_emb = external_t_emb if external_t_emb is not None else self.timestep_emb

        # Fall back to a fixed mid-range timestep when no embedding is
        # available or when its batch size does not match the features.
        if time_emb is None or time_emb.size(0) != processed_features.size(0):
            time_emb = self._default_time_emb(processed_features)

        pred_noise = self.denoiser(processed_features, time_emb)

        # Blend 80% of the fully denoised signal with 20% of the input so the
        # correction is applied conservatively.
        fully_denoised = processed_features - pred_noise
        denoised_features = 0.8 * fully_denoised + 0.2 * processed_features

        if return_pred:
            return denoised_features, pred_noise
        return denoised_features