import torch
import torch.nn as nn
from pointnet2_ops import pointnet2_utils
from models.point_encoder import PointcloudEncoder
import logging

class FeaturePropagation(nn.Module):
    """PointNet++-style feature propagation: upsample features from sparse to dense points.

    Features defined on ``S`` low-resolution points are interpolated onto ``N``
    high-resolution points. Interpolation uses inverse-squared-distance
    weighting of the ``k`` nearest low-resolution neighbours. The result is
    optionally concatenated with skip features already defined on the
    high-resolution points, then passed through a pointwise (1x1 Conv1d) MLP.

    Args:
        in_channels: Channels entering the MLP. This is the channel count of
            the interpolated features plus that of the skip features (if any).
        out_channels: Channels produced by the MLP.
        k: Number of nearest low-resolution neighbours used for interpolation.
            At runtime it is reduced to ``S`` when fewer low-resolution points
            are available, so tiny point sets do not crash ``topk``.
    """

    def __init__(self, in_channels, out_channels, k=3):
        super().__init__()
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        self.k = k
        # The last conv has no norm/activation, so the output is unbounded
        # (e.g. it can be consumed directly as per-point logits).
        self.mlp = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, 1),
        )

    def forward(self, low_res_points, high_res_points, low_res_features, high_res_features):
        """
        Args:
            low_res_points: B x S x 3 (e.g. centers from FPS)
            high_res_points: B x N x 3 (e.g. original points)
            low_res_features: B x C1 x S (features on the low-res points)
            high_res_features: B x C2 x N skip features, or None for no skip.
                C1 + C2 (or C1 alone when None) must equal ``in_channels``.
        Returns:
            B x out_channels x N features on the high-resolution points.
        """
        k = min(self.k, low_res_points.shape[1])

        # k nearest low-res neighbours for every high-res point.
        sqrdists = square_distance(high_res_points, low_res_points)  # B x N x S
        dist, idx = torch.topk(sqrdists, k=k, dim=-1, largest=False, sorted=False)  # B x N x k

        # The expanded-form distance can dip slightly below zero through float
        # cancellation (typically when a query point coincides with a center).
        # Clamp before adding epsilon so every weight is finite and positive.
        dist = dist.clamp_min(0.0) + 1e-8
        recip = 1.0 / dist
        weight = recip / recip.sum(dim=-1, keepdim=True)  # B x N x k, rows sum to 1

        # Gather the neighbours' features. The expand calls create views
        # (no copies), and gather materialises only the B x C1 x N x k result.
        channels = low_res_features.size(1)
        idx_expanded = idx.unsqueeze(1).expand(-1, channels, -1, -1)  # B x C1 x N x k
        gathered = torch.gather(
            low_res_features.unsqueeze(-1).expand(-1, -1, -1, k), dim=2, index=idx_expanded
        )  # B x C1 x N x k

        interpolated = (gathered * weight.unsqueeze(1)).sum(dim=-1)  # B x C1 x N

        if high_res_features is not None:
            features = torch.cat([interpolated, high_res_features], dim=1)  # B x (C1+C2) x N
        else:
            features = interpolated

        return self.mlp(features)

