import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_scatter import scatter_mean, scatter_add

class PrecipitationEnsembleSelector(nn.Module):
    """Score the members of an ensemble, pick the best one and blend them all.

    Each ensemble member gets three scalar scores, which ``score_aggregator``
    turns into one logit per member; a softmax over members gives the weights.

    * consensus -- attention the member receives from the other members
      (averaged over query members and stations);
    * spatial   -- mean output of ``SpatialCorrelationNet`` over the graph,
      or a constant 0.5 when no graph is given;
    * extreme   -- mean output of ``extreme_detector`` (sigmoid, so in (0, 1)).

    Args:
        hidden_dim: width of the hidden layers of the sub-networks.
        output_dim: size of the last dimension of ``ensemble_outputs``; also
            the attention embedding size.
        n_samples: nominal ensemble size. Kept as an attribute for
            compatibility; ``forward`` uses the actual leading dimension of
            its input so a different ensemble size cannot misalign the scores.
    """

    def __init__(self, hidden_dim, output_dim, n_samples=5):
        super(PrecipitationEnsembleSelector, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_samples = n_samples

        self.self_attention = nn.MultiheadAttention(
            embed_dim=output_dim,
            num_heads=1,
            dropout=0.1
        )
        self.spatial_correlation = SpatialCorrelationNet(output_dim, hidden_dim)

        self.extreme_detector = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        self.score_aggregator = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, ensemble_outputs, edge_index=None, batch=None, pos=None):
        """Score every ensemble member; return best member, blend and weights.

        Args:
            ensemble_outputs: 3-D tensor indexed ``[member, station, feature]``
                with ``feature == output_dim`` (presumably
                ``(n_samples, n_stations, output_dim)`` -- TODO confirm with
                callers).
            edge_index: optional graph connectivity forwarded to
                ``SpatialCorrelationNet``. When ``None`` every member gets a
                neutral spatial score of 0.5.
            batch, pos: accepted for interface compatibility; unused.

        Returns:
            ``(selected_output, weighted_output, weights)``:
            ``selected_output`` is the member with the largest weight,
            ``weighted_output`` is the weight-averaged ensemble, and
            ``weights`` is a softmax over members, shape ``(n_members,)``.
        """
        n_members = ensemble_outputs.shape[0]

        # nn.MultiheadAttention is sequence-first (batch_first=False). Feeding
        # (members, stations, dim) unpermuted makes the members the sequence
        # and the stations the batch, so members attend to each other at
        # every station. Permuting to station-first would make stations
        # attend to each other, and the averaged weights would collapse to
        # the constant 1 / n_stations for every member.
        _, attn_weights = self.self_attention(
            ensemble_outputs, ensemble_outputs, ensemble_outputs
        )
        # attn_weights: (stations, query member, key member). Averaging over
        # queries gives the attention each member *receives*; then average
        # over stations -> one consensus score per member.
        sample_consensus = attn_weights.mean(dim=1).mean(dim=0)

        if edge_index is not None:
            spatial_scores = torch.stack([
                self.spatial_correlation(ensemble_outputs[i], edge_index).mean()
                for i in range(n_members)
            ])
        else:
            spatial_scores = ensemble_outputs.new_full((n_members,), 0.5)
        # NOTE(review): SpatialCorrelationNet returns unbounded GCN outputs,
        # while the 0.5 fallback and the extreme score live in [0, 1];
        # confirm whether a sigmoid is intended on the spatial score.

        # Run the detector on all members at once, then average each member.
        extreme_scores = (
            self.extreme_detector(ensemble_outputs)
            .reshape(n_members, -1)
            .mean(dim=1)
        )

        combined_scores = torch.stack(
            [sample_consensus, spatial_scores, extreme_scores], dim=1
        )  # (n_members, 3)

        final_scores = self.score_aggregator(combined_scores).squeeze(-1)
        weights = F.softmax(final_scores, dim=0)

        best_idx = torch.argmax(weights)
        selected_output = ensemble_outputs[best_idx]

        weighted_output = (weights.view(-1, 1, 1) * ensemble_outputs).sum(dim=0)

        return selected_output, weighted_output, weights

