import torch
from torch import nn as nn
from torch_geometric.nn import MessagePassing, InstanceNorm
from torch.nn import Parameter
import torch.nn.functional as F
import torch_scatter
from Modules.Activations import Tanh
import torch_geometric
from Modules.GNN.StationLinear import StationAwareLinear
from Modules.GNN.DyT import DynamicTanh
from Modules.GNN.liquidnet import LiquidNeuron,GeoLiquidEmbedding,LiquidEmbedding
class GNN_Layer_External(MessagePassing):
    """Message-passing layer that injects external-node information into internal nodes.

    External nodes are embedded from their most recent value (``ex_x[:, -1]``)
    plus their 2-column position. They are stacked after the internal nodes
    and send messages along ``edge_index`` to internal nodes. Each message is
    weighted by multi-head attention that is

    * decayed with haversine distance (learnable rate ``distance_beta``),
    * hard-gated by ``distance_threshold``, and
    * normalized over the incoming edges of each receiving node.

    The attention weights are averaged over heads.

    Shape requirements implied by the wiring below: ``in_dim == hidden_dim``
    (internal features and external embeddings are stacked row-wise and fed
    through the same networks) and ``out_dim == in_dim`` (residual update).
    """

    def __init__(self, in_dim, out_dim, hidden_dim, ex_in_dim, heads=4, dropout=0.1):
        # aggr=None: aggregation is done by the overridden ``aggregate`` below.
        super(GNN_Layer_External, self).__init__(node_dim=-2, aggr=None)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.ex_in_dim = ex_in_dim
        self.heads = heads
        self.dropout = dropout

        # Embeds [last external value (1), position (2)] -> hidden_dim.
        self.ex_embed_net_1 = nn.Sequential(nn.Linear(1 + 2, hidden_dim),
                                            Tanh()
                                            )
        self.ex_embed_net_2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                            DynamicTanh(hidden_dim, n_params=3)
                                            )
        # NOTE: not used by forward(); kept because it is part of the module's
        # state_dict (removing it would break loading existing checkpoints).
        self.temporal_attention = nn.Sequential(
            nn.Linear(ex_in_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=1)
        )
        # Input: [x_i, x_j, pos_i - pos_j].
        self.message_net_1 = nn.Sequential(nn.Linear(in_dim + hidden_dim + 2, hidden_dim),
                                           Tanh()
                                           )
        self.message_net_2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                           DynamicTanh(hidden_dim, n_params=1)
                                           )

        # Per-edge attention logits, one per head. Input: [x_i, x_j, pos_i - pos_j].
        self.att_net = nn.Sequential(
            nn.Linear(hidden_dim * 2 + 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, heads),
            nn.LeakyReLU(0.2)
        )
        # Raw distance-decay rate; squashed into [0.1, 5.0] in ``message``.
        self.distance_beta = nn.Parameter(torch.tensor(1.0))
        self.update_net_1 = nn.Sequential(nn.Linear(in_dim + hidden_dim, hidden_dim),
                                          Tanh()
                                          )
        # The activation acts on the Linear's output, so it is sized by out_dim.
        self.update_net_2 = nn.Sequential(nn.Linear(hidden_dim, out_dim),
                                          DynamicTanh(out_dim, n_params=1)
                                          )

        self.norm = InstanceNorm(out_dim)

        # Hard gate on haversine distance. The comparison in ``message`` is not
        # differentiable, so this parameter receives no gradient.
        self.distance_threshold = nn.Parameter(torch.tensor(0.5))

    def forward(self, in_x, ex_x, in_pos, ex_pos, edge_index, batch, station_indices=None):
        """Update internal node features with messages from external nodes.

        Args:
            in_x: Internal node features, ``[N_in, in_dim]``.
            ex_x: External node values, 2-D; only the last column is used.
            in_pos: Internal node positions, ``[N_in, 2]``.
            ex_pos: External node positions, ``[N_ex, 2]``.
            edge_index: ``[2, E]``. Row 0 holds 0-based external node indices
                (message sources); row 1 holds internal node indices
                (message receivers).
            batch: Graph assignment of each internal node, passed to InstanceNorm.
            station_indices: Unused; accepted for interface compatibility.

        Returns:
            Instance-normalized internal node features, ``[N_in, out_dim]``.
        """
        n_in_x = in_x.size(0)

        # Embed the most recent external value together with its position.
        ex_feat = torch.cat((ex_x[:, -1:], ex_pos), dim=1)
        ex_feat = self.ex_embed_net_2(self.ex_embed_net_1(ex_feat))

        # Internal nodes occupy rows [0, n_in_x); external nodes follow them.
        x = torch.cat((in_x, ex_feat), dim=0)
        pos = torch.cat((in_pos, ex_pos), dim=0)

        # Shift the source (external) indices past the internal block.
        shifted_edge_index = edge_index.clone()
        shifted_edge_index[0] += n_in_x

        x = self.propagate(shifted_edge_index, x=x, pos=pos)
        return self.norm(x[:n_in_x], batch)

    def message(self, x_i, x_j, pos_i, pos_j, edge_index, size):
        """Compute attention-weighted messages from source j to receiver i.

        Returns:
            ``[E, hidden_dim]`` messages, each scaled by the head-averaged
            attention weight of its edge.
        """
        rel_pos = pos_i - pos_j
        pair = torch.cat((x_i, x_j, rel_pos), dim=-1)
        message = self.message_net_2(self.message_net_1(pair))

        distance = self.haversine_distance(pos_i, pos_j)  # [E, 1]
        alpha = self.att_net(pair)  # [E, heads]

        # Learnable decay rate squashed into [min_beta, max_beta].
        min_beta = 0.1
        max_beta = 5.0
        beta = min_beta + (max_beta - min_beta) * torch.sigmoid(self.distance_beta)
        # Adding log(exp(-10 * beta * d)) to the logits multiplies each
        # post-softmax weight by that decay factor. Scaling the logits instead
        # would pull negative logits of far edges *up* towards 0, and the
        # factor underflows to 0 for large distances (uniform attention).
        alpha = alpha - 10.0 * beta * distance

        # Hard distance gate: out-of-range edges get the most negative finite
        # logit, so they receive zero weight without producing NaN.
        in_range = distance <= self.distance_threshold  # [E, 1] bool
        alpha = alpha.masked_fill(~in_range, torch.finfo(alpha.dtype).min)
        # Normalize over the incoming edges of each receiving node (row 1,
        # the same index ``aggregate`` scatters onto).
        alpha = torch_geometric.utils.softmax(alpha, edge_index[1], num_nodes=size[1])
        # A receiver whose edges are all out of range would otherwise get a
        # uniform softmax; zero those edges explicitly.
        alpha = alpha * in_range.to(alpha.dtype)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Averaging per-head weighted messages == weighting by the mean head weight.
        return message * alpha.mean(dim=-1, keepdim=True)

    def update(self, aggr_out, x):
        """Residual update ``x + MLP([x, aggr_out])``; requires out_dim == in_dim."""
        hidden = self.update_net_1(torch.cat((x, aggr_out), dim=-1))
        return x + self.update_net_2(hidden)

    def aggregate(self, message, edge_index, size):
        """Sum weighted messages onto their receiving nodes (row 1 of ``edge_index``).

        The attention weights are already normalized per receiver in
        ``message``, so summing yields a weighted average. A mean would divide
        again by the in-degree, which also counts out-of-range edges.
        """
        node_dim = -2
        return torch_scatter.scatter(message, edge_index[1], dim=node_dim,
                                     dim_size=size[1], reduce='sum')

    def haversine_distance(self, pos_i, pos_j):
        """Great-circle distance between paired rows of ``pos_i`` and ``pos_j``.

        Column 0 is treated as longitude and column 1 as latitude, both in
        degrees. With ``R = 6371.0`` the result is in kilometres (assuming
        positions are true lon/lat; TODO confirm against the data pipeline).

        Returns:
            ``[E, 1]`` distances.
        """
        R = 6371.0
        lon1, lat1 = pos_i[:, 0] * torch.pi / 180.0, pos_i[:, 1] * torch.pi / 180.0
        lon2, lat2 = pos_j[:, 0] * torch.pi / 180.0, pos_j[:, 1] * torch.pi / 180.0
        dlon = lon2 - lon1
        dlat = lat2 - lat1

        a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
        # Rounding can push ``a`` slightly outside [0, 1], making sqrt(1 - a) NaN.
        a = a.clamp(0.0, 1.0)
        c = 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a))
        distance = R * c

        return distance.unsqueeze(-1)