# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F

# numpy
import numpy as np

import torch

def normalize_pc_list(pc_list, eps=1e-12):
    """
    Normalize all point clouds in a list together.

    The coordinates (first 3 channels) of every cloud in ``pc_list`` are
    concatenated along the point axis. From that union a per-sample centroid
    and a single scalar scale are computed. The scale is the largest distance
    to the centroid over the whole batch and the whole list. Both are applied
    to every cloud, so the clouds stay aligned relative to each other.
    Feature channels (index 3 onward) are passed through unchanged.

    Args:
        pc_list: List of point cloud tensors, each of shape (b, n_i, c) with a
                 shared b; n_i may differ between entries. The first 3
                 channels are coordinates and the rest are features.
        eps: Lower bound on the scale. When all points coincide, the output
             is then zeros instead of NaN/inf.

    Returns:
        List of normalized point clouds with the same order and shapes.
        An empty input list returns an empty list.
    """
    if len(pc_list) == 0:
        return []

    # (b, total_n, 3): coordinates of every cloud, concatenated over points.
    all_coords = torch.cat([pc[:, :, :3] for pc in pc_list], dim=1)

    # Per-sample centroid of the union of all clouds.
    centroid = torch.mean(all_coords, dim=1, keepdim=True)  # (b, 1, 3)
    centered_all = all_coords - centroid
    # NOTE(review): the max is taken over the whole batch, so one sample's
    # scale depends on the other samples in the batch. Confirm this is intended.
    max_dist = torch.max(torch.sqrt(torch.sum(centered_all ** 2, dim=-1)))  # scalar
    # Guard against division by zero for degenerate (single-point) inputs.
    max_dist = max_dist.clamp(min=eps)

    normalized_pc_list = []
    for pc in pc_list:
        coords = pc[:, :, :3]    # (b, n_i, 3)
        features = pc[:, :, 3:]  # (b, n_i, c-3), possibly empty
        normalized_coords = (coords - centroid) / max_dist
        normalized_pc_list.append(torch.cat([normalized_coords, features], dim=-1))

    return normalized_pc_list

def rotate(points, rotation):
    """Rotate each sample's points by its own XYZ Euler angles.

    Args:
        points: [B, N, 3] tensor of points.
        rotation: [B, >=3] tensor. Columns 0, 1 and 2 are the angles in
            radians about the x, y and z axes for each sample.

    Returns:
        [B, N, 3] tensor ``points @ R^T`` where ``R = Rz @ Ry @ Rx``. The
        x rotation is applied first, then y, then z.
    """
    angle_x = rotation[:, 0]
    angle_y = rotation[:, 1]
    angle_z = rotation[:, 2]
    one = torch.ones_like(angle_x)
    zero = torch.zeros_like(angle_x)

    def batch_matrix(*entries):
        # Nine per-sample scalars in row-major order -> [B, 3, 3].
        return torch.stack(entries, dim=-1).reshape(angle_x.shape + (3, 3))

    cx, sx = torch.cos(angle_x), torch.sin(angle_x)
    cy, sy = torch.cos(angle_y), torch.sin(angle_y)
    cz, sz = torch.cos(angle_z), torch.sin(angle_z)

    rot_x = batch_matrix(one, zero, zero,
                         zero, cx, -sx,
                         zero, sx, cx)
    rot_y = batch_matrix(cy, zero, sy,
                         zero, one, zero,
                         -sy, zero, cy)
    rot_z = batch_matrix(cz, -sz, zero,
                         sz, cz, zero,
                         zero, zero, one)

    combined = rot_z @ rot_y @ rot_x  # [B, 3, 3]
    # Row-vector convention: multiply by the transpose of the combined matrix.
    return torch.bmm(points, combined.permute(0, 2, 1))  # [B, N, 3]

