import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from torch.nn.init import trunc_normal_
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).absolute().parent.parent))
from utils.timm.models.layers import DropPath
from pointnet2_ops.pointnet2_utils import furthest_point_sample

# NOTE(review): neither global is read or written by any code in this file;
# presumably kept for external modules that import them — confirm before removing.
all_dist = [list() for _ in range(10)]  # ten independent (non-aliased) empty lists
spse_m = None


def checkpoint(function, *args, **kwargs):
    """Run ``function(*args, **kwargs)`` under activation checkpointing.

    Thin wrapper over ``torch.utils.checkpoint.checkpoint`` that always selects
    the non-reentrant implementation (``use_reentrant=False``). Passing
    ``use_reentrant`` explicitly raises ``TypeError`` (duplicate keyword).
    """
    non_reentrant = {"use_reentrant": False}
    return torch_checkpoint(function, *args, **non_reentrant, **kwargs)


class SA(nn.Module):
    """Project features, max-pool them over each point's neighbours, subtract the centre, apply BatchNorm.

    Args:
        in_dim: input feature width.
        out_dim: output width of the projection and the BatchNorm.
        bn_momentum: BatchNorm momentum.
        init: constant the BatchNorm scale (``weight``) is initialised to.
    """

    def __init__(self, in_dim, out_dim, bn_momentum, init=0.):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim, bias=False)
        self.bn = nn.BatchNorm1d(out_dim, momentum=bn_momentum)
        nn.init.constant_(self.bn.weight, init)

    def forward(self, x, prev_knn):
        """
        Args:
            x: (B, N, in_dim) features.
            prev_knn: (B, N, k) neighbour indices into dim 1 of ``x``; any integer dtype.

        Returns:
            (B, N, out_dim): ``bn(max_j proj(x)[knn[:, n, j]] - proj(x)[n])``.
        """
        nbr_idx = prev_knn.long()
        k = nbr_idx.shape[-1]
        feats = self.proj(x)
        batch, points, width = feats.shape
        # Gather into (B, N, C, k): column j holds neighbour j's projected feature.
        source = feats.unsqueeze(-1).expand(-1, -1, -1, k)
        index = nbr_idx.unsqueeze(2).expand(-1, -1, width, -1)
        pooled, _ = torch.gather(source, dim=1, index=index).max(dim=-1)
        relative = pooled - feats
        return self.bn(relative.view(batch * points, -1)).view(batch, points, -1)


