import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from timm.models.layers import DropPath

def square_distance(src, dst):
    """
    Pairwise squared Euclidean distances between two batched point sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, so the whole
    distance matrix comes from one batched matmul plus two broadcast adds.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: squared distance from every src point to every dst point, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))           # [B, N, M]
    src_norm = torch.sum(src ** 2, -1).view(batch, n_src, 1)  # [B, N, 1]
    dst_norm = torch.sum(dst ** 2, -1).view(batch, 1, n_dst)  # [B, 1, M]
    return -2 * cross + src_norm + dst_norm


def query_knn_point(k, xyz, new_xyz):
    """
    Return, for each query point, the indices of its k nearest reference points.

    Input:
        k: number of neighbours
        xyz: reference points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        indices into xyz, [B, S, k], ordered nearest first (topk's default sorted=True)
    """
    return square_distance(new_xyz, xyz).topk(k, largest=False)[1]


def index_points(points, idx):
    """
    Gather points along dim 1 using a (batched) index tensor.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] or [B, S, K, ...] (any trailing dims)
    Return:
        new_points: indexed points data, idx.shape + [C]
    """
    device = points.device
    B = points.shape[0]
    # Shape [B, 1, ..., 1] so the batch index broadcasts against idx inside
    # advanced indexing; no explicit repeat/copy is needed.
    view_shape = [B] + [1] * (idx.dim() - 1)
    # Build on the input's device (was hard-coded .cuda(), which broke CPU use).
    batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape)
    return points[batch_indices, idx, :]

def group_local(xyz, k=20):
    """
    Gather the k nearest neighbours of every point, with the neighbour search
    done in the space of the given channels.

    Input:
        xyz: channel-first point data, [B, C, N]. C is 3 for coordinates, but
             LEA also passes feature maps, so neighbours are found in feature space.
        k: number of neighbours
    Return:
        group_xyz: neighbour values, [B, C, N, k]. Query and reference sets are
                   the same, so each point is normally among its own neighbours.
    """
    points_last = xyz.transpose(2, 1).contiguous()          # [B, N, C]
    knn_idx = query_knn_point(k, points_last, points_last)  # [B, N, k]
    neighbours = index_points(points_last, knn_idx)         # [B, N, k, C]
    return neighbours.permute(0, 3, 1, 2)                   # [B, C, N, k]

def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling (random start point per batch element).

    Input:
        xyz: pointcloud data, [B, N, C] (C is normally 3)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint], on xyz.device
    """
    device = xyz.device
    B, N, C = xyz.shape
    # All bookkeeping tensors must live on xyz's device: previously they were
    # created on CPU, so `dist < distance` failed for CUDA inputs.
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running min-distance of every point to the selected set; 1e10 acts as +inf.
    distance = torch.ones(B, N, device=device) * 1e10
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the real coordinate width instead of a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        # Next sample: the point farthest from everything selected so far.
        farthest = torch.max(distance, -1)[1]
    return centroids

