import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch_scatter import scatter_sum, scatter_softmax


def local_to_global_nodes(R: torch.Tensor, t: torch.Tensor, p: torch.Tensor):
    """Map points from each node's local frame into the global frame.

    Computes ``R @ p + t`` independently for every node.

    Args:
        R: Per-node matrices, indexed ``(n, i, j)``.
        t: Per-node translations, indexed ``(n, i)``.
        p: Local points, indexed ``(n, h, p, j)``.

    Returns:
        Global points, indexed ``(n, h, p, i)``.
    """
    rotated = torch.einsum('nij,nhpj->nhpi', R, p)
    # Insert two singleton axes so the translation broadcasts over heads/points.
    shift = t[:, None, None, :]
    return rotated + shift

def global_to_local_nodes(R: torch.Tensor, t: torch.Tensor, q: torch.Tensor):
    """Map points from the global frame into each node's local frame.

    Computes ``R^T @ (q - t)`` independently for every node.

    Args:
        R: Per-node matrices, indexed ``(n, j, i)``.
        t: Per-node translations, indexed ``(n, j)``.
        q: Global points, indexed ``(n, h, p, j)``.

    Returns:
        Local points, indexed ``(n, h, p, i)``.
    """
    # Remove the translation first; it broadcasts over the head/point axes.
    centered = q - t[:, None, None, :]
    # Contracting over the first matrix index applies the transpose of R.
    return torch.einsum('nji,nhpj->nhpi', R, centered)

def normalize_vector(v, dim, eps=1e-6):
    """Scale ``v`` to (approximately) unit 2-norm along ``dim``.

    ``eps`` is added to the norm before dividing, so all-zero vectors map to
    zeros instead of producing NaNs.
    """
    length = torch.linalg.norm(v, ord=2, dim=dim, keepdim=True)
    return v / (length + eps)

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with optional affine terms.

    Args:
        normal_shape: Feature size as an int, or a shape whose last entry is
            the feature size (only that last entry is used).
        gamma: If true, learn a per-feature scale (initialised to 1).
        beta: If true, learn a per-feature shift (initialised to 0).
        epsilon: Constant added to the variance before the square root.
    """

    def __init__(self, normal_shape, gamma=True, beta=True, epsilon=1e-10):
        super().__init__()
        if isinstance(normal_shape, int):
            feature_shape = (normal_shape,)
        else:
            feature_shape = (normal_shape[-1],)
        self.normal_shape = torch.Size(feature_shape)
        self.epsilon = epsilon

        # Register scale then shift; a disabled term is registered as None so
        # that attribute access in forward() still works.
        for name, enabled in (('gamma', gamma), ('beta', beta)):
            param = nn.Parameter(torch.Tensor(*feature_shape)) if enabled else None
            self.register_parameter(name, param)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to ones and the shift to zeros (when present)."""
        if self.gamma is not None:
            nn.init.ones_(self.gamma)
        if self.beta is not None:
            nn.init.zeros_(self.beta)

    def forward(self, x):
        """Normalise ``x`` along its last axis and apply the affine terms."""
        centered = x - x.mean(dim=-1, keepdim=True)
        # Biased (population) variance over the feature axis.
        variance = centered.pow(2).mean(dim=-1, keepdim=True)
        out = centered / torch.sqrt(variance + self.epsilon)
        if self.gamma is not None:
            out = out * self.gamma
        if self.beta is not None:
            out = out + self.beta
        return out