def square_distance(src, dst):
    """Return pairwise squared Euclidean distances between two batched point sets.

    Uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 <s, d>. This needs
    one batched matmul instead of materialising every pairwise difference.

    Args:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]

    Returns:
        dist: [B, N, M] with dist[b, n, m] = ||src[b, n] - dst[b, m]||^2.
        Values are clamped at zero, because the expansion suffers from
        floating-point cancellation. Without the clamp it can produce small
        negative values, e.g. for coincident points with large coordinates,
        which would break callers that take 1/dist or sqrt(dist).
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    # Out-of-place clamp keeps autograd happy (no in-place op on a tensor
    # whose value clamp's backward needs).
    return dist.clamp_min(0.0)

class PointcloudEncoderHighRes(nn.Module):
    """Per-point feature head on top of a frozen pretrained point-cloud transformer.

    Token features from blocks 4, 8 and the last block of the frozen encoder
    are propagated from the group centers back onto the input points, through
    a chain of ``FeaturePropagation`` layers (fp12 -> fp8 -> fp4). Only the
    FP layers are trainable.
    """

    def __init__(self, pretrained_encoder: PointcloudEncoder, args):
        super().__init__()
        # Freeze the pretrained encoder (no gradient updates). Its train/eval
        # mode is pinned separately in ``train`` below.
        self.encoder = pretrained_encoder.point_encoder
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Add feature propagation modules.
        # NOTE(review): 1408 is presumably the encoder's transformer width
        # (channels of H4/H8), and args.pc_feat_dim must equal the width of
        # H12. TODO confirm both against the pretrained checkpoint.
        # 50 output channels presumably = number of part classes; confirm.
        self.fp12 = FeaturePropagation(in_channels=args.pc_feat_dim, out_channels=256)
        self.fp8 = FeaturePropagation(in_channels=1408 + 256, out_channels=128)
        self.fp4 = FeaturePropagation(in_channels=1408 + 128, out_channels=50)

    def train(self, mode: bool = True):
        """Set train/eval mode on the trainable heads; keep the frozen encoder in eval mode.

        ``requires_grad = False`` only stops gradient updates. It does not stop
        train-mode behaviour inside the encoder, such as BatchNorm running-stat
        updates or the ``pos_drop`` / ``patch_dropout`` layers used in
        ``forward``. If patch dropout drops tokens, the block outputs would no
        longer line up with ``center`` in the FP layers.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def load_pretrained_encoder(self, state_dict):
        """
        Load pretrained weights into the frozen encoder.
        """
        self.encoder.load_state_dict(state_dict)

    def forward(self, pts, colors):
        """Return per-point features for ``pts``, as produced by ``fp4``.

        Raises:
            ValueError: if the encoder has fewer than 8 transformer blocks,
                so the block-4 / block-8 features cannot be taken.
        """
        # Forward pass through the frozen encoder
        _, center, features = self.encoder.group_divider(pts, colors)
        group_input_tokens = self.encoder.encoder(features)  # B G N
        group_input_tokens = self.encoder.encoder2trans(group_input_tokens)

        batch_size = group_input_tokens.size(0)
        cls_tokens = self.encoder.cls_token.expand(batch_size, -1, -1)
        cls_pos = self.encoder.cls_pos.expand(batch_size, -1, -1)

        pos = self.encoder.pos_embed(center)
        x = torch.cat((cls_tokens, group_input_tokens), dim=1)
        pos = torch.cat((cls_pos, pos), dim=1)
        x = x + pos
        x = self.encoder.patch_dropout(x)

        x = self.encoder.visual.pos_drop(x)

        # Collect intermediate token features (cls token dropped, channels first).
        # Independent `if`s: the last block may also be block 4 or 8 in
        # shallow encoders, and each capture must still happen.
        # NOTE(review): H12 is the raw last-block output; no final norm is
        # applied here. Confirm this matches how the encoder was pretrained.
        blocks = self.encoder.visual.blocks
        last_index = len(blocks) - 1
        H4, H8, H12 = None, None, None
        for i, blk in enumerate(blocks):
            x = blk(x)
            if i == 3:  # 4th block
                H4 = x[:, 1:].permute(0, 2, 1)
            if i == 7:  # 8th block
                H8 = x[:, 1:].permute(0, 2, 1)
            if i == last_index:  # last block
                H12 = x[:, 1:].permute(0, 2, 1)

        if H4 is None or H8 is None:
            raise ValueError(
                f"PointcloudEncoderHighRes needs at least 8 transformer blocks, got {len(blocks)}"
            )

        # Coarse-to-fine propagation onto the input points, with skip chaining.
        fp12 = self.fp12(center, pts, H12, None)
        fp8 = self.fp8(center, pts, H8, fp12)
        fp4 = self.fp4(center, pts, H4, fp8)

        return fp4  # Return point-wise features
