import torch.nn.functional as F
import torch
import numpy as np
import cv2
from skimage.transform import iradon
from torchvision.transforms.functional import rotate
from chip.utils.utils import get_uniform_angles

class Sinogram:
    """
    Projection data paired with its acquisition angles, kept sorted by angle.

    Row ``i`` of ``sinogram`` belongs to ``angles[i]``; every construction or
    insertion re-sorts both jointly along dim 0.
    """

    def __init__(self, sinogram, angles):
        assert sinogram is not None
        assert angles is not None

        self.sinogram = sinogram
        self.angles = angles
        # Always true while asserts are active; only matters under ``python -O``.
        if angles is not None:
            self._sort()

    def add(self, angles, observations):
        """
        Append new projections and re-sort everything by angle.

        :param angles: 1-D tensor of angles.
        :param observations: 2-D tensor with one row per entry of ``angles``.
        """
        assert angles.ndim == 1 and observations.ndim == 2
        assert angles.shape[0] == observations.shape[0]

        if self.angles is not None:
            self.angles = torch.cat((self.angles, angles))
            self.sinogram = torch.cat((self.sinogram, observations))
        else:
            self.angles, self.sinogram = angles, observations
        self._sort()

    def to(self, device):
        """Move both tensors to ``device`` in place; returns ``self`` for chaining."""
        self.sinogram, self.angles = self.sinogram.to(device), self.angles.to(device)
        return self

    def __getitem__(self, key):
        """Index into the angle-sorted projection rows."""
        return self.sinogram[key]

    def _sort(self):
        """Sort ``angles`` ascending and permute the rows of ``sinogram`` to match."""
        # NOTE(review): the permutation runs under no_grad, so the stored tensors
        # are no longer attached to any autograd graph — confirm this is intended.
        with torch.no_grad():
            sorted_angles, order = torch.sort(self.angles)
            self.angles = sorted_angles
            self.sinogram = self.sinogram[order]

    @property
    def shape(self):
        """Shape of the projection tensor, or None when it is unset."""
        return None if self.sinogram is None else self.sinogram.shape

    def __len__(self):
        """Number of stored angles (0 when unset)."""
        return 0 if self.angles is None else len(self.angles)


def batched_sinogram(images, sinogram_angles=None, theta=180):
    """
    Differentiable sinogram of one image or a batch of images.

    Every image is rotated by every angle with ``F.affine_grid`` /
    ``F.grid_sample`` (``align_corners=False``) and then summed over its row
    axis (dim -2), giving one projection of length W per angle.

    :param images: tensor of shape (H, W) or (B, H, W).
    :param sinogram_angles: rotation angles in degrees (tensor, array,
        sequence or scalar). Defaults to ``theta`` evenly spaced angles in
        [0, 179].
    :param theta: number of default angles; only used when
        ``sinogram_angles`` is None.
    :return: float32 tensor of shape (len(angles), W) for a 2-D input, or
        (B, len(angles), W) for a 3-D input.
    """
    device = images.device
    unsqueezed = images.dim() == 2
    if unsqueezed:
        images = images.unsqueeze(0)

    # Resolve the default before moving to the device so None is supported.
    if sinogram_angles is None:
        sinogram_angles = torch.linspace(0., 179., theta, device=device)
    # float32 so the sampling grid matches the float32 input of grid_sample;
    # as_tensor keeps the autograd graph of an angle tensor that requires grad.
    sinogram_angles = torch.as_tensor(
        sinogram_angles, dtype=torch.float32, device=device).reshape(-1)

    radians = torch.deg2rad(sinogram_angles)
    cos_t, sin_t = torch.cos(radians), torch.sin(radians)
    zeros = torch.zeros_like(radians)
    # One 2x3 affine rotation matrix per angle: (A, 2, 3).
    rotation_matrix = torch.stack([
        torch.stack([cos_t, -sin_t, zeros], dim=1),
        torch.stack([sin_t, cos_t, zeros], dim=1),
    ], dim=1)

    # (B, H, W) -> (A, B, H, W): angles become the grid_sample batch and the
    # images its channels, so every image is rotated by every angle at once.
    repeated = images.repeat(len(sinogram_angles), 1, 1, 1).float()
    grid = F.affine_grid(rotation_matrix, repeated.size(), align_corners=False)
    rotated = F.grid_sample(repeated, grid, align_corners=False)

    # (A, B, H, W) -> (B, A, H, W), then integrate along rows -> (B, A, W).
    # squeeze(2) only has an effect when W == 1; kept for compatibility.
    sinogram = rotated.transpose(0, 1).sum(dim=-2).squeeze(2)
    return sinogram[0] if unsqueezed else sinogram


def compute_sinogram(image, angles=None):
    """
    Single-call convenience wrapper around ``batched_sinogram``.

    :param image: image tensor accepted by ``batched_sinogram``.
    :param angles: projection angles in degrees, or None; forwarded as
        ``sinogram_angles``.
    :return: the result of ``batched_sinogram`` for these arguments.
    """
    projections = batched_sinogram(images=image, sinogram_angles=angles)
    return projections