class GABlockPyG(nn.Module):
    """Geometric attention block operating on a sparse edge list.

    For every edge ``src -> dst`` three per-head logits are combined: a scaled
    dot product of node queries/keys, a linear bias from the edge features,
    and a penalty on the squared distance between query points (of ``dst``)
    and key points (of ``src``) expressed in the global frame. The logits are
    soft-maxed over the incoming edges of each destination node (per head) and
    used to aggregate edge features, node values and value points. The
    aggregate is projected back to ``node_feat_dim`` and followed by a
    residual + layer norm and a residual MLP transition + layer norm.

    Args:
        node_feat_dim: Width of the node embedding (input and output).
        pair_feat_dim: Width of the edge feature vectors.
        value_dim: Per-head width of the scalar value vectors.
        query_key_dim: Per-head width of the scalar queries and keys.
        num_query_points: Number of 3-D query/key points per head.
        num_value_points: Number of 3-D value points per head.
        num_heads: Number of attention heads.
        bias: Whether the query/key/value/pair/point projections use a bias.
    """

    def __init__(self,
                 node_feat_dim,
                 pair_feat_dim,
                 value_dim=32,
                 query_key_dim=32,
                 num_query_points=8,
                 num_value_points=8,
                 num_heads=12,
                 bias=False):
        super().__init__()
        self.node_feat_dim = node_feat_dim
        self.pair_feat_dim = pair_feat_dim
        self.value_dim = value_dim
        self.query_key_dim = query_key_dim
        self.num_query_points = num_query_points
        self.num_value_points = num_value_points
        self.num_heads = num_heads

        # Scalar (node) attention projections.
        self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_key   = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
        self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias)

        # One attention bias per head, computed from the edge features.
        self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias)

        # Per-head weight of the point-distance term. The initial value is the
        # inverse softplus of 1, so softplus(spatial_coef) starts at 1.
        self.spatial_coef = nn.Parameter(
            torch.full([1, 1, 1, num_heads], fill_value=np.log(np.exp(1.) - 1.), dtype=torch.float32),
            requires_grad=True
        )
        # Key points must match the number of query points to be compared.
        self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_key_point   = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
        self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias)

        # Concatenated aggregate: edge features, scalar values, and for each
        # value point its coordinates (3), direction (3) and distance (1).
        in_features = (num_heads * pair_feat_dim) + (num_heads * value_dim) + \
                      (num_heads * num_value_points * (3 + 3 + 1))
        self.out_transform = nn.Linear(in_features, node_feat_dim)

        self.layer_norm_1 = LayerNorm(node_feat_dim)
        self.mlp_transition = nn.Sequential(
            nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
            nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
            nn.Linear(node_feat_dim, node_feat_dim),
        )
        self.layer_norm_2 = LayerNorm(node_feat_dim)

    def _split_heads(self, x, n_heads, ch):
        """Reshape ``(N, n_heads * ch)`` projections to ``(N, n_heads, ch)``."""
        N = x.size(0)
        return x.view(N, n_heads, ch)

    def forward(self, data):
        """Run one attention + transition update.

        Args:
            data: Object exposing ``node_embed`` (N, node_feat_dim),
                ``edge_attr`` (E, pair_feat_dim), ``edge_index`` (2, E) with
                row 0 = source and row 1 = destination node indices,
                ``R_t`` (N, 3, 3) and ``trans_t`` (N, 3).

        Returns:
            Updated node embeddings of shape ``(N, node_feat_dim)``.
        """
        x = data.node_embed
        z = data.edge_attr
        edge_index = data.edge_index
        R = data.R_t
        t = data.trans_t

        N = x.size(0)
        E = edge_index.size(1)
        H = self.num_heads
        Pq = self.num_query_points
        Pv = self.num_value_points
        dk = self.query_key_dim
        device = x.device
        src, dst = edge_index[0], edge_index[1]

        # --- Scalar attention logits: query of dst against key of src -------
        q = self._split_heads(self.proj_query(x), H, dk)                  # (N, H, dk)
        k = self._split_heads(self.proj_key(x),   H, dk)                  # (N, H, dk)
        v = self._split_heads(self.proj_value(x), H, self.value_dim)      # (N, H, dv)

        q_dst = q[dst]                                                    # (E, H, dk)
        k_src = k[src]                                                    # (E, H, dk)
        logits_node = (q_dst * k_src).sum(-1) * (1.0 / np.sqrt(dk))       # (E, H)

        # --- Edge-feature bias ----------------------------------------------
        logits_pair = self.proj_pair_bias(z)                              # (E, H)

        # --- Point-distance logits (compared in the global frame) -----------
        qpts_local = self.proj_query_point(x).view(N, H, Pq, 3)
        kpts_local = self.proj_key_point(x).view(N, H, Pq, 3)
        qpts_global = local_to_global_nodes(R, t, qpts_local)             # (N, H, Pq, 3)
        kpts_global = local_to_global_nodes(R, t, kpts_local)             # (N, H, Pq, 3)
        qpts_dst = qpts_global[dst]                                       # (E, H, Pq, 3)
        kpts_src = kpts_global[src]                                       # (E, H, Pq, 3)

        sum_sq_dist = ((qpts_dst - kpts_src) ** 2).sum(dim=(-1, -2))      # (E, H)
        gamma = F.softplus(self.spatial_coef).view(1, H)                  # (1, H)
        logits_spatial = sum_sq_dist * ((-1.0 * gamma * np.sqrt(2.0 / (9.0 * Pq))) / 2.0)  # (E, H)

        # --- Softmax groups: one group per (destination node, head) ---------
        # The flat id is node-major (dst * H + head) so that scatter outputs of
        # length N * H can be viewed directly as (N, H, ...). A head-major id
        # (dst + head * N) would silently mix nodes and heads in that view.
        heads = torch.arange(H, device=device)
        group_index = dst.view(E, 1) * H + heads.view(1, H)               # (E, H)
        group_index_flat = group_index.reshape(-1).to(torch.long)         # (E * H,)

        # Average the three logit sources, then normalise over incoming edges.
        logits_sum = (logits_node + logits_pair + logits_spatial) * np.sqrt(1.0 / 3.0)  # (E, H)
        alpha = scatter_softmax(logits_sum.reshape(-1), group_index_flat)               # (E * H,)
        alpha = alpha.view(E, H)                                                        # (E, H)
        alpha_flat = alpha.reshape(E * H, 1)                                            # (E * H, 1)

        # --- Aggregate edge features ----------------------------------------
        feat_p_ehc = alpha.unsqueeze(-1) * z.unsqueeze(1)                 # (E, H, C)
        feat_p_flat = feat_p_ehc.reshape(E * H, self.pair_feat_dim)       # (E * H, C)
        feat_p = scatter_sum(feat_p_flat, group_index_flat, dim=0, dim_size=N * H)  # (N * H, C)
        feat_p = feat_p.view(N, H, self.pair_feat_dim).reshape(N, H * self.pair_feat_dim)

        # --- Aggregate scalar values of the source nodes --------------------
        v_src = v[src]                                                    # (E, H, dv)
        v_src_flat = v_src.reshape(E * H, self.value_dim)                 # (E * H, dv)
        feat_n = scatter_sum(alpha_flat * v_src_flat, group_index_flat, dim=0, dim_size=N * H)  # (N * H, dv)
        feat_n = feat_n.view(N, H, self.value_dim).reshape(N, H * self.value_dim)

        # --- Aggregate value points in the global frame ---------------------
        vpts_local = self.proj_value_point(x).view(N, H, Pv, 3)           # (N, H, Pv, 3)
        vpts_global = local_to_global_nodes(R, t, vpts_local)             # (N, H, Pv, 3)
        vpts_src = vpts_global[src]                                       # (E, H, Pv, 3)
        vpts_src_flat = vpts_src.reshape(E * H, Pv, 3)                    # (E * H, Pv, 3)
        aggr_pts = scatter_sum(alpha_flat.view(E * H, 1, 1) * vpts_src_flat,
                               group_index_flat, dim=0, dim_size=N * H)   # (N * H, Pv, 3)
        aggr_pts = aggr_pts.view(N, H, Pv, 3)                             # (N, H, Pv, 3)

        # Express the aggregated points in the receiving node's own frame and
        # derive distance / direction features from them.
        aggr_pts_local = global_to_local_nodes(R, t, aggr_pts)            # (N, H, Pv, 3)
        feat_dist = torch.linalg.norm(aggr_pts_local, dim=-1)             # (N, H, Pv)
        feat_dir  = normalize_vector(aggr_pts_local, dim=-1, eps=1e-4)    # (N, H, Pv, 3)

        feat_s = torch.cat([
            aggr_pts_local.reshape(N, H * Pv * 3),
            feat_dist.reshape(N, H * Pv),
            feat_dir.reshape(N, H * Pv * 3),
        ], dim=-1)                                                        # (N, H * Pv * 7)

        # --- Output projection, residuals and transition --------------------
        feat_all = torch.cat([feat_p, feat_n, feat_s], dim=-1)
        out = self.out_transform(feat_all)                                # (N, node_feat_dim)
        x_updated = self.layer_norm_1(x + out)
        x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated))
        return x_updated




