import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from pathlib import Path
import sys, math

from utils.pointnet2_ops_lib.pointnet2_ops.pointnet2_utils import FurthestPointSampling

sys.path.append(str(Path(__file__).absolute().parent.parent))
from utils.timm.models.layers import DropPath
from pointnet2_ops import pointnet2_utils
from torch.cuda.amp import autocast
from pointnet2_ops.pointnet2_utils import furthest_point_sample
# Module-level scratch state: ten independent empty lists and a None placeholder.
# NOTE(review): nothing in the code visible here reads or writes either name;
# presumably leftover analysis/debug globals - confirm before removing.
all_dist = [[] for _ in range(10)]
spse_m = None

@autocast(False)
def calc_pwd(x):
    """Return the pairwise squared Euclidean distances between the points of x.

    Uses the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 <a, b>, so the whole
    matrix comes from a single batched matrix product. The decorator runs the
    body with autocast switched off.

    x: B x N x C
    returns: B x N x N, entry [b, i, j] is the squared distance between
             x[b, i] and x[b, j]
    """
    sq_norm = x.square().sum(dim=2, keepdim=True)  # B x N x 1
    neg_two_xt = x.transpose(1, 2).mul(-2)  # B x C x N
    cross = torch.bmm(x, neg_two_xt)  # -2 <x_i, x_j>
    return sq_norm + sq_norm.transpose(1, 2) + cross

def get_graph_feature(x, idx):
    """Gather neighbour features and express them relative to their centre point.

    x: B x N x C
    idx: B x N x k indices into dim 1 of x
    returns: (B*N) x k x C, where row [b*N + n, j] is
             x[b, idx[b, n, j]] - x[b, n]
    """
    batch, num_pts, chans = x.shape
    num_nbrs = idx.shape[-1]
    # Flatten the (point, neighbour) axes so a single gather along dim 1 suffices.
    flat_idx = idx.view(batch, num_pts * num_nbrs, 1).expand(-1, -1, chans)
    neighbours = torch.gather(x, 1, flat_idx).view(batch * num_pts, num_nbrs, chans)
    centres = x.view(batch * num_pts, 1, chans).expand(-1, num_nbrs, -1)
    return neighbours - centres

def get_nbr_feature(x, idx):
    """Look up the features of every neighbour, flattened to one row per neighbour.

    x: B x N x C
    idx: B x N x k indices into dim 1 of x
    returns: (B*N*k) x C, where row [(b*N + n)*k + j] is x[b, idx[b, n, j]]
    """
    batch, num_pts, num_nbrs = idx.shape
    chans = x.shape[-1]
    total = num_pts * num_nbrs
    flat_idx = idx.view(batch, total, 1).expand(-1, -1, chans)
    gathered = torch.gather(x, 1, flat_idx)
    return gathered.view(batch * total, chans)
