import math
import warnings
from math import pi
from typing import Callable, Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import SGD
from tqdm import trange


class GridGenerator_S1_S2(nn.Module):
    """
    Generates grids of points on the circle S1 or on the sphere S2.

    Calling the module dispatches on `dim`: `dim == 2` returns `n` evenly
    spaced points on S1, `dim == 3` returns `n` points on S2 obtained by
    repulsing a random initialisation. Any other `dim` raises `ValueError`.

    Arguments:
        - dim: Dimension of the embedding space (2 for S1, 3 for S2).
        - n: Number of grid points.
        - steps: Number of repulsion steps used for S2.
        - step_size: Learning rate of the SGD optimizer used for S2.
        - device: Device on which the grid is created, defaults to CPU.
    """

    def __init__(
        self,
        dim: int,
        n: int,
        steps: int = 200,
        step_size: float = 0.01,
        device: torch.device = None,
    ):
        super().__init__()
        self.dim = dim
        self.n = n
        self.steps = steps
        self.step_size = step_size
        self.device = device if device else torch.device("cpu")

    def forward(self) -> torch.Tensor:
        """
        Returns the grid for the configured dimension.

        Returns:
            - Tensor of shape `(n, 2)` for S1 or `(n, 3)` for S2.

        Raises:
            - ValueError: if `dim` is neither 2 nor 3.
        """
        if self.dim == 2:
            return self.generate_s1()
        elif self.dim == 3:
            return self.generate_s2()
        else:
            raise ValueError("Only S1 and S2 are supported.")

    def generate_s1(self) -> torch.Tensor:
        """
        Returns `n` equally spaced points on the unit circle.

        Returns:
            - Tensor of shape `(n, 2)` holding `(cos, sin)` pairs.
        """
        # The end point stops one step short of 2*pi so that the first
        # point (angle 0) is not duplicated.
        angles = torch.linspace(
            start=0,
            end=2 * torch.pi - (2 * torch.pi / self.n),
            steps=self.n,
            device=self.device,
        )
        x = torch.cos(angles)
        y = torch.sin(angles)
        return torch.stack((x, y), dim=1)

    def generate_s2(self) -> torch.Tensor:
        """
        Returns `n` points on the unit sphere, obtained by repulsing a
        random initial grid.

        Returns:
            - Tensor of shape `(n, 3)`.
        """
        grid = self.random_s2((self.n,), device=self.device)
        return self.repulse(grid)

    def random_s2(self, shape: tuple[int, ...], device: torch.device) -> torch.Tensor:
        """
        Samples random points on the unit sphere by normalizing
        standard normal vectors.

        Arguments:
            - shape: Batch shape of the output.
            - device: Device on which the tensor is created.

        Returns:
            - Tensor of shape `(*shape, 3)`.
        """
        x = torch.randn((*shape, 3), device=device)
        return x / torch.linalg.norm(x, dim=-1, keepdim=True)

    def repulse(self, grid: torch.Tensor) -> torch.Tensor:
        """
        Spreads points over the sphere by running `steps` SGD steps on
        the summed inverse squared pairwise distance, projecting the
        points back onto the unit sphere after every step.

        Arguments:
            - grid: Tensor of shape `(n, 3)` with the initial points.

        Returns:
            - Detached tensor of shape `(n, 3)`.
        """
        grid = grid.clone().detach().requires_grad_(True)
        optimizer = SGD([grid], lr=self.step_size)

        for _ in range(self.steps):
            optimizer.zero_grad()
            dists = torch.cdist(grid, grid, p=2)
            dists = torch.clamp(dists, min=1e-6)  # Avoid division by zero
            energy = dists.pow(-2).sum()  # Simplified Coulomb energy calculation
            energy.backward()
            optimizer.step()

            with torch.no_grad():
                # Renormalize points back to the sphere after update
                grid /= grid.norm(dim=-1, keepdim=True)

        return grid.detach()

    @staticmethod
    def fibonacci_lattice(
        n: int, offset: float = 0.5, device: Optional[str] = None
    ) -> Tensor:
        """
        Creating ~uniform grid of points on S2 using the fibonacci spiral algorithm.

        Declared as a static method: it takes no `self`, so it can be
        called both on the class and on an instance.

        Arguments:
            - n: Number of points.
            - offset: Strength for how much points are pushed away from the poles.
                    Default of 0.5 works well for uniformity.
            - device: Device on which the new tensor is created.

        Returns:
            - Tensor of shape `(n, 3)`.

        Raises:
            - ValueError: if `n` is smaller than 1.
        """
        if n < 1:
            raise ValueError("n must be greater than 0.")

        i = torch.arange(n, device=device)

        theta = (math.pi * i * (1 + math.sqrt(5))) % (2 * math.pi)
        phi = torch.acos(1 - 2 * (i + offset) / (n - 1 + 2 * offset))

        cos_theta, sin_theta = torch.cos(theta), torch.sin(theta)
        cos_phi, sin_phi = torch.cos(phi), torch.sin(phi)

        return torch.stack((cos_theta * sin_phi, sin_theta * sin_phi, cos_phi), dim=-1)








# SO3


# Repulsion functions
def columb_energy(d: Tensor, k: int = 2) -> Tensor:
    """
    Coulomb-style repulsion energy: raises the distances to the
    power `-k`, element-wise.

    Arguments:
        - d: Tensor of distances.
        - k: Exponent of the inverse power law.

    Returns:
        - Tensor of the same shape as `d`.
    """
    exponent = -k
    return d**exponent

def repulse(
    grid: Tensor,
    steps: int = 200,
    step_size: float = 10,
    metric_fn: Callable = lambda x, y: x - y,
    transform_fn: Callable = lambda x: x,
    energy_fn: Callable = columb_energy,
    dist_normalization_constant: float = 1,
    alpha: float = 0.001,
    show_pbar: bool = True,
    in_place: bool = False,
) -> Tensor:
    """
    Spreads out the elements of `grid` by running SGD on a mean
    pairwise repulsion energy, with annealed Gaussian noise added to
    the gradient.

    Arguments:
        - grid: Tensor holding the initial grid elements along its
                first dimension.
        - steps: Number of optimization steps.
        - step_size: Learning rate of the SGD optimizer.
        - metric_fn: Callable evaluated as
                     `metric_fn(x[:, None], x)` on the transformed grid.
        - transform_fn: Callable applied to `grid` before distances are
                        computed.
        - energy_fn: Callable mapping normalized distances to energies.
        - dist_normalization_constant: Divisor applied to the distances
                                       before `energy_fn`.
        - alpha: Scale of the gradient noise; it is multiplied by
                 `(steps - step) / steps`, so it decays linearly.
        - show_pbar: If True, a tqdm progress bar is shown.
        - in_place: If True, `grid` itself is optimized, otherwise a
                    clone is used.

    Returns:
        - Detached tensor of optimized grid elements.
    """
    progress = trange(steps, disable=not show_pbar, desc="Optimizing")

    if not in_place:
        grid = grid.clone()
    grid.requires_grad = True

    sgd = torch.optim.SGD([grid], lr=step_size)

    for step in progress:
        sgd.zero_grad(set_to_none=True)

        points = transform_fn(grid)

        # Sort along the last dimension and drop index 0 of dimension 1;
        # for a 2-D distance matrix this removes each element's smallest
        # (self) distance.
        pairwise = metric_fn(points[:, None], points)
        sorted_dists = pairwise.sort(dim=-1)[0]
        neighbour_dists = sorted_dists[:, 1:]

        mean_energy = energy_fn(neighbour_dists / dist_normalization_constant).mean()
        mean_energy.backward()

        # Linearly decaying noise on the gradient.
        noise_scale = (steps - step) / steps * alpha
        grid.grad += noise_scale * torch.randn(grid.grad.shape, device=grid.device)

        sgd.step()

        progress.set_postfix_str(f"mean total energy: {mean_energy.item():.3f}")

    grid.requires_grad = False

    return grid.detach()

# SO3 functions
def matrix_x(theta: Tensor) -> Tensor:
    """
    Builds the rotation matrices around the x-axis for the given
    angles. The result has the dtype and device of `theta`.

    Arguments:
        - theta: Tensor containing angles.

    Returns:
        - Tensor of shape '(*theta.shape, 3, 3)' of rotation matrices.
    """
    c = torch.cos(theta)
    s = torch.sin(theta)

    # Start from all zeros and fill in only the non-zero entries.
    rot = theta.new_zeros(*theta.shape, 3, 3)

    rot[..., 0, 0] = 1

    rot[..., 1, 1] = c
    rot[..., 1, 2] = -s

    rot[..., 2, 1] = s
    rot[..., 2, 2] = c

    return rot


def matrix_y(theta: Tensor) -> Tensor:
    """
    Builds the rotation matrices around the y-axis for the given
    angles. The result has the dtype and device of `theta`.

    Arguments:
        - theta: Tensor containing angles.

    Returns:
        - Tensor of shape '(*theta.shape, 3, 3)' of rotation matrices.
    """
    c = torch.cos(theta)
    s = torch.sin(theta)

    # Start from all zeros and fill in only the non-zero entries.
    rot = theta.new_zeros(*theta.shape, 3, 3)

    rot[..., 0, 0] = c
    rot[..., 0, 2] = s

    rot[..., 1, 1] = 1

    rot[..., 2, 0] = -s
    rot[..., 2, 2] = c

    return rot


def matrix_z(theta: Tensor) -> Tensor:
    """
    Builds the rotation matrices around the z-axis for the given
    angles. The result has the dtype and device of `theta`.

    Arguments:
        - theta: Tensor containing angles.

    Returns:
        - Tensor of shape '(*theta.shape, 3, 3)' of rotation matrices.
    """
    c = torch.cos(theta)
    s = torch.sin(theta)

    # Start from all zeros and fill in only the non-zero entries.
    rot = theta.new_zeros(*theta.shape, 3, 3)

    rot[..., 0, 0] = c
    rot[..., 0, 1] = -s

    rot[..., 1, 0] = s
    rot[..., 1, 1] = c

    rot[..., 2, 2] = 1

    return rot

def euler_to_matrix(g: Tensor, eps: float = 1e-6) -> Tensor:
    """
    Converts Euler angles `(alpha, beta, gamma)` to rotation matrices.
    Entries whose magnitude does not exceed `eps` are multiplied by
    zero in the result.

    Arguments:
        - g: Tensor of shape `(..., 3)`.
        - eps: Float for preventing numerical instabilities.

    Returns:
        - Tensor of shape `(..., 3, 3)`.
    """
    alpha, beta, gamma = g[..., 0], g[..., 1], g[..., 2]

    ca, cb, cg = torch.cos(alpha), torch.cos(beta), torch.cos(gamma)
    sa, sb, sg = torch.sin(alpha), torch.sin(beta), torch.sin(gamma)

    rot = g.new_empty((*g.shape[:-1], 3, 3))

    # Fill the matrix one row at a time.
    rot[..., 0, :] = torch.stack(
        (
            ca * cb * cg - sa * sg,
            -ca * sg - cg * cb * sa,
            cg * sb,
        ),
        dim=-1,
    )
    rot[..., 1, :] = torch.stack(
        (
            cg * sa + cb * ca * sg,
            cg * ca - cb * sa * sg,
            sg * sb,
        ),
        dim=-1,
    )
    rot[..., 2, :] = torch.stack(
        (
            -ca * sb,
            sb * sa,
            cb,
        ),
        dim=-1,
    )

    # Multiplying by the boolean mask suppresses near-zero entries.
    keep = rot.abs() > eps

    return rot * keep

def random_quat(shape: tuple[int] | int, device: Optional[str] = None) -> Tensor:
    """
    Uniformly samples SO3 elements parameterized as quaternions.

    Arguments:
        - shape: Int or tuple denoting the shape of the output tensor.
        - device: device on which the new tensor is created.

    Returns:
        - Tensor of shape `(*shape, 4)`.
    """
    shape = shape if type(shape) is tuple else (shape,)

    q = torch.randn(*shape, 4, device=device)
    q = q / torch.linalg.norm(q, dim=-1, keepdim=True)

    return q

def random_euler(shape: tuple[int] | int, device: Optional[str] = None) -> Tensor:
    """
    Draws random SO3 elements as quaternions and converts them to
    euler angles (ZYZ).

    Arguments:
        - shape: Int or tuple denoting the shape of the output tensor.
        - device: device on which the new tensor is created.

    Returns:
        - Tensor of shape `(*shape, 3)`.
    """
    quats = random_quat(shape, device=device)
    angles = quat_to_euler(quats)
    return angles

def quat_to_euler(q: Tensor) -> Tensor:
    """
    Converts quaternions to Euler angles by going through the
    rotation matrix representation.

    Arguments:
        - q: Tensor of shape `(..., 4)`.

    Returns:
        - Tensor of shape `(..., 3)`.
    """
    rotations = quat_to_matrix(q)
    angles = matrix_to_euler(rotations)
    return angles

def matrix_to_euler(r: Tensor, eps: float = 1e-5, no_warn: bool = False) -> Tensor:
    """
    Transforms rotation matrices to euler angles. When a gimbal lock is
    detected, the third euler angle (gamma) will be set to zero. Note
    that this will still result in the correct rotation.

    Adapted from 'scipy.spatial.transform._rotation.pyx/_euler_from_matrix'.

    Arguments:
        - r: Tensor of shape `(..., 3, 3)`.
        - eps: Float for preventing numerical instabilities.
        - no_warn: If True, the gimbal lock warning is suppressed,
                   default `no_warn = False`.

    Returns:
        - Tensor of shape `(..., 3)`.
    """
    # The masked assignments below assume exactly one batch dimension,
    # so flatten any batch shape here and restore it on return.
    batch_shape = r.shape[:-2]
    r = r.reshape(batch_shape.numel(), 3, 3)

    g = r.new_empty((r.shape[0], 3))

    # step 1, 2
    c = r.new_tensor([[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

    # step 3
    res = torch.matmul(c, r)
    matrix_trans = torch.matmul(res, c.T)

    # step 4
    # Clamp so that rounding errors cannot push acos out of its domain.
    matrix_trans[:, 2, 2] = torch.clamp(matrix_trans[:, 2, 2], -1, 1)

    g[:, 1] = torch.acos(matrix_trans[:, 2, 2])

    # step 5, 6
    # An element is gimbal locked when beta is (numerically) 0 or pi.
    safe1 = torch.abs(g[:, 1]) >= eps
    safe2 = torch.abs(g[:, 1] - pi) >= eps
    safe = safe1 & safe2

    not_safe1 = ~safe1
    not_safe2 = ~safe2

    # step 5b
    g[safe, 2] = torch.atan2(matrix_trans[safe, 0, 2], -matrix_trans[safe, 1, 2])
    g[safe, 0] = torch.atan2(matrix_trans[safe, 2, 0], matrix_trans[safe, 2, 1])

    g[~safe, 2] = 0
    g[not_safe1, 0] = torch.atan2(
        matrix_trans[not_safe1, 1, 0] - matrix_trans[not_safe1, 0, 1],
        matrix_trans[not_safe1, 0, 0] + matrix_trans[not_safe1, 1, 1],
    )
    g[not_safe2, 0] = -torch.atan2(
        matrix_trans[not_safe2, 1, 0] + matrix_trans[not_safe2, 0, 1],
        matrix_trans[not_safe2, 0, 0] - matrix_trans[not_safe2, 1, 1],
    )

    # step 7
    adjust_and_safe = ((g[:, 1] < 0) | (g[:, 1] > pi)) & safe

    g[adjust_and_safe, 0] -= pi
    g[adjust_and_safe, 1] *= -1
    g[adjust_and_safe, 2] += pi

    # Wrap all angles into [-pi, pi).
    g[g < -pi] += 2 * pi
    g[g >= pi] -= 2 * pi

    # Warn as soon as at least one element is gimbal locked; an empty
    # batch has no locked element and stays silent.
    if not no_warn and not torch.all(safe):
        warnings.warn(
            "Gimbal lock detected. Setting third angle to zero "
            "since it is not possible to uniquely determine "
            "all angles."
        )

    return g.reshape(*batch_shape, 3)

def quat_to_matrix(q: Tensor) -> Tensor:
    """
    Converts quaternions, stored as `(w, x, y, z)` along the last
    dimension, to rotation matrices.

    Arguments:
        - q: Tensor of shape `(..., 4)`.

    Returns:
        - Tensor of shape `(..., 3, 3)`.
    """
    w, x, y, z = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

    # Squares and pairwise products shared between the entries.
    xx, yy, zz, ww = x * x, y * y, z * z, w * w
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    # The nine entries in row-major order.
    entries = (
        xx - yy - zz + ww,
        2 * (xy - zw),
        2 * (xz + yw),
        2 * (xy + zw),
        -xx + yy - zz + ww,
        2 * (yz - xw),
        2 * (xz - yw),
        2 * (yz + xw),
        -xx - yy + zz + ww,
    )

    return torch.stack(entries, dim=-1).reshape(*q.shape[:-1], 3, 3)

# def matrix_to_quat(r: Tensor) -> Tensor:
#     """
#     Transforms rotation matrices to quaternions.

#     Arguments:
#         - r: Tensor of shape `(..., 3, 3)`.

#     Returns:
#         - Tensor of shape `(..., 4)`.
#     """
#     batch_shape = r.shape[:-2]
#     q = r.new_empty(*batch_shape, 4)

#     decision = torch.diagonal(r, 0, dim1=-2, dim2=-1)
#     decision = torch.cat((decision, torch.sum(decision, dim=-1, keepdim=True)), dim=-1)

#     i = torch.argmax(decision, dim=-1)
#     j = (i + 1) % 3
#     k = (j + 1) % 3
#     c_m = i != 3
#     nc_m = ~c_m

#     # Create indices for advanced indexing
#     batch_indices = torch.arange(r.shape[0], device=r.device).reshape(-1, 1)
#     if len(batch_shape) > 1:
#         for dim_size in batch_shape[1:]:
#             batch_indices = batch_indices.unsqueeze(-1).expand(-1, dim_size)

#     c_i = i[c_m]
#     c_j = j[c_m]
#     c_k = k[c_m]

#     # Using advanced indexing to assign values
#     q[c_m, c_i + 1] = 1 - decision[c_m, 3] + 2 * r[c_m, c_i, c_i]
#     q[c_m, c_j + 1] = r[c_m, c_j, c_i] + r[c_m, c_i, c_j]
#     q[c_m, c_k + 1] = r[c_m, c_k, c_i] + r[c_m, c_i, c_k]
#     q[c_m, 0] = r[c_m, c_k, c_j] - r[c_m, c_j, c_k]

#     q[nc_m, 1] = r[nc_m, 2, 1] - r[nc_m, 1, 2]
#     q[nc_m, 2] = r[nc_m, 0, 2] - r[nc_m, 2, 0]
#     q[nc_m, 3] = r[nc_m, 1, 0] - r[nc_m, 0, 1]
#     q[nc_m, 0] = 1 + decision[nc_m, 3]

#     return q / torch.linalg.norm(q, dim=-1, keepdim=True)


def matrix_to_quat(r: Tensor) -> Tensor:
    """
    Transforms rotation matrices to quaternions, stored as
    `(w, x, y, z)` along the last dimension.

    Arguments:
        - r: Tensor of shape `(..., 3, 3)`.

    Returns:
        - Tensor of shape `(..., 4)` of unit quaternions.
    """
    # The masked assignments below assume exactly one batch dimension,
    # so flatten any batch shape here and restore it on return.
    batch_shape = r.shape[:-2]
    r = r.reshape(batch_shape.numel(), 3, 3)

    q = r.new_empty(r.shape[0], 4)

    # Per matrix: the three diagonal entries followed by the trace. The
    # position of the largest value selects the formula that is used.
    decision = torch.diagonal(r, 0, dim1=-2, dim2=-1)
    decision = torch.cat((decision, torch.sum(decision, dim=-1, keepdim=True)), dim=-1)

    i = torch.argmax(decision, dim=-1)
    j = (i + 1) % 3
    k = (j + 1) % 3
    c_m = i != 3  # a diagonal entry is the largest value
    nc_m = ~c_m  # the trace is the largest value

    c_i = i[c_m]
    c_j = j[c_m]
    c_k = k[c_m]

    # The "+ 1" offsets skip the leading w component.
    q[c_m, c_i + 1] = 1 - decision[c_m, 3] + 2 * r[c_m, c_i, c_i]
    q[c_m, c_j + 1] = r[c_m, c_j, c_i] + r[c_m, c_i, c_j]
    q[c_m, c_k + 1] = r[c_m, c_k, c_i] + r[c_m, c_i, c_k]
    q[c_m, 0] = r[c_m, c_k, c_j] - r[c_m, c_j, c_k]

    q[nc_m, 1] = r[nc_m, 2, 1] - r[nc_m, 1, 2]
    q[nc_m, 2] = r[nc_m, 0, 2] - r[nc_m, 2, 0]
    q[nc_m, 3] = r[nc_m, 1, 0] - r[nc_m, 0, 1]
    q[nc_m, 0] = 1 + decision[nc_m, 3]

    q = q / torch.linalg.norm(q, dim=-1, keepdim=True)

    return q.reshape(*batch_shape, 4)

@torch.jit.script
def geodesic_distance(
    qx: torch.Tensor, qy: torch.Tensor, eps: float = 1e-7
) -> torch.Tensor:
    """
    Computes the geodesic distance between quaternions `qx` and `qy`
    as the arccosine of the absolute value of their inner product.
    Usual rules of broadcasting apply.

    Arguments:
        - qx, qy: Tensors of shape `(..., 4)`.
        - eps: Float for preventing numerical instabilities.

    Returns:
        - Tensor of shape `(...)`.
    """
    inner = torch.abs((qx * qy).sum(-1))

    # Stay strictly inside [-1, 1] before taking the arccosine.
    lower = -1 + eps
    upper = 1 - eps

    return torch.acos(inner.clamp(lower, upper))

# Main function to create uniform grid of 3D rotation matrices
def uniform_grid_so3(
    size: int,
    steps: int = 200,
    step_size: Optional[float] = None,
    show_pbar: bool = True,
    device: Optional[str] = None,
    cache_grid: bool = False
) -> Tensor:
    """
    Creates an approximately uniform grid of 3D rotation matrices by
    repulsing randomly sampled rotations.

    The optimization runs on quaternions: `geodesic_distance` is defined
    on quaternions of shape `(..., 4)`, so the grid is kept in that
    parameterization and only converted to matrices at the end.

    Arguments:
        - size: Number of grid elements.
        - steps: Number of repulsion steps.
        - step_size: Step size of the repulsion optimizer. Defaults to
                     `size ** (1 / 3)` when None.
        - show_pbar: If True, a progress bar is shown.
        - device: device on which the grid is created.
        - cache_grid: Accepted for interface compatibility; currently
                      not used.

    Returns:
        - Tensor of shape `(size, 3, 3)`.
    """
    if size == 0:
        # torch.empty honours `device` and keeps the trailing (3, 3)
        # dimensions of the non-empty case.
        return torch.empty((0, 3, 3), device=device)

    step_size = step_size if step_size is not None else size ** (1 / 3)

    # Random unit quaternions as the initial grid.
    grid = random_quat(size, device=device)

    # Perform repulsion to obtain a uniform distribution. The raw
    # parameters are re-normalized to unit quaternions before distances
    # are measured. geodesic_distance takes |<qx, qy>|, so distances lie
    # in [0, pi / 2]; a plain float constant avoids device mismatches.
    grid = repulse(
        grid,
        steps=steps,
        step_size=step_size,
        alpha=0.001,
        metric_fn=geodesic_distance,
        transform_fn=lambda q: nn.functional.normalize(q, dim=-1),
        dist_normalization_constant=torch.pi / 2,
        show_pbar=show_pbar,
        in_place=True,
    )

    # Convert the final unit quaternions to rotation matrices
    uniform_grid = quat_to_matrix(nn.functional.normalize(grid, dim=-1))

    return uniform_grid


class GridGenerator_SO3(nn.Module):
    """
    Module wrapper around `uniform_grid_so3`: stores the grid settings
    and produces the SO3 grid when called.

    Arguments:
        - dim: Stored on the module; not used by `forward`.
        - n: Number of grid elements.
        - steps: Number of repulsion steps.
        - step_size: Step size of the repulsion optimizer.
        - device: Device on which the grid is created, defaults to CPU.
    """

    def __init__(
        self,
        dim: int,
        n: int,
        steps: int = 200,
        step_size: float = 0.01,
        device: torch.device = None,
    ):
        super().__init__()
        self.dim, self.n = dim, n
        self.steps, self.step_size = steps, step_size
        self.device = device or torch.device("cpu")

    def forward(self) -> torch.Tensor:
        """
        Returns the grid produced by `uniform_grid_so3` for the stored
        settings, with the progress bar enabled and caching disabled.
        """
        return uniform_grid_so3(
            self.n,
            self.steps,
            self.step_size,
            show_pbar=True,
            device=self.device,
            cache_grid=False,
        )