class GAEncoderPyG(nn.Module):
    """Stack of ``GABlockPyG`` layers applied sequentially to one graph.

    Args:
        node_feat_dim: Width of the node embeddings.
        pair_feat_dim: Width of the edge features.
        num_layers: Number of attention blocks to stack.
        ga_block_opt: Optional mapping of extra keyword arguments forwarded to
            every ``GABlockPyG``. ``None`` (the default) means no extras.
    """

    def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt=None):
        super().__init__()
        # Avoid a shared mutable default; copy so later caller-side edits of
        # the mapping cannot leak in.
        block_kwargs = {} if ga_block_opt is None else dict(ga_block_opt)
        self.blocks = nn.ModuleList([
            GABlockPyG(node_feat_dim, pair_feat_dim, **block_kwargs)
            for _ in range(num_layers)
        ])

    def _ensure_fields(self, data):
        """Raise ``ValueError`` if ``data`` lacks an attribute the blocks read."""
        fields = ["node_embed", "edge_attr", "edge_index", "R_t", "trans_t"]
        missing = [f for f in fields if not hasattr(data, f)]
        if missing:
            raise ValueError(f"Data is missing fields: {missing}")

    def forward(self, data):
        """Run every block in order and return the final node embeddings.

        Each block reads its input from ``data.node_embed``, so that attribute
        is temporarily overwritten with the running activations. It is put
        back to its original value before returning (also on error), so the
        caller's ``data`` object is left unchanged.

        Raises:
            ValueError: If ``data`` is missing a required attribute.
        """
        self._ensure_fields(data)
        original_embed = data.node_embed
        x = original_embed
        try:
            for blk in self.blocks:
                data.node_embed = x
                x = blk(data)
        finally:
            # Do not leak intermediate activations into the caller's object.
            data.node_embed = original_embed
        return x



