import torch
from torch import nn
from torch_geometric.nn import MessagePassing, InstanceNorm
import torch.nn.functional as F
import torch_scatter
from Modules.Activations import Tanh
import torch_geometric
from Modules.GNN.StationLinear import StationAwareLinear
from Modules.GNN.DyT import DynamicTanh

class GNN_Layer_Internal(MessagePassing):
    """Message-passing layer with distance-modulated attention and a residual update.

    For every edge the layer builds a feature vector from both endpoint
    states, the difference of their ``u`` features and the difference of
    their positions.  That vector feeds a message MLP and an attention MLP.
    The attention logits are scaled by a great-circle distance factor,
    normalised with a grouped softmax, averaged over heads and used to weight
    the message.  Messages are mean-aggregated per node, pushed through an
    update MLP, added to the input state and finally normalised with
    ``InstanceNorm`` using the ``batch`` vector.

    Args:
        in_dim: Width of the node state ``x``.
        out_dim: Width of the updated node state.  ``update`` adds the result
            to ``x``, so it has to be broadcast-compatible with ``in_dim``.
        hidden_dim: Width of the message and hidden update representations.
        org_in_dim: Width of the ``u`` node features.
        heads: Number of attention heads produced by the attention MLP.
        dropout: Dropout probability, used inside the attention MLP and on
            the normalised attention weights.
        max_distance: Edges whose haversine distance exceeds this value have
            their attention logits multiplied by zero.  Same unit as
            ``haversine_distance`` (kilometres if positions are in degrees).
        distance_decay: Rate ``k`` of the ``exp(-k * distance)`` factor
            applied to the attention logits.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, org_in_dim, heads=4, dropout=0.1,
                 *, max_distance=0.7, distance_decay=10.0):
        super().__init__(node_dim=-2, aggr=None)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.heads = heads
        self.dropout = dropout
        self.max_distance = max_distance
        self.distance_decay = distance_decay

        # Per-edge input: x_i, x_j (2 * in_dim), u_i - u_j (org_in_dim) and
        # pos_i - pos_j (2 coordinates).
        edge_feat_dim = 2 * in_dim + org_in_dim + 2

        self.message_net_1 = nn.Sequential(
            nn.Linear(edge_feat_dim, hidden_dim),
            Tanh(),
        )
        self.message_net_2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            DynamicTanh(hidden_dim, n_params=1),
        )

        # Produces one raw attention logit per head for every edge.
        self.att_net = nn.Sequential(
            nn.Linear(edge_feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, heads),
            nn.LeakyReLU(0.2),
        )

        # NOTE(review): this parameter is not read anywhere in this class;
        # it is kept so that existing state dicts keep loading.
        self.distance_beta = nn.Parameter(torch.tensor(1.0))

        self.update_net_1 = nn.Sequential(
            nn.Linear(in_dim + hidden_dim, hidden_dim),
            Tanh(),
        )
        # The activation follows a Linear that emits out_dim features, so it
        # is sized with out_dim (previously hidden_dim, which only matches
        # when hidden_dim == out_dim).  Assumes DynamicTanh's first argument
        # is the feature width -- TODO confirm against Modules.GNN.DyT.
        self.update_net_2 = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            DynamicTanh(out_dim, n_params=1),
        )
        self.norm = InstanceNorm(out_dim)

    def forward(self, x: torch.Tensor, u: torch.Tensor, pos: torch.Tensor,
                edge_index: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Run one round of message passing followed by instance normalisation.

        Args:
            x: Node states; the first dimension indexes nodes.
            u: Additional node features of width ``org_in_dim``.
            pos: Node positions; column 0 is read as longitude and column 1
                as latitude by ``haversine_distance``.
            edge_index: Edge list handed to ``propagate``.
            batch: Batch vector forwarded to ``InstanceNorm``.

        Returns:
            The normalised, updated node states.
        """
        num_nodes = x.size(0)
        x = self.propagate(edge_index, x=x, u=u, pos=pos, size=(num_nodes, num_nodes))
        return self.norm(x, batch)

    def message(self, x_i, x_j, u_i, u_j, pos_i, pos_j, edge_index, size):
        """Build the attention-weighted message carried by every edge.

        Returns:
            Tensor with one row per edge and ``hidden_dim`` columns.
        """
        edge_feat = torch.cat((x_i, x_j, u_i - u_j, pos_i - pos_j), dim=-1)
        message = self.message_net_2(self.message_net_1(edge_feat))

        distance = self.haversine_distance(pos_i, pos_j)  # [num_edges, 1]
        alpha = self.att_net(edge_feat)  # [num_edges, heads]

        # NOTE(review): the gate and decay multiply the *logits*, so an edge
        # beyond max_distance ends up with logit 0 (not -inf) and still
        # receives softmax weight.  Kept as-is; confirm this is intended.
        edge_gate = (distance <= self.max_distance).float()
        distance_factor = torch.exp(-distance * self.distance_decay)
        alpha = alpha * distance_factor * edge_gate

        # NOTE(review): the softmax groups edges by edge_index[0] whereas
        # aggregate() scatters on edge_index[1]; confirm the grouping.
        alpha = torch_geometric.utils.softmax(alpha, edge_index[0], num_nodes=size[0])
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Averaging the per-head weighted messages equals weighting once by
        # the head-averaged attention, without materialising `heads` copies
        # of the message tensor.
        return message * alpha.mean(dim=-1, keepdim=True)

    def update(self, aggr_out, x):
        """Combine aggregated messages with the node state (residual update)."""
        hidden = self.update_net_1(torch.cat((x, aggr_out), dim=-1))
        return x + self.update_net_2(hidden)

    def aggregate(self, message, edge_index, size):
        """Mean-reduce edge messages onto the nodes listed in ``edge_index[1]``."""
        return torch_scatter.scatter(
            message,
            edge_index[1],
            dim=self.node_dim,  # same axis that was passed to MessagePassing
            dim_size=size[1],
            reduce='mean',
        )

    def haversine_distance(self, pos_i, pos_j):
        """Return the great-circle distance between paired positions.

        Column 0 of each input is treated as longitude and column 1 as
        latitude; both are converted from degrees to radians.  With the
        sphere radius of 6371.0 used here the result is in kilometres
        (assuming the inputs really are degrees -- TODO confirm with caller).

        Returns:
            Tensor of shape ``[num_pairs, 1]``.
        """
        R = 6371.0
        deg_to_rad = torch.pi / 180.0
        lon1, lat1 = pos_i[:, 0] * deg_to_rad, pos_i[:, 1] * deg_to_rad
        lon2, lat2 = pos_j[:, 0] * deg_to_rad, pos_j[:, 1] * deg_to_rad
        dlon = lon2 - lon1
        dlat = lat2 - lat1
        a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
        # Rounding can push `a` marginally outside [0, 1]; without the clamp
        # one of the square roots below would return NaN.
        a = a.clamp(min=0.0, max=1.0)
        c = 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1 - a))
        return (R * c).unsqueeze(-1)