import torch.nn as nn
import torch.nn.functional as F
import torch


class HarmonicEmbedding(torch.nn.Module):
    def __init__(self, n_harmonic_functions=60, omega0=0.1):
        """
        Encode each feature of an input `x` of shape [..., dim] as sines and
        cosines of that feature at geometrically increasing frequencies.

        The frequencies are `omega0 * 2**k` for k = 0 .. n_harmonic_functions-1
        and are stored in the `frequencies` buffer. `forward` returns, along
        the last axis:

            sin(f_k * x[..., d])  for d in range(dim), k in range(n)   (d-major)
            cos(f_k * x[..., d])  for d in range(dim), k in range(n)   (d-major)

        i.e. all sine features first, then all cosine features.
        """
        super().__init__()
        powers = torch.arange(n_harmonic_functions)
        self.register_buffer('frequencies', omega0 * 2.0 ** powers)

    def forward(self, x):
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: tensor of shape [..., dim * n_harmonic_functions * 2]
                holding the sine features followed by the cosine features.
        """
        # [..., dim, n] -> [..., dim * n]; the product is a fresh tensor, so view is safe.
        phases = x.unsqueeze(-1) * self.frequencies
        phases = phases.view(*x.shape[:-1], -1)
        sines = torch.sin(phases)
        cosines = torch.cos(phases)
        return torch.cat([sines, cosines], dim=-1)


class NeuralRadianceField(torch.nn.Module):
    def __init__(self, width=512, n_harmonic_functions=60, n_hidden_neurons=256, use_mlp=False):
        """
        A learnable square image of shape [width, width].

        Two parameterisations are available:
          * `use_mlp=False` (default): pixel values are read directly from the
            learnable `self.img` parameter, initialised to zeros.
          * `use_mlp=True`: pixel values are predicted from a harmonic
            embedding of the pixel coordinates by `self.mlp` followed by
            `self.density_layer`, then mapped to 1 - exp(-raw).

        The MLP sub-modules are constructed in both modes, so the parameter
        list and state_dict keys do not depend on `use_mlp`.

        Args:
            width: side length of the image in pixels.
            n_harmonic_functions: number of frequencies in the harmonic embedding.
            n_hidden_neurons: width of the MLP hidden layers.
            use_mlp: use the MLP parameterisation instead of the direct pixel grid.
        """
        super().__init__()
        self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions)
        self.w = width
        self.use_mlp = use_mlp
        self.img = nn.Parameter(torch.zeros(width, width))
        # 2 coordinates x n_harmonic_functions x (sin, cos)
        embedding_dim = n_harmonic_functions * 2 * 2
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, n_hidden_neurons),
            torch.nn.Softplus(beta=10.0),
            torch.nn.Linear(n_hidden_neurons, n_hidden_neurons),
            torch.nn.Softplus(beta=10.0),
        )
        self.density_layer = torch.nn.Sequential(
            torch.nn.Linear(n_hidden_neurons, 1),
            torch.nn.Softplus(beta=10.0),
        )
        # Negative initial bias keeps the initial MLP output small.
        self.density_layer[0].bias.data[0] = -1.5

    def forward_coords(self, coords):
        """
        Evaluate the image at the given pixel coordinates.

        Args:
            coords: tensor of shape [N, 2] holding (row, col) pixel indices,
                as built by `forward` (integer dtype there).
        Returns:
            Tensor of shape [N] in direct mode, or [N, 1] in MLP mode, where
            the values are 1 - exp(-raw_densities).
        """
        if not self.use_mlp:
            return self.img[coords[:, 0], coords[:, 1]]

        # NOTE(review): raw pixel indices are embedded without normalisation;
        # with many harmonics the top frequencies alias heavily - confirm intended.
        embeds = self.harmonic_embedding(coords)
        # embeds.shape = [N, n_harmonic_functions * 2 * 2]

        # self.mlp maps each harmonic embedding to a latent feature space.
        features = self.mlp(embeds)
        # features.shape = [N, n_hidden_neurons]

        raw_densities = self.density_layer(features)
        return 1 - (-raw_densities).exp()

    def forward(self):
        """Render the full [width, width] image by evaluating every pixel."""
        device = next(iter(self.parameters())).device
        # Build the coordinate grid directly on the target device.
        axis = torch.arange(self.w, device=device)
        coords = torch.cartesian_prod(axis, axis)
        values = self.forward_coords(coords)
        return values.reshape(self.w, self.w)

def positional_encoding(x, L):
    """
    Concatenate `x` with sin/cos of `x` at L octave frequencies.

    Along dim 1 the output is laid out as
        [x, sin(x), cos(x), sin(2x), cos(2x), ..., sin(2**(L-1) x), cos(2**(L-1) x)]
    so a [N, D] input becomes [N, D * (2 * L + 1)].
    """
    bands = [x]
    for octave in range(L):
        scaled = (2 ** octave) * x
        bands.extend((torch.sin(scaled), torch.cos(scaled)))
    return torch.cat(bands, dim=1)


class NeRFWithEncoding(nn.Module):
    """
    Coordinate MLP that renders a [w, w] image.

    Each (row, col) pixel index is expanded with `positional_encoding`
    (raw coordinate plus sin/cos at `num_freqs_coords` octaves) and mapped to
    one scalar by a 4-layer ReLU MLP with 128 hidden units.
    """

    def __init__(self, num_freqs_coords=10, w=512):
        super().__init__()
        self.w = w
        hidden = 128
        # 2 coordinates, each encoded as itself + (sin, cos) per frequency.
        in_features = 2 * (1 + 2 * num_freqs_coords)
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        self.num_freqs_coords = num_freqs_coords

    def forward_coords(self, coords):
        """Evaluate the MLP on encoded coordinates ([N, 2] as built by `forward`)."""
        encoded = positional_encoding(coords, self.num_freqs_coords)
        return self.layers(encoded)

    def forward(self):
        """Render the full [w, w] image by evaluating every pixel index."""
        device = next(self.parameters()).device
        axis = torch.arange(self.w)
        coords = torch.cartesian_prod(axis, axis).reshape(-1, 2).to(device)
        return self.forward_coords(coords).reshape(self.w, self.w)


class TomographicReconstruction(nn.Module):
    """
    Differentiable projection model around an image estimate.

    The image is `sigmoid(5 * (prior - 0.5) + self.img())` when `use_sigmoid`
    is True, otherwise the prior itself. `forward` rotates the image by the
    selected angles with `grid_sample` and sums over the height axis to build
    a sinogram.
    """

    def __init__(self, prior, use_sigmoid: bool, theta: int = 180):
        """
        Args:
            prior: initial image estimate. The learnable image and circular
                mask are 512x512, so a 512x512 prior is assumed (TODO confirm
                for other shapes). It is cloned; the caller's tensor is not modified.
            use_sigmoid: blend the prior with the learnable `self.img()` through
                a sigmoid; if False the prior is used unchanged.
            theta: number of projection angles, evenly spaced over [0, 179] degrees.
        """
        super(TomographicReconstruction, self).__init__()

        w = 512
        # Flat indices of pixels outside the inscribed circle (zeroed by circle_crop).
        self.zero_indices = self._circle_zero_indices(w)

        self.use_sigmoid = use_sigmoid
        self.prior = prior.clone()
        self.img = NeuralRadianceField(512, 60, 128)

        self.theta = theta
        # Default is 180 angles from 0 to 179
        self.theta_range = torch.linspace(0, 179, self.theta)

        # We compute the rotation matrices in advance
        self.rotation_matrix = self._rotation_matrices(self.theta_range)

        # Sampling grid for all angles; used when forward() gets no indices.
        self.grid = F.affine_grid(self.rotation_matrix, self._grid_size(len(self.theta_range)),
                                  align_corners=False)

    @staticmethod
    def _circle_zero_indices(w):
        """Flat indices of the pixels of a w x w image outside its inscribed circle."""
        cp = torch.cartesian_prod(torch.arange(w), torch.arange(w))
        return torch.where((cp[:, 0] - w / 2) ** 2 + (cp[:, 1] - w / 2) ** 2 > (w / 2) ** 2)[0]

    @staticmethod
    def _rotation_matrices(theta_range):
        """Stack of 2x3 matrices [[cos, -sin, 0], [sin, cos, 0]], one per angle in degrees."""
        rad = torch.deg2rad(theta_range)
        cos, sin = torch.cos(rad), torch.sin(rad)
        zero = torch.zeros_like(theta_range)
        return torch.stack([
            torch.stack([cos, -sin, zero], 1),
            torch.stack([sin, cos, zero], 1),
        ], 1)

    def _grid_size(self, n):
        """Shape of `self.prior.repeat(n, 1, 1, 1)`, computed without allocating that tensor."""
        shape = (1,) * (4 - self.prior.dim()) + tuple(self.prior.shape)
        return torch.Size((n * shape[0], *shape[1:]))

    def get_img(self, filter=None, circle_crop=False, mle=False):
        """
        Build the current image estimate.

        Args:
            filter: optional callable applied to the image.
            circle_crop: zero every pixel outside the inscribed circle.
            mle: binarise the result with threshold 0.5.
        Returns:
            The image tensor. Neither `self.prior` nor any autograd-saved
            tensor is modified.
        """
        if self.use_sigmoid:
            image = torch.sigmoid(5 * (self.prior - 0.5) + self.img())
        else:
            image = self.prior

        if filter is not None:
            image = filter(image)

        if circle_crop:
            # Out-of-place on purpose: an in-place write would overwrite
            # self.prior (use_sigmoid=False) or the sigmoid output autograd saved.
            flat = image.reshape(-1).index_fill(0, self.zero_indices.to(image.device), 0.)
            image = flat.reshape(image.shape)

        if mle:
            image = (image >= 0.5).float()
        return image

    def forward(self, sinogram_indices, filter=None, filter_in_sinogram_space: bool = False):
        """
        Project the image at the selected angles.

        Args:
            sinogram_indices: indices into `self.theta_range`, or None for all angles.
            filter: optional callable, applied to the image, or to the sinogram
                when `filter_in_sinogram_space` is True.
            filter_in_sinogram_space: apply `filter` after projection instead of before.
        Returns:
            The sinogram; for a 2-D H x W image its shape is [n_angles, W].
        """
        # The image is always needed; the filter goes here only if not reserved for the sinogram.
        images = self.get_img(filter=None if filter_in_sinogram_space else filter)

        if sinogram_indices is not None:
            n_angles = len(self.theta_range[sinogram_indices])
            current_grid = F.affine_grid(self.rotation_matrix[sinogram_indices].to(images.device),
                                         self._grid_size(n_angles), align_corners=False)
        else:
            n_angles = len(self.theta_range)
            current_grid = self.grid.to(images.device)

        rotated = F.grid_sample(images.repeat(n_angles, 1, 1, 1).float(), current_grid,
                                align_corners=False)
        # Sum over one of the dimensions to compute the projection
        sinogram = rotated.sum(dim=-2).squeeze(1)
        if filter_in_sinogram_space and filter is not None:
            return filter(sinogram)
        return sinogram

    def get_mle_tr(self):
        """Return a new, non-sigmoid reconstruction whose prior is the binarised current image."""
        return TomographicReconstruction(self.get_img(mle=True), use_sigmoid=False, theta=self.theta)



if __name__ == '__main__':
    # Demo: project an all-zero prior at every default angle and print it.
    import lovely_tensors

    # Patch tensor printing via the third-party lovely_tensors dev dependency.
    lovely_tensors.monkey_patch()

    all_angles = torch.arange(180)
    reconstruction = TomographicReconstruction(torch.zeros(512, 512), use_sigmoid=True, theta=180)
    print(reconstruction(all_angles))