import logging

import torch
import torch.nn as nn
from pointnet2_ops import pointnet2_utils

def fps(data, number):
    '''
        Pick `number` points from `data` via furthest point sampling.

        data   B N 3  (input point coordinates)
        number int    (how many points to sample)

        Returns the sampled points gathered from `data`
        (presumably B number 3 -- depends on pointnet2_utils semantics).
    '''
    sampled_idx = pointnet2_utils.furthest_point_sample(data, number)
    # gather_operation is fed a channels-first, contiguous tensor; the result
    # is transposed back to channels-last afterwards.
    channels_first = data.transpose(1, 2).contiguous()
    gathered = pointnet2_utils.gather_operation(channels_first, sampled_idx)
    return gathered.transpose(1, 2).contiguous()

# https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py 
def knn_point(nsample, xyz, new_xyz):
    """
    Find, for every query point, the indices of its nearest points in xyz.

    Input:
        nsample: number of neighbours to return per query point
        xyz: candidate points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: indices into xyz, [B, S, nsample]; the nsample entries
                   are not ordered by distance (topk is called with
                   sorted=False)
    """
    # Rows are query points, columns are candidates.
    pairwise = square_distance(new_xyz, xyz)
    nearest = pairwise.topk(nsample, dim=-1, largest=False, sorted=False)
    return nearest.indices

def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Uses the expansion
        |s - d|^2 = |s|^2 + |d|^2 - 2 * <s, d>
    so the cross term is a single batched matrix product.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: squared distance for every (source, target) pair, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape

    src_sq_norm = torch.sum(src ** 2, -1).view(batch, n_src, 1)
    dst_sq_norm = torch.sum(dst ** 2, -1).view(batch, 1, n_dst)

    # Start from the cross term, then add the squared norms in place
    # (broadcast along the missing axis).
    dist = -2 * torch.matmul(src, dst.transpose(1, 2))
    dist += src_sq_norm
    dist += dst_sq_norm
    return dist


