from torch.nn import functional as F

from chip.utils.fourier import fft_2D, ifft_2D, b_fft_2D, b_ifft_2D
import torch
import math
from torch.nn.functional import normalize



def make_affine(R):
    """
    Appends a zero translation column to a batch of matrices, producing the affine
    matrices expected by torch.nn.functional.affine_grid.

    Args:
        R: tensor of shape (..., rows, cols), typically (batch, 3, 3) rotation matrices.

    Returns: tensor of shape (..., rows, cols + 1) with the same dtype and device as R,
        whose last column is zero.
    """
    # new_zeros inherits dtype and device from R, so no implicit dtype promotion occurs.
    return torch.cat([R, R.new_zeros(*R.shape[:-1], 1)], dim=-1)


def rotations_z_axis(projection_angles):
    """
    Builds one affine matrix per angle describing a rotation about the z-axis.

    Args:
        projection_angles: angles in degrees. May be a tensor (including tensor subclasses
            such as nn.Parameter, whose autograd graph is preserved), a sequence of numbers,
            or a single number, which is treated as one angle.

    Returns: tensor of shape (number of angles, 3, 4) on the same device as the angles, with
        the 3x3 block [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] and a zero last column.
    """
    # isinstance (not an exact type check) so tensor subclasses are not re-wrapped;
    # torch.tensor() on an existing tensor copies and detaches it.
    if not isinstance(projection_angles, torch.Tensor):
        projection_angles = torch.tensor(projection_angles)
    device = projection_angles.device

    # Flatten so a scalar angle behaves like a batch of one.
    alpha_rad = torch.deg2rad(projection_angles.reshape(-1))
    cos_alpha = torch.cos(alpha_rad)
    sin_alpha = torch.sin(alpha_rad)

    Rz = torch.zeros(alpha_rad.shape[0], 3, 3, device=device)
    Rz[:, 0, 0] = cos_alpha
    Rz[:, 0, 1] = sin_alpha
    Rz[:, 1, 0] = -sin_alpha
    Rz[:, 1, 1] = cos_alpha
    Rz[:, 2, 2] = 1

    return make_affine(Rz)


def compute_grid_y_axis(beta, W):
    """
    Computes a sampling grid, for use with torch.nn.functional.grid_sample, of a
    W x W x W volume rotated by ``beta`` degrees around the y-axis.

    Args:
        beta: rotation angle in degrees (a number, or a tensor holding a single value).
        W: edge length of the cubic volume.

    Returns: grid of shape (1, W, W, W, 3) produced by torch.nn.functional.affine_grid
        with align_corners=False.
    """
    # isinstance keeps tensor subclasses (e.g. nn.Parameter) intact instead of copying them.
    if not isinstance(beta, torch.Tensor):
        beta = torch.tensor([beta])
    beta_rad = torch.deg2rad(beta)
    cos_beta = torch.cos(beta_rad)
    sin_beta = torch.sin(beta_rad)

    # 3x3 rotation about the y-axis; the translation column is added by make_affine.
    Ry = torch.zeros(1, 3, 3)
    Ry[:, 0, 0] = cos_beta
    Ry[:, 0, 2] = sin_beta
    Ry[:, 1, 1] = 1
    Ry[:, 2, 0] = -sin_beta
    Ry[:, 2, 2] = cos_beta

    affine_Ry = make_affine(Ry)
    return torch.nn.functional.affine_grid(affine_Ry, (1, 1, W, W, W), align_corners=False)

# Module-level cache for the y-axis (tilt) sampling grid used by laminography_projection.
# The grid depends on the tilt angle, the volume edge length and the device, so the key it
# was built for is stored alongside it and the grid is rebuilt whenever any of them change.
_grid_y_axis = None
_grid_y_axis_key = None


