import torch
import torch.nn.functional as F
import torch.nn as nn
from chip.models.gaussian_splatting import GaussianSplatter
class GaussianReconstruction(nn.Module):
    """Render an image with a ``GaussianSplatter`` and project it to a sinogram.

    ``forward`` rotates the rendered image once per requested angle (via
    ``F.affine_grid`` / ``F.grid_sample``) and sums each rotated copy along
    its second-to-last axis to obtain one projection per angle.

    Attributes set by this class:
        device: device of ``prior`` at construction time (kept for callers;
            ``forward`` itself works on the device of the rendered image).
        use_sigmoid: whether ``get_img`` applies ``torch.sigmoid``.
        prior: a clone of the ``prior`` tensor; only its shape is used here.
        img: the ``GaussianSplatter`` that produces the image.
        computed_image: the image used by the most recent ``forward`` call.
    """

    def __init__(self, prior, use_sigmoid: bool, num_gaussians: int, min_scale=0.1, max_scale=10,
                 distribution=None):
        """Build the splatter from the shape of ``prior``.

        Args:
            prior: tensor whose shape defines the image size; it is cloned.
                NOTE(review): ``forward`` assumes the trailing two dims are
                (height, width) -- confirm against ``GaussianSplatter``.
            use_sigmoid: apply ``torch.sigmoid`` to the splatter output.
            num_gaussians, min_scale, max_scale, distribution: forwarded
                unchanged to ``GaussianSplatter``.
        """
        super().__init__()
        self.device = prior.device

        self.use_sigmoid = use_sigmoid
        self.prior = prior.clone()
        self.img = GaussianSplatter(prior.shape, num_gaussians, distribution=distribution,
                                    min_scale=min_scale, max_scale=max_scale)

    def get_img(self, degree=None, filter=None, circle_crop=False, mle=False):
        """Return the splatter output, optionally squashed and filtered.

        Args:
            degree: forwarded to the splatter call.
            filter: optional callable applied to the image; ``None`` skips it.
            circle_crop: unused; kept for interface compatibility.
            mle: unused; kept for interface compatibility.

        Returns:
            The image; a leading axis of length 1 is dropped.
        """
        image = self.img(degree)
        if self.use_sigmoid:
            image = torch.sigmoid(image)

        if filter is not None:
            image = filter(image)

        return image if len(image) > 1 else image[0]

    def forward(self, degree=None, sinogram_angles=None, filter=None, filter_in_sinogram_space: bool = False):
        """Compute the sinogram of the current image.

        Args:
            degree: forwarded to ``get_img``.
            sinogram_angles: 1-D collection of angles in degrees (tensor or
                anything ``torch.as_tensor`` accepts). Defaults to the
                integers 0..179.
            filter: optional callable. Applied to the image when
                ``filter_in_sinogram_space`` is false, otherwise to the
                sinogram. ``None`` means no filtering in either case.
            filter_in_sinogram_space: selects where ``filter`` is applied.

        Returns:
            The (optionally filtered) sinogram; a leading axis of length 1
            is dropped.
        """
        # When filtering is deferred to sinogram space the image stays raw.
        image_filter = None if filter_in_sinogram_space else filter
        images = self.get_img(degree, filter=image_filter)

        self.computed_image = images

        if sinogram_angles is None:
            sinogram_angles = torch.arange(180)
        # Work on the image's device and in float32: grid_sample requires the
        # grid to match its input, which is cast with .float() below.
        angles = torch.as_tensor(sinogram_angles, dtype=torch.float32, device=images.device)
        num_angles = len(angles)

        radians = torch.deg2rad(angles)
        cos, sin = torch.cos(radians), torch.sin(radians)
        no_shift = torch.zeros_like(cos)  # pure rotation, zero translation column
        rotation_matrix = torch.stack([
            torch.stack([cos, -sin, no_shift], 1),
            torch.stack([sin, cos, no_shift], 1),
        ], 1)

        # affine_grid only needs the target size, so build it directly rather
        # than materialising a repeated copy of the prior.
        grid_size = torch.Size((num_angles, 1, *self.prior.shape[-2:]))
        current_grid = F.affine_grid(rotation_matrix, grid_size, align_corners=False)

        rotated = F.grid_sample(images.repeat(num_angles, 1, 1, 1).float(), current_grid,
                                align_corners=False)
        rotated = rotated.reshape(1, num_angles, *self.prior.shape[1:])
        # Sum over one of the dimensions to compute the projection
        sinogram = rotated.sum(axis=-2).squeeze(2)
        sinogram = sinogram if len(sinogram) > 1 else sinogram[0]

        if filter_in_sinogram_space and filter is not None:
            return filter(sinogram)
        return sinogram

if __name__ == '__main__':
    # Manual smoke test: build a model, project it over 180 angles, and
    # back-propagate an MSE loss against an all-ones target.
    import lovely_tensors as lt

    lt.monkey_patch()

    model = GaussianReconstruction(
        prior=torch.randn(100, 512, 512),
        num_gaussians=100**2,
        use_sigmoid=True,
    )
    model.img.clean_gaussians()

    projection = model.forward(sinogram_angles=torch.tensor(list(range(180))))
    print(projection)

    target = torch.ones_like(projection)
    criterion = nn.MSELoss()
    loss = criterion(projection, target)
    loss.backward()