class PatchDropout(nn.Module):
    """
    Randomly keep a subset of the tokens along the sequence dimension.

    https://arxiv.org/abs/2212.00794

    Args:
        prob: fraction of tokens to drop, must satisfy 0 <= prob < 1.
        exclude_first_token: when True, the first token (CLS) is split off
            before sampling and re-attached in front of the kept tokens.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token
        logging.info("patch dropout prob is {}".format(prob))

    def forward(self, x):
        """
        Args:
            x: token tensor indexed as [batch, token, ...].

        Returns:
            Tensor with the same leading batch size holding the randomly
            kept tokens (in random order), preceded by the first token when
            ``exclude_first_token`` is set.
        """
        # NOTE(review): there is no `self.training` guard, so tokens are also
        # dropped in eval mode -- confirm this is intended.
        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch, num_tokens = x.size()[0], x.size()[1]

        keep_prob = 1 - self.prob
        # Keep at least one token, but never ask for more than exist (the
        # sequence may hold nothing besides the CLS token).
        num_patches_keep = min(num_tokens, max(1, int(num_tokens * keep_prob)))

        # Build the index tensors on the input's device so no cross-device
        # transfer is needed when indexing.
        batch_indices = torch.arange(batch, device=x.device)[..., None]
        rand = torch.randn(batch, num_tokens, device=x.device)
        # The top-k of i.i.d. noise is a uniformly random subset per sample.
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x


def _index_points(points, idx):
    """
    Gather per-batch rows of `points` selected by `idx`.

    Input:
        points: source tensor, [B, N, C]
        idx: integer indices into the N axis, [B, S, K]
    Return:
        new_points: gathered rows, [B, S, K, C], contiguous

    The lookup is done on a flattened [B * N, C] view of `points`, with each
    batch's indices shifted by its row offset. `reshape` is used instead of
    `view` so that non-contiguous inputs (e.g. a transposed tensor) work too.
    """
    batch_size, num_points, _ = points.shape
    num_queries, num_neighbors = idx.shape[1], idx.shape[2]

    # Row offset of each batch inside the flattened [B * N, C] tensor.
    batch_offset = torch.arange(batch_size, device=points.device).view(-1, 1, 1) * num_points
    flat_idx = (idx + batch_offset).reshape(-1)

    flat_points = points.reshape(batch_size * num_points, -1)
    gathered = flat_points[flat_idx, :]

    return gathered.view(batch_size, num_queries, num_neighbors, -1).contiguous()

class Group(nn.Module):
    """
    Split a point cloud into local neighbourhoods around sampled centres.

    num_group centres are chosen with `fps`; each centre is paired with its
    group_size nearest points found by `knn_point`.
    """

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz, color):
        """
        Input:
            xyz: point coordinates (rank-3; [B, N, 3] per the `fps` docstring)
            color: per-point attributes indexed with the same neighbour
                   indices as xyz
        Return:
            neighborhood: neighbour coordinates relative to their centre
            center: the sampled centres
            features: neighborhood concatenated with the gathered colours
                      along the last axis
        """
        # The unpack has no other use than enforcing a rank-3 xyz.
        _batch, _count, _dims = xyz.shape

        center = fps(xyz, self.num_group)
        knn_idx = knn_point(self.group_size, xyz, center)

        # Express every neighbour relative to its group centre.
        local_xyz = _index_points(xyz, knn_idx) - center.unsqueeze(2)
        local_color = _index_points(color, knn_idx)

        return local_xyz, center, torch.cat((local_xyz, local_color), dim=-1)

class Encoder(nn.Module):
    """
    Embed each group of 6-channel points into one feature vector.

    Two stacks of 1x1 convolutions are applied per point; a max over the
    points of a group is taken after each stack.
    """

    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        # Per-point lift: 6 -> 128 -> 256 channels.
        self.first_conv = nn.Sequential(
            nn.Conv1d(6, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
        )
        # Consumes [group-wide max, per-point] features: 512 -> encoder_channel.
        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1),
        )

    def forward(self, point_groups):
        """
        Input:
            point_groups: [B, G, n, 6]
        Return:
            one feature per group, [B, G, encoder_channel]
        """
        batch, groups, pts_per_group, _ = point_groups.shape

        # Fold batch and group axes together; Conv1d wants channels first.
        flat = point_groups.reshape(batch * groups, pts_per_group, 6)
        local = self.first_conv(flat.transpose(2, 1))

        # Broadcast the group-wide max back onto every point of the group.
        pooled = local.max(dim=2, keepdim=True).values
        fused = torch.cat([pooled.expand(-1, -1, pts_per_group), local], dim=1)

        fused = self.second_conv(fused)
        group_feature = fused.max(dim=2).values
        return group_feature.reshape(batch, groups, self.encoder_channel)

class PointcloudEncoder(nn.Module):
    """
    Tokenise a coloured point cloud and run it through a transformer.

    Points are grouped (512 groups of 32 neighbours), each group is embedded
    by `Encoder`, projected to the transformer width, and passed through the
    blocks of `point_transformer`. Hidden states after the 4th block, the
    8th block and the last block are returned together with point sets at
    several resolutions.

    `point_transformer` must expose `pos_drop`, `blocks` and `norm`.
    """

    def __init__(self, point_transformer):
        super().__init__()
        # Fixed hyper-parameters.
        self.trans_dim = 768
        self.embed_dim = 1024
        self.group_size = 32
        self.num_group = 512
        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)

        self.encoder_dim = 512
        self.encoder = Encoder(encoder_channel=self.encoder_dim)

        # Projection from group features to the transformer width.
        self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
        # Defined here but not used by forward() in this class.
        self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)

        # Learnable CLS token and its positional embedding.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))

        # Positional embedding computed from the 3-D group centres.
        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim),
        )

        # Patch dropout is disabled: the rate is hard-coded to zero.
        drop_rate = 0
        self.patch_dropout = nn.Identity() if not drop_rate > 0. else PatchDropout(drop_rate)
        self.visual = point_transformer

    def forward(self, pts, colors):
        """
        Input:
            pts: point coordinates, forwarded to `Group` and `fps`
            colors: per-point attributes, forwarded to `Group`
        Return (tuple):
            h4, h8, h12: normalised hidden states after block index 3,
                block index 7 and the final block, CLS token removed
            pts: the input coordinates, unchanged
            center_level_0: pts with its last two axes swapped
            center_level_1: 1536 fps-sampled points, last two axes swapped
            center_level_2: 1024 fps-sampled points, last two axes swapped
            center_level_3: the group centres, last two axes swapped
        """
        _, center, grouped = self.group_divider(pts, colors)

        # Group features -> transformer-width tokens.
        tokens = self.encoder2trans(self.encoder(grouped))
        batch = tokens.size(0)

        cls_tokens = self.cls_token.expand(batch, -1, -1)
        cls_pos = self.cls_pos.expand(batch, -1, -1)

        # Prepend CLS to both the token and the positional sequences.
        x = torch.cat((cls_tokens, tokens), dim=1)
        pos = torch.cat((cls_pos, self.pos_embed(center)), dim=1)
        x = self.patch_dropout(x + pos)

        x = self.visual.pos_drop(x)

        # Record intermediate hidden states at the tapped block indices.
        tap_names = {3: 'h4', 7: 'h8'}
        hidden = {}
        for depth, block in enumerate(self.visual.blocks):
            x = block(x)
            if depth in tap_names:
                hidden[tap_names[depth]] = x
        hidden['h_last'] = x

        # Apply the shared final norm to every tap, then drop the CLS token.
        h4, h8, h12 = (
            self.visual.norm(hidden[name])[:, 1:, :]
            for name in ('h4', 'h8', 'h_last')
        )

        center_level_0 = pts.permute(0, 2, 1)
        center_level_1 = fps(pts, 1536).transpose(-1, -2).contiguous()
        center_level_2 = fps(pts, 1024).transpose(-1, -2).contiguous()
        center_level_3 = center.transpose(-1, -2).contiguous()

        return (
            h4, h8, h12, pts,
            center_level_0, center_level_1, center_level_2, center_level_3,
        )