class Local_Aggregation(nn.Module):
    """Neighbourhood aggregation layer (``Block`` uses one per even sub-layer).

    After a shared pre-processing step the ``in_dim`` channels are split in half:

    * sequence-conv path: for each coordinate axis, the k neighbours are sorted by
      that coordinate offset and a 1x3 conv (``conv3x3[axis]``) runs along the
      sorted order. The results are put back into the original neighbour order and
      the three axes are summed. The first 27 output channels go through a position
      refinement conv (``add_conv``) whose output is added to the remaining
      channels. Then BN + ReLU and a max over neighbours.
    * MLP path: 1x1 conv + BN + ReLU, summed over neighbours.

    The two halves are concatenated and mixed by ``post_conv``.

    NOTE(review): the channel bookkeeping only works when ``in_dim`` is even and
    ``in_dim == out_dim``. ``post_conv`` receives ``in_dim // 2 + out_dim // 2``
    channels but is built for ``out_dim``. The un-sort gather also uses an index of
    ``in_dim // 2 + 27`` channels on an ``out_dim // 2 + 27``-channel tensor.
    ``Block`` always constructs this layer with ``(dim, dim)``.
    """

    def __init__(self,in_dim,out_dim
                 ):
        super().__init__()
        # NOTE(review): sample_fn is not used by forward(). In this file the
        # neighbour FPS is done by Block.forward before this layer is called.
        self.sample_fn = furthest_point_sample
        # MLP path: (3 offset channels + in_dim // 2 feature channels) -> in_dim // 2.
        self.conv = nn.Conv2d(in_dim // 2 + 3, in_dim // 2, 1, 1, bias=False)
        self.bn = nn.BatchNorm2d(in_dim // 2, eps=1e-05, momentum=0.1, affine=True,
                                 track_running_stats=True)
        self.action = nn.ReLU(inplace=True)
        # Embedding of the 3-d neighbour offsets, added to the feature differences.
        self.PE=nn.Sequential(nn.Conv2d(3,in_dim,1,1,bias=False),nn.BatchNorm2d(in_dim))
        # One 1x3 conv per axis (x, y, z), applied along the neighbour order sorted by
        # that axis. Only indexed as conv3x3[i]; never called as a Sequential.
        # The first 27 output channels feed add_conv (position refinement).
        self.conv3x3 = nn.Sequential(
            nn.Conv2d(in_dim // 2 + 3, out_dim // 2 + 27, [1, 3], [1, 1], padding=[0, 1], bias=False),
            nn.Conv2d(in_dim // 2 + 3, out_dim // 2 + 27, [1, 3], [1, 1], padding=[0, 1], bias=False),
            nn.Conv2d(in_dim // 2 + 3, out_dim // 2 + 27, [1, 3], [1, 1], padding=[0, 1], bias=False))
        self.bn3x3 = nn.BatchNorm2d(out_dim // 2, eps=1e-05, momentum=0.1, affine=True,
                                    track_running_stats=True)
        # Position refinement conv: 3 groups x 9 taps -> out_dim // 2 channels.
        # All of its weights are overwritten from conv3x3 on every forward() call.
        self.add_conv = nn.Conv2d(3, out_dim // 2, [1, 9], [1, 1], bias=False)
        self.pre_conv=nn.Sequential(nn.Conv2d(in_dim,in_dim,1,1,bias=False),nn.BatchNorm2d(in_dim),nn.ReLU(inplace=True))

        # Normalises the neighbour-minus-centre feature differences.
        self.bnp = nn.BatchNorm2d(in_dim, eps=1e-05, momentum=0.1, affine=True,
                                  track_running_stats=True)
        # Final channel mix of the concatenated MLP-path and conv-path outputs.
        self.post_conv = nn.Sequential(
            nn.Conv1d(out_dim, out_dim, 1, bias=False),
            nn.BatchNorm1d(out_dim, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
    def forward(self, x,knn,dp) -> torch.Tensor:
        """Aggregate features over each point's neighbourhood.

        Args:
            x: (B, N, in_dim) point features.
            knn: (B, N, k) neighbour indices into dim 1 of ``x``. Must already be
                int64, because ``torch.gather`` requires it and no cast is done here.
            dp: (B, 3, N, k) neighbour offsets relative to each centre point.

        Returns:
            (B, out_dim, N) aggregated features, channels first.

        Side effect: every call rewrites two sets of weights via ``.data``: the
        centre taps of the three ``conv3x3`` kernels on the position channels, and
        all of ``add_conv``'s weights (see comments below).
        """
        # Neighbour sub-sampling is left to the caller; this layer aggregates over
        # whatever k neighbours it receives.

        f = x
        channel = f.shape[2]
        length = knn.shape[-1]
        # No .long() cast here: knn must already be int64.
        # fj[b, n, c, j] = f[b, knn[b, n, j], c]: the features of every neighbour.
        fj = torch.gather(f.unsqueeze(-1).expand(-1,-1,-1,length),dim=1,index= knn.unsqueeze(2).expand(-1,-1,channel,-1))
        fj = fj.permute(0,2,1,3)  # (B, C, N, k)
        f = f.permute(0,2,1)  # (B, C, N)


        # Normalised neighbour-minus-centre features plus an embedding of the offsets.
        fj = self.bnp(fj-f.unsqueeze(-1))+self.PE(dp)
        fj = self.action(fj)
        fj = self.pre_conv(fj)

        # Split channels: first half -> MLP path (fj0), second half -> sequence-conv path (fj).
        fj0 = fj[:, :channel//2, :, :]
        fj = fj[:, channel//2:, :, :]

        # Prepend the 3 offset channels to both halves (they become input channels :3).
        fj = torch.cat([dp, fj], 1)
        fj0 = torch.cat([dp, fj0], 1)


        num_position_features = 3
        # NOTE: here num_samples is N (points) and num_points is k (neighbours).
        bt, num_features, num_samples, num_points = fj.shape
        # Sequence mapping: per axis, the permutation sorting the k neighbours by that
        # offset coordinate, and (argsort of the argsort) its inverse permutation.
        sorted_indices_xyz = torch.argsort(dp, dim=-1)
        sorted_indices_xyz__ = torch.argsort(sorted_indices_xyz, dim=-1)
        sorted_indices_xyz = sorted_indices_xyz.unsqueeze(1).expand(-1, num_features, -1, -1, -1)

        # features[:, :, a] is fj with its neighbours reordered by axis a: (B, F, 3, N, k).
        features = torch.gather(fj.unsqueeze(2).expand(-1, -1, num_position_features, -1, -1), dim=-1,
                                index=sorted_indices_xyz)

        # Re-parameterised position refinement: copy the position-channel taps
        # (input channels :3, output channels 27:) of the three axis convs into
        # add_conv's 9 taps. This writes through .data, so autograd does not see it,
        # and add_conv's own weights are replaced on every call.
        self.add_conv.weight.data[:, :, :, :3] = self.conv3x3[0].weight.data[27:, :3, :, 0:3]
        self.add_conv.weight.data[:, :, :, 3:6] = self.conv3x3[1].weight.data[27:, :3, :, 0:3]
        self.add_conv.weight.data[:, :, :, 6:9] = self.conv3x3[2].weight.data[27:, :3, :, 0:3]
        shaped_channel = num_features - 3  # feature channels excluding the 3 offset channels

        # Seq. conv and Seq. fuse: conv along each axis' sorted order, put the result
        # back into the original neighbour order via the inverse permutation, and sum
        # over the three axes.
        sorted_indices_xyz__ = sorted_indices_xyz__.unsqueeze(1).expand(-1, shaped_channel + 27, -1, -1, -1)
        list_f = []
        for i in range(3):
            # AdaConv constraint: on the 3 position input channels, force the centre
            # tap to -(left + right). The response on those channels then becomes
            # w_left*(left - centre) + w_right*(right - centre), i.e. it depends only
            # on offset differences along the sorted order (apart from the zero-padded
            # ends). Written via .data (outside autograd) and re-applied every call.
            self.conv3x3[i].weight.data[:, :3, :, 1] = -(
                    self.conv3x3[i].weight.data[:, :3, :, 0] + self.conv3x3[i].weight.data[:, :3, :, 2])
            sort_f = self.conv3x3[i](features[:, :, i, :, :])
            list_f.append(torch.gather(sort_f, dim=-1, index=sorted_indices_xyz__[:, :, i, :, :]))
        fj = list_f[0] + list_f[1] + list_f[2]

        # Position refinement (SPR): view the first 27 channels as 3 groups x 9 and
        # apply add_conv (1x9 kernel) across the 9. This yields out_dim // 2 channels
        # per (point, neighbour), which are added to the remaining channels.
        dfj = fj[:, :27, :, :].view(bt, 3, 9, -1).permute(0, 1, 3, 2)
        dfj = self.add_conv(dfj).reshape(bt, -1, num_samples, num_points)
        fj = fj[:, 27:, :, :] + dfj

        fj = self.action(self.bn3x3(fj))

        # Reduction: max over the k neighbours -> (B, out_dim // 2, N).
        f = torch.max(fj,dim=-1,keepdim=False)[0]

        # MLP path: 1x1 conv + BN + ReLU, then sum over the k neighbours.
        fj0 = self.action(self.bn(self.conv(fj0)))
        fj0 = torch.sum(fj0, dim=-1, keepdim=False)

        # Channel mix of both paths.
        f = torch.cat([fj0, f], 1)

        f = self.post_conv(f)
        return f


class Mlp(nn.Module):
    """Point-wise two-layer MLP: Linear -> act -> Linear (no bias) -> BatchNorm.

    Args:
        in_dim: input and output width.
        mlp_ratio: hidden width is ``round(in_dim * mlp_ratio)``.
        bn_momentum: momentum of the trailing BatchNorm.
        act: activation class, instantiated with no arguments.
        init: constant for the trailing BatchNorm scale. With the default 0 (and
            BatchNorm's zero bias) the layer initially outputs zeros.
    """

    def __init__(self, in_dim, mlp_ratio, bn_momentum, act, init=0.):
        super().__init__()
        hidden = round(in_dim * mlp_ratio)
        layers = [
            nn.Linear(in_dim, hidden),
            act(),
            nn.Linear(hidden, in_dim, bias=False),
            nn.BatchNorm1d(in_dim, momentum=bn_momentum),
        ]
        self.mlp = nn.Sequential(*layers)
        nn.init.constant_(layers[-1].weight, init)

    def forward(self, x):
        """Apply the MLP to (B, N, in_dim) features; returns (B, N, in_dim)."""
        batch, points, _ = x.shape
        flat = x.view(batch * points, -1)
        return self.mlp(flat).view(batch, points, -1)


class Block(nn.Module):
    """Residual stack alternating ``Local_Aggregation`` layers and point-wise ``Mlp`` layers.

    An extra Mlp residual (``self.mlp``) runs first and shares drop-path slot 0.
    Sub-layer ``i`` (``0 <= i < depth``) depends on the parity of ``i``:

    * even ``i``: ``relu(skip_convs[i // 2](x) + drop_path(DSConvs[i // 2](x, ...)))``;
    * odd ``i``: ``x + drop_path(mlps[i // 2](x))``.

    Only ``depth // 2`` layers of each kind are built, so ``depth`` must be even.

    Args:
        dim: feature width, unchanged through the block.
        depth: number of alternating sub-layers (must be even).
        drop_path: list of per-sub-layer drop-path rates (length >= ``depth``), or a
            scalar maximum rate spread linearly from 0 over ``depth`` sub-layers.
        mlp_ratio: hidden-width multiplier for every ``Mlp``.
        bn_momentum: BatchNorm momentum.
        act: activation class used inside the ``Mlp`` layers.
        num_sub: neighbours kept per point by furthest point sampling before
            aggregation (default 8, the value that used to be hard-coded).
    """

    # Class-level default, so instances pickled before ``num_sub`` existed still work.
    num_sub = 8

    def __init__(self, dim, depth, drop_path, mlp_ratio, bn_momentum, act, num_sub=8):
        super().__init__()

        self.depth = depth
        self.num_sub = num_sub
        self.DSConvs = nn.ModuleList([
            Local_Aggregation(dim, dim) for _ in range(depth // 2)
        ])
        self.skip_convs = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim, dim, bias=False),
                nn.BatchNorm1d(dim, momentum=bn_momentum),
            ) for _ in range(depth // 2)
        ])
        self.mlp = Mlp(dim, mlp_ratio, bn_momentum, act, 0.2)
        self.mlps = nn.ModuleList([
            Mlp(dim, mlp_ratio, bn_momentum, act) for _ in range(depth // 2)
        ])

        if isinstance(drop_path, list):
            drop_rates = drop_path
            self.dp = [dp > 0. for dp in drop_path]
        else:
            drop_rates = torch.linspace(0., drop_path, self.depth).tolist()
            self.dp = [drop_path > 0.] * depth
        self.drop_paths = nn.ModuleList([
            DropPath(dpr) for dpr in drop_rates
        ])
        self.relu = nn.ReLU()
        self.sample_fn = furthest_point_sample

    def drop_path(self, x, i, pts):
        """Apply the i-th DropPath to a residual branch ``x`` during training only.

        ``x`` is split along dim 1 into chunks of sizes ``pts`` (presumably one chunk
        per cloud in a packed batch), and each chunk gets its own drop decision.
        When ``pts`` is None, the whole of ``x`` is treated as a single chunk
        instead of crashing in ``torch.split``.
        """
        if not self.dp[i] or not self.training:
            return x
        if pts is None:
            return self.drop_paths[i](x)
        return torch.cat([self.drop_paths[i](xx) for xx in torch.split(x, pts, dim=1)], dim=1)

    def forward(self, x, knn, pts=None):
        """
        Args:
            x: (B, N, dim) point features.
            knn: tuple ``(knns, xyz)``. ``knns`` is (B, N, k) neighbour indices into
                dim 1 and ``xyz`` is (B, N, 3) coordinates. Requires
                ``k >= self.num_sub``.
            pts: chunk sizes along dim 1 forwarded to ``drop_path`` (may be None).

        Returns:
            (B, N, dim) updated features.
        """
        x = x + self.drop_path(self.mlp(x), 0, pts)
        B, N, _ = x.shape

        knns, xyz = knn
        xyz_ch = xyz.shape[-1]
        B = xyz.shape[0]
        length = knns.shape[-1]
        knns = knns.long()

        # Relative neighbour offsets: dp[b, n, :, j] = xyz[b, knns[b, n, j]] - xyz[b, n].
        dp = torch.gather(xyz.unsqueeze(-1).expand(-1, -1, -1, length), dim=1,
                          index=knns.unsqueeze(2).expand(-1, -1, xyz_ch, -1)) - xyz.unsqueeze(-1)

        dp = dp.permute(0, 2, 1, 3)  # (B, 3, N, k)
        # FPS input: one (k, 3) offset set per point, taken before normalisation.
        sp = dp.permute(0, 2, 3, 1).reshape(-1, length, 3).contiguous()

        # Scale each neighbourhood so its furthest neighbour lies at unit distance.
        # clamp_min only affects a zero radius (every neighbour duplicates the centre),
        # which previously produced 0 / 0 = NaN that BatchNorm then spread everywhere.
        dp_r = dp[:, 0, :, :] * dp[:, 0, :, :] + dp[:, 1, :, :] * dp[:, 1, :, :] + dp[:, 2, :, :] * dp[:, 2, :, :]
        radius = torch.max(dp_r.sqrt(), dim=-1)[0].clamp_min(torch.finfo(dp.dtype).tiny)
        dp = dp / radius.reshape(B, 1, -1, 1)

        # Keep num_sub of the k neighbours per point via furthest point sampling.
        n_sub = self.num_sub
        idx = self.sample_fn(sp, n_sub).long().reshape(B, -1, n_sub)
        dp = torch.gather(dp, 3, idx.unsqueeze(1).expand(-1, 3, -1, -1))
        knns = torch.gather(knns, -1, idx)

        for i in range(self.depth):
            if i % 2 == 1:
                x = x + self.drop_path(self.mlps[i // 2](x), i, pts)
            else:
                skip = self.skip_convs[i // 2](x.view(B * N, -1)).view(B, N, -1)
                agg = self.DSConvs[i // 2](x, knns, dp).permute(0, 2, 1).contiguous()
                x = self.relu(skip + self.drop_path(agg, i, pts))
        return x


class Stage(nn.Module):
    """One level of the multi-resolution encoder; recursively builds the deeper levels.

    ``Stage(args)`` is level 0. It constructs ``Stage(args, 1)``, which constructs
    level 2, and so on up to ``len(args.depths) - 1``. ``forward`` at each level:

    1. downsamples the previous level (skipped at level 0);
    2. adds a neighbourhood spatial embedding;
    3. runs a ``Block``;
    4. recurses into the next level;
    5. adds a coordinate-regression loss while training;
    6. projects its features to ``args.head_dim`` and sums them with the deeper
       levels' output.

    Args:
        args: config object. Reads ``depths``, ``ks``, ``use_cp``,
            ``cp_bn_momentum`` (only when ``use_cp``), ``bn_momentum``, ``dims``,
            ``nbr_dims``, ``act``, ``drop_paths``, ``mlp_ratio``, ``head_drops``,
            ``head_dim`` and ``cor_std``.
        depth: index of this level; 0 is the first level (no downsampling).
    """
    def __init__(self, args, depth=0):
        super().__init__()

        self.depth = depth
        self.up_depth = len(args.depths) - 1  # index of the last (deepest) level

        self.first = first = depth == 0
        self.last = last = depth == self.up_depth
        self.relu = nn.ReLU(inplace=True)  # NOTE(review): not used by forward() below
        self.k = args.ks[depth]  # upper bound of the random neighbour draw in forward()

        # use_cp enables activation checkpointing of nbr_embed and the main block in
        # forward(); the BatchNorms inside that checkpointed code use cp_bn_momentum.
        self.cp = cp = args.use_cp
        cp_bn_momentum = args.cp_bn_momentum if cp else args.bn_momentum

        dim = args.dims[depth]
        # Spatial-embedding input width. At level 0 it is 3 relative-xyz channels plus
        # the neighbour's 7 raw feature channels (see forward); deeper levels use xyz only.
        nbr_in_dim = 10 if first else 3
        nbr_hid_dim = args.nbr_dims[0] if first else args.nbr_dims[1] // 2
        nbr_out_dim = dim if first else args.nbr_dims[1]
        self.nbr_embed = nn.Sequential(
            nn.Linear(nbr_in_dim, nbr_hid_dim // 2, bias=False),
            nn.BatchNorm1d(nbr_hid_dim // 2, momentum=cp_bn_momentum),
            args.act(),
            nn.Linear(nbr_hid_dim // 2, nbr_hid_dim, bias=False),
            nn.BatchNorm1d(nbr_hid_dim, momentum=cp_bn_momentum),
            args.act(),
            nn.Linear(nbr_hid_dim, nbr_out_dim, bias=False),
        )
        self.nbr_bn = nn.BatchNorm1d(dim, momentum=args.bn_momentum)
        nn.init.constant_(self.nbr_bn.weight, 0.8 if first else 0.2)
        # Level 0 embeds straight to `dim`; deeper levels project nbr_dims[1] -> dim.
        self.nbr_proj = nn.Identity() if first else nn.Linear(nbr_out_dim, dim, bias=False)

        self.sp_dim = nbr_out_dim  # NOTE(review): not read by any code in this file

        if not first:
            # Downsampling path from the previous level's width to this level's width:
            # a linear skip projection plus SA neighbourhood max-pooling (see forward).
            in_dim = args.dims[depth - 1]
            self.sa = SA(in_dim, dim, args.bn_momentum, 0.3)
            self.skip_proj = nn.Sequential(
                nn.Linear(in_dim, dim, bias=False),
                nn.BatchNorm1d(dim, momentum=args.bn_momentum)
            )
            nn.init.constant_(self.skip_proj[1].weight, 0.3)

        self.blk = Block(dim, args.depths[depth], args.drop_paths[depth], args.mlp_ratio, cp_bn_momentum, args.act)
        self.drop = DropPath(args.head_drops[depth])  # applied to this level's head output
        # Per-level head: BN + projection to the shared head width. The outputs of all
        # levels are summed in forward().
        self.postproj = nn.Sequential(
            nn.BatchNorm1d(dim, momentum=args.bn_momentum),
            nn.Linear(dim, args.head_dim, bias=False),
        )
        # Initial BN scale is sqrt(dims[0] / dim).
        nn.init.constant_(self.postproj[0].weight, (args.dims[0] / dim) ** 0.5)

        # Inverse of the configured coordinate scale: multiplying by it divides
        # coordinates by args.cor_std[depth].
        self.cor_std = 1 / args.cor_std[depth]
        # Regularisation head: predicts a neighbour's scaled relative xyz (3 values)
        # from the feature difference between that neighbour and the point.
        self.cor_head = nn.Sequential(
            nn.Linear(dim, 32, bias=False),
            nn.BatchNorm1d(32, momentum=args.bn_momentum),
            args.act(),
            nn.Linear(32, 3, bias=False),
        )

        if not last:
            self.sub_stage = Stage(args, depth + 1)

    def local_aggregation(self, x, knn, pts):
        """Run ``self.blk`` on (N, C) features by adding, then removing, a leading batch dim of 1."""
        x = x.unsqueeze(0)
        x = self.blk(x, knn, pts)
        x = x.squeeze(0)
        return x

    def forward(self, x, xyz, prev_knn, indices, pts_list):
        """Run this level and, recursively, all deeper levels.

        Args:
            x: (N_prev, C_prev) features of the previous level. At level 0 these are
                the raw per-point input features (7 channels, see ``nbr_in_dim``).
            xyz: (N_prev, 3) coordinates matching ``x``.
            prev_knn: the ``(knn, xyz_std)`` tuple built by the previous level, or
                None at level 0.
            indices: list consumed from the end with ``pop()`` and mutated in place.
                Level 0 pops its neighbour table. Every later level pops its
                downsampling indices, then its neighbour table. Entries left at the
                front are read as ``indices[depth - 1]`` for upsampling.
            pts_list: optional list; one entry is popped per level and forwarded to
                ``Block`` as ``pts``.

        Returns:
            ``(sub_x, sub_c)``: the summed head-width features of this and all deeper
            levels, and the summed coordinate-regression loss (None when not training).
        """
        # downsampling
        if not self.first:
            # ids selects this level's points from the previous level. Both the skip
            # projection and SA run on all previous-level points and are then indexed.
            ids = indices.pop()
            xyz = xyz[ids]
            x = self.skip_proj(x)[ids] + self.sa(x.unsqueeze(0), prev_knn[0]).squeeze(0)[ids]

        knn = indices.pop()  # (N, k) neighbour table indexing this level's points

        xyz_std = xyz * self.cor_std

        # spatial encoding
        # We retain the spatial encoding part of DeLA, which works synergistically with downsampling to
        # achieve the same effect as Set Abstraction in PointNeXt, but with less computational cost.

        N, k = knn.shape
        nbr = xyz[knn] - xyz.unsqueeze(1)  # (N, k, 3) neighbour offsets
        nbr = torch.cat([nbr, x[knn]], dim=-1).view(-1, 10) if self.first else nbr.view(-1, 3)
        if self.training and self.cp:
            # Likely a leftover from reentrant checkpointing, which needs an input that
            # requires grad. TODO(review): confirm it is unnecessary with use_reentrant=False.
            nbr.requires_grad_()
        # Embed every (point, neighbour) pair, then max-pool over the k neighbours -> (N, C).
        nbr_embed_func = lambda x: self.nbr_embed(x).view(N, k, -1).max(dim=1)[0]
        nbr = checkpoint(nbr_embed_func, nbr) if self.training and self.cp else nbr_embed_func(nbr)
        nbr = self.nbr_proj(nbr)
        nbr = self.nbr_bn(nbr)
        x = nbr if self.first else nbr + x

        # main block
        nbr_knn = knn.unsqueeze(0).long()  # (1, N, k)
        # Also passed to the next level as its prev_knn (SA reads prev_knn[0]).
        knn_xyz = (nbr_knn, xyz_std.view(1, -1, 3))
        pts = pts_list.pop() if pts_list is not None else None
        x = checkpoint(self.local_aggregation, x, knn_xyz, pts) if self.training and self.cp else self.local_aggregation(x,
                                                                                                                     knn_xyz,
                                                                                                                     pts)

        # get subsequent feature maps
        if not self.last:
            sub_x, sub_c = self.sub_stage(x, xyz, knn_xyz, indices, pts_list)
        else:
            sub_x = sub_c = None

        # regularization
        if self.training:
            # For every point, draw one random neighbour column in [0, self.k) and
            # regress its scaled relative position from the feature difference.
            # NOTE(review): assumes args.ks[depth] <= the number of knn columns.
            rel_k = torch.randint(self.k, (N, 1), device=x.device)
            rel_k = torch.gather(knn_xyz[0].long().squeeze(0), 1, rel_k).squeeze(1)
            rel_cor = (xyz[rel_k] - xyz)
            rel_cor.mul_(self.cor_std)
            rel_p = x[rel_k] - x
            rel_p = self.cor_head(rel_p)
            closs = F.mse_loss(rel_p, rel_cor)
            sub_c = sub_c + closs if sub_c is not None else closs

        # upsampling
        x = self.postproj(x)
        if not self.first:
            # back_nn gathers rows of this level. For the sums with sub_x (and with the
            # level-0 output) to line up, it must hold one index per level-0 point.
            back_nn = indices[self.depth - 1]
            x = x[back_nn]
        x = self.drop(x)
        sub_x = sub_x + x if sub_x is not None else x

        return sub_x, sub_c


class DSConvSeg(nn.Module):
    """Segmentation network: a recursive ``Stage`` backbone plus a per-point classification head.

    Args:
        args: config object. It is mutated: ``args.cp_bn_momentum`` is written
            before the backbone is built. Also reads ``bn_momentum``,
            ``head_dim``, ``num_classes`` and ``act``.
    """

    def __init__(self, args):
        super().__init__()

        # Momentum m' for BatchNorms inside checkpointed code, chosen so that
        # (1 - m') ** 2 == 1 - m. Presumably this compensates for those layers
        # running forward twice per step (original pass + recomputation), so two
        # running-stat updates equal one update at momentum m.
        args.cp_bn_momentum = 1 - (1 - args.bn_momentum) ** 0.5

        self.stage = Stage(args)

        head_width = args.head_dim
        n_classes = args.num_classes
        self.head = nn.Sequential(
            nn.BatchNorm1d(head_width, momentum=args.bn_momentum),
            args.act(),
            nn.Linear(head_width, n_classes),
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for every ``nn.Linear``; other modules untouched."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, xyz, x, indices, pts_list=None):
        """Return per-point logits, plus the regularisation loss when training.

        ``indices`` is shallow-copied before the backbone pops from it, so the
        caller's list is left intact.
        """
        feats, closs = self.stage(x, xyz, None, indices[:], pts_list)
        logits = self.head(feats)
        return (logits, closs) if self.training else logits