import torch
import torch.nn as nn


class ShapeClassifier(nn.Module):
    """Classify local point neighbourhoods with an MLP over per-point features.

    Per-point features are assembled from (optionally) the xyz coordinates or
    their second-degree monomials, plus (optionally) Laplacian eigenvectors of
    a Gaussian-weighted graph built on each neighbourhood. Laplacian
    eigenvalues may additionally be appended to the flattened input.

    Args:
        args: Configuration object. The attributes read here are ``lpe_dim``,
            ``lpe_normalize``, ``use_xyz``, ``use_second_deg``,
            ``use_lap_reorder``, ``lap_eigenvalues_dim``, ``sampled_points``,
            ``num_mlp_layers``, ``num_neurons_per_layer`` and ``output_dim``.
    """

    def __init__(self, args):
        super().__init__()
        self.lpe_dim = args.lpe_dim
        self.lpe_normalize = args.lpe_normalize
        self.use_xyz = args.use_xyz
        self.use_second_deg = args.use_second_deg
        self.use_lap_reorder = args.use_lap_reorder
        self.lap_eigenvalues_dim = args.lap_eigenvalues_dim

        # Per-point channel count; second-degree features (which include the
        # linear x, y, z terms) take precedence over plain xyz.
        if self.use_second_deg:
            input_dim = 9
        elif self.use_xyz:
            input_dim = 3
        else:
            input_dim = 0
        if self.lpe_dim:
            input_dim += self.lpe_dim

        # NOTE(review): presumably each neighbourhood holds
        # ``sampled_points + 1`` points (k_neighbors in forward) - confirm.
        input_size = input_dim * (args.sampled_points + 1)
        if self.lap_eigenvalues_dim:
            input_size += self.lap_eigenvalues_dim

        self.classifier = MLP(
            input_size=input_size,
            num_layers=args.num_mlp_layers,
            num_neurons_per_layer=args.num_neurons_per_layer,
            output_size=args.output_dim,
        )

    def forward(self, x):
        """Run the classifier on every neighbourhood in the batch.

        Args:
            x: Tensor of shape ``(batch, 3, num_centroids, k_neighbors)``.

        Returns:
            Tensor of shape ``(batch, num_centroids, -1)`` holding the
            classifier output for every neighbourhood.
        """
        batch_size, _, num_centroids, k_neighbors = x.shape
        points = x.permute(0, 2, 3, 1).reshape(batch_size * num_centroids, k_neighbors, 3)

        if self.use_lap_reorder or self.lpe_dim or self.lap_eigenvalues_dim:
            laplacian = self.create_laplacian(points)
            eigvecs, eigenvals = self.top_k_smallest_eigenvectors(laplacian, self.lpe_dim)
            indices, fixed_eigs = self.sort_by_first_eigenvector(eigvecs)

        data = points
        if self.use_lap_reorder:
            data = torch.gather(points, 1, indices.unsqueeze(2).expand(-1, -1, 3))
            data = transform_point_clouds_to_canonical(data)

        # Assemble exactly the channels that __init__ counted in input_dim, so
        # the flattened tensor always matches the first Linear layer. Without
        # use_xyz / use_second_deg no coordinate channels are emitted.
        features = []
        if self.use_second_deg:
            px, py, pz = data.unbind(dim=2)
            features.append(
                torch.stack(
                    [px**2, px * py, px * pz, py**2, py * pz, pz**2, px, py, pz],
                    dim=2,
                )
            )
        elif self.use_xyz:
            features.append(data)

        if self.lpe_dim:
            # NOTE(review): fixed_eigs rows are always in the sorted order,
            # while the points are only reordered when use_lap_reorder is set
            # - confirm this is intended when use_lap_reorder is off.
            features.append(fixed_eigs)

        if features:
            data = torch.cat(features, dim=2)
        else:
            data = data.new_zeros((data.size(0), k_neighbors, 0))

        data = data.permute(0, 2, 1)

        if self.lap_eigenvalues_dim:
            # Skip index 0, the smallest eigenvalue returned by eigh.
            output = self.classifier(data, eigenvals[:, 1 : 1 + self.lap_eigenvalues_dim])
        else:
            output = self.classifier(data)

        return output.view(batch_size, num_centroids, -1)

    def create_laplacian(self, point_cloud):
        """Build a graph Laplacian with Gaussian weights ``exp(-d**2)``.

        Args:
            point_cloud: Tensor of shape ``(N, k, C)``.

        Returns:
            Tensor of shape ``(N, k, k)``: ``I - D^-1/2 W D^-1/2`` when
            ``self.lpe_normalize`` is truthy, otherwise ``D - W``.
        """
        distances = torch.cdist(point_cloud, point_cloud)
        weights = torch.exp(-distances**2)
        column_sums = weights.sum(dim=1)

        if self.lpe_normalize:
            # The small constant keeps rsqrt finite.
            inv_sqrt = torch.rsqrt(column_sums + 1e-7)
            identity = torch.eye(weights.size(1), dtype=weights.dtype, device=weights.device)
            return identity - (inv_sqrt.unsqueeze(2) * weights * inv_sqrt.unsqueeze(1))

        return torch.diag_embed(column_sums) - weights

    def top_k_smallest_eigenvectors(self, graph, k):
        """Return eigenvectors 1..k (ascending eigenvalue) and all eigenvalues.

        Args:
            graph: Batch of symmetric matrices of shape ``(N, n, n)``.
            k: Number of eigenvectors to keep; falsy values (``0``/``None``)
                are treated as 1.

        Returns:
            Tuple ``(eigenvectors[:, :, 1:k+1], eigenvalues)``. The eigenvector
            at index 0 is always skipped.
        """
        k = max(k or 0, 1)
        eigenvalues, eigenvectors = torch.linalg.eigh(graph)
        return eigenvectors[:, :, 1 : k + 1], eigenvalues

    def sort_by_first_eigenvector(self, eigenvectors):
        """Order points by their entry in the first eigenvector.

        The first eigenvector is multiplied by the sign of its
        largest-magnitude entry before sorting. Row 0 always stays first; the
        remaining rows are sorted ascending by the signed value.

        Args:
            eigenvectors: Tensor of shape ``(N, n, k)``.

        Returns:
            Tuple ``(sorted_indices, sorted_eigvecs)`` of shapes ``(N, n)`` and
            ``(N, n, k)``.
        """
        first = eigenvectors[:, :, 0]
        max_indices = torch.argmax(torch.abs(first), dim=1)
        signs = torch.sign(first[torch.arange(first.shape[0]), max_indices])
        signed = first * signs.unsqueeze(1)

        tail_order = torch.argsort(signed[:, 1:])
        anchor = torch.zeros((first.shape[0], 1), device=first.device, dtype=torch.long)
        sorted_indices = torch.cat([anchor, 1 + tail_order], dim=1)

        # NOTE(review): the sign correction is used for ordering only; the
        # gathered eigenvectors keep their original sign - confirm intended.
        sorted_eigvecs = torch.gather(
            eigenvectors,
            1,
            sorted_indices.unsqueeze(-1).expand(-1, -1, eigenvectors.size(-1)),
        )

        return sorted_indices, sorted_eigvecs


class MLP(nn.Module):
    """Fully connected network: Linear, then repeated BatchNorm1d-ReLU-Linear.

    Args:
        input_size: Width of the flattened input (including any appended
            eigenvalues).
        num_layers: Number of BatchNorm1d-ReLU-Linear stages; the last one
            maps to ``output_size``, the others keep the hidden width.
        num_neurons_per_layer: Hidden width.
        output_size: Width of the final Linear layer.
    """

    def __init__(self, input_size, num_layers, num_neurons_per_layer, output_size):
        super().__init__()
        hidden = num_neurons_per_layer

        # Output width of every stage after the input projection: hidden for
        # the intermediate stages, output_size for the last one.
        stage_widths = [hidden] * (num_layers - 1) + [output_size]

        stack = [nn.Linear(input_size, hidden)]
        for width in stage_widths:
            stack.append(nn.BatchNorm1d(hidden))
            stack.append(nn.ReLU())
            stack.append(nn.Linear(hidden, width))

        self.model = nn.Sequential(*stack)

    def forward(self, x, eigenvals=None):
        """Flatten ``x`` per sample, optionally append ``eigenvals``, and run the net.

        Args:
            x: Tensor whose first dimension is the batch; all other dimensions
                are flattened.
            eigenvals: Optional 2-D tensor concatenated to the flattened input
                along dim 1.

        Returns:
            Output of the sequential model.
        """
        flat = x.reshape(x.size(0), -1)
        if eigenvals is None:
            return self.model(flat)
        return self.model(torch.cat((flat, eigenvals), dim=1))


