import torch
from torch import nn as nn
from torch_geometric.data import Data
import torch
import torch.nn.functional as F
from Modules.Activations import Tanh
from Modules.GNN.GNN_Layer_External import GNN_Layer_External
from Modules.GNN.GNN_Layer_Internal import GNN_Layer_Internal
from Modules.GNN.selector import PrecipitationEnsembleSelector
from Modules.GNN.diffusion_ensemble import SimpleGraphDiffusion
from Modules.GNN.StationLinear import StationAwareLinear
from Modules.GNN.DyT import DynamicTanh
from Modules.GNN.DIffusion import DiffusionNodeEnhancer
from Modules.GNN.liquidnet import LiquidTimeSeries,GeoLiquidEmbedding
class MPNN(nn.Module):
    """Message-passing network over a batch of station graphs.

    Pipeline (as wired in ``forward``):

    1. Flatten the batch of station graphs into one disjoint-union graph
       (node indices of sample ``b`` are offset by ``b * n_stations``).
    2. Embed ``[flattened node features, position]`` with ``embedding_mlp``.
    3. If external nodes are supplied, apply ``gnn_ex_1``; then run the
       ``n_passing`` internal layers; then, again only with external nodes,
       apply ``gnn_ex_2``.
    4. Draw an ensemble of hidden states with ``diffusion_model``, decode each
       sample with ``output_mlp`` and combine the decoded samples with
       ``ensemble_selector``.
    5. Return ``relu`` of the combined output, shaped
       ``(n_batch, n_stations_m, n_out_features)``.

    Args:
        n_passing: Number of internal (station -> station) GNN layers.
        lead_hrs: Stored on the module; not used by the code in this class.
        n_node_features_m: Passed as ``org_in_dim`` to every internal layer.
        n_node_features_e: Passed as ``ex_in_dim`` to both external layers.
        n_out_features: Number of output features per station.
        hidden_dim: Width of the hidden node representation.
        heads: Forwarded as ``heads`` to the GNN layers.
        dropout: Dropout probability for the GNN layers and the output head.
        liquid_weight: Stored on the module; not used by the code in this class.
        n_stations: Number of station slots in the ``StationAwareLinear``
            output head (formerly hard-coded to 2288). ``forward`` passes
            station ids ``0 .. n_stations_m - 1``, so presumably
            ``n_stations_m`` must not exceed this value -- TODO confirm.
        n_samples: Ensemble size shared by ``diffusion_model`` and
            ``ensemble_selector`` (formerly hard-coded to 10 in both places,
            which allowed them to drift apart).
    """

    def __init__(self,
                 n_passing,
                 lead_hrs,
                 n_node_features_m,
                 n_node_features_e,
                 n_out_features,
                 hidden_dim=128,
                 heads=4,
                 dropout=0.1,
                 liquid_weight=0.5,
                 n_stations=2288,
                 n_samples=10):
        super().__init__()
        self.liquid_weight = liquid_weight
        self.lead_hrs = lead_hrs
        self.n_node_features_m = n_node_features_m
        self.n_node_features_e = n_node_features_e
        self.n_passing = n_passing
        self.hidden_dim = hidden_dim
        self.n_out_features = n_out_features
        self.heads = heads
        self.dropout = dropout
        self.n_stations = n_stations
        self.n_samples = n_samples

        # NOTE: submodules are registered in the same order as before so that
        # parameters() order (optimizer state) and seeded init are unchanged.
        ex_kwargs = dict(in_dim=self.hidden_dim,
                         out_dim=self.hidden_dim,
                         hidden_dim=self.hidden_dim,
                         ex_in_dim=self.n_node_features_e,
                         heads=self.heads,
                         dropout=self.dropout)
        self.gnn_ex_1 = GNN_Layer_External(**ex_kwargs)
        self.gnn_ex_2 = GNN_Layer_External(**ex_kwargs)

        self.gnn_layers = nn.ModuleList(
            GNN_Layer_Internal(
                in_dim=self.hidden_dim,
                hidden_dim=self.hidden_dim,
                out_dim=self.hidden_dim,
                org_in_dim=self.n_node_features_m,
                heads=self.heads,
                dropout=self.dropout)
            for _ in range(self.n_passing))

        self.embedding_mlp = GeoLiquidEmbedding(
            feature_dim=1,
            hidden_dim=self.hidden_dim,
            n_days=6,
            dt=0.1,
            steps=6,
        )

        # The last module takes an extra station-index argument, which is why
        # forward() splits this Sequential into "trunk" and "head".
        self.output_mlp = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            DynamicTanh(self.hidden_dim, n_params=1),
            nn.Dropout(self.dropout),
            StationAwareLinear(self.hidden_dim, self.n_out_features,
                               n_stations=self.n_stations),
        )

        self.diffusion_model = SimpleGraphDiffusion(
            hidden_dim=self.hidden_dim,
            n_samples=self.n_samples,
            noise_level=0.01,
            residual_strength=0.7,
        )

        self.ensemble_selector = PrecipitationEnsembleSelector(
            hidden_dim=self.hidden_dim,
            output_dim=self.n_out_features,
            n_samples=self.n_samples,
        )

    @staticmethod
    def _flatten_positions(lon, lat, n_batch, n_nodes):
        """Return node positions as a ``(n_batch * n_nodes, -1)`` tensor.

        ``lon``/``lat`` may be ``(n_batch, n_nodes)`` or
        ``(n_batch, n_nodes, k)``; both are reshaped to
        ``(n_batch, n_nodes, -1)`` first, so every output row is that node's
        lon value(s) followed by its lat value(s).
        """
        lon = lon.reshape(n_batch, n_nodes, -1)
        lat = lat.reshape(n_batch, n_nodes, -1)
        return torch.cat((lon, lat), dim=2).reshape(n_batch * n_nodes, -1)

    @staticmethod
    def _offset_edge_index(edge_index, n_batch, src_offset, dst_offset, device):
        """Merge per-sample edge indices into one index over the flattened batch.

        For sample ``b``, the source row is shifted by ``b * src_offset`` and
        the target row by ``b * dst_offset``. The shifted per-sample blocks
        are then concatenated along the edge axis into a ``(2, n_batch * E)``
        tensor. Because the shift is added with broadcasting, ``edge_index``
        may be a stacked ``(n_batch, 2, E)`` tensor or a single ``(2, E)``
        index shared by every sample.
        """
        steps = torch.arange(n_batch, device=device)
        shift = torch.stack((steps * src_offset, steps * dst_offset), dim=1).unsqueeze(-1)
        return torch.cat(list(edge_index + shift), dim=1)

    def build_graph_internal(self, x, station_lon, station_lat, edge_index):
        """Flatten a batch of station graphs into one disjoint-union ``Data``.

        Args:
            x: Node features; ``x.size(0)`` is the batch size and
                ``x.size(1)`` the number of stations per sample.
            station_lon, station_lat: Per-station coordinates,
                ``(n_batch, n_stations)`` or ``(n_batch, n_stations, k)``.
            edge_index: Station -> station edges, ``(n_batch, 2, E)`` or a
                shared ``(2, E)``; indices are local to each sample.

        Returns:
            ``Data`` with ``x`` ``(n_batch * n_stations, -1)``, ``pos``,
            a long ``batch`` vector (graph id of every node) and a long,
            globally-offset ``edge_index``.
        """
        n_batch, n_stations = x.size(0), x.size(1)
        device = x.device
        x = x.reshape(n_batch * n_stations, -1)
        pos = self._flatten_positions(station_lon, station_lat, n_batch, n_stations)
        # Graph id per flattened node: [0]*n_stations + [1]*n_stations + ...
        batch = torch.arange(n_batch, device=device).repeat_interleave(n_stations)
        edge_index = self._offset_edge_index(edge_index, n_batch, n_stations, n_stations, device)
        return Data(x=x, pos=pos, batch=batch, edge_index=edge_index.long())

    def build_graph_external(self, station_x, ex_x, ex_lon, ex_lat, edge_index):
        """Flatten external nodes and their external -> station edges.

        Row 0 of ``edge_index`` (external node ids) is offset by
        ``b * n_stations_e`` and row 1 (station ids) by ``b * n_stations_m``.
        Station ids therefore line up with the graph built by
        ``build_graph_internal``.

        Args:
            station_x: Station tensor; only its first two sizes
                (batch, stations) are read.
            ex_x: External node features; ``ex_x.size(1)`` is the number of
                external nodes per sample.
            ex_lon, ex_lat: External coordinates, reshaped to
                ``(n_batch, n_stations_e, -1)``.
            edge_index: ``(n_batch, 2, E)`` or shared ``(2, E)`` local indices.

        Returns:
            ``Data`` with flattened ``x``, ``pos`` and a long ``edge_index``.
        """
        n_batch = station_x.size(0)
        n_stations_m = station_x.size(1)
        n_stations_e = ex_x.size(1)
        ex_x = ex_x.reshape(n_batch * n_stations_e, -1)
        ex_pos = self._flatten_positions(ex_lon, ex_lat, n_batch, n_stations_e)
        edge_index = self._offset_edge_index(
            edge_index, n_batch, n_stations_e, n_stations_m, station_x.device)
        return Data(x=ex_x, pos=ex_pos, edge_index=edge_index.long())

    def forward(self,
                station_x,
                station_lon,
                station_lat,
                edge_index,
                ex_lon,
                ex_lat,
                ex_x,
                edge_index_e2m):
        """Run the full model and return per-station outputs.

        Args:
            station_x: 4-D station input; per the unpacked names presumably
                ``(batch, stations, hours, features)`` -- the last two axes
                are flattened into one feature axis.
            station_lon, station_lat: Station coordinates.
            edge_index: Station -> station edges (per-sample local indices).
            ex_lon, ex_lat: External-node coordinates (unused if ``ex_x`` is None).
            ex_x: External-node features, or ``None`` to skip both external layers.
            edge_index_e2m: External -> station edges (unused if ``ex_x`` is None).

        Returns:
            ``relu`` of the ensemble-weighted output,
            shape ``(n_batch, n_stations_m, n_out_features)``.
        """
        n_batch, n_stations_m, _n_hours, _n_features = station_x.shape
        station_x = station_x.reshape(n_batch, n_stations_m, -1)

        in_graph = self.build_graph_internal(station_x, station_lon, station_lat, edge_index)
        u = in_graph.x
        in_pos = in_graph.pos
        batch = in_graph.batch
        edge_index_m2m = in_graph.edge_index

        # Local station id of every flattened node (0..n_stations_m-1 per sample).
        station_indices = torch.arange(n_stations_m, device=station_x.device).repeat(n_batch)
        in_x = self.embedding_mlp(torch.cat((u, in_pos), dim=-1))

        has_external = ex_x is not None
        ex_pos = None
        if has_external:
            ex_graph = self.build_graph_external(station_x, ex_x, ex_lon, ex_lat, edge_index_e2m)
            ex_x, ex_pos, edge_index_e2m = ex_graph.x, ex_graph.pos, ex_graph.edge_index
            in_x = self.gnn_ex_1(in_x, ex_x, in_pos, ex_pos, edge_index_e2m, batch, station_indices)

        for layer in self.gnn_layers:
            in_x = layer(in_x, u, in_pos, edge_index_m2m, batch)

        if has_external:
            in_x = self.gnn_ex_2(in_x, ex_x, in_pos, ex_pos, edge_index_e2m, batch, station_indices)

        # Only the stacked samples are used; mean/std are discarded.
        _mean, _std, samples_stack = self.diffusion_model(
            x=in_x,
            edge_index=edge_index_m2m,
            batch=batch,
            pos=in_pos,
        )

        # The final StationAwareLinear needs station ids, so it cannot be run
        # through the Sequential directly.
        trunk, head = self.output_mlp[:-1], self.output_mlp[-1]
        output_stack = torch.stack(
            [head(trunk(sample), station_indices) for sample in samples_stack])

        _, weighted_output, _ = self.ensemble_selector(
            ensemble_outputs=output_stack,
            edge_index=edge_index_m2m,
            batch=batch,
            pos=in_pos,
        )
        out = weighted_output.reshape(n_batch, n_stations_m, self.n_out_features)
        return torch.relu(out)
    


