import torch
import torch.nn as nn
from torch.nn.functional import normalize


def rodrigues_rotations(rodrigues_vector, alpha):
    """
    Build one rotation matrix per row with the Rodrigues formula
    R = I + sin(alpha) * K + (1 - cos(alpha)) * K @ K.
    @param rodrigues_vector: Tensor bs, 3 of rotation axes; each row is normalized before use
    @param alpha: Tensor bs of rotation angles, fed directly to torch.sin / torch.cos (i.e. radians)
    @return: A tensor bs, 3, 3 with one rotation matrix per (axis, angle) pair
    """
    dtype = rodrigues_vector.dtype
    axis = normalize(rodrigues_vector)

    # Reorder the unit axis (x, y, z) -> (z, -y, x): this is the component
    # layout that to_skew_symmetric needs so that, once transposed, the result
    # is the cross-product matrix K of the axis.
    reordered = axis.flip(1)
    reordered[:, 1] *= -1
    K = to_skew_symmetric(reordered).transpose(1, 2).to(axis.device, dtype=dtype)
    K_squared = torch.bmm(K, K)

    sin_alpha = torch.sin(alpha).to(dtype=dtype)[:, None, None]
    one_minus_cos = (1 - torch.cos(alpha).to(dtype=dtype))[:, None, None]

    sin_term = K * sin_alpha
    cos_term = K_squared * one_minus_cos
    identity = torch.diag(torch.ones(3, dtype=dtype, device=axis.device))
    return identity.expand(sin_term.shape) + sin_term + cos_term

def to_skew_symmetric(x):
    """
    Builds, for every row (a, b, c) of x, the skew symmetric matrix
        [[ 0,  a,  b],
         [-a,  0,  c],
         [-b, -c,  0]]
    @param x: The input tensor bs, 3
    @return: The tensor bs, 3, 3 with the skew symmetric matrices corresponding to x
    """
    batch = x.shape[0]
    out = torch.zeros(batch, 3, 3, device=x.device, dtype=x.dtype)
    # (row, col, component): component goes to the upper triangle at (row, col)
    # and its negation to the mirrored position (col, row).
    layout = ((0, 1, 0), (0, 2, 1), (1, 2, 2))
    for row, col, component in layout:
        out[:, row, col] = x[:, component]
        out[:, col, row] = -x[:, component]
    return out


def gaussian_2d(x, center, density, scale, rotation_3d, projection_3D):
    """Evaluate a single projected Gaussian at a set of 2D points.

    With Z = rotation_3d @ diag(scale), each point p of x is mapped to
    v = (p - center @ projection_3D.T) @ projection_3D @ inverse(Z) and the
    returned value is density * exp(-0.5 * <v, v>).

    NOTE(review): the indexing below suggests x is (P, 2), center and scale
    are (3,), rotation_3d is (3, 3) and projection_3D is (2, 3), giving a (P,)
    result -- confirm against the callers.
    """
    # Displacement of every query point from the projected centre.
    projected_center = center @ projection_3D.T
    displacement = x - projected_center

    # Inverse of the (rotation @ scaling) matrix.
    stretch = torch.diag(scale).to(center.device)
    whitening = torch.inverse(rotation_3d @ stretch)

    # Lift the 2D displacements with the projection and whiten them.
    whitened = displacement @ (projection_3D @ whitening)

    # Row-wise squared norm written as a batched (1, 3) @ (3, 1) product,
    # with the -0.5 factor folded into the left operand.
    scaled_rows = -0.5 * whitened.unsqueeze(1)
    exponent = scaled_rows @ whitened.unsqueeze(-1)
    return density * torch.exp(exponent).squeeze(-1, -2)


# Batched variant of gaussian_2d: torch.vmap (default in_dims=0) maps the
# function over the leading dimension of every argument, so each argument must
# carry the same leading batch size.
b_gaussian_2d = torch.vmap(gaussian_2d)


def sample_from_3D_distribution(dist, num_samples):
    """Draw grid coordinates from an unnormalized discrete distribution.

    The tensor is flattened, normalized to sum to one and sampled with
    torch.multinomial (without replacement, so num_samples must not exceed the
    number of non-zero entries). The flat indices are then converted back into
    one coordinate per dimension of ``dist``.

    @param dist: Tensor of non-negative weights (any number of dimensions >= 1)
    @param num_samples: The number of coordinates to draw
    @return: A float CPU tensor num_samples, dist.dim() of index coordinates,
             ordered like the dimensions of dist
    """
    # Sampling happens on the CPU; the caller's tensor is never modified
    # because the normalization below is out of place.
    weights = dist.detach().to('cpu').reshape(-1)
    if not weights.is_floating_point():
        weights = weights.float()
    weights = weights / weights.sum()

    # Sample from the flattened distribution
    remaining = torch.multinomial(weights, num_samples=num_samples)

    # Unravel the row-major flat index, peeling off the fastest-varying (last)
    # dimension first.
    coords = []
    for size in reversed(dist.shape):
        coords.append(remaining % size)
        remaining = remaining // size
    coords.reverse()

    return torch.stack(coords, -1).float()


class GaussianSplatter(nn.Module):
    """A learnable set of 3D Gaussians that is rasterized onto a 2D image.

    Every Gaussian has five unconstrained parameters:
      * centers_  (N, 3): logits, mapped by centers() into
        [-shape / 2, shape / 2]
      * density_  (N, 1): logits, mapped by density() into (0, 1)
      * scale_    (N, 3): logits, mapped by scale() into
        (min_scale, min_scale + max_scale)
      * rotation  (N, 1): rotation angle around the rodrigues axis
      * rodrigues (N, 3): rotation axis (normalized when used)
    """

    # Centres are clamped into [eps, 1 - eps] before the logit so that grid
    # points lying exactly on 0 or 1 do not become +-inf parameters.
    _LOGIT_EPS = 1e-6

    def __init__(self, shape, num_gaussians, distribution=None, min_scale=0.1, max_scale=5):
        """
        :param shape: The 3 extents of the volume (sequence or tensor)
        :param num_gaussians: The number of Gaussians to create
        :param distribution: Optional tensor of weights from which the initial
            centers are sampled; if None a regular grid topped up with uniform
            random points is used
        :param min_scale: Lower bound of the per-axis scale
        :param max_scale: Width of the scale range above min_scale
        """
        super().__init__()
        self.bounding_box = None
        shape = shape if isinstance(shape, torch.Tensor) else torch.tensor(shape)
        self.register_buffer("shape", shape)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.num_gaussians = num_gaussians

        if distribution is None:
            # side x side x depth grid in the unit cube, where depth is the
            # last extent and side is the largest value that keeps the grid
            # within num_gaussians points.
            depth = int(shape[-1])
            side = int((num_gaussians / depth) ** 0.5)
            axis = torch.linspace(0, 1, side)
            grid = torch.cartesian_prod(axis, axis, torch.linspace(0, 1, depth))

            # Whatever the grid does not cover is filled with uniform samples.
            centers = torch.cat([
                grid,
                torch.rand(num_gaussians - len(grid), 3),
            ], 0)
        else:
            # we want the range of values between 0 and 1 to apply inverse of sigmoid
            centers = sample_from_3D_distribution(distribution, num_gaussians) / self.shape[None, :]

        # Inverse sigmoid with clamping, so every parameter starts finite.
        centers = torch.logit(centers, eps=self._LOGIT_EPS)

        self.centers_ = nn.Parameter(centers)
        self.density_ = nn.Parameter(torch.zeros(num_gaussians, 1))
        self.scale_ = nn.Parameter(torch.zeros(num_gaussians, 3))
        self.rotation = nn.Parameter(torch.pi * torch.rand(num_gaussians, 1))
        self.rodrigues = nn.Parameter(torch.rand(num_gaussians, 3))

        self.sigmoid = nn.Sigmoid()

    def centers(self):
        """Return the centers, in [-shape / 2, shape / 2] along every axis."""
        return self.shape[None, :] * (self.sigmoid(self.centers_) - 0.5)

    def density(self):
        """Return the densities, squashed into (0, 1)."""
        return self.sigmoid(self.density_)

    def scale(self):
        """Return the per-axis scales."""
        # We enforce Gaussians to have a minimum size and a maximum size
        return self.min_scale + self.max_scale * self.sigmoid(self.scale_)

    def clean_gaussians(self, min_density=0.1, min_size=0.05):
        """Drop Gaussians that are nearly transparent or nearly degenerate.

        A Gaussian survives when its density exceeds min_density and its
        largest scale exceeds min_size. Every per-Gaussian parameter is
        replaced by a new nn.Parameter holding only the survivors.

        NOTE: the parameters are new objects afterwards, so an optimizer
        created before this call still refers to the old ones and has to be
        rebuilt.
        """
        with torch.no_grad():
            alive_mask = self.density().reshape(-1) > min_density
            alive_mask &= torch.max(self.scale(), dim=-1)[0] > min_size

            # All five tensors are indexed by Gaussian and must stay aligned;
            # rodrigues in particular is indexed with the same masks as the
            # others when rendering.
            for name in ("centers_", "density_", "scale_", "rotation", "rodrigues"):
                survivors = getattr(self, name).data[alive_mask].clone()
                setattr(self, name, nn.Parameter(survivors))
            self.num_gaussians = len(self.centers_)

    def span_new_gaussians(self):
        """Placeholder: currently does nothing.

        TODO: densification is not implemented. The earlier sketch duplicated
        the Gaussians with the largest center / scale gradients next to the
        originals and reset their scales, but it was written for separate
        scale_x_ / scale_y_ parameters that no longer exist.
        """
        pass

    def compute_pixels_values(self, mask, bounding_box, non_zero, projection_3D):
        """Evaluate the selected Gaussians on a square window around each one.

        :param mask: Boolean mask over the Gaussians selecting those to evaluate
        :param bounding_box: Side length (in pixels) of the evaluation window
        :param non_zero: The number of True entries in mask
        :param projection_3D: The 2 x 3 matrix projecting 3D points to the image plane
        :return: (pixels, values): integer pixel coordinates
            non_zero, bounding_box ** 2, 2 relative to the image center, and
            the value of each Gaussian at those pixels
        """
        device = self.centers_.device
        window_axis = torch.arange(-bounding_box // 2, bounding_box // 2, device=device)
        window = torch.cartesian_prod(window_axis, window_axis)
        pixels = window.repeat(non_zero, 1, 1)

        centers = self.centers()[mask]
        rotation_3d = rodrigues_rotations(self.rodrigues[mask], self.rotation[mask].reshape(-1))

        # NOTE(review): the kernel receives the fractional part of the 3D
        # center (projected inside gaussian_2d), while the integer offset
        # below is the floor of the projected center. The two only form a
        # consistent split when frac(c) @ P.T == frac(c @ P.T) -- confirm this
        # is intended for rotated projections.
        values = b_gaussian_2d(
            pixels,
            centers % 1,
            self.density()[mask],
            self.scale()[mask],
            rotation_3d,
            projection_3D.expand(len(pixels), *projection_3D.shape)
        )
        offset = ((centers @ projection_3D.T) // 1).int()
        return pixels + offset[:, None, :], values

    def render_image(self, pixels, values, img):
        """Accumulate values into img in place at the given pixel coordinates.

        :param pixels: Integer tensor n, 2 of (row, column) indices into img
        :param values: Tensor n with the value to add at each pixel
        :param img: The 2D image that is updated in place
        """
        if pixels.shape[0] == 0:
            # Nothing to draw.
            return

        unique_mask, inverse_indices = torch.unique(pixels, sorted=True, dim=0, return_inverse=True)
        # Aggregate the values for each unique coordinate
        aggregated_values = torch.zeros(unique_mask.shape[0], dtype=values.dtype, device=pixels.device)
        aggregated_values.scatter_add_(0, inverse_indices, values)

        img[unique_mask[:, 0], unique_mask[:, 1]] += aggregated_values

    def forward(self, degree, patch_shape=None, img_scale=1.):
        """
        Compute the rasterization of the Gaussians in this model.

        The volume is rotated by -degree degrees in the plane of its first two
        axes and the last two rotated coordinates are kept as image
        coordinates.

        :param degree: The direction onto which to project, in degrees
            (scalar or one-element tensor); None uses a fixed angle of 0.1 rad
        :param patch_shape: The rectangular patch to be computed, parametrized
            as (top_left_0, top_left_1, size_0, size_1); if None, all the image
        :param img_scale: Currently unused
        :return: The rasterization of the image patch, of shape (size_0, size_1)
        """
        device = self.centers_.device
        dtype = self.centers_.dtype

        if patch_shape is None:
            patch_shape = torch.tensor([0, 0, *self.shape[1:].tolist()], device=device)
        else:
            patch_shape = torch.as_tensor(patch_shape, device=device)

        # We compute the 3D rotation and projection to 2D space; none of it is
        # learnable. Everything is built in the parameter dtype so the matrix
        # products below do not mix precisions.
        with torch.no_grad():
            if degree is None:
                theta = torch.tensor(0.1, dtype=dtype, device=device)
            else:
                degree = torch.as_tensor(degree, dtype=dtype, device=device)
                theta = (-torch.pi * degree / 180.).reshape(())

            cos_theta = torch.cos(theta)
            sin_theta = torch.sin(theta)
            zero = torch.zeros_like(theta)
            one = torch.ones_like(theta)
            # rotation in 3D
            R3D = torch.stack([
                torch.stack([cos_theta, -sin_theta, zero]),
                torch.stack([sin_theta, cos_theta, zero]),
                torch.stack([zero, zero, one]),
            ])
            # projection to 2D: keep the last two coordinates
            P2D = torch.eye(3, dtype=dtype, device=device)[1:, :]

            projection_3D = P2D @ R3D

        # Centers are relative to the middle of the volume; shifting by half
        # the image extent turns them into image indices.
        half_extent = self.shape[1:] // 2

        # we compute the mask of Gaussians whose projected center lies inside
        # the patch
        projected_centers = self.centers() @ projection_3D.T + half_extent
        in_patch_mask = projected_centers[:, 0] >= patch_shape[0]
        in_patch_mask &= projected_centers[:, 0] <= patch_shape[0] + patch_shape[2]
        in_patch_mask &= projected_centers[:, 1] >= patch_shape[1]
        in_patch_mask &= projected_centers[:, 1] <= patch_shape[1] + patch_shape[3]

        num_selected = int(torch.count_nonzero(in_patch_mask).item())
        self.bounding_box = [num_selected]

        img = torch.zeros(
            int(patch_shape[2].item()), int(patch_shape[3].item()), dtype=dtype, device=device
        )
        if num_selected == 0:
            # Also covers a model without any Gaussians left, where taking
            # the maximum scale below would raise.
            return img

        # One window size for all Gaussians: 4x the largest scale.
        bounding_box = int((4 * self.scale().max()).item())

        # We compute the pixel values of each gaussian in a bounding box around their center
        pixels, values = self.compute_pixels_values(
            in_patch_mask, bounding_box, num_selected, projection_3D
        )
        pixels = pixels.reshape(-1, 2) + half_extent[None, :]
        values = values.reshape(-1)

        # the windows might contain pixels that fall outside the patch
        in_bounds = (pixels[:, 0] < patch_shape[0] + patch_shape[2]) & (pixels[:, 0] >= patch_shape[0])
        in_bounds &= (pixels[:, 1] < patch_shape[1] + patch_shape[3]) & (pixels[:, 1] >= patch_shape[1])

        self.render_image(pixels[in_bounds] - patch_shape[:2], values[in_bounds], img=img)
        return img


if __name__ == '__main__':
    # Smoke test: render one patch of a randomly initialized model, prune it
    # and print the rendered patch.
    try:
        # Optional: only changes how tensors are printed.
        import lovely_tensors as lt
    except ImportError:
        lt = None
    if lt is not None:
        lt.monkey_patch()

    tr = GaussianSplatter([100, 100, 80], 5000)
    # Call the module rather than .forward() so registered hooks also run.
    img = tr(degree=0, patch_shape=[20, 10, 25, 50])
    tr.clean_gaussians()
    print(img)