class ParaEstimator(nn.Module):
    """Three-layer MLP mapping ``in_dim`` features to ``out_dim`` values.

    Layout (stored as ``self.encoder``):
    Linear(in_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> hidden_dim)
    -> ReLU -> Linear(hidden_dim -> out_dim).
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        widths = [in_dim, hidden_dim, hidden_dim, out_dim]
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            # A ReLU separates consecutive Linear layers (none before the first).
            if position:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        self.encoder = nn.Sequential(*layers)

    def forward(self, feature):
        """Apply the MLP along the last dimension of ``feature``."""
        return self.encoder(feature)
    
class StructTransformerBlock(nn.Module):
    """Dense point-transformer-style attention block with a residual path.

    Every point attends to every other point (there is no neighbourhood
    selection), so the intermediate tensors are [b, n, n, d_model].

    Args:
        d_pos: Width of the positional input ``xyz``.
        d_points: Width of the input/output ``features``.
        d_model: Inner width used for queries, keys, values and attention.
    """

    def __init__(self, d_pos, d_points, d_model) -> None:
        super().__init__()
        # Project features into the model width and back again.
        self.fc1 = nn.Linear(d_points, d_model)
        self.fc2 = nn.Linear(d_model, d_points)
        # Relative-position encoder and attention-logit MLP.
        self.fc_delta = self._two_layer_mlp(d_pos, d_model)
        self.fc_gamma = self._two_layer_mlp(d_model, d_model)
        # Bias-free query / key / value projections (created in that order).
        self.w_qs, self.w_ks, self.w_vs = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(3)
        )

    @staticmethod
    def _two_layer_mlp(d_in, d_out):
        """Build Linear(d_in -> d_out) -> ReLU -> Linear(d_out -> d_out)."""
        return nn.Sequential(
            nn.Linear(d_in, d_out),
            nn.ReLU(),
            nn.Linear(d_out, d_out),
        )

    def forward(self, xyz, features):
        """Run one attention step.

        Args:
            xyz: [b, n, d_pos] positions.
            features: [b, n, d_points] per-point features.

        Returns:
            (res, attn):
            - res is [b, n, d_points]: the fc2 output plus the input features.
            - attn is [b, n, n, d_model], softmax-normalised over dim -2 (the
              attended-to point), separately for each channel.
        """
        shortcut = features
        hidden = self.fc1(features)
        query = self.w_qs(hidden)
        key = self.w_ks(hidden)
        value = self.w_vs(hidden)

        # rel_pos[b, i, j] = xyz[b, i] - xyz[b, j]
        rel_pos = xyz.unsqueeze(2) - xyz.unsqueeze(1)
        pos_enc = self.fc_delta(rel_pos)  # [b, n, n, d_model]

        logits = self.fc_gamma(query.unsqueeze(2) - key.unsqueeze(1) + pos_enc)
        scale = np.sqrt(key.size(-1))
        weights = F.softmax(logits / scale, dim=-2)  # [b, n, n, d_model]

        n_points = value.size(1)
        # value_grid[b, i, j] = value[b, j]
        value_grid = value.unsqueeze(1).expand(-1, n_points, n_points, -1)
        aggregated = torch.einsum('bmnf,bmnf->bmf', weights, value_grid + pos_enc)
        return self.fc2(aggregated) + shortcut, weights
    
class StructEncoder(nn.Module):
    """Two-stage attention encoder over a list of point clouds.

    Stage 1 (per cloud): xyz -> ``fc_feature1`` -> ``PT1`` -> attention
    pooling (``attn_pool1``) -> ``fc_head1``. This gives one ``d_token``
    vector per cloud.

    Stage 2: the per-cloud tokens are stacked into a sequence and passed
    through ``fc_feature2`` / ``PT2``; the tokens themselves act as PT2's
    positional input. The result is attention-pooled with ``attn_pool2`` and
    projected by ``fc_head2`` to one ``d_Hidden`` vector per batch element.

    Args:
        d_Feature1: Per-point feature width produced by ``fc_feature1``.
        d_Attention1: Inner width of ``PT1`` and of the ``attn_pool1`` scorer.
        d_token: Width of each per-cloud token produced by ``fc_head1``.
        d_Feature2: Per-token feature width produced by ``fc_feature2``.
        d_Attention2: Inner width of ``PT2`` and of the ``attn_pool2`` scorer.
        d_Hidden: Width of the final output produced by ``fc_head2``.
    """

    def __init__(self, d_Feature1=32, d_Attention1=64, d_token=128, d_Feature2 = 64, d_Attention2 = 128, d_Hidden = 256):
        super().__init__()
        # ---- Stage 1: per-cloud encoder ----
        self.fc_feature1 = nn.Sequential(
            nn.Linear(3, d_Feature1),
            nn.ReLU(),
            nn.Linear(d_Feature1, d_Feature1),
        )
        self.PT1 = StructTransformerBlock(3, d_Feature1, d_Attention1)

        # Scores each point; softmax over points gives pooling weights.
        self.attn_pool1 = nn.Sequential(
            nn.Linear(d_Feature1, d_Attention1),
            nn.Tanh(),
            nn.Linear(d_Attention1, 1)
        )

        self.fc_head1 = nn.Sequential(
            nn.Linear(d_Feature1, d_token),
            nn.ReLU(),
            nn.Linear(d_token, d_token*2),
            nn.ReLU(),
            nn.Linear(d_token*2, d_token),
        )

        # ---- Stage 2: encoder over the sequence of per-cloud tokens ----
        self.fc_feature2 = nn.Sequential(
            nn.Linear(d_token, d_Feature2),
            nn.ReLU(),
            nn.Linear(d_Feature2, d_Feature2)
        )
        # d_pos = d_token: the tokens are used as PT2's "positions".
        self.PT2 = StructTransformerBlock(d_token, d_Feature2, d_Attention2)

        self.attn_pool2 = nn.Sequential(
            nn.Linear(d_Feature2, d_Attention2),
            nn.Tanh(),
            nn.Linear(d_Attention2, 1)
        )

        self.fc_head2 = nn.Sequential(
            nn.Linear(d_Feature2, d_Hidden),
            nn.ReLU(),
            nn.Linear(d_Hidden, d_Hidden*2),
            nn.ReLU(),
            nn.Linear(d_Hidden*2, d_Hidden),
        )

    @staticmethod
    def _attention_pool(scorer, tokens):
        """Softmax-weighted sum of ``tokens`` [B, N, D] over dim 1, giving [B, D]."""
        weights = F.softmax(scorer(tokens), dim=1)  # [B, N, 1]
        return torch.sum(weights * tokens, dim=1)

    def _encode_cloud(self, x):
        """Encode one point cloud [B, N, >=3] into a token [B, d_token]."""
        xyz = x[:, :, 0:3]                   # [B, N, 3]; extra channels are ignored
        features = self.fc_feature1(xyz)     # [B, N, d_Feature1]
        res, _ = self.PT1(xyz, features)     # [B, N, d_Feature1]
        pooled = self._attention_pool(self.attn_pool1, res)  # [B, d_Feature1]
        return self.fc_head1(pooled)         # [B, d_token]

    def forward(self, lx):
        """Encode a list of point clouds into one vector per batch element.

        Args:
            lx: Non-empty list of tensors x, each [B, N_i, C] with a shared B.
                Only channels 0:3 (xyz) are read. Points are described as
                (x, y, z, label), but the label channel is not used here.

        Returns:
            Tensor [B, d_Hidden].

        Raises:
            ValueError: If ``lx`` is empty.
        """
        if len(lx) == 0:
            raise ValueError("StructEncoder.forward expects a non-empty list of point clouds")

        # One token per cloud, stacked into a sequence: [B, L, d_token].
        label_tokens = torch.stack([self._encode_cloud(x) for x in lx], dim=1)

        features = self.fc_feature2(label_tokens)     # [B, L, d_Feature2]
        res, _ = self.PT2(label_tokens, features)     # [B, L, d_Feature2]
        pooled = self._attention_pool(self.attn_pool2, res)  # [B, d_Feature2]
        return self.fc_head2(pooled)                  # [B, d_Hidden]