import torch
import torch.nn as nn
import torch.nn.functional as F

from pointnet2_ops import pointnet2_utils

from dataset.prepross import *

import random

#dgcnn
def knn(x, k):
    """Return, for every point, the indices of its k nearest neighbours.

    Args:
        x: tensor of shape (B, C, N) -- N points with C coordinates each.
        k: number of neighbours to return.

    Returns:
        LongTensor of shape (B, N, k), ordered nearest first. For distinct
        points the point itself typically comes first (its distance is ~0).
    """
    # ||a - b||^2 = ||a||^2 - 2<a, b> + ||b||^2; negated so that topk picks the closest.
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)          # (B, 1, N)
    cross = -2 * torch.matmul(x.transpose(2, 1), x)           # (B, N, N)
    neg_sq_dist = -sq_norm - cross - sq_norm.transpose(2, 1)  # (B, N, N)
    return neg_sq_dist.topk(k=k, dim=-1).indices

def get_graph_feature_split(x_all, k=20):
    """Build DGCNN-style edge features, treating xyz and the other channels differently.

    The k-nearest-neighbour graph is built from the first three channels only.
    For those xyz channels the edge feature is ``(neighbour - centre, centre)``;
    for all remaining channels it is ``(neighbour, centre)`` with no difference.

    Args:
        x_all: tensor of shape (B, D, N); channels 0..2 are used as coordinates.
        k: number of neighbours per point.

    Returns:
        ``(xyz_feat, other_feat)`` with shapes (B, 6, N, k) and (B, 2 * (D - 3), N, k).
    """
    batch, _, n_pts = x_all.shape
    coords = x_all[:, :3, :]   # (B, 3, N)
    extras = x_all[:, 3:, :]   # (B, D-3, N)

    neighbour_idx = knn(coords, k=k)  # (B, N, k), graph from xyz only

    # Shift each batch's indices so they address rows of a flattened (B*N, C) view.
    offsets = torch.arange(0, batch, device=x_all.device).view(-1, 1, 1) * n_pts
    flat_idx = (neighbour_idx + offsets).view(-1)  # (B*N*k)

    def _edge_parts(feat):
        # Returns (neighbours, centres), both (B, N, k, C), for a (B, C, N) tensor.
        channels = feat.shape[1]
        rows = feat.transpose(2, 1).contiguous()  # (B, N, C)
        nbrs = rows.view(batch * n_pts, channels)[flat_idx, :]
        nbrs = nbrs.view(batch, n_pts, k, channels)
        ctrs = rows.view(batch, n_pts, 1, channels).repeat(1, 1, k, 1)
        return nbrs, ctrs

    xyz_nbr, xyz_ctr = _edge_parts(coords)
    xyz_feat = torch.cat((xyz_nbr - xyz_ctr, xyz_ctr), dim=-1)
    xyz_feat = xyz_feat.permute(0, 3, 1, 2).contiguous()      # (B, 6, N, k)

    ext_nbr, ext_ctr = _edge_parts(extras)
    other_feat = torch.cat((ext_nbr, ext_ctr), dim=-1)
    other_feat = other_feat.permute(0, 3, 1, 2).contiguous()  # (B, 2*(D-3), N, k)

    return xyz_feat, other_feat

