import torch
import torch.nn.functional as F
import torch.nn as nn


class TomographicReconstruction(nn.Module):
    """Learnable image whose sinogram is rendered by rotating the image and summing.

    The reconstructed image is ``prior + img`` or, when ``use_sigmoid`` is set,
    ``sigmoid(sigmoid_alpha * (prior - 0.5) + img)``. Here ``img`` is the only
    trainable parameter and is initialised to zeros.

    Args:
        prior: Initial image, either ``(w, w)`` or batched ``(bs, w, w)``. The
            circular crop treats images as square with side ``prior.shape[1]``
            (after a batch dimension has been added).
        use_sigmoid: Whether to squash the image through a sigmoid.
        sigmoid_alpha: Scale applied to ``prior - 0.5`` before the sigmoid.
    """

    def __init__(self, prior, use_sigmoid: bool, sigmoid_alpha: float = 5.):
        super().__init__()
        self.computed_image = None  # last value returned by get_img()
        self.device = prior.device
        if len(prior.shape) == 2:
            prior = prior.unsqueeze(0)  # always keep a leading batch dimension
        self.bs = prior.shape[0]
        w = prior.shape[1]

        # Squared distance of pixel (row i, column j) from the centre (w/2, w/2).
        coords = torch.arange(w)
        dist2 = (coords[:, None] - w / 2) ** 2 + (coords[None, :] - w / 2) ** 2
        outside = dist2 > (w / 2) ** 2
        # Flat row-major indices of the pixels outside the inscribed circle.
        self.zero_indices = torch.where(outside.reshape(-1))[0]
        # Same pixels as a (w, w) boolean mask; it broadcasts over any batch dimension.
        self.outside_circle_mask = outside.to(self.device)

        self.use_sigmoid = use_sigmoid
        self.sigmoid_alpha = sigmoid_alpha
        self.prior = prior.clone()
        self.img = nn.Parameter(torch.zeros_like(prior, device=self.device))

    def get_img(self, filter=None, circle_crop=False, mle=False):
        """Return the current image, optionally filtered, cropped and binarised.

        Args:
            filter: Optional callable applied to the image.
            circle_crop: If True, zero every pixel outside the circle inscribed in
                the ``w x w`` image. This is applied to every image of a batch.
            mle: If True, threshold the image at 0.5 into a {0., 1.} float tensor.

        Returns:
            The image. The leading batch dimension is dropped when it has size 1.
            The result is also stored in ``self.computed_image``.
        """
        if self.use_sigmoid:
            image = torch.sigmoid(self.sigmoid_alpha * (self.prior - 0.5) + self.img)
        else:
            image = self.prior + self.img

        if filter is not None:
            image = filter(image)

        if circle_crop:
            # Out-of-place on purpose. An in-place write would clobber the sigmoid
            # output that autograd saved for backward. The (w, w) mask broadcasts,
            # so each image of a batch is cropped, not only the first.
            image = image.masked_fill(self.outside_circle_mask.to(image.device), 0.)

        if mle:
            image = (image >= 0.5).float()
        image = image if len(image) > 1 else image[0]
        self.computed_image = image
        return image

    def forward(self, sinogram_angles=None, filter=None, filter_in_sinogram_space: bool = False):
        """Compute the sinogram of the current image.

        Args:
            sinogram_angles: Projection angles in degrees, as a 1-D tensor, a
                sequence, or a scalar. Defaults to the integers 0..179.
            filter: Optional callable applied to the image. When
                ``filter_in_sinogram_space`` is True it is applied to the sinogram
                instead.
            filter_in_sinogram_space: Apply ``filter`` to the sinogram rather than
                to the image.

        Returns:
            A tensor of shape ``(n_angles, w)`` for a single image, or
            ``(bs, n_angles, w)`` for a batch.
        """
        images = self.get_img(filter=None if filter_in_sinogram_space else filter)

        if sinogram_angles is None:
            sinogram_angles = torch.arange(180)
        # grid_sample needs the grid to have the same dtype as its input, which is
        # cast to float32 below. It also needs the grid on the image's device.
        angles = torch.atleast_1d(
            torch.as_tensor(sinogram_angles, dtype=torch.float32, device=images.device)
        )
        n_angles = len(angles)

        theta = torch.deg2rad(angles)
        cos, sin = torch.cos(theta), torch.sin(theta)
        zeros = torch.zeros_like(theta)
        # One 2x3 affine matrix per angle, giving shape (n_angles, 2, 3).
        rotation_matrix = torch.stack([
            torch.stack([cos, -sin, zeros], dim=1),
            torch.stack([sin, cos, zeros], dim=1),
        ], dim=1)
        # affine_grid needs only the target size. Building it from shapes avoids
        # materialising an (n_angles, 1, h, w) copy of the parameter.
        grid_size = torch.Size((n_angles, 1, *self.img.shape[-2:]))
        current_grid = F.affine_grid(rotation_matrix, grid_size, align_corners=False)

        # An (h, w) image becomes (n_angles, 1, h, w). A batched (bs, h, w) image
        # becomes (n_angles, bs, h, w), so the batch is carried along as channels.
        rotated = F.grid_sample(images.repeat(n_angles, 1, 1, 1).float(), current_grid,
                                align_corners=False)
        rotated = rotated.transpose(0, 1)
        # Sum over one of the dimensions to compute the projection
        sinogram = rotated.sum(dim=-2).squeeze(2)
        sinogram = sinogram if len(sinogram) > 1 else sinogram[0]
        if filter_in_sinogram_space and filter is not None:
            return filter(sinogram)
        return sinogram

    def get_mle_tr(self):
        """Return a new non-sigmoid reconstruction whose prior is the binarised current image."""
        return TomographicReconstruction(self.get_img(mle=True), use_sigmoid=False)



if __name__ == '__main__':
    # Ad-hoc visual check: rotate a grid pattern by every integer angle in
    # [0, 180) and print a lovely_tensors summary of the result.
    import lovely_tensors as lt

    lt.monkey_patch()

    # tr = TomographicReconstruction(prior=torch.randn(512, 512), use_sigmoid=True)
    # sinogram = tr.forward(sinogram_angles=torch.tensor([*range(180)]))
    # print(sinogram)
    # mse = nn.MSELoss()
    # loss = mse(sinogram, torch.ones_like(sinogram))
    # loss.backward()

    device = torch.device('cpu')
    angles = torch.arange(180, dtype=torch.float32, device=device)
    # 10 planes of 512x512 with a line every 10 pixels in both directions.
    image = torch.zeros(10, 512, 512, device=device)
    image[:, ::10, :] = 1
    image[:, :, ::10] = 1

    theta = torch.deg2rad(angles)
    cos, sin = torch.cos(theta), torch.sin(theta)
    zeros = torch.zeros_like(theta)
    # One 2x3 affine matrix per angle, giving shape (len(angles), 2, 3).
    rotation_matrix = torch.stack([
        torch.stack([cos, -sin, zeros], dim=1),
        torch.stack([sin, cos, zeros], dim=1),
    ], dim=1)
    # Target size is (N, C, H, W): one sample per angle, with the 10 planes as channels.
    current_grid = F.affine_grid(
        rotation_matrix,
        (len(angles), *image.shape),
        align_corners=False
    )

    # NOTE: repeat materialises len(angles) copies of the image (~1.9 GB at these sizes).
    rotated_image = F.grid_sample(image.repeat(len(angles), 1, 1, 1), current_grid,
                                  align_corners=False)
    print(rotated_image)
