import torch
from torch import nn as nn
from torch_geometric.data import Data
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, MessagePassing
from torch_scatter import scatter_mean, scatter_add, scatter_max

class Diffusion(nn.Module):
    """Ensemble denoiser: perturb the input, subtract a predicted noise, average.

    Each forward pass draws ``n_samples`` independent Gaussian perturbations of
    the input, runs every perturbed copy through a small MLP that predicts the
    noise, subtracts that prediction, and returns the mean of the results.
    """

    def __init__(self, hidden_dim, n_samples=10, noise_level=0.1):
        """Build the noise-prediction MLP.

        Args:
            hidden_dim: Size of the last dimension of the features passed to
                ``forward``.
            n_samples: Number of noisy copies averaged per forward pass.
            noise_level: Scale applied to the standard-normal noise.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_samples = n_samples
        self.noise_level = noise_level

        expanded_dim = hidden_dim * 2
        self.noise_predictor = nn.Sequential(
            nn.Linear(hidden_dim, expanded_dim),
            nn.ReLU(),
            nn.Linear(expanded_dim, hidden_dim),
        )

    def _draw_sample(self, x):
        """Return one perturb-then-denoise estimate of ``x``."""
        perturbation = torch.randn_like(x) * self.noise_level
        perturbed = x + perturbation
        return perturbed - self.noise_predictor(perturbed)

    def forward(self, x, batch=None, pos=None):
        """Return the mean of ``n_samples`` denoised estimates of ``x``.

        Args:
            x: Feature tensor whose last dimension is ``hidden_dim``.
            batch: Accepted but not used by this module.
            pos: Accepted but not used by this module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        estimates = [self._draw_sample(x) for _ in range(self.n_samples)]
        return torch.stack(estimates).mean(dim=0)

class GraphAwareDiffusion(nn.Module):
    """Ensemble denoiser that uses graph structure and 2-D node positions.

    Each forward pass draws ``n_samples`` noisy copies of the node features.
    When ``edge_index``, ``batch`` and ``pos`` are all supplied, the noise is
    spatially correlated within each graph and the noise prediction is
    conditioned on two GCN layers plus a position embedding. Otherwise the
    module falls back to ``simple_forward`` (i.i.d. noise, no graph context).
    Both paths return the ensemble mean and the per-element standard deviation
    across samples.
    """

    def __init__(self, hidden_dim, n_samples=10, noise_level=0.1):
        """Create the sub-networks.

        Args:
            hidden_dim: Size of the node-feature dimension.
            n_samples: Number of noisy copies drawn per forward pass.
            noise_level: Scale applied to the generated noise.
        """
        super(GraphAwareDiffusion, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_samples = n_samples
        self.noise_level = noise_level

        # Embeds 2-column positions into the feature space.
        self.pos_encoder = nn.Sequential(
            nn.Linear(2, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim)
        )
        self.graph_conv1 = GCNConv(hidden_dim, hidden_dim)
        self.graph_conv2 = GCNConv(hidden_dim, hidden_dim)

        # Consumes [graph features | position embedding] concatenated on dim 1.
        self.noise_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )

        # Produces one scalar per node that controls the noise blend.
        self.spatial_encoder = nn.Sequential(
            nn.Linear(2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, 1)
        )

    def generate_correlated_noise(self, x, pos, batch):
        """Return Gaussian noise that is spatially smoothed within each graph.

        For every graph id present in ``batch`` the i.i.d. noise of its nodes
        is mixed with a distance-weighted average of the noise of the other
        nodes of the same graph. The result is returned in the same row order
        as ``x``, regardless of whether ``batch`` is sorted.

        Args:
            x: Node features; only its shape/dtype/device are used.
            pos: Node positions with 2 columns, one row per node.
            batch: Graph id of every node, one entry per node.

        Returns:
            Tensor with the same shape as ``x`` (not yet scaled by
            ``noise_level``).
        """
        base_noise = torch.randn_like(x)

        # Per-node blend factor in (0, 2). Values above 1 give the smoothed
        # term a negative coefficient, so the mix is not always convex.
        spatial_weights = torch.sigmoid(self.spatial_encoder(pos)) * 2

        pieces = []
        node_ids = []
        for graph_id in torch.unique(batch):
            idx = (batch == graph_id).nonzero(as_tuple=True)[0]
            graph_noise = base_noise[idx]
            graph_pos = pos[idx]
            graph_weights = spatial_weights[idx]

            # Pairwise Euclidean distances via broadcasting (no repeat copies).
            offsets = graph_pos.unsqueeze(1) - graph_pos.unsqueeze(0)
            dist = torch.sqrt((offsets ** 2).sum(dim=2))

            # Exponential kernel with a fixed length scale of 0.2; entries are
            # positive, so L1 row-normalisation makes each row sum to 1.
            similarity = F.normalize(torch.exp(-dist / 0.2), p=1, dim=1)
            smoothed_noise = torch.mm(similarity, graph_noise)

            pieces.append(
                graph_weights * graph_noise + (1 - graph_weights) * smoothed_noise
            )
            node_ids.append(idx)

        if not pieces:
            # No nodes at all: nothing to correlate.
            return base_noise

        # Chunks were collected graph by graph; undo that permutation so that
        # row i of the result belongs to node i of ``x``.
        stacked = torch.cat(pieces, dim=0)
        restore_order = torch.argsort(torch.cat(node_ids, dim=0))
        return stacked[restore_order]

    @staticmethod
    def _ensemble_stats(samples):
        """Return (mean, std) across a list of equally shaped sample tensors."""
        stacked = torch.stack(samples)
        mean = stacked.mean(dim=0)
        if stacked.shape[0] > 1:
            spread = stacked.std(dim=0)
        else:
            # std of a single sample is NaN; report zero spread instead.
            spread = torch.zeros_like(mean)
        return mean, spread

    def forward(self, x, edge_index=None, batch=None, pos=None):
        """Denoise ``x`` with graph context and estimate the ensemble spread.

        Args:
            x: Node features with ``hidden_dim`` columns.
            edge_index: Graph connectivity passed to the GCN layers.
            batch: Graph id of every node.
            pos: Node positions with 2 columns.

        Returns:
            Tuple ``(ensemble_mean, uncertainty)``, both shaped like ``x``.
            If any of ``edge_index``, ``batch`` or ``pos`` is ``None`` the
            result of ``simple_forward(x)`` is returned instead.
        """
        if edge_index is None or batch is None or pos is None:
            return self.simple_forward(x)

        # Positions do not change between samples, so embed them once.
        pos_embedding = self.pos_encoder(pos)

        samples = []
        for _ in range(self.n_samples):
            noise = self.generate_correlated_noise(x, pos, batch) * self.noise_level
            noisy_x = x + noise

            graph_x = F.relu(self.graph_conv1(noisy_x, edge_index))
            graph_x = F.dropout(graph_x, p=0.1, training=self.training)
            graph_x = self.graph_conv2(graph_x, edge_index)

            combined_features = torch.cat([graph_x, pos_embedding], dim=1)
            pred_noise = self.noise_predictor(combined_features)
            samples.append(noisy_x - pred_noise)

        return self._ensemble_stats(samples)

    def simple_forward(self, x):
        """Denoise ``x`` without graph or position information.

        The position half of the noise-predictor input is filled with zeros.

        Args:
            x: Node features with ``hidden_dim`` columns.

        Returns:
            Tuple ``(ensemble_mean, uncertainty)``, both shaped like ``x``.
        """
        samples = []
        for _ in range(self.n_samples):
            noise = torch.randn_like(x) * self.noise_level
            noisy_x = x + noise

            padded = torch.cat([noisy_x, torch.zeros_like(noisy_x)], dim=1)
            pred_noise = self.noise_predictor(padded)
            samples.append(noisy_x - pred_noise)

        return self._ensemble_stats(samples)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class SimpleGraphDiffusion(nn.Module):
    """Residual ensemble denoiser built on a small GCN encoder/decoder.

    Each forward pass perturbs the node features ``n_samples`` times, runs
    every copy through ``unet_process`` and blends the result with the
    unperturbed input using ``residual_strength``. The samples are then
    combined according to ``output_mode``.
    """

    # Values accepted for the ``output_mode`` argument of ``forward``.
    _OUTPUT_MODES = ("mean", "weighted", "best", "learned", "all")

    def __init__(self, hidden_dim, n_samples=5, noise_level=0.01, residual_strength=0.95):
        """Create the sub-networks.

        Args:
            hidden_dim: Size of the node-feature dimension.
            n_samples: Number of noisy copies drawn per forward pass.
            noise_level: Scale applied to the standard-normal noise.
            residual_strength: Weight of the original features in each sample;
                the network output gets ``1 - residual_strength``.
        """
        super(SimpleGraphDiffusion, self).__init__()
        self.hidden_dim = hidden_dim
        self.n_samples = n_samples
        self.noise_level = noise_level
        self.residual_strength = residual_strength

        self.pos_encoder = nn.Linear(2, hidden_dim)

        self.down_conv1 = GCNConv(hidden_dim, hidden_dim)
        self.down_conv2 = GCNConv(hidden_dim, hidden_dim)
        self.middle_conv = GCNConv(hidden_dim, hidden_dim)
        # The "up" layers take a skip connection concatenated on dim 1.
        self.up_conv1 = GCNConv(hidden_dim * 2, hidden_dim)
        self.up_conv2 = GCNConv(hidden_dim * 2, hidden_dim)
        self.output_conv = nn.Linear(hidden_dim, hidden_dim)

        # Scores a sample per node; used only by the "learned" output mode.
        self.weight_net = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Softplus()
        )

    def unet_process(self, x, edge_index):
        """Run the UNet-style stack of GCN layers with two skip connections.

        Args:
            x: Node features with ``hidden_dim`` columns.
            edge_index: Graph connectivity passed to every GCN layer.

        Returns:
            Tensor with ``hidden_dim`` columns, one row per node.
        """
        x1 = F.relu(self.down_conv1(x, edge_index))
        x2 = F.relu(self.down_conv2(x1, edge_index))
        x3 = F.relu(self.middle_conv(x2, edge_index))
        # Skip connections: concatenate with the matching "down" activation.
        x4 = torch.cat([x3, x2], dim=1)
        x4 = F.relu(self.up_conv1(x4, edge_index))
        x5 = torch.cat([x4, x1], dim=1)
        x5 = F.relu(self.up_conv2(x5, edge_index))
        out = self.output_conv(x5)
        return out

    def forward(self, x, edge_index,batch=None,pos=None, output_mode="mean"):
        """Draw an ensemble of denoised samples and combine them.

        Args:
            x: Node features with ``hidden_dim`` columns.
            edge_index: Graph connectivity passed to the GCN layers.
            batch: Accepted but not used by this module.
            pos: Optional node positions with 2 columns. When given, their
                linear embedding is added to the noisy features; when
                ``None`` the position term is skipped.
            output_mode: How to combine the samples:
                ``"mean"`` - plain average;
                ``"weighted"`` - average weighted by inverse uncertainty;
                ``"best"`` - the sample with the smallest uncertainty;
                ``"learned"`` - softmax-weighted by ``weight_net`` scores;
                ``"all"`` - the full sample stack.

        Returns:
            Tuple ``(output, uncertainties, samples_stack)``. ``uncertainties``
            holds one value per sample: the L2 norm of that sample's deviation
            from the ensemble mean. ``samples_stack`` has the samples stacked
            along a new leading dimension.

        Raises:
            ValueError: If ``output_mode`` is not one of the supported modes.
        """
        if output_mode not in self._OUTPUT_MODES:
            raise ValueError(
                f"Unknown output_mode {output_mode!r}; "
                f"expected one of {self._OUTPUT_MODES}"
            )

        base_features = x
        # ``pos`` is optional in the signature, so only embed it when present.
        pos_features = self.pos_encoder(pos) if pos is not None else None

        samples = []
        for _ in range(self.n_samples):
            noise = torch.randn_like(x) * self.noise_level
            noisy_x = x + noise
            if pos_features is None:
                input_features = noisy_x
            else:
                input_features = noisy_x + pos_features
            denoised = self.unet_process(input_features, edge_index)

            sample = self.residual_strength * base_features + (1 - self.residual_strength) * denoised
            samples.append(sample)

        samples_stack = torch.stack(samples)
        mean_sample = samples_stack.mean(dim=0, keepdim=True)
        # One scalar per sample: distance of that sample from the ensemble mean.
        uncertainties = torch.sqrt(((samples_stack - mean_sample) ** 2).sum(dim=(1, 2)))

        if output_mode == "mean":
            return mean_sample.squeeze(0), uncertainties, samples_stack

        if output_mode == "weighted":
            # Epsilon keeps the inverse finite when a sample equals the mean.
            weights = 1.0 / (uncertainties + 1e-6)
            weights = weights / weights.sum()
            weighted_sample = torch.sum(samples_stack * weights.view(-1, 1, 1), dim=0)
            return weighted_sample, uncertainties, samples_stack

        if output_mode == "best":
            best_idx = torch.argmin(uncertainties)
            return samples_stack[best_idx], uncertainties, samples_stack

        if output_mode == "learned":
            learned_weights = torch.stack(
                [self.weight_net(sample).mean() for sample in samples_stack]
            )
            learned_weights = F.softmax(learned_weights, dim=0)
            learned_sample = torch.sum(samples_stack * learned_weights.view(-1, 1, 1), dim=0)
            return learned_sample, uncertainties, samples_stack

        # Only "all" remains after the validation above.
        return samples_stack, uncertainties, samples_stack