def get_graph_feature(x, k=20, idx=None):
    """Build DGCNN edge features ``(neighbour - centre, centre)`` for every point.

    Args:
        x: point features, reshaped to (B, C, N) with N = ``x.size(2)``.
        k: number of neighbours; used to build the kNN graph when ``idx`` is None.
        idx: optional precomputed neighbour indices of shape (B, N, k'). When
            given, its last dimension determines the neighbour count.

    Returns:
        Tensor of shape (B, 2 * C, N, k) (a permuted, non-contiguous view).
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x, k=k)   # (batch_size, num_points, k)
    # Take the neighbour count from idx so a caller-supplied graph need not match `k`.
    k = idx.size(-1)

    # Use the input's device instead of hard-coding CUDA, so CPU tensors work too.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    # Flatten per-batch indices into row indices of a (B*N, C) view.
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)

    x = x.transpose(2, 1).contiguous()   # (B, N, C)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    return torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2)

class GSDEncoder(nn.Module):
    """Three-layer DGCNN-style edge-conv encoder for per-Gaussian attributes.

    ``forward`` takes ``gs_feat`` of shape (B, N, C_gs) with channels laid out as
    ``[xyz 0:3, opacity 3:4, scale 4:7, quaternion 7:11, colour 11:]``. The
    quaternion is expanded to a flattened 3x3 rotation matrix before the first
    edge-conv. The output has shape (B, 64, N).

    Args:
        cfg: config object; ``cfg.model.k`` sets the number of graph neighbours.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.k = self.cfg.model.k

        # 38 input channels = 6 (xyz offset + xyz centre) + 2 * 16, where 16 is the
        # number of non-xyz channels after the quaternion -> 3x3 expansion
        # (1 opacity + 3 scale + 9 rotation + 3 colour); i.e. gs_feat has 14 channels.
        self.conv1 = nn.Sequential(nn.Conv2d(38, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))

        self.conv2 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))

        self.conv3 = nn.Sequential(nn.Conv2d(128, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))

    @staticmethod
    def quat_to_rotmat(q: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Convert quaternions (..., 4) to flattened row-major 3x3 rotation matrices (..., 9).

        The quaternion is normalised first (``eps`` guards against division by zero).
        Components are unpacked scalar-last as (x, y, z, w).
        NOTE(review): some Gaussian-splat exporters store rotations scalar-first
        (w, x, y, z); confirm the layout of gs_feat[..., 7:11].
        """
        q = q / (q.norm(dim=-1, keepdim=True) + eps)
        x, y, z, w = q.unbind(-1)

        xx, yy, zz = x * x, y * y, z * z
        xy, xz, yz = x * y, x * z, y * z
        wx, wy, wz = w * x, w * y, w * z

        m00 = 1.0 - 2.0 * (yy + zz)
        m01 = 2.0 * (xy - wz)
        m02 = 2.0 * (xz + wy)
        m10 = 2.0 * (xy + wz)
        m11 = 1.0 - 2.0 * (xx + zz)
        m12 = 2.0 * (yz - wx)
        m20 = 2.0 * (xz - wy)
        m21 = 2.0 * (yz + wx)
        m22 = 1.0 - 2.0 * (xx + yy)

        rot = torch.stack([m00, m01, m02,
                           m10, m11, m12,
                           m20, m21, m22], dim=-1)  # (..., 9)
        return rot

    def forward(self, gs_feat):
        """Encode (B, N, C_gs) Gaussian attributes into (B, 64, N) point features."""
        xyz = gs_feat[..., 0:3]
        opacity = gs_feat[..., 3:4]
        scale = gs_feat[..., 4:7]        # used as stored; no exp() is applied here
        quat = gs_feat[..., 7:11]
        rgb = gs_feat[..., 11:]

        rotmat = self.quat_to_rotmat(quat)  # (B, N, 9)
        gs_tensor = torch.cat([xyz, opacity, scale, rotmat, rgb], dim=-1)

        # Layer 1: graph from xyz only; xyz gets offset features, other channels raw pairs.
        # Bug fix: honour cfg.model.k (stored in self.k) instead of a hard-coded 20.
        xyz_feat, other_feat = get_graph_feature_split(gs_tensor.transpose(2, 1), k=self.k)
        graph_feat = torch.cat([xyz_feat, other_feat], dim=1)
        x = self.conv1(graph_feat)
        x = x.max(dim=-1, keepdim=False)[0]

        # Layers 2-3: graph recomputed in the current feature space.
        x = get_graph_feature(x, k=self.k)
        x = self.conv2(x)
        x = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x, k=self.k)
        x = self.conv3(x)
        x = x.max(dim=-1, keepdim=False)[0]

        return x
    
class FeatureSpaceUpSampleAvg(nn.Module):
    """Double the number of points by k-NN averaging in feature space.

    Each input point is duplicated, and the copy is perturbed with small Gaussian
    noise. For every one of the resulting 2N query points, the k nearest original
    points are found by Euclidean distance between feature vectors, and their
    features are averaged. The result is passed through a 1x1 Conv -> ReLU -> BN.

    Args:
        in_channels: channel count C_in of the input features.
        out_channels: channel count of the output features.
        k: number of neighbours averaged per query point.
    """

    def __init__(self, in_channels=128, out_channels=64, k=3):
        super().__init__()
        self.k = k
        self.mlp = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(out_channels)
        )

    @staticmethod
    def _group(points, idx):
        """Gather ``points[b, :, idx[b, m, j]]`` for points (B, C, N), idx (B, M, k) -> (B, C, M, k).

        Uses batched advanced indexing, so neither the forward pass nor its gradient
        materialises a (B, C, M, N) tensor.
        """
        batch = torch.arange(points.size(0), device=points.device).view(-1, 1, 1)
        return points.transpose(1, 2)[batch, idx].permute(0, 3, 1, 2)

    def forward(self, feat_sparse):
        """Upsample (B, C_in, N) features to (B, out_channels, 2N)."""
        # Query set: the original points plus a slightly jittered copy of each.
        # NOTE(review): the noise is also drawn in eval mode, so the output is not
        # deterministic at inference; gate on self.training if that is unintended.
        feat_dense = torch.cat([
            feat_sparse,
            feat_sparse + torch.randn_like(feat_sparse) * 0.01
        ], dim=2)  # (B, C_in, 2N)

        # k nearest original points for every query, in feature space.
        dist = torch.cdist(feat_dense.transpose(1, 2), feat_sparse.transpose(1, 2))  # (B, 2N, N)
        _, idx = dist.topk(self.k, dim=-1, largest=False, sorted=False)             # (B, 2N, k)

        feat_interp = self._group(feat_sparse, idx).mean(dim=3)  # (B, C_in, 2N)
        return self.mlp(feat_interp)                              # (B, C_out, 2N)

def gs_feature_fps(gs_tensor, num):
    """Furthest-point-sample ``num`` points from a channels-first feature tensor.

    The sampled rows are concatenated with themselves along the channel axis. If
    ``index_points`` maps (B, N, C) points and (B, S) indices to (B, S, C) -- it
    is not defined in the visible code, presumably coming from the
    ``dataset.prepross`` star import; confirm -- then the result is (B, 2 * C, num).

    NOTE(review): ``pointnet2_utils.furthest_point_sample`` is normally fed
    (B, N, 3) xyz; here it receives C-channel features (64 channels from the
    callers in this file). Confirm the kernel honours C != 3; otherwise this is
    not a true feature-space FPS.
    """
    points_last = gs_tensor.transpose(2, 1).contiguous()  # (B, N, C)
    sample_idx = pointnet2_utils.furthest_point_sample(points_last, num).long()
    picked = index_points(points_last, sample_idx)
    doubled = torch.cat([picked, picked], dim=2)
    return doubled.transpose(2, 1)

class firstscale_net(nn.Module):
    """First branch (applied to ``scale2`` by ``spread``): encode, then FPS-downsample to 512 points.

    ``forward(x)`` returns two tensors:
      - the per-point ``GSDEncoder`` features of ``x``;
      - the features of 512 furthest-point-sampled points, refined by one edge-conv
        MLP and max-pooled over neighbours.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Attribute name (sic) kept unchanged: it is part of the state_dict keys.
        self.enconder = GSDEncoder(self.cfg)

        # 256 in-channels = edge-feature width (x2) of gs_feature_fps's channel-doubled output.
        layers = [
            nn.Conv2d(256, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(128, 64, kernel_size=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.fps_mlp = nn.Sequential(*layers)

    def forward(self, x):
        dense = self.enconder(x)
        sampled = gs_feature_fps(dense, 512)
        edges = get_graph_feature(sampled, k=20)
        sampled = self.fps_mlp(edges).max(dim=-1).values
        return dense, sampled
    
class otherscale_net(nn.Module):
    """Intermediate branch: fuses a zoomed-in view with the previous branch's features.

    ``forward(x, feat_dowm)`` performs four steps:
      1. Encode the 512 Gaussians of ``x`` nearest the origin and fuse them with
         ``feat_dowm`` (64 channels, 512 points).
      2. Upsample that fused signal to 1024 points in feature space.
      3. Concatenate it with an encoding of all of ``x`` and refine with an edge-conv.
      4. Return the refined per-point features and a 512-point FPS-downsampled copy.

    ``x`` is (B, N, C_gs) with xyz in channels 0:3. Because step 3 concatenates
    with the 1024-point upsampled branch, N must be 1024.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Attribute names (sic) kept unchanged: they are state_dict keys.
        self.enconder = GSDEncoder(self.cfg)
        self.enconder2 = GSDEncoder(self.cfg)

        # 128 = 64 encoder channels + 64 channels of feat_dowm.
        self.mlp_1 = nn.Sequential(
            nn.Conv1d(128, 256, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(256, 64, kernel_size=1),
        )

        self.mlp_2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.mlp_4 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(128, 64, kernel_size=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.up = FeatureSpaceUpSampleAvg(128, 64)

        self.fps_mlp = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(128, 64, kernel_size=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, x, feat_dowm):
        # Zoom in: keep the 512 Gaussians whose centres lie closest to the origin.
        radius = x[:, :, :3].norm(dim=2)
        nearest = radius.argsort(dim=1)[:, :512]
        gather_idx = nearest.unsqueeze(-1).expand(-1, -1, x.shape[2])  # (B, 512, C)
        zoomed = torch.gather(x, 1, gather_idx)                        # (B, 512, C)

        # Fuse with the previous branch's features, refine with an edge-conv, then upsample.
        fused = self.mlp_1(torch.cat([self.enconder(zoomed), feat_dowm], dim=1))
        fused = self.mlp_2(get_graph_feature(fused, k=20)).max(dim=-1).values
        upsampled = self.up(fused)

        # Full-resolution encoding combined with the upsampled context.
        merged = torch.cat([self.enconder2(x), upsampled], dim=1)
        merged = self.mlp_4(get_graph_feature(merged, k=20)).max(dim=-1).values

        # 512-point summary handed to the next branch.
        down = gs_feature_fps(merged, 512)
        down = self.fps_mlp(get_graph_feature(down, k=20)).max(dim=-1).values

        return merged, down
    
class lastscale_net(nn.Module):
    """Final branch: like ``otherscale_net`` but returns no FPS-downsampled output.

    ``forward(x, feat_dowm)`` performs three steps:
      1. Fuse an encoding of the 512 Gaussians nearest the origin with ``feat_dowm``
         (64 channels, 512 points).
      2. Upsample the result to 1024 points.
      3. Combine it with an encoding of all of ``x`` and apply two edge-conv stages.

    It returns a single (B, 64, N) tensor. ``x`` is (B, N, C_gs) with xyz in
    channels 0:3, and N must be 1024 to match the upsampled branch.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Attribute names (sic) kept unchanged: they are state_dict keys.
        self.enconder = GSDEncoder(self.cfg)
        self.enconder2 = GSDEncoder(self.cfg)

        # 128 = 64 encoder channels + 64 channels of feat_dowm.
        self.mlp_1 = nn.Sequential(
            nn.Conv1d(128, 256, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(256, 64, kernel_size=1),
        )

        self.mlp_2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.mlp_4 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(128, 64, kernel_size=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.up = FeatureSpaceUpSampleAvg(128, 64)
        self.b_mlp = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, x, feat_dowm):
        # Zoom in: keep the 512 Gaussians whose centres lie closest to the origin.
        radius = x[:, :, :3].norm(dim=2)
        nearest = radius.argsort(dim=1)[:, :512]
        gather_idx = nearest.unsqueeze(-1).expand(-1, -1, x.shape[2])  # (B, 512, C)
        zoomed = torch.gather(x, 1, gather_idx)                        # (B, 512, C)

        # Fuse with the previous branch's features, refine with an edge-conv, then upsample.
        fused = self.mlp_1(torch.cat([self.enconder(zoomed), feat_dowm], dim=1))
        fused = self.mlp_2(get_graph_feature(fused, k=20)).max(dim=-1).values
        upsampled = self.up(fused)

        # Full-resolution encoding combined with the upsampled context.
        merged = torch.cat([self.enconder2(x), upsampled], dim=1)
        merged = self.mlp_4(get_graph_feature(merged, k=20)).max(dim=-1).values

        # One more edge-conv stage at full resolution (no downsampling here).
        refined = self.b_mlp(get_graph_feature(merged, k=20)).max(dim=-1).values
        return refined
    
class AttnPool1D(nn.Module):
    """Attention pooling over the last axis with a single learned query.

    Maps (B, C, N) -> (B, C). Each position n gets the score ``<proj(x)[:, :, n], q>``.
    The scores are softmax-normalised over N and weight a sum of the *unprojected* x.
    """

    def __init__(self, C):
        super().__init__()
        self.q = nn.Parameter(torch.randn(1, C, 1))
        self.proj = nn.Conv1d(C, C, 1, bias=False)

    def forward(self, x):
        keys = self.proj(x)                                     # (B, C, N)
        logits = torch.sum(keys * self.q, dim=1, keepdim=True)  # (B, 1, N)
        attn = logits.softmax(dim=-1)
        return torch.sum(x * attn, dim=-1)                      # (B, C)
    
class ResidualMLP(nn.Module):
    """Two-layer MLP with an identity skip.

    Computes ``relu(x + BN2(fc2(relu(BN1(fc1(x))))))`` on (B, C) inputs.
    """

    def __init__(self, C, hidden):
        super().__init__()
        self.fc1 = nn.Linear(C, hidden, bias=False)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.fc2 = nn.Linear(hidden, C, bias=False)
        self.bn2 = nn.BatchNorm1d(C)

    def forward(self, x):
        h = self.fc1(x)
        h = F.relu(self.bn1(h))
        h = self.bn2(self.fc2(h))
        return F.relu(h + x)

class spread(nn.Module):
    """Multi-scale Gaussian-splat classifier.

    Four branches run in sequence: ``scale2`` -> ``scale1`` -> ``scale0`` -> base.
    Each later branch receives the 512-point summary produced by the previous one.
    The four 64-channel per-point outputs are concatenated (so they must share the
    same point count), lifted to 1024 channels, and globally max- and avg-pooled.
    A three-layer MLP head then classifies the pooled vector.

    Args:
        cfg: config; ``cfg.opt.dropout`` and ``cfg.data.classes`` are read here, and
            ``cfg`` is forwarded to every branch.
        background: accepted for interface compatibility; not used in this class.
    """

    # Training-time probability of feeding the scale0 Gaussians (instead of the
    # base-scale ones) into the base branch.
    SCALE0_SWAP_PROB = 0.4

    def __init__(self, cfg, background):
        super(spread, self).__init__()
        self.cfg = cfg

        self.scale2_net = firstscale_net(self.cfg)
        self.scale1_net = otherscale_net(self.cfg)
        self.scale0_net = otherscale_net(self.cfg)
        self.scaleb_net = lastscale_net(self.cfg)

        # 64*4: one 64-channel feature map per branch.
        self.conv1 = nn.Sequential(nn.Conv1d(64*4, 512, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(512),
                                   nn.LeakyReLU(negative_slope=0.2),
                                   nn.Conv1d(512, 1024, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(1024),
                                   nn.LeakyReLU(negative_slope=0.2))

        # 1024*2: concatenation of global max- and avg-pooled features.
        self.fc = nn.Sequential(nn.Linear(1024*2, 512, bias=False),
                                nn.BatchNorm1d(512),
                                nn.LeakyReLU(negative_slope=0.2),
                                nn.Dropout(p=cfg.opt.dropout),
                                nn.Linear(512, 256),
                                nn.BatchNorm1d(256),
                                nn.LeakyReLU(negative_slope=0.2),
                                nn.Dropout(p=cfg.opt.dropout),
                                nn.Linear(256, cfg.data.classes)
                                )

    def forward(self, x):
        """Classify a dict of per-scale Gaussian tensors.

        Args:
            x: mapping with keys ``'scale_base'``, ``'scale0'``, ``'scale1'``, ``'scale2'``.

        Returns:
            Unnormalised class scores of shape (B, cfg.data.classes).
        """
        base_gs = x['scale_base']
        scale0_gs = x['scale0']
        scale1_gs = x['scale1']
        scale2_gs = x['scale2']
        batchsize = base_gs.shape[0]

        out_feat_2_, out_feat_2_down = self.scale2_net(scale2_gs)
        out_feat_1_, out_feat_1_down = self.scale1_net(scale1_gs, out_feat_2_down)
        out_feat_0_, out_feat_0_down = self.scale0_net(scale0_gs, out_feat_1_down)

        # Bug fix: the random scale substitution is a training-time augmentation.
        # Previously it also fired in eval mode, so the base branch's input varied
        # between inference calls on the same sample.
        use_scale0 = self.training and random.random() < self.SCALE0_SWAP_PROB
        feat = scale0_gs if use_scale0 else base_gs
        out_feat_b_down = self.scaleb_net(feat, out_feat_0_down)

        # Classification head.
        all_down = torch.cat([out_feat_b_down, out_feat_0_, out_feat_1_, out_feat_2_], dim=1)
        x = self.conv1(all_down)
        x1 = F.adaptive_max_pool1d(x, 1).view(batchsize, -1)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batchsize, -1)
        x = torch.cat((x1, x2), dim=1)

        return self.fc(x)