import itertools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f


def knn(x: torch.Tensor, k: int) -> torch.Tensor:
    """Find the ``k`` nearest neighbours of every point as flat gather indices.

    Args:
        x: point features laid out as ``(batch_size, num_dims, num_points)``
            (dim 0 is read as the batch, dim 2 as the points).
        k: number of neighbours per point. ``topk`` raises if it exceeds
            ``num_points``.

    Returns:
        A 1-D int64 tensor of length ``batch_size * num_points * k``. Viewed as
        ``(batch_size, num_points, k)``, entry ``(b, i, j)`` is the index of
        point ``i``'s ``j``-th nearest neighbour, already offset by
        ``b * num_points`` so it can index a ``(batch_size * num_points, ...)``
        flattened tensor directly. Neighbours are ordered nearest-first, so a
        point is normally its own first neighbour (distance zero).
    """
    batch_size = x.size(0)
    num_points = x.size(2)

    # -||a - b||^2 = -||a||^2 + 2 a.b - ||b||^2, so the largest entries of this
    # matrix are the smallest squared distances.
    inner = -2 * torch.matmul(x.transpose(2, 1), x)  # (batch_size, num_points, num_points)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)  # (batch_size, 1, num_points)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)

    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    # Shift each batch's indices so they address the flattened batch*points axis.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = idx + idx_base
    return idx.view(-1)