def _cached_grid_y_axis(beta, size, device):
    """Returns the y-axis rotation grid for (beta, size) on ``device``, rebuilding it on a cache miss."""
    global _grid_y_axis, _grid_y_axis_key
    key = (float(beta), int(size), str(device))
    if _grid_y_axis is None or _grid_y_axis_key != key:
        _grid_y_axis = compute_grid_y_axis(beta, size).to(device)
        _grid_y_axis_key = key
    return _grid_y_axis


def laminography_projection(projection_angles: torch.Tensor, batched_object: torch.Tensor, beta,
                            crop_ratio: float = 0.64):
    """
    Computes laminographic projections of a 3D volume.

    The volume is rotated about the z-axis by every projection angle, then resampled on
    a grid rotated by ``beta`` degrees about the y-axis and cropped to a sub-window.
    Finally it is summed along its last axis.

    Args:
        projection_angles: z-axis rotation angles in degrees, one per projection
            (a tensor or anything torch.as_tensor accepts).
        batched_object: volume of shape (D, W, H).
        beta: tilt angle in degrees (a number or a single-element tensor).
        crop_ratio: fraction of W kept along the second output axis (default 0.64).

    Returns: tensor of shape (len(projection_angles), d, w). d and w are the sizes of the
        crop window: about DD rows (derived from beta, W and D) and about crop_ratio * W columns.
    """
    device = batched_object.device
    D, W, H = batched_object.shape
    grid_y_axis = _cached_grid_y_axis(beta, H, device)

    if not isinstance(beta, torch.Tensor):
        beta = torch.tensor(beta)
    beta = beta.to(device)
    projection_angles = torch.as_tensor(projection_angles).to(device)

    # Rotate the volume about the z-axis once per projection angle -> (N, 1, D, W, H).
    Rz = rotations_z_axis(projection_angles)
    n = len(Rz)
    grid = F.affine_grid(Rz.to(device), (n, 1, D, W, H), align_corners=False)
    rotated_volume = F.grid_sample(
        batched_object.expand(n, 1, D, W, H), grid,
        mode='bilinear', padding_mode='zeros', align_corners=False,
    )

    beta_rad = torch.deg2rad(beta)
    DD = abs(torch.ceil(W * torch.sin(beta_rad) - D * torch.cos(beta_rad)).int().item())

    ratio = W / D
    # Crop window: [rows, columns] start/stop indices into the tilted grid.
    bounds = torch.tensor([
        [W // 2 - DD // 2, W // 2 + DD // 2],
        [W // 2 - (W * crop_ratio) // 2, W // 2 + (W * crop_ratio) // 2],
    ]).int()
    (row_start, row_stop), (col_start, col_stop) = bounds.tolist()

    # clone() because the slice is modified in place and the full grid is cached.
    subgrid = grid_y_axis[:, row_start:row_stop, col_start:col_stop].clone()
    # Scale the z (depth) coordinate by W / D; presumably compensates for D != W — TODO confirm.
    subgrid[..., 2] *= ratio
    subgrid = subgrid.expand(n, *subgrid.shape[1:])

    final_volume = F.grid_sample(
        rotated_volume, subgrid,
        mode='bilinear', padding_mode='zeros', align_corners=False,
    )

    # Sum along the last axis, then drop only the channel dim so the batch and crop
    # dims survive even when they have size 1.
    return final_volume.sum(dim=-1)[:, 0]

def project_2D(projection_angles, batched_object):
    """
    Projects a 3D object of shape (depth, width, height) at a set of angles.

    The depth dimension acts as the channel dimension of a 2D rotation. Every
    (width, height) slice is rotated in-plane by each angle, and the rotated slices
    are summed along the width axis. Assumes width == height.

    Args:
        projection_angles: angles in degrees (a tensor or anything torch.as_tensor accepts).
            They are moved to the object's device before use.
        batched_object: the 3D object to be projected; it is sampled as float32.

    Returns: float32 tensor of shape (len(projection_angles), depth, height), i.e.
        (len(projection_angles), depth, width) when width == height.
    """
    device = batched_object.device
    # Move the angles to the object's device so every tensor below lives on one device.
    angles = torch.as_tensor(projection_angles, device=device)
    n = angles.shape[0]

    theta = torch.deg2rad(angles)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    zero = torch.zeros_like(cos_t)
    # (N, 2, 3) rotation matrices with zero translation. Cast to float32 so the sampling
    # grid matches the float32 input given to grid_sample; mixed dtypes there raise.
    rotation_matrix = torch.stack([
        torch.stack([cos_t, -sin_t, zero], dim=1),
        torch.stack([sin_t, cos_t, zero], dim=1),
    ], dim=1).float()

    volume = batched_object.float()
    current_grid = F.affine_grid(rotation_matrix, [n, *volume.shape], align_corners=False)
    # expand() shares the volume's storage instead of copying it once per angle.
    rotated = F.grid_sample(volume.expand(n, *volume.shape), current_grid, align_corners=False)

    # Sum over the width axis: (N, depth, width, height) -> (N, depth, height).
    return rotated.sum(dim=-2)

class Projections:
    """
    Container pairing a set of projections with their acquisition angles.

    When ``fourier_magnitude`` is True, ``projections`` holds ``abs(b_fft_2D(input))``,
    ``phases`` holds the matching phase and ``real_projections`` keeps a copy of the input.
    Otherwise ``projections`` and ``real_projections`` refer to the same tensor.
    """

    def __init__(self, projections, angles, fourier_magnitude: bool = False, laminography: bool = False, laminography_tilt_angle=-1):
        assert (projections is not None)
        assert (angles is not None)

        self.angles = angles
        self.fourier_magnitude = fourier_magnitude
        # Stored for callers; not read anywhere in this class.
        self.laminography = laminography
        self.laminography_tilt_angle = laminography_tilt_angle

        if fourier_magnitude:
            self.real_projections = projections.clone()
            # Transform once and derive both magnitude and phase from the same spectrum.
            spectrum = b_fft_2D(projections)
            self.projections = torch.abs(spectrum)
            self.phases = torch.angle(spectrum)
        else:
            self.real_projections = self.projections = projections

    def add(self, angles, observations):
        """
        Appends observations taken at ``angles``, then re-sorts everything by angle.

        Args:
            angles: 1D tensor of angles.
            observations: 2D tensor with one row per angle.

        NOTE(review): in fourier_magnitude mode the observations are appended as-is to the
        magnitude tensor, and real_projections/phases are not extended. Confirm that callers
        only use add() in real-space mode.
        """
        assert (len(angles.shape) == 1) and (len(observations.shape) == 2)
        assert angles.shape[0] == observations.shape[0]

        if self.angles is None:
            self.angles = angles
            self.projections = observations
        else:
            self.angles = torch.cat([self.angles, angles])
            self.projections = torch.cat([self.projections, observations])

        self._sort()

    def __getitem__(self, key):
        return self.projections[key]

    def to(self, device):
        """Moves all stored tensors to ``device`` and returns self."""
        self.projections = self.projections.to(device)
        self.angles = self.angles.to(device)
        if self.fourier_magnitude:
            self.real_projections = self.real_projections.to(device)
            self.phases = self.phases.to(device)
        else:
            # Keep the alias instead of creating a second copy on the device.
            self._sync_real_projections()
        return self

    def _sort(self):
        """Sorts the angles ascending and reorders the projections jointly (stable for equal angles)."""
        self.angles, indices = torch.sort(self.angles, stable=True)
        self.projections = self.projections[indices]
        self._sync_real_projections()

    def _sync_real_projections(self):
        """In real-space mode, rebinds real_projections to the current projections tensor."""
        if not self.fourier_magnitude:
            self.real_projections = self.projections

    @property
    def shape(self):
        if self.projections is None:
            return None
        return self.projections.shape

    def __len__(self):
        if self.angles is None:
            return 0
        return len(self.angles)