class Local_Aggregation(nn.Module):
    """Aggregate neighbour features for every point through two parallel paths.

    After a shared pre-processing step the neighbour features are split in two
    along the channel axis:

    * the first half goes through a pointwise (1x1) conv + BN + ReLU
      ("MLP path");
    * the rest is re-ordered along the neighbour axis by each of the three
      relative coordinates, filtered with one width-3 convolution per ordering,
      put back into the original neighbour order and summed ("Seq. Conv path"),
      then combined with a position-refinement term.

    Both results are max-pooled over the neighbours, concatenated and mixed by
    ``post_conv``. ``post_conv`` takes ``out_dim`` input channels while the
    concatenation carries ``in_dim // 2 + out_dim // 2`` channels, so the two
    must match (they do when ``in_dim == out_dim`` and both are even).
    """

    def __init__(self,in_dim,out_dim
                 ):
        """Create the layers of both paths.

        in_dim: number of feature channels of the incoming point features
        out_dim: number of feature channels produced by ``forward``
        """
        super().__init__()
        # Stored but not called anywhere inside this class.
        self.sample_fn = furthest_point_sample
        # MLP path: pointwise conv over 3 position channels + half of the features.
        self.conv = nn.Conv2d(in_dim // 2 + 3, in_dim // 2, 1, 1, bias=False)
        self.bn = nn.BatchNorm2d(in_dim // 2, eps=1e-05, momentum=0.1, affine=True,
                                 track_running_stats=True)
        self.action = nn.ReLU(inplace=True)
        # Embedding of the 3 relative-position channels into in_dim channels.
        self.PE=nn.Sequential(nn.Conv2d(3,in_dim,1,1,bias=False),nn.BatchNorm2d(in_dim))
        # Three independent [1, 3] convs, one per coordinate ordering. The
        # Sequential is only a container here: forward() indexes the layers one
        # by one and never calls the Sequential itself. Each conv emits 27 extra
        # output channels that forward() splits off for the position refinement.
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(in_dim // 2 + 3, out_dim // 2 + 27, [1, 3], [1, 1], padding=[0, 1], bias=False),
            nn.Conv2d(in_dim // 2 + 3, out_dim // 2 + 27, [1, 3], [1, 1], padding=[0, 1], bias=False),
            nn.Conv2d(in_dim // 2 + 3, out_dim // 2 + 27, [1, 3], [1, 1], padding=[0, 1], bias=False))
        self.bn3x3 = nn.BatchNorm2d(out_dim // 2, eps=1e-05, momentum=0.1, affine=True,
                                    track_running_stats=True)
        # [1, 9] conv used for the position refinement; its weights are
        # overwritten from the conv3x3 weights on every forward() call.
        self.add_conv = nn.Conv2d(3, out_dim // 2, [1, 9], [1, 1], bias=False)
        self.pre_conv=nn.Sequential(nn.Conv2d(in_dim,in_dim,1,1,bias=False),nn.BatchNorm2d(in_dim),nn.ReLU(inplace=True))

        # Normalises the relative (neighbour - centre) features.
        self.bnp = nn.BatchNorm2d(in_dim, eps=1e-05, momentum=0.1, affine=True,
                                  track_running_stats=True)
        # Channel mixing applied after the two pooled halves are concatenated.
        self.post_conv = nn.Sequential(
            nn.Conv1d(out_dim, out_dim, 1, bias=False),
            nn.BatchNorm1d(out_dim, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
    def forward(self, x,knn,dp) -> torch.Tensor:
        """Run both aggregation paths and return the mixed per-point features.

        Shapes are inferred from the indexing and layer definitions in this
        method - TODO confirm against the caller.

        x: B x N x C point features (C is expected to equal in_dim)
        knn: B x N x k neighbour indices into dim 1 of x
        dp: B x 3 x N x k relative neighbour positions
        returns: B x out_dim x N

        Side effect: every call overwrites part of the weights of ``add_conv``
        and of the three ``conv3x3`` layers in place through ``.data``.
        """
        # The neighbourhood (knn, dp) is supplied by the caller; no grouping or
        # sampling happens inside this method.

        f = x
        channel = f.shape[2]
        length = knn.shape[-1]
        knn = knn.long()
        # Gather neighbour features: fj[b, n, c, j] = f[b, knn[b, n, j], c].
        fj = torch.gather(f.unsqueeze(-1).expand(-1,-1,-1,length),dim=1,index= knn.unsqueeze(2).expand(-1,-1,channel,-1))
        # Channels to dim 1 so the 2-d convs / batch norms apply: B x C x N x k.
        fj = fj.permute(0,2,1,3)
        f = f.permute(0,2,1)


        # Relative features (neighbour - centre), normalised, plus the learned
        # embedding of the relative positions.
        fj = self.bnp(fj-f.unsqueeze(-1))+self.PE(dp)
        fj = self.action(fj)
        fj = self.pre_conv(fj)

        # Channel split: the first half feeds the MLP path, the rest the
        # Seq. Conv path.
        fj0 = fj[:, :channel//2, :, :]
        fj = fj[:, channel//2:, :, :]

        #Seq. Conv Path
        # The 3 position channels are prepended to both halves, so channels 0..2
        # of every conv input below are positions.
        fj = torch.cat([dp, fj], 1)
        fj0 = torch.cat([dp, fj0], 1)


        num_position_features = 3
        # fj is B x F x N x k at this point, so despite the names num_samples
        # receives N and num_points receives k.
        bt, num_features, num_samples, num_points = fj.shape
        #Seq. mapping
        # Per coordinate axis: the neighbour order that sorts that coordinate,
        # and (argsort of the argsort) the inverse permutation used to undo it.
        sorted_indices_xyz = torch.argsort(dp, dim=-1)
        sorted_indices_xyz__ = torch.argsort(sorted_indices_xyz, dim=-1)
        sorted_indices_xyz = sorted_indices_xyz.unsqueeze(1).expand(-1, num_features, -1, -1, -1)

        # features[:, :, a] holds fj with the neighbour axis re-ordered by
        # coordinate a: B x F x 3 x N x k.
        features = torch.gather(fj.unsqueeze(2).expand(-1, -1, num_position_features, -1, -1), dim=-1,
                                index=sorted_indices_xyz)

        #reparam position refinement conv
        # Rebuild the [1, 9] kernel of add_conv from the three width-3 kernels:
        # output channels 27.., the first 3 (position) input channels, all taps.
        # Written through .data, i.e. outside autograd. This copy runs BEFORE the
        # centre-tap overwrite in the loop below, so the statement order matters.
        self.add_conv.weight.data[:, :, :, :3] = self.conv3x3[0].weight.data[27:, :3, :, 0:3]
        self.add_conv.weight.data[:, :, :, 3:6] = self.conv3x3[1].weight.data[27:, :3, :, 0:3]
        self.add_conv.weight.data[:, :, :, 6:9] = self.conv3x3[2].weight.data[27:, :3, :, 0:3]
        shaped_channel = num_features - 3

        #Seq.Conv and Seq.Fuse
        # Expands the inverse permutation to shaped_channel + 27 channels; the
        # gather below requires this to cover the conv output (out_dim // 2 + 27).
        sorted_indices_xyz__ = sorted_indices_xyz__.unsqueeze(1).expand(-1, shaped_channel + 27, -1, -1, -1)
        list_f = []
        for i in range(3):
            #Adaconv
            # Overwrite the centre tap on the 3 position input channels with
            # -(left tap + right tap), so those three taps sum to zero.
            self.conv3x3[i].weight.data[:, :3, :, 1] = -(
                    self.conv3x3[i].weight.data[:, :3, :, 0] + self.conv3x3[i].weight.data[:, :3, :, 2])
            sort_f = self.conv3x3[i](features[:, :, i, :, :])
            # Undo the sort so the three orderings line up neighbour by neighbour.
            list_f.append(torch.gather(sort_f, dim=-1, index=sorted_indices_xyz__[:, :, i, :, :]))
        fj = list_f[0] + list_f[1] + list_f[2]

        #position refinement (SPR)
        # The first 27 output channels are viewed as 3 x 9, moved to the last
        # axis and reduced by the [1, 9] add_conv to out_dim // 2 channels.
        dfj = fj[:, :27, :, :].view(bt, 3, 9, -1).permute(0, 1, 3, 2)
        dfj = self.add_conv(dfj).reshape(bt, -1, num_samples, num_points)
        fj = fj[:, 27:, :, :] + dfj

        fj = self.action(self.bn3x3(fj))

        # Reduction
        # Max over the neighbour axis.
        f = torch.max(fj,dim=-1,keepdim=False)[0]

        # MLP path
        fj0 = self.action(self.bn(self.conv(fj0)))
        fj0 = torch.max(fj0, dim=-1, keepdim=False)[0]

        #channel mix
        f = torch.cat([fj0, f], 1)

        f = self.post_conv(f)
        return f
class SA(nn.Module):
    """Project point features, max-pool them over each neighbourhood and
    return the batch-normalised difference to the point's own projection."""

    def __init__(self, in_dim, out_dim, bn_momentum, init=0.):
        """
        in_dim / out_dim: channel counts of the bias-free linear projection
        bn_momentum: momentum of the output BatchNorm1d
        init: constant the BatchNorm scale (weight) is initialised to
        """
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim, bias=False)
        self.bn = nn.BatchNorm1d(out_dim, momentum=bn_momentum)
        nn.init.constant_(self.bn.weight, init)

    def forward(self, x, prev_knn):
        """
        x: B x N x C_in
        prev_knn: B x N x k neighbour indices into dim 1 of x
        returns: B x N x C_out
        """
        nbr_idx = prev_knn.long()
        num_nbrs = nbr_idx.shape[-1]

        feats = self.proj(x)
        batch, num_pts, chans = feats.shape

        # gathered[b, n, c, j] = feats[b, nbr_idx[b, n, j], c]
        source = feats.unsqueeze(-1).expand(-1, -1, -1, num_nbrs)
        index = nbr_idx.unsqueeze(2).expand(-1, -1, chans, -1)
        gathered = torch.gather(source, dim=1, index=index)

        # Strongest neighbour response per channel, relative to the point itself.
        residual = gathered.max(dim=-1).values - feats

        normed = self.bn(residual.view(batch * num_pts, -1))
        return normed.view(batch, num_pts, -1)


class Mlp(nn.Module):
    """Per-point two-layer perceptron: Linear -> act -> Linear -> BatchNorm1d."""

    def __init__(self, in_dim, mlp_ratio, bn_momentum, act, init=0.):
        """
        in_dim: input and output channel count
        mlp_ratio: hidden width as a multiple of in_dim (rounded)
        bn_momentum: momentum of the trailing BatchNorm1d
        act: zero-argument factory returning the activation module
        init: constant the BatchNorm scale (weight) is initialised to
        """
        super().__init__()
        hid_dim = round(in_dim * mlp_ratio)
        layers = [
            nn.Linear(in_dim, hid_dim),
            act(),
            nn.Linear(hid_dim, in_dim, bias=False),
            nn.BatchNorm1d(in_dim, momentum=bn_momentum),
        ]
        self.mlp = nn.Sequential(*layers)
        nn.init.constant_(layers[-1].weight, init)

    def forward(self, x):
        """
        x: B x N x C
        returns: B x N x C
        """
        batch, num_pts, _ = x.shape
        # BatchNorm1d expects (rows, channels), so fold batch and points together.
        flat = x.view(batch * num_pts, -1)
        return self.mlp(flat).view(batch, num_pts, -1)

class Block(nn.Module):
    """Residual stack that alternates local aggregation layers and MLPs.

    The forward pass first applies one residual Mlp, then runs ``depth``
    layers: even layer indices apply a Local_Aggregation with a linear skip
    branch followed by ReLU, odd layer indices apply a residual Mlp. Every
    branch is wrapped in its own DropPath.
    """

    def __init__(self, dim, depth, drop_path, mlp_ratio, bn_momentum, act):
        """
        dim: channel count kept throughout the block
        depth: number of alternating layers run in ``forward``
        drop_path: either a list with one drop rate per layer, or a single
                   maximum rate that is spread linearly from 0 over the layers
        mlp_ratio: hidden width multiplier passed to every Mlp
        bn_momentum: momentum of the BatchNorm layers created here
        act: zero-argument factory returning the activation module
        """
        super().__init__()
        self.depth = depth
        # Layers 0, 2, 4, ... use DSConvs[i // 2] / skip_convs[i // 2], so an odd
        # depth needs ceil(depth / 2) of them; depth // 2 would leave the last
        # even layer without a module (IndexError in forward). For an even
        # depth both expressions give the same count.
        num_agg = (depth + 1) // 2
        self.DSConvs = nn.ModuleList([
            Local_Aggregation(dim, dim) for _ in range(num_agg)
        ])
        self.skip_convs = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.BatchNorm1d(dim, momentum=bn_momentum),
            ) for _ in range(num_agg)
        ])
        # Residual Mlp applied once before the alternating layers.
        self.mlp = Mlp(dim, mlp_ratio, bn_momentum, act, 0.2)
        # Layers 1, 3, 5, ... use mlps[i // 2].
        self.mlps = nn.ModuleList([
            Mlp(dim, mlp_ratio, bn_momentum, act) for _ in range(depth // 2)
        ])
        if isinstance(drop_path, list):
            drop_rates = drop_path
        else:
            drop_rates = torch.linspace(0., drop_path, self.depth).tolist()
        self.drop_paths = nn.ModuleList([
            DropPath(dpr) for dpr in drop_rates
        ])
        self.relu = nn.ReLU()
        self.sample_fn = furthest_point_sample

    def forward(self, x, knn):
        """
        x: B x N x C
        knn: tuple (knns, xyz) with knns B x N x k neighbour indices and
             xyz B x N x 3 point coordinates
        returns: B x N x C
        """
        x = x + self.drop_paths[0](self.mlp(x))
        B, N, _ = x.shape

        knns, xyz = knn
        xyz_ch = xyz.shape[-1]
        length = knns.shape[-1]
        knns = knns.long()

        # Relative neighbour positions:
        # dp[b, n, c, j] = xyz[b, knns[b, n, j], c] - xyz[b, n, c]
        dp = torch.gather(xyz.unsqueeze(-1).expand(-1, -1, -1, length), dim=1,
                          index=knns.unsqueeze(2).expand(-1, -1, xyz_ch, -1)) - xyz.unsqueeze(-1)

        # feature preprocess
        dp = dp.permute(0, 2, 1, 3)  # B x 3 x N x k
        # Un-normalised offsets, one small point set per centre, for the sampler.
        sp = dp.permute(0, 2, 3, 1).reshape(-1, length, 3).contiguous()

        # decentering: scale every neighbourhood by its largest neighbour distance.
        dp_r = dp[:, 0] * dp[:, 0] + dp[:, 1] * dp[:, 1] + dp[:, 2] * dp[:, 2]
        max_radius = torch.max(dp_r.sqrt(), dim=-1)[0].reshape(B, 1, -1, 1)
        # A neighbourhood whose points all coincide with the centre has radius
        # 0; clamping turns the former 0 / 0 = NaN into 0 and leaves every
        # non-degenerate neighbourhood untouched.
        dp = dp / max_radius.clamp(min=1e-12)

        # Farthest point sampling keeps 8 of the k neighbours of every point.
        n_keep = 8
        idx = self.sample_fn(sp, n_keep).long().reshape(B, -1, n_keep)
        dp = torch.gather(dp, 3, idx.unsqueeze(1).expand(-1, 3, -1, -1))
        knns = torch.gather(knns, -1, idx)

        for i in range(self.depth):
            pair = i // 2
            if i % 2 == 1:
                x = x + self.drop_paths[i](self.mlps[pair](x))
            else:
                skip = self.skip_convs[pair](x.view(B * N, -1)).view(B, N, -1)
                # Local_Aggregation returns B x C x N; bring it back to B x N x C.
                agg = self.DSConvs[pair](x, knns, dp).permute(0, 2, 1).contiguous()
                x = self.relu(skip + self.drop_paths[i](agg))
        return x


class Stage(nn.Module):
    """One resolution level of the network; owns the next level recursively.

    A stage optionally downsamples the incoming features, adds a spatial
    encoding computed from the k nearest neighbours, runs a Block, hands the
    result to its sub-stage (unless it is the last one) and, in training mode,
    adds a coordinate-regression regularisation loss.

    Configuration read from ``args``: depths, ns, ks, dims, nbr_dims,
    bn_momentum, act, drop_paths, mlp_ratio, cor_std.
    """

    def __init__(self, args, depth=0):
        super().__init__()

        self.depth = depth
        self.first = depth == 0
        self.last = depth == len(args.depths) - 1

        # Number of points kept at this level and neighbourhood size.
        self.n = args.ns[depth]
        self.k = args.ks[depth]

        dim = args.dims[depth]
        momentum = args.bn_momentum

        # The first stage embeds 4 inputs per neighbour (3 offsets + height),
        # later stages only the 3 offsets.
        if self.first:
            nbr_in_dim, nbr_hid_dim, nbr_out_dim = 4, args.nbr_dims[0], dim
        else:
            nbr_in_dim, nbr_hid_dim, nbr_out_dim = 3, args.nbr_dims[1] // 2, args.nbr_dims[1]
        half_hid = nbr_hid_dim // 2
        self.nbr_embed = nn.Sequential(
            nn.Linear(nbr_in_dim, half_hid, bias=False),
            nn.BatchNorm1d(half_hid, momentum=momentum),
            args.act(),
            nn.Linear(half_hid, nbr_hid_dim, bias=False),
            nn.BatchNorm1d(nbr_hid_dim, momentum=momentum),
            args.act(),
            nn.Linear(nbr_hid_dim, nbr_out_dim, bias=False),
        )
        self.nbr_bn = nn.BatchNorm1d(dim, momentum=momentum)
        nn.init.constant_(self.nbr_bn.weight, 0.8 if self.first else 0.2)
        if self.first:
            self.nbr_proj = nn.Identity()
        else:
            self.nbr_proj = nn.Linear(nbr_out_dim, dim, bias=False)
        self.sample_fn = furthest_point_sample

        if not self.first:
            # Projections from the previous stage's width to this stage's width.
            in_dim = args.dims[depth - 1]
            self.sa = SA(in_dim, dim, momentum, 0.3)
            self.skip_proj = nn.Sequential(
                nn.Linear(in_dim, dim, bias=False),
                nn.BatchNorm1d(dim, momentum=momentum),
            )
            nn.init.constant_(self.skip_proj[1].weight, 0.3)

        self.blk = Block(dim, args.depths[depth], args.drop_paths[depth],
                         args.mlp_ratio, momentum, args.act)

        # Coordinates are multiplied by this factor before entering the block
        # and the regulariser.
        self.cor_std = 1 / args.cor_std[depth]
        self.cor_head = nn.Sequential(
            nn.Linear(dim, 32, bias=False),
            nn.BatchNorm1d(32, momentum=momentum),
            args.act(),
            nn.Linear(32, 3, bias=False),
        )

        if not self.last:
            self.sub_stage = Stage(args, depth + 1)
        self.relu = nn.ReLU()

    def _coordinate_loss(self, x, xyz, knn):
        """MSE between the scaled offset to one random neighbour per point and
        the offset predicted by ``cor_head`` from the feature difference."""
        B, N, _ = knn.shape
        pick = torch.randint(self.k, (B, N, 1), device=x.device)
        rel_k = torch.gather(knn.long(), 2, pick)
        target = get_graph_feature(xyz, rel_k).flatten(1).mul_(self.cor_std)
        predicted = self.cor_head(get_graph_feature(x, rel_k).flatten(1))
        return F.mse_loss(predicted, target)

    def forward(self, x, xyz, prev_knn, pwd):
        """
        x: B x N x C features of the previous stage (ignored by the first stage)
        xyz: B x N x 3 point coordinates
        prev_knn: (knn, xyz) tuple of the previous stage; only element 0 is read
        pwd: pairwise distance matrix; its leading n x n sub-matrix is used
        returns: (features of the deepest stage, accumulated regularisation
                 loss or None when not training)
        """
        # downsampling: keep the first n points, mixing a linear skip branch
        # with neighbourhood-pooled features of the previous stage.
        if not self.first:
            batch, num_prev, chans = x.shape
            xyz = xyz[:, :self.n].contiguous()
            skipped = self.skip_proj(x.view(batch * num_prev, chans)).view(batch, num_prev, -1)
            pooled = self.sa(x, prev_knn[0])
            x = skipped[:, :self.n] + pooled[:, :self.n]

        # k nearest neighbours among the kept points (unsorted).
        local_pwd = pwd[:, :self.n, :self.n]
        knn = local_pwd.topk(k=self.k, dim=-1, largest=False, sorted=False).indices

        # spatial encoding
        # We retain the spatial encoding part of DeLA, which works synergistically with downsampling to
        # achieve the same effect as Set Abstraction in PointNeXt, but with less computational cost.
        B, N, k = knn.shape
        nbr = get_graph_feature(xyz, knn).view(-1, 3)
        if self.first:
            # Extra input: coordinate 1 divided by 40, shifted so that the
            # minimum of every cloud is 0.
            height = xyz[..., 1:2] / 40
            height -= height.min(dim=1, keepdim=True)[0]
            nbr = torch.cat([nbr, get_nbr_feature(height, knn)], dim=1)
        nbr = self.nbr_embed(nbr).view(B * N, k, -1).max(dim=1)[0]
        nbr = self.nbr_bn(self.nbr_proj(nbr)).view(B, N, -1)
        x = nbr if self.first else nbr + x

        # main block
        knn_xyz = (knn, xyz * self.cor_std)
        x = self.blk(x, knn_xyz)

        # next stage
        if self.last:
            sub_x, sub_c = x, None
        else:
            sub_x, sub_c = self.sub_stage(x, xyz, knn_xyz, pwd)

        # regularization
        if self.training:
            closs = self._coordinate_loss(x, xyz, knn)
            sub_c = closs if sub_c is None else sub_c + closs

        return sub_x, sub_c


class DSConv_Cls(nn.Module):
    """Classification network: a recursive Stage backbone followed by a
    BatchNorm + Linear head on the pooled (mean, std) point features."""

    def __init__(self, args):
        """Reads dims, num_classes and bn_momentum from ``args``; the rest of
        the configuration is consumed by Stage."""
        super().__init__()
        self.stage = Stage(args)

        # mean and std pooling are concatenated, hence twice the last width.
        pooled_dim = args.dims[-1] * 2
        self.head = nn.Sequential(
            nn.BatchNorm1d(pooled_dim, momentum=args.bn_momentum),
            nn.Linear(pooled_dim, args.num_classes),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Re-initialise Linear layers (truncated normal weights with std 0.02,
        zero bias); every other module type is left untouched."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, xyz):
        """
        xyz: B x N x 3 point coordinates
        returns: (logits, regularisation loss) in training mode,
                 logits only otherwise
        """
        if not self.training:
            # Outside training the cloud is resampled to 1024 points with
            # farthest point sampling.
            idx = pointnet2_utils.furthest_point_sample(xyz, 1024).long()
            xyz = torch.gather(xyz, 1, idx.unsqueeze(-1).expand(-1, -1, 3))

        pwd = calc_pwd(xyz)
        x, closs = self.stage(None, xyz, None, pwd)

        # Global pooling over the points: per-channel mean and std.
        pooled = torch.cat([x.mean(dim=1), x.std(dim=1)], dim=1)
        logits = self.head(pooled)
        return (logits, closs) if self.training else logits