def transform_point_clouds_to_canonical(point_clouds: torch.Tensor, epsilon: float = 1e-10) -> torch.Tensor:
    """Rotate each point cloud into a canonical orientation.

    Two rotations are applied per cloud:

    1. A rotation (Rodrigues' formula) that maps the direction of the cloud's
       mean point onto the +z axis.
    2. A rotation about z that places the last point of the cloud at y = 0
       with non-negative x.

    Args:
        point_clouds: Floating-point tensor of shape ``(B, N, 3)``.
        epsilon: Small constant guarding divisions and square roots; entries
            of the result whose magnitude is below it are set to exactly zero.

    Returns:
        Tensor of shape ``(B, N, 3)`` with the same dtype and device as the
        input.
    """
    batch = point_clouds.shape[0]

    mean_point = torch.mean(point_clouds, dim=1)
    mean_norm = torch.norm(mean_point, dim=1, keepdim=True)
    v1 = mean_point / (mean_norm + epsilon)
    v2 = torch.zeros_like(mean_point)
    v2[:, 2] = 1.0

    # NOTE(review): when the mean points along -z the cross product vanishes,
    # K is zero and R1 is the identity, so the cloud is not flipped.
    rotation_axis = torch.cross(v1, v2, dim=1)
    rotation_axis = rotation_axis / (torch.norm(rotation_axis, dim=1, keepdim=True) + epsilon)

    cos_theta = torch.sum(v1 * v2, dim=1)
    # Clamp before the sqrt: rounding can push cos**2 slightly above 1, which
    # would otherwise yield NaN.
    sin_theta = torch.sqrt(torch.clamp(1 - cos_theta**2, min=0.0) + epsilon)

    # Skew-symmetric cross-product matrix of the rotation axis. new_zeros
    # keeps the dtype/device of the input so the later bmm calls see matching
    # dtypes (bmm does not type-promote).
    K = point_clouds.new_zeros((batch, 3, 3))
    K[:, 0, 1] = -rotation_axis[:, 2]
    K[:, 0, 2] = rotation_axis[:, 1]
    K[:, 1, 0] = rotation_axis[:, 2]
    K[:, 1, 2] = -rotation_axis[:, 0]
    K[:, 2, 0] = -rotation_axis[:, 1]
    K[:, 2, 1] = rotation_axis[:, 0]

    identity = torch.eye(3, dtype=point_clouds.dtype, device=point_clouds.device).unsqueeze(0)
    R1 = (
        identity
        + sin_theta.unsqueeze(-1).unsqueeze(-1) * K
        + (1 - cos_theta).unsqueeze(-1).unsqueeze(-1) * (K @ K)
    )

    rotated_points = torch.bmm(point_clouds, R1.transpose(1, 2))
    last_point = rotated_points[:, -1, :]

    # Direction of the last point in the xy-plane; rotating by the negative
    # of that angle brings it onto the x axis.
    xy_magnitude = torch.sqrt(last_point[:, 0] ** 2 + last_point[:, 1] ** 2 + epsilon)
    cos_phi = last_point[:, 0] / xy_magnitude
    sin_phi = -(last_point[:, 1] / xy_magnitude)

    R2 = point_clouds.new_zeros((batch, 3, 3))
    R2[:, 0, 0] = cos_phi
    R2[:, 0, 1] = -sin_phi
    R2[:, 1, 0] = sin_phi
    R2[:, 1, 1] = cos_phi
    R2[:, 2, 2] = 1.0

    final_points = torch.bmm(rotated_points, R2.transpose(1, 2))
    # Remove the residual rounding error on the anchor point's y coordinate.
    final_points[:, -1, 1] = 0

    return torch.where(torch.abs(final_points) < epsilon, torch.zeros_like(final_points), final_points)