def concat_by_ptr(feat1: torch.Tensor, ptr1: torch.Tensor,
                  feat2: torch.Tensor, ptr2: torch.Tensor) -> torch.Tensor:
    """Interleave two batched feature tensors graph by graph.

    ``ptr1`` / ``ptr2`` hold cumulative row offsets: graph ``i`` owns rows
    ``ptr[i]:ptr[i + 1]`` of the matching feature tensor. The result contains,
    for each graph in turn, its rows from ``feat1`` followed by its rows from
    ``feat2``.

    Args:
        feat1: Row-stacked features of the first batch.
        ptr1: Offsets into ``feat1`` (``num_graphs + 1`` entries).
        feat2: Row-stacked features of the second batch.
        ptr2: Offsets into ``feat2`` (``num_graphs + 1`` entries).

    Returns:
        The concatenated tensor; it has zero rows when the batch is empty.

    Raises:
        ValueError: If the two pointer tensors describe different numbers of
            graphs.
    """
    if ptr1.numel() != ptr2.numel():
        raise ValueError(
            f"ptr1 and ptr2 must have the same length, got "
            f"{ptr1.numel()} and {ptr2.numel()}"
        )

    # Convert once: per-element .item() calls force a device sync every time.
    bounds1 = ptr1.reshape(-1).tolist()
    bounds2 = ptr2.reshape(-1).tolist()

    chunks = []
    for s1, e1, s2, e2 in zip(bounds1, bounds1[1:], bounds2, bounds2[1:]):
        chunks.append(feat1[s1:e1])
        chunks.append(feat2[s2:e2])

    if not chunks:
        # torch.cat rejects an empty list; return a correctly shaped empty
        # tensor instead.
        return torch.cat([feat1[:0], feat2[:0]], dim=0)
    return torch.cat(chunks, dim=0)



class GAInterfacePipeline(nn.Module):
    """Encode two chain graphs, then refine their nodes on an interface graph.

    Args:
        node_feat_dim: Width of the node embeddings.
        pair_feat_dim: Width of the edge features.
        num_layers: Number of attention blocks per encoder.
        share_blocks: If true, the interface stage reuses the chain encoder
            (same weights); otherwise it gets its own encoder.
        ga_block_opt: Optional mapping of extra keyword arguments for the
            attention blocks. ``None`` (the default) means no extras.
    """

    def __init__(self, node_feat_dim, pair_feat_dim, num_layers, share_blocks=False, ga_block_opt=None):
        super().__init__()
        # Avoid a shared mutable default; always hand the encoders a real dict.
        block_opt = {} if ga_block_opt is None else dict(ga_block_opt)
        self.chain_encoder = GAEncoderPyG(node_feat_dim, pair_feat_dim, num_layers, block_opt)
        if share_blocks:
            self.interface_encoder = self.chain_encoder
        else:
            self.interface_encoder = GAEncoderPyG(node_feat_dim, pair_feat_dim, num_layers, block_opt)

    def forward(self, batch_dict):
        """Run the two-stage encoding.

        Args:
            batch_dict: Mapping with keys ``"chain_1_graph"``,
                ``"chain_2_graph"`` and ``"interface_graph"``. The chain
                graphs must expose ``ptr`` offsets, and the interface graph
                must support ``clone()``.

        Returns:
            Node embeddings produced by the interface encoder.
        """
        c1 = batch_dict["chain_1_graph"]
        c2 = batch_dict["chain_2_graph"]
        iF = batch_dict["interface_graph"]

        # Both chains go through the same encoder.
        c1_updated = self.chain_encoder(c1)
        c2_updated = self.chain_encoder(c2)

        # Per sample: chain-1 rows followed by chain-2 rows.
        # NOTE(review): assumes the interface graph orders its nodes the same
        # way - confirm against the dataset code.
        iface_node_embed = concat_by_ptr(c1_updated, c1.ptr, c2_updated, c2.ptr)

        # Work on a copy so the caller's interface graph is not modified.
        iF = iF.clone()
        iF.node_embed = iface_node_embed

        iface_updated = self.interface_encoder(iF)
        return iface_updated