def local_cov(pts: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Append a per-point outer-product feature to the point coordinates.

    For every point, the first two entries of its neighbour list are gathered.
    Their outer product is flattened and concatenated to the raw coordinates.
    With ``idx`` from :func:`knn`, the first entry is normally the point itself.

    Args:
        pts: points as ``(batch_size, num_dims, num_points)``.
        idx: flat neighbour indices of length ``batch_size * num_points * k``
            with ``k >= 2``, offset per batch as returned by :func:`knn`.

    Returns:
        Tensor of shape ``(batch_size, num_dims + num_dims ** 2, num_points)``,
        i.e. ``(batch_size, 12, num_points)`` for 3-D points.
    """
    batch_size = pts.size(0)
    num_points = pts.size(2)
    pts = pts.view(batch_size, -1, num_points)  # (batch_size, num_dims, num_points)
    num_dims = pts.size(1)

    x = pts.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    x = x.view(batch_size * num_points, -1)[idx, :]  # (batch_size*num_points*k, num_dims)
    x = x.view(batch_size, num_points, -1, num_dims)  # (batch_size, num_points, k, num_dims)

    # Outer product of neighbour 0 and neighbour 1:
    # (B, N, D, 1) @ (B, N, 1, D) -> (B, N, D, D)
    x = torch.matmul(x[:, :, 0].unsqueeze(3), x[:, :, 1].unsqueeze(2))

    # Flatten the D x D matrix. Derived from num_dims rather than a literal 9,
    # which only worked for 3-D points.
    x = x.view(batch_size, num_points, num_dims * num_dims).transpose(2, 1)
    return torch.cat((pts, x), dim=1)  # (batch_size, num_dims + num_dims**2, num_points)


def local_maxpool(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Max-pool features over each point's neighbourhood.

    Args:
        x: features as ``(batch_size, num_dims, num_points)``.
        idx: flat neighbour indices of length ``batch_size * num_points * k``,
            offset per batch as returned by :func:`knn`.

    Returns:
        Tensor of shape ``(batch_size, num_points, num_dims)``. The channel
        axis is moved last. Each row holds the element-wise maximum over that
        point's ``k`` gathered neighbours.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    num_dims = x.size(1)

    # (batch_size, num_points, num_dims)
    x = x.transpose(2, 1).contiguous()

    # (batch_size*num_points, num_dims) -> (batch_size*num_points*k, num_dims)
    x = x.view(batch_size * num_points, -1)[idx, :]

    # (batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, -1, num_dims)

    # Reduce over the neighbour axis -> (batch_size, num_points, num_dims)
    x, _ = torch.max(x, dim=2)
    return x


class SkipVariationalEncoder(nn.Module):
    """FoldingNet-style graph encoder with a skip connection from the first MLP.

    ``forward`` maps a point cloud ``(batch_size, num_points, 3)`` to a global
    descriptor ``(batch_size, 64 + 1024, 1)``. The descriptor is the max over
    points of the 64-channel ``mlp1`` features concatenated with the
    1024-channel graph-layer features. The last input dim must be 3, because
    ``mlp1`` expects the 3 + 3*3 = 12 channels that ``local_cov`` produces.
    """

    def __init__(self, feat_dims: int = 512, k: int = 16) -> None:
        """
        Args:
            feat_dims: stored as ``self.feat_dims``; no layer in this class
                depends on it.
            k: neighbours per point for ``knn``. It must be >= 2, because
                ``local_cov`` reads the first two neighbours, and at most
                ``num_points``. Defaults to 16.
        """
        super().__init__()
        self.k = k
        self.feat_dims = feat_dims
        self.mlp1 = nn.Sequential(
            nn.Conv1d(12, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.ReLU(),
        )
        self.linear1 = nn.Linear(64, 64)
        self.conv1 = nn.Conv1d(64, 128, 1)
        self.linear2 = nn.Linear(128, 128)
        self.conv2 = nn.Conv1d(128, 1024, 1)

    def graph_layer(self, x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """Map (B, 64, N) to (B, 1024, N) via two neighbourhood max-pool stages."""
        x = local_maxpool(x, idx)  # (B, N, 64)
        x = self.linear1(x)
        x = x.transpose(2, 1)  # (B, 64, N)
        x = f.relu(self.conv1(x))  # (B, 128, N)
        x = local_maxpool(x, idx)  # (B, N, 128)
        x = self.linear2(x)
        x = x.transpose(2, 1)  # (B, 128, N)
        return self.conv2(x)  # (B, 1024, N)

    def forward(self, pts: torch.Tensor) -> torch.Tensor:
        pts = pts.transpose(2, 1)  # (B, N, 3) -> (B, 3, N)
        idx = knn(pts, k=self.k)
        x = local_cov(pts, idx)  # (B, 12, N)
        local_feat_1 = self.mlp1(x)  # (B, 64, N)
        local_feat_2 = self.graph_layer(local_feat_1, idx)  # (B, 1024, N)
        cat_feat = torch.cat([local_feat_1, local_feat_2], 1)  # (B, 1088, N)
        # Global max over points, keeping a singleton point axis: (B, 1088, 1).
        return torch.max(cat_feat, 2, keepdim=True)[0]


class FoldingNetDecoder(nn.Module):
    """Two-stage folding decoder that deforms a fixed random unit-sphere grid.

    ``forward`` takes a codeword ``(batch_size, 1, feat_dims)`` and returns
    ``(folding_result2, folding_result1)``, each ``(batch_size, m_points, 3)``.
    """

    def __init__(self, feat_dims: int, m_points: int) -> None:
        super().__init__()
        self.m = m_points
        grid = self.build_grid(m_points)  # (3, m_points)
        # A buffer follows the module through .to()/.cuda()/.double(), instead of
        # being copied to the input's device on every forward. persistent=False
        # keeps it out of state_dict, so existing checkpoints load unchanged.
        # NOTE(review): the grid comes from NumPy's global RNG and is not saved,
        # so a reloaded model only sees its training grid if NumPy is seeded
        # identically before construction.
        self.register_buffer("grid", torch.from_numpy(grid), persistent=False)
        self.folding1 = nn.Sequential(
            nn.Conv1d(feat_dims + 3, feat_dims, 1),
            nn.ReLU(),
            nn.Conv1d(feat_dims, feat_dims, 1),
            nn.ReLU(),
            nn.Conv1d(feat_dims, 3, 1),
        )
        self.folding2 = nn.Sequential(
            nn.Conv1d(feat_dims + 3, feat_dims, 1),
            nn.ReLU(),
            nn.Conv1d(feat_dims, feat_dims, 1),
            nn.ReLU(),
            nn.Conv1d(feat_dims, 3, 1),
        )

    @staticmethod
    def build_grid(n_points):
        """Return ``n_points`` random unit vectors as a float32 ``(3, n_points)`` array."""
        # Normalising isotropic Gaussian samples gives points on the unit sphere.
        vec = np.random.randn(3, n_points)
        vec /= np.linalg.norm(vec, axis=0)
        return np.array(vec, dtype=np.float32)  # (3, n_points)

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        batch_size = x.shape[0]
        # (B, 1, feat_dims) -> (B, feat_dims, 1) -> (B, feat_dims, M). expand is a
        # view; the single copy happens in torch.cat below.
        x = x.transpose(1, 2).expand(-1, -1, self.m)
        # No-op when the buffer already matches the input's device and dtype.
        grid = self.grid.to(device=x.device, dtype=x.dtype)
        grid = grid.unsqueeze(0).expand(batch_size, -1, -1)  # (B, 3, M)

        cat1 = torch.cat([x, grid], dim=1)  # (B, feat_dims + 3, M)
        folding_result1 = self.folding1(cat1)  # (B, 3, M)

        cat2 = torch.cat([x, folding_result1], dim=1)  # (B, feat_dims + 3, M)
        folding_result2 = self.folding2(cat2)  # (B, 3, M)

        return folding_result2.transpose(1, 2), folding_result1.transpose(1, 2)


class SkipVariationalFoldingNet(nn.Module):
    """Variational FoldingNet: skip encoder -> Gaussian latent -> folding decoder."""

    def __init__(self, feat_dims=512, m_points=2048) -> None:
        super().__init__()
        self.encoder = SkipVariationalEncoder(feat_dims)
        # The encoder's descriptor has 1024 (graph layer) + 64 (mlp1) channels.
        self.fc_mu = nn.Conv1d(1024 + 64, feat_dims, 1)
        self.fc_var = nn.Conv1d(1024 + 64, feat_dims, 1)
        self.decoder = FoldingNetDecoder(feat_dims, m_points)

    @staticmethod
    def sample_z(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Reparameterised sample ``mu + exp(0.5 * logvar) * eps``, ``eps ~ N(0, I)``.

        :param mu: (Tensor) Mean of the latent Gaussian.
        :param logvar: (Tensor) Log-variance of the latent Gaussian, not the
            standard deviation.
        :return: (Tensor) A sample with the broadcast shape of ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: torch.Tensor):
        """Encode, sample and decode a point cloud.

        Returns ``(folding2, folding1, mu, logvar)``. ``mu`` and ``logvar`` are
        ``(batch_size, feat_dims, 1)``. The fourth value is a log-variance (it is
        what ``sample_z`` exponentiates), not a standard deviation.
        """
        x = self.encoder(input)  # (B, 1088, 1)
        mu = self.fc_mu(x)  # (B, feat_dims, 1)
        logvar = self.fc_var(x)  # (B, feat_dims, 1)
        # The decoder expects the codeword as (B, 1, feat_dims).
        feature = self.sample_z(mu.transpose(2, 1), logvar.transpose(2, 1))
        folding2, folding1 = self.decoder(feature)
        return folding2, folding1, mu, logvar

    def get_parameter(self, *args):
        """Return every trainable parameter of the model as a list.

        This includes ``fc_mu`` and ``fc_var``, which sit outside the encoder
        and decoder; an optimizer built from encoder + decoder alone would
        never update them. ``*args`` is accepted and ignored.
        """
        return list(self.parameters())


if __name__ == "__main__":
    # Smoke test: push a random batch through the model, then report the
    # output shapes and the parameter count in millions.
    point_cloud = torch.randn(8, 2048, 3)
    model = SkipVariationalFoldingNet()
    recon_fine, recon_coarse, latent_mu, latent_logvar = model(point_cloud)
    print(recon_fine.shape, latent_mu.shape, latent_logvar.shape)

    num_params = sum(param.numel() for param in model.parameters())
    print(f"Total params: {num_params / 1e6:.2f}M")