def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Ball query: for each query point, up to `nsample` indices of points within `radius`.

    Input:
        radius: local region radius
        nsample: max sample number in local region (assumed <= N; otherwise the
                 padding step below fails on a shape mismatch)
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]. The lowest in-ball
                   indices come first, and unused slots are padded with the first
                   in-ball index. NOTE(review): if a query has no point inside the
                   ball, the out-of-range sentinel N is returned for it.
    """
    batch, n_all, _ = xyz.shape
    _, n_query, _ = new_xyz.shape
    sq_dist = square_distance(new_xyz, xyz)
    candidates = torch.arange(n_all, dtype=torch.long, device=xyz.device)
    candidates = candidates.view(1, 1, n_all).repeat([batch, n_query, 1])
    # Sentinel N sorts after every valid index, pushing out-of-ball points to the end.
    candidates[sq_dist > radius ** 2] = n_all
    candidates = torch.sort(candidates, dim=-1).values[:, :, :nsample]
    first_hit = candidates[:, :, 0].view(batch, n_query, 1).repeat([1, 1, nsample])
    padded = candidates == n_all
    candidates[padded] = first_hit[padded]
    return candidates

# Called from PointNet_SA_Module.forward as:
#   sample_and_group(xyz, points, self.npoint, self.nsample, self.radius, self.use_xyz)

def sample_and_group(xyz, points, npoint, nsample, radius, use_xyz=True):
    """
    FPS-sample `npoint` centroids, then group up to `nsample` neighbours within
    `radius` of each centroid.

    Input:
        xyz: input point positions, channel-first [B, C, N] (C is normally 3)
        points: input point features, channel-first [B, D, N], or None
        npoint: number of centroids to sample
        nsample: neighbours per centroid
        radius: ball-query radius
        use_xyz: prepend centroid-relative coordinates to the grouped features
    Return:
        new_xyz: sampled centroids, [B, C, npoint]
        new_points: grouped data, [B, C+D (or D, or C when points is None), npoint, nsample]
        grouped_xyz: absolute neighbour positions, [B, npoint, nsample, C]
        fps_idx: sampled indices, [B, npoint]
    """
    xyz = xyz.transpose(2, 1).contiguous()
    # Only transpose features that exist: the old unconditional transpose made
    # the `points is None` branch unreachable (it crashed first).
    if points is not None:
        points = points.transpose(2, 1).contiguous()
    B, N, C = xyz.shape
    S = npoint
    # The torch.cuda.empty_cache() calls that used to sit between these steps were
    # dropped: they cannot release memory held by live tensors, and calling them
    # on every forward pass slows it down.
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)  # [B, npoint, C]
    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # [B, npoint, nsample]
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    # Neighbour coordinates relative to their centroid.
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    if points is None:
        new_points = grouped_xyz_norm
    else:
        grouped_points = index_points(points, idx)  # [B, npoint, nsample, D]
        if use_xyz:
            new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
        else:
            new_points = grouped_points
    return new_xyz.permute(0, 2, 1), new_points.permute(0, 3, 1, 2), grouped_xyz, fps_idx


def sample_and_group_all(xyz, points, use_xyz=True):
    """
    Treat each whole cloud as a single group centred at the origin.

    Args:
        xyz: Tensor, (B, 3, nsample)
        points: Tensor, (B, f, nsample), or None
        use_xyz: if True, coordinates are concatenated in front of the features

    Returns:
        new_xyz: Tensor, (B, 3, 1), all zeros
        new_points: Tensor, (B, f+3 | f | 3, 1, nsample)
        idx: Tensor, (B, 1, nsample), simply 0..nsample-1
        grouped_xyz: Tensor, (B, 3, 1, nsample)
    """
    batch, _, n_pts = xyz.shape
    dev = xyz.device
    new_xyz = torch.zeros((batch, 3, 1), dtype=torch.float, device=dev)
    grouped_xyz = xyz.reshape((batch, 3, 1, n_pts))
    idx = torch.arange(n_pts, device=dev).reshape(1, 1, n_pts).repeat(batch, 1, 1)
    if points is None:
        # No features: the grouped coordinates double as the grouped data.
        return new_xyz, grouped_xyz, idx, grouped_xyz
    features = torch.cat([xyz, points], 1) if use_xyz else points
    return new_xyz, features.unsqueeze(2), idx, grouped_xyz


class Conv2d(nn.Module):
    """nn.Conv2d -> optional BatchNorm2d -> optional activation."""

    def __init__(self, in_channel, out_channel, kernel_size=(1, 1), stride=(1, 1), if_bn=True, activation_fn=torch.relu):
        """
        Args:
            in_channel: input channels
            out_channel: output channels
            kernel_size: convolution kernel size
            stride: convolution stride
            if_bn: apply BatchNorm2d after the convolution
            activation_fn: callable applied last, or None for no activation
        """
        super(Conv2d, self).__init__()
        # Kept as an attribute for compatibility; the activation actually
        # applied is `activation_fn` (previously ignored in favour of this ReLU).
        self.af = nn.ReLU()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride=stride)
        self.if_bn = if_bn
        # Created unconditionally (as before), so the parameter layout does not depend on if_bn.
        self.bn = nn.BatchNorm2d(out_channel)
        self.activation_fn = activation_fn

    def forward(self, input):
        out = self.conv(input)
        if self.if_bn:
            out = self.bn(out)

        if self.activation_fn is not None:
            out = self.activation_fn(out)

        return out

class MLP_CONV(nn.Module):
    """
    Point-wise MLP built from 1x1 Conv1d layers.

    Every layer except the last is followed by (optional) BatchNorm1d and ReLU;
    the final Conv1d has no activation. Input [B, in_channel, N] ->
    output [B, layer_dims[-1], N].
    """

    def __init__(self, in_channel, layer_dims, bn=None):
        super(MLP_CONV, self).__init__()
        self.af = nn.ReLU()
        hidden_widths = [in_channel] + list(layer_dims[:-1])
        modules = []
        for c_in, c_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            modules.append(nn.Conv1d(c_in, c_out, 1))
            if bn:
                modules.append(nn.BatchNorm1d(c_out))
            # The same (stateless) ReLU instance is reused for every hidden layer.
            modules.append(self.af)
        modules.append(nn.Conv1d(hidden_widths[-1], layer_dims[-1], 1))
        self.mlp = nn.Sequential(*modules)

    def forward(self, inputs):
        return self.mlp(inputs)

class PointNet_SA_Module(nn.Module):
    def __init__(self, npoint, nsample, radius, in_channel, mlp, if_bn=True, group_all=False, use_xyz=True):
        """
        PointNet++ set-abstraction layer: group points, apply a shared MLP, max-pool each group.

        Args:
            npoint: int, number of points to sample
            nsample: int, number of points in each local region
            radius: float, ball-query radius
            in_channel: int, input channel of features (points)
            mlp: list of int, output widths of the shared Conv2d layers
            if_bn: use BatchNorm in each Conv2d
            group_all: treat the whole cloud as one group instead of sampling
            use_xyz: prepend the 3 coordinate channels to the grouped features
        """
        super(PointNet_SA_Module, self).__init__()
        self.npoint = npoint
        self.nsample = nsample
        self.radius = radius
        self.mlp = mlp
        self.group_all = group_all
        self.use_xyz = use_xyz

        first_width = in_channel + 3 if use_xyz else in_channel
        widths = [first_width] + list(mlp)
        self.mlp_conv = nn.Sequential(
            *[Conv2d(c_in, c_out, if_bn=if_bn) for c_in, c_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, xyz, points):
        """
        Args:
            xyz: Tensor, (B, 3, N)
            points: Tensor, (B, f, N)

        Returns:
            new_xyz: Tensor, (B, 3, npoint)
            new_points: Tensor, (B, mlp[-1], npoint)
        """
        if self.group_all:
            new_xyz, grouped, _, _ = sample_and_group_all(xyz, points, self.use_xyz)
        else:
            new_xyz, grouped, _, _ = sample_and_group(
                xyz, points, self.npoint, self.nsample, self.radius, self.use_xyz)

        # Shared MLP over every (centroid, neighbour) pair, then max over the neighbour axis.
        new_points = self.mlp_conv(grouped).max(dim=3)[0]
        return new_xyz, new_points


class LEA(torch.nn.Module):
    """
    Local edge aggregation.

    With k neighbours: kNN is computed on the input channels themselves (via
    group_local). Edge features (centre - neighbour) produce per-channel
    softmax weights over the neighbours. A conv on [edge, centre] produces
    values, which are combined as a weighted sum over the neighbours.
    With k=None: the conv is applied point-wise to [x, x].

    Input:
        inputs: [B, input_channel, N]
    Return:
        [B, output_channel, N]
    """

    def __init__(self, input_channel, output_channel, k):
        super(LEA, self).__init__()
        self.num_neigh = k
        half = output_channel // 2
        if k is not None:
            self.weight_mlp = nn.Sequential(
                nn.Conv2d(input_channel, output_channel, kernel_size=1),
                nn.BatchNorm2d(output_channel),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Conv2d(output_channel, output_channel, kernel_size=1),
            )
        self.input_channel = input_channel
        self.conv = nn.Sequential(
            nn.Conv2d(2 * input_channel, half, kernel_size=1),
            nn.BatchNorm2d(half),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(half, half, kernel_size=1),
            nn.BatchNorm2d(half),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(half, output_channel, kernel_size=1),
        )

    def forward(self, inputs):
        if self.num_neigh is None:
            # A single trailing "neighbour" slot, which max() then collapses.
            expanded = inputs.unsqueeze(-1)                               # B, C, N, 1
            out = self.conv(torch.cat((expanded, expanded), dim=1))
            return out.max(dim=-1, keepdim=False)[0]

        k = self.num_neigh
        neighbours = group_local(inputs, k=k)                            # B, C, N, k
        centre = inputs.unsqueeze(dim=3).repeat(1, 1, 1, k)              # B, C, N, k
        edges = centre - neighbours
        weights = F.softmax(self.weight_mlp(edges), dim=-1)              # normalised over neighbours
        values = self.conv(torch.cat((edges, centre), dim=1))
        return einsum('b c i j, b c i j -> b c i', weights, values)     # B, C_out, N


class GroupFFN(nn.Module):
    """
    Feed-forward block that mixes each point with its kNN neighbours.

    Neighbours are found in coordinate space (xyz). Features are projected to
    hidden_dim, each centre is concatenated with each neighbour, passed through
    two 1x1 Conv2d layers, averaged over the neighbours and projected back to
    dim. `drop` is accepted but not used.
    """

    def __init__(self, dim=128, hidden_dim=128, n_knn=16, act_layer=nn.GELU, drop = 0.):
        super().__init__()
        self.linear1 = nn.Sequential(nn.Conv1d(dim, hidden_dim, 1), act_layer())
        self.conv = nn.Sequential(
            nn.Conv2d(hidden_dim * 2, hidden_dim // 2, 1),
            act_layer(),
            nn.Conv2d(hidden_dim // 2, hidden_dim, 1),
            act_layer(),
        )
        self.linear2 = nn.Sequential(nn.Conv1d(hidden_dim, dim, 1))
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.n_knn = n_knn

    def forward(self, feat, xyz):
        """
        Args:
             feat: B, C, N
             xyz: B, 3, N

        Returns:
            feat: (B, C, N)
        """
        hidden = self.linear1(feat)                                    # B, H, N
        coords = xyz.permute(0, 2, 1).contiguous()                     # B, N, 3
        knn_idx = query_knn_point(self.n_knn, coords, coords)          # B, N, k
        neighbours = index_points(hidden.transpose(2, 1), knn_idx)     # B, N, k, H
        neighbours = neighbours.permute(0, 3, 1, 2)                    # B, H, N, k
        centre = hidden.unsqueeze(-1).repeat(1, 1, 1, self.n_knn)      # B, H, N, k
        mixed = self.conv(torch.cat([centre, neighbours], dim=1))
        return self.linear2(mixed.mean(dim=-1, keepdim=False))

class Attention(nn.Module):
    """
    Multi-head cross-attention: queries come from x, keys and values from y.

    x: [B, N, dim], y: [B, NK, dim] -> output [B, N, dim].
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # An explicit (truthy) qk_scale overrides the usual 1/sqrt(head_dim).
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        B, N, C = x.shape
        _, NK, _ = y.shape
        heads = self.num_heads
        per_head = C // heads

        # Split channels into heads: [B, L, C] -> [B, H, L, C/H].
        queries = self.q(x).reshape(B, N, heads, per_head).permute(0, 2, 1, 3)
        keys = self.k(y).reshape(B, NK, heads, per_head).permute(0, 2, 1, 3)
        values = self.v(y).reshape(B, NK, heads, per_head).permute(0, 2, 1, 3)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale    # B, H, N, NK
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ values).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))