class SpatialCorrelationNet(nn.Module):
    """Three-layer graph-convolution network producing one raw value per node.

    Pipeline: conv1 -> LayerNorm -> ReLU -> dropout -> conv2 -> LayerNorm ->
    ReLU -> dropout -> conv3. No activation is applied to the final output,
    so the values are unbounded.

    Args:
        input_dim: node feature size.
        hidden_dim: width of the first hidden layer (the second uses half).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        half_dim = hidden_dim // 2
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.conv2 = GCNConv(hidden_dim, half_dim)
        self.norm2 = nn.LayerNorm(half_dim)
        self.conv3 = GCNConv(half_dim, 1)
        self.dropout = nn.Dropout(0.1)
        self.act = nn.ReLU()

    def forward(self, x, edge_index):
        """Propagate node features ``x`` over ``edge_index``; return conv3 output."""
        h = x
        # The two hidden stages share the same conv -> norm -> act -> dropout shape.
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            h = self.dropout(self.act(norm(conv(h, edge_index))))
        return self.conv3(h, edge_index)

class MultiLayerAttention(nn.Module):
    """Stack of ``nn.MultiheadAttention`` layers, each followed by residual + LayerNorm.

    Layer 0 is self-attention over ``query``; every later layer attends from
    the running representation to the supplied ``key`` / ``value``. Inputs
    use ``nn.MultiheadAttention``'s default sequence-first layout
    (``batch_first=False``): ``(seq_len, batch, embed_dim)``.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: attention heads per layer.
        num_layers: number of attention blocks (may be 0).
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, embed_dim, num_heads=4, num_layers=2, dropout=0.1):
        super(MultiLayerAttention, self).__init__()
        self.layers = nn.ModuleList([
            nn.MultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout
            ) for _ in range(num_layers)
        ])
        self.norms = nn.ModuleList([
            nn.LayerNorm(embed_dim) for _ in range(num_layers)
        ])

    def forward(self, query, key, value):
        """Run all layers and return ``(output, last_layer_attention_weights)``.

        With an empty stack the input ``query`` is returned unchanged with
        ``None`` weights, instead of failing on an unbound variable.
        """
        x = query
        weights = None  # stays None when the stack has no layers
        for i, (attn, norm) in enumerate(zip(self.layers, self.norms)):
            residual = x
            if i == 0:
                # NOTE(review): the first layer ignores key/value entirely
                # (pure self-attention); confirm this is intended.
                x, weights = attn(x, x, x)
            else:
                x, weights = attn(x, key, value)
            x = norm(x + residual)
        return x, weights

class EnhancedDetector(nn.Module):
    """Four-layer MLP that maps features to a single sigmoid score in (0, 1).

    Widths: input_dim -> hidden_dim -> hidden_dim // 2 -> hidden_dim // 4 -> 1.
    The first two layers are followed by LayerNorm, ReLU and dropout (p=0.2);
    the third by ReLU only; the last by a sigmoid.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        half_dim = hidden_dim // 2
        quarter_dim = hidden_dim // 4
        # Creation order is kept fixed so parameter init and state_dict keys match.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, half_dim)
        self.norm2 = nn.LayerNorm(half_dim)
        self.fc3 = nn.Linear(half_dim, quarter_dim)
        self.fc4 = nn.Linear(quarter_dim, 1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return a score in (0, 1) with a trailing dimension of size 1."""
        h = x
        for linear, norm in ((self.fc1, self.norm1), (self.fc2, self.norm2)):
            h = self.dropout(F.relu(norm(linear(h))))
        h = F.relu(self.fc3(h))
        return torch.sigmoid(self.fc4(h))
    
class ScoreAggregator(nn.Module):
    """Fuse ``n_scores`` scores per row into a single unbounded logit.

    The input scores are first rescaled by a learnable mixing vector
    (``softmax(score_weights)``, initialised uniform). The result then goes
    through an MLP: Linear -> LayerNorm -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        n_scores: number of scores in the last dimension of the input.
        hidden_dim: width of the first hidden layer (the second uses half).
    """

    def __init__(self, n_scores=3, hidden_dim=64):
        super().__init__()
        self.score_weights = nn.Parameter(torch.ones(n_scores) / n_scores)
        self.fc1 = nn.Linear(n_scores, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, 1)

    def forward(self, scores):
        """Return one logit per row of ``scores`` (trailing dimension of size 1)."""
        mixing = F.softmax(self.score_weights, dim=0)
        hidden = F.relu(self.norm1(self.fc1(scores * mixing[None, :])))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)