def numpy_radon(images, angles):
    """
    Direct implementation of the radon transform in numpy/OpenCV.

    Each image is rotated about pixel ``(W // 2, H // 2)`` with
    ``cv2.warpAffine`` and summed over its rows, giving one projection of
    length W per angle. Faster than the torch version below.

    :param images: array of shape (N, H, W).
    :param angles: iterable of projection angles in degrees.
    :return: float64 array of shape (N, W, len(angles)); for square images
        this equals (images.shape[0], images.shape[1], len(angles)).
    """
    n_images, height, width = images.shape[0], images.shape[1], images.shape[2]
    # Rotation centre (x, y); for square images both are ``width // 2``.
    cx, cy = width // 2, height // 2
    projections = np.zeros(shape=(len(angles), n_images, width))
    for k, angle in enumerate(np.deg2rad(angles)):
        # 2x3 affine matrix rotating by -angle about (cx, cy), in the same form
        # as cv2.getRotationMatrix2D((cx, cy), -degrees, 1).
        cos_a, sin_a = np.cos(-angle), np.sin(-angle)
        R = np.array([[cos_a, sin_a, (1 - cos_a) * cx - sin_a * cy],
                      [-sin_a, cos_a, sin_a * cx + (1 - cos_a) * cy]])

        for j, img in enumerate(images):
            # cv2 takes dsize as (width, height), i.e. img.shape reversed.
            projections[k, j] = cv2.warpAffine(img, R, (width, height)).sum(0)

    # (A, N, W) -> (N, W, A)
    return projections.swapaxes(0, 1).swapaxes(1, 2)

def sklearn_fbp(sinogram, sinogram_angles, filter_name='shepp-logan'):
    """
    Filtered back-projection reconstruction via ``skimage.transform.iradon``.

    :param sinogram: projections with one row per angle, i.e. shape
        (len(sinogram_angles), W); transposed before being passed to iradon.
        Tensors on any device (with or without grad), arrays and lists are
        accepted.
    :param sinogram_angles: projection angles in degrees. They are negated
        before reaching iradon — presumably to match the rotation direction
        of the forward projectors in this module (TODO confirm).
    :param filter_name: filter name forwarded to iradon.
    :return: the reconstruction as a torch tensor.
    """
    # Detach and move to host memory so GPU / grad-requiring tensors work too.
    projections = torch.as_tensor(sinogram).detach().cpu().numpy()
    angles = torch.as_tensor(sinogram_angles).detach().cpu().numpy()
    reconstruction_fbp = iradon(projections.T, theta=-angles, filter_name=filter_name)
    return torch.tensor(reconstruction_fbp)


def torch_radon(images, angles):
    """
    Torch implementation of the radon transform. Slower than the numpy version above.
    Batched across images, but not across angles: each angle is one
    ``rotate`` call on the whole batch followed by a sum over dim 2.

    NOTE(review): this sums over dim 2 (the last axis of an (N, H, W) batch),
    while ``numpy_radon`` sums each image over its rows; the two may not share
    a projection convention — confirm before mixing them.

    :param images: tensor of shape (N, H, W); presumably square, since the
        projection length is taken from ``images.shape[1]``.
    :param angles: iterable of angles in degrees (Python numbers, numpy
        scalars or tensor elements).
    :return: float64 numpy array of shape (N, images.shape[1], len(angles)).
    """
    n_detectors = images.shape[1]
    projections = np.zeros(shape=(len(angles), len(images), n_detectors))
    for k, angle in enumerate(angles):
        # torchvision's rotate type-checks angle as int/float, so numpy
        # scalars and 0-d tensors are converted explicitly.
        rotated = rotate(images, angle=float(angle))
        projections[k] = rotated.sum(2).detach().cpu().numpy()
    # (A, N, H) -> (N, H, A)
    return projections.swapaxes(0, 1).swapaxes(1, 2)


def get_observation_matrices(input_shape, forward_model, params):
    """
    Build the linear observation matrices of a forward model by probing it
    with every standard-basis image.

    :param input_shape: (height, width)
    :param forward_model: callable ``forward_model(inputs, params)`` applied to
        a batch of ``d = prod(input_shape)`` basis images of shape
        (d, *input_shape); it must return an array of shape (d, m, n) with
        n = len(params) — e.g. ``numpy_radon``.
    :param params: list of n=len(params) parameters passed to the forward model
    :return: (n, d, m) array; entry [k, i, :] is the observation of basis
        image i (flattened index) under parameter k.
    """
    dim = int(np.prod(input_shape))
    basis = np.eye(dim).reshape((-1, *input_shape))
    output = forward_model(basis, params)
    # (d, m, n) -> (d, n, m) -> (n, d, m)
    return output.swapaxes(1, 2).swapaxes(0, 1)