class DAFormer(nn.Module):
    """
    Pre-norm cross-attention block followed by a neighbourhood FFN (GroupFFN).

    Both sub-layers are residual and wrapped in stochastic depth (DropPath).
    """

    def __init__(self, dim, num_heads, hidden_dim=64, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # `norm_layer` used to be accepted but ignored (nn.LayerNorm was hard-coded);
        # the default keeps the previous behaviour. The attributes keep their
        # historical `bn*` names even though they are normalisation layers of any type.
        self.bn1 = norm_layer(dim)
        self.bn2 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth on both residual branches; identity when drop_path == 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.bn3 = norm_layer(dim)
        self.ffn = GroupFFN(dim=dim, hidden_dim=hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, y, xyz):
        """
        Args:
             x: query features, (B, C, N)
             y: key/value features, (B, C, M)
             xyz: coordinates of the N query points, (B, 3, N), used for GroupFFN's kNN

        Returns:
            feat: (B, C, N)
        """
        x_seq = x.transpose(2, 1).contiguous()  # B, N, C
        y_seq = y.transpose(2, 1).contiguous()  # B, M, C
        # Residual cross-attention; both inputs are normalised over channels.
        x_seq = x_seq + self.drop_path(self.attn(self.bn1(x_seq), self.bn2(y_seq)))
        # GroupFFN works channel-first, so normalise in (B, N, C) and transpose back.
        ffn_in = self.bn3(x_seq).transpose(2, 1).contiguous()  # B, C, N
        return x_seq.transpose(2, 1).contiguous() + self.drop_path(self.ffn(ffn_in, xyz))


class DPA(nn.Module):
    def __init__(self, dim_feat=512, up_factor=2):
        """
        One upsampling stage: local features -> cross-attention against the
        previous stage -> per-point offsets that split each point into `up_factor` points.

        Args:
            dim_feat: accepted for interface compatibility; not used by this module
            up_factor: number of output points generated per input point
        """
        super(DPA, self).__init__()
        self.up_factor = up_factor
        self.lea = LEA(3, 128, 16)
        self.daformer = DAFormer(dim=128, num_heads=8, hidden_dim=64)
        self.mlp_2 = MLP_CONV(in_channel=128 * 3, layer_dims=[256, 128])

        self.fc = nn.Sequential(
            nn.Linear(128, 512),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(512, 512),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(512, 128),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(128, 3 * self.up_factor),
        )

    def forward(self, pcd, feat_global, pcd_prev=None, feat_prev=None):
        """
        Args:
            pcd: Tensor, (B, 3, N), points of the current stage
            feat_global: not used by this implementation
            pcd_prev: not used by this implementation
            feat_prev: Tensor, (B, 128, N_prev), features from the previous stage;
                       defaults to this stage's own LEA features

        Returns:
            pcd_up: Tensor, (B, 3, N * up_factor)
            pcd: the input points, unchanged
            feat_cur: Tensor, (B, 128, N), this stage's features (before attention)
        """
        b = pcd.size(0)
        local = self.lea(pcd)                                                  # B, 128, N
        if feat_prev is None:
            feat_prev = local
        n_pts = local.size(2)
        own_max = torch.max(local, 2, keepdim=True)[0].repeat((1, 1, n_pts))
        prev_max = torch.max(feat_prev, 2, keepdim=True)[0].repeat((1, 1, n_pts))
        feat_cur = self.mlp_2(torch.cat([local, own_max, prev_max], 1))        # B, 128, N

        attended = self.daformer(feat_cur, feat_prev, pcd)
        attended = attended.transpose(2, 1).contiguous()                       # B, N, 128

        offsets = self.fc(attended).view(b, -1, 3)                             # B, N * up_factor, 3
        # Duplicate every point up_factor times, then displace each copy.
        base = pcd.transpose(2, 1).contiguous().unsqueeze(dim=2)
        base = base.repeat(1, 1, self.up_factor, 1).view(b, -1, 3)
        pcd_up = (base + offsets).transpose(2, 1).contiguous()

        return pcd_up, pcd, feat_cur


class MultiDecoder(nn.Module):
    """
    Coarse-to-fine decoder: an FC coarse cloud merged with the partial input,
    FPS-resampled to `num_p0` seeds, then refined by one DPA stage per up factor.
    """

    def __init__(self, dim_feat=512, num_pc=256, num_p0=512, up_factors=None):
        super(MultiDecoder, self).__init__()
        self.num_p0 = num_p0
        self.decoder_coarse = FCLayer(dim_feat=dim_feat, num_pc=num_pc)
        factors = [1] if up_factors is None else up_factors
        self.uppers = nn.ModuleList([DPA(dim_feat=dim_feat, up_factor=f) for f in factors])

    def forward(self, feat, partial, return_P0=False):
        """
        Args:
            feat: Tensor, global feature handed to FCLayer and to every DPA stage
            partial: Tensor, (b, n, 3)
            return_P0: accepted but not used

        Returns:
            list of (b, n_i, 3) clouds: the num_p0 FPS seeds, then one entry per DPA stage
        """
        coarse = self.decoder_coarse(feat).permute(0, 2, 1).contiguous()  # B, num_pc, 3
        merged = torch.cat([coarse, partial], 1)
        seeds = index_points(merged, farthest_point_sample(merged, self.num_p0))  # B, num_p0, 3
        outputs = [seeds]

        pcd = seeds.permute(0, 2, 1).contiguous()
        pcd_pre, feat_pre = None, None
        for upper in self.uppers:
            pcd, pcd_pre, feat_pre = upper(pcd, feat, pcd_pre, feat_pre)
            outputs.append(pcd.permute(0, 2, 1).contiguous())

        return outputs


class FCLayer(nn.Module):
    """MLP that maps a global feature vector to a coarse cloud of `num_pc` points."""

    def __init__(self, dim_feat=512, num_pc=256):
        super(FCLayer, self).__init__()
        self.num_pc = num_pc
        self.fc = nn.Sequential(
            nn.Linear(dim_feat, 1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(1024, 3 * num_pc),
        )

    def forward(self, feat):
        """
        Args:
            feat: Tensor (b, dim_feat, 1) or (b, dim_feat)
        Returns:
            Tensor (b, 3, num_pc): channel-major coordinates
        """
        batch = feat.size(0)
        # Flatten per sample; unlike a bare squeeze() this never drops the batch axis.
        flat = feat.reshape(batch, -1)
        return self.fc(flat).reshape(batch, 3, self.num_pc)


class MEA(nn.Module):
    """
    Encoder: one PointNet++ set-abstraction layer, three parallel LEA branches
    (k=10, k=20, point-wise), a fusion MLP, then concatenated max+avg pooling.

    Input (B, 3, N) -> global feature (B, out_dim, 1) (for even out_dim).
    """

    def __init__(self, num_coarse=256, out_dim=1024):
        super(MEA, self).__init__()
        self.num_coarse = num_coarse  # stored only; not read in this class
        self.sa_module = PointNet_SA_Module(npoint=512, nsample=32, radius=0.4, in_channel=3, mlp=[32, 32, 64], group_all=False)
        self.gcn_1 = LEA(64, 128, 10)
        self.gcn_2 = LEA(64, 128, 20)
        self.gcn_3 = LEA(64, 128, None)
        self.k = 20  # stored only; not read in this class
        self.conv1 = nn.Sequential(
            nn.Conv1d(128 + 128 + 128, 512, 1),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(512, out_dim // 2, 1),
        )

    def forward(self, inputs):
        batch = inputs.size(0)
        # The raw coordinates serve as both positions and input features.
        _, sa_feat = self.sa_module(inputs, inputs)       # (B, 64, 512)

        branches = [self.gcn_1(sa_feat), self.gcn_2(sa_feat), self.gcn_3(sa_feat)]
        fused = torch.cat(branches, dim=1)                # (B, 384, 512)

        # TODO Post-BN
        fused = self.conv1(fused)                         # (B, out_dim // 2, 512)

        max_pooled = F.adaptive_max_pool1d(fused, 1).view(batch, -1)  # (B, out_dim // 2)
        avg_pooled = F.adaptive_avg_pool1d(fused, 1).view(batch, -1)  # (B, out_dim // 2)
        return torch.cat((max_pooled, avg_pooled), dim=1).unsqueeze(-1)  # (B, out_dim, 1)


class Model(nn.Module):
    """Point cloud completion network: MEA encoder + three-stage MultiDecoder (up factors 1, 2, 2)."""

    def __init__(self, global_feature_size=1024):
        super(Model, self).__init__()
        self.encoder = MEA(out_dim=global_feature_size)
        self.num_pc = 128
        self.num_p0 = 512
        self.up_factors = [1, 2, 2]
        self.decoder = MultiDecoder(dim_feat=global_feature_size, num_pc=self.num_pc,
                                    num_p0=self.num_p0, up_factors=self.up_factors)

    def forward(self, x, gt=None, is_training=True, mean_feature=None, alpha=None):
        """
        Args:
            x: partial cloud, (B, 3, N)
            gt, is_training, mean_feature, alpha: accepted but not used

        Returns:
            coarse_raw: FPS seeds, (B, num_p0, 3)
            coarse, coarse_high, fine: outputs of the three DPA stages, (B, n_i, 3)
        """
        global_feat = self.encoder(x)
        partial = x.transpose(2, 1).contiguous()  # (B, N, 3), merged with the coarse cloud
        coarse_raw, coarse, coarse_high, fine = self.decoder(global_feat, partial)
        return coarse_raw, coarse, coarse_high, fine