"""Synthetic Regression Datasets."""
import torch
import numpy as np

from .operators import MatrixOperator


class Cosine(torch.utils.data.Dataset):
    """A dataset of noisy cosine waves. Task: Denoising.

    Samples are generated on the fly. ``__getitem__`` ignores ``index`` and draws a fresh
    random wave on every call, so ``num_datapoints`` only controls ``len(dataset)``.
    The measurement ``y`` is the clean wave ``x`` passed through ``self.op_cpu``, which is
    built from ``get_matrix()``, plus optional Gaussian noise. Subclasses change the forward
    operator by overriding ``get_matrix`` and the signal family by overriding ``__getitem__``.
    """

    def __init__(self, num_datapoints, noise_level=0.1, freq_range=(0, 2 * np.pi), wave_length=50):
        """Initialize the dataset.

        Args:
            num_datapoints: Value reported by ``__len__``; no data is precomputed.
            noise_level: Standard deviation of the additive Gaussian noise on the
                measurement, or ``None`` to disable noise.
            freq_range: ``(low, high)`` bounds; frequencies are drawn uniformly from
                ``[low, high)``.
            wave_length: Number of samples per wave.
        """
        self.wave_length = wave_length
        self.noise_level = noise_level
        # get_matrix() is called here and may read subclass attributes, so subclasses
        # must set those (e.g. ``sigma``) *before* calling ``super().__init__``.
        # Two independent operator instances are built from the same matrix; only
        # ``op_cpu`` is used for sample generation in this file.
        self.operator = MatrixOperator(self.get_matrix(), dimension=1, channels=1)
        self.op_cpu = MatrixOperator(self.get_matrix(), dimension=1, channels=1)

        # Sampling grid over [-pi/2, pi/2].
        self.indices = torch.linspace(-np.pi / 2, np.pi / 2, wave_length)
        # Copy into a fresh list. This avoids aliasing a shared default or the
        # caller's list while keeping the attribute a list, as before.
        self.freq_range = list(freq_range)

        self.num_datapoints = num_datapoints

    def get_matrix(self):
        """Return the forward-operator matrix; the identity for plain denoising."""
        return torch.eye(self.wave_length)

    def __len__(self):
        """Return the nominal dataset size (``num_datapoints``)."""
        return self.num_datapoints

    def __getitem__(self, index):
        """Draw one random sample; ``index`` is ignored.

        Returns:
            ``(y, x)``: ``x`` is the clean ground-truth wave of shape ``(1, wave_length)``,
            and ``y`` is the measurement ``op_cpu(x)``, plus noise if ``noise_level`` is set.
        """
        low, high = self.freq_range[0], self.freq_range[1]
        frequency = torch.rand(1) * (high - low) + low
        offset_x = torch.randn(1)  # random phase shift
        offset_y = torch.randn(1)  # random vertical shift
        x = torch.cos(frequency * self.indices + offset_x) + offset_y
        x = x.unsqueeze(0)  # single channel dimension
        y = self.op_cpu(x)
        if self.noise_level is not None:
            # Out-of-place add: the ground truth x can never be corrupted, even if the
            # operator returns its input tensor.
            y = y + torch.randn_like(y) * self.noise_level
        return y, x


class SincWave(Cosine):
    """Variant of ``Cosine`` whose ground-truth signal is a randomly shifted sinc wave.

    ``torch.sinc`` is the normalized sinc, ``sin(pi * t) / (pi * t)``.
    """

    def __getitem__(self, index):
        """Draw one random sinc sample; ``index`` is ignored.

        Returns:
            ``(y, x)``: ``x`` is the clean ground-truth wave of shape ``(1, wave_length)``,
            and ``y`` is the measurement ``op_cpu(x)``, plus noise if ``noise_level`` is set.
        """
        frequency = torch.rand(1) * (self.freq_range[1] - self.freq_range[0]) + self.freq_range[0]
        offset_x = torch.randn(1)  # random horizontal shift
        offset_y = torch.randn(1)  # random vertical shift
        x = torch.sinc(frequency * self.indices + offset_x) + offset_y
        x = x.unsqueeze(0)  # single channel dimension
        y = self.op_cpu(x)
        if self.noise_level is not None:
            # Out-of-place add so x is never modified through an aliased operator output.
            y = y + torch.randn_like(y) * self.noise_level
        return y, x


class PiecewiseConstant(Cosine):
    """Variant of ``Cosine`` whose ground-truth signal is a piecewise-constant wave.

    A jump is placed at every sample where an independent standard-normal draw exceeds
    1.0. This happens with probability 1 - Phi(1) ~= 0.1587 per sample, so a wave of
    ``self.wave_length`` samples has about ``0.1587 * self.wave_length`` jumps on average.
    Each constant segment gets its own standard-normal level.
    """

    def __getitem__(self, index):
        """Draw one random piecewise-constant sample; ``index`` is ignored.

        Returns:
            ``(y, x)``: ``x`` is the clean ground-truth wave of shape ``(1, wave_length)``,
            and ``y`` is the measurement ``op_cpu(x)``, plus noise if ``noise_level`` is set.
        """
        # Segment id of every sample = running count of jumps seen so far.
        segment_ids = (torch.randn(self.wave_length) > 1).cumsum(dim=0)
        num_segments = int(segment_ids.max()) + 1
        levels = torch.randn(num_segments)
        x = levels[segment_ids]
        x = x.unsqueeze(0)  # single channel dimension
        y = self.op_cpu(x)
        if self.noise_level is not None:
            # Out-of-place add so x is never modified through an aliased operator output.
            y = y + torch.randn_like(y) * self.noise_level
        return y, x


class Integration(Cosine):
    """Measurements through a lower-triangular summation operator. Task: Differentiation and denoising.

    Assuming ``MatrixOperator`` applies the matrix as ``A @ x`` (TODO confirm), each
    measurement is a weighted cumulative sum of the wave.
    """

    def get_matrix(self):
        """Return the lower-triangular ones matrix with every column rescaled to sum to one.

        Entry ``(i, j)`` equals ``1 / (wave_length - j)`` for ``i >= j`` and 0 otherwise.
        """
        n = self.wave_length
        lower = torch.ones(n, n).tril()
        column_totals = lower.sum(dim=0, keepdim=True)
        return lower / column_totals


class Differentiation(Cosine):
    """Measurements through a first-difference operator. Task: Integration and denoising.

    Assuming ``MatrixOperator`` applies the matrix as ``A @ x`` (TODO confirm), the
    measurement is ``y[0] = x[0]`` and ``y[i] = x[i] - x[i - 1]`` for ``i > 0``.
    """

    def get_matrix(self):
        """Return the matrix with ones on the diagonal and -1 on the first subdiagonal."""
        n = self.wave_length
        subdiagonal = torch.diag(torch.ones(n - 1), diagonal=-1)
        return torch.eye(n) - subdiagonal


class Blur(Cosine):
    """Measurements through a banded blur operator. Task: Deblurring and denoising."""

    # Half-width of the blur band: offsets -kernel_radius..+kernel_radius are used.
    kernel_radius = 7

    def __init__(self, num_datapoints, noise_level=0.1, freq_range=(np.pi / 2, 2 * np.pi), wave_length=50, sigma=20):
        """Initialize the dataset.

        Args:
            num_datapoints, noise_level, freq_range, wave_length: As for ``Cosine``.
            sigma: Width of the Gaussian weights. Offsets are measured in units of
                ``wave_length``, so large values (like the default 20) give nearly
                uniform weights across the band.
        """
        # Must be set before super().__init__(), which calls get_matrix().
        self.sigma = sigma
        super().__init__(num_datapoints, noise_level, freq_range, wave_length)

    def get_matrix(self):
        """Return the column-normalized banded blur matrix.

        The band is symmetric. The original ``range(-7, 7)`` omitted offset +7, which
        made the kernel lopsided. The radius is clamped to ``wave_length - 1``, so short
        waves no longer request negative-length diagonals.
        """
        n = self.wave_length
        blur_op = torch.eye(n)
        radius = min(self.kernel_radius, n - 1)
        for offset in range(-radius, radius + 1):
            # Same (scaled) distance for every entry on this diagonal.
            distance = torch.ones(n - abs(offset)) * offset / n
            weights = torch.exp(-0.5 * (distance / self.sigma) ** 2)
            # NOTE(review): the identity start plus the offset-0 term gives the main
            # diagonal a raw weight of 2 before normalization; kept as-is, confirm intent.
            blur_op += torch.diag_embed(weights, offset=offset)
        # Normalize every column to sum to one.
        return blur_op / blur_op.sum(dim=0, keepdim=True)


class Downsampling(Cosine):
    """Measurements are blurred, then subsampled. Task: Deblurring, denoising and upsampling."""

    # Half-width of the blur band: offsets -kernel_radius..+kernel_radius are used.
    kernel_radius = 7

    def __init__(self, num_datapoints, noise_level=0.1, freq_range=(np.pi / 2, 2 * np.pi), wave_length=50, sigma=20,
                 downsampling=4):
        """Initialize the dataset.

        Args:
            num_datapoints, noise_level, freq_range, wave_length: As for ``Cosine``.
            sigma: Width of the Gaussian blur weights. Offsets are measured in units of
                ``wave_length``.
            downsampling: Keep every ``downsampling``-th row of the blur matrix.
        """
        # Must be set before super().__init__(), which calls get_matrix().
        self.sigma = sigma
        self.downsampling = downsampling
        super().__init__(num_datapoints, noise_level, freq_range, wave_length)

    def get_matrix(self):
        """Return the symmetric, column-normalized blur matrix with only every
        ``downsampling``-th row kept, i.e. ``ceil(wave_length / downsampling)`` rows.

        The band is symmetric; the original ``range(-7, 7)`` omitted offset +7. The radius
        is clamped to ``wave_length - 1``, so short waves do not crash.
        """
        n = self.wave_length
        blur_op = torch.eye(n)
        radius = min(self.kernel_radius, n - 1)
        for offset in range(-radius, radius + 1):
            distance = torch.ones(n - abs(offset)) * offset / n
            weights = torch.exp(-0.5 * (distance / self.sigma) ** 2)
            # NOTE(review): identity start + offset-0 term doubles the raw diagonal weight.
            blur_op += torch.diag_embed(weights, offset=offset)
        blur_op = blur_op / blur_op.sum(dim=0, keepdim=True)
        # Subsample: drop all but every `downsampling`-th measurement row.
        return blur_op[::self.downsampling]


# Combinations of each signal variant (sinc, piecewise constant) with each non-identity operator.


class IntegrationSinc(SincWave, Integration):
    """Sinc-wave signals (``SincWave.__getitem__``) measured through the ``Integration`` matrix."""


class DifferentationSinc(SincWave, Differentiation):
    """Sinc-wave signals (``SincWave.__getitem__``) measured through the ``Differentiation`` matrix.

    The class name keeps its historical misspelling for compatibility; the correctly
    spelled alias ``DifferentiationSinc`` refers to the same class.
    """


DifferentiationSinc = DifferentationSinc


class BlurSinc(SincWave, Blur):
    """Sinc-wave signals measured through the ``Blur`` matrix; constructor arguments follow ``Blur``."""


class DownsamplingSinc(SincWave, Downsampling):
    """Sinc-wave signals measured through the ``Downsampling`` matrix; constructor follows ``Downsampling``."""


class IntegrationPiecewise(PiecewiseConstant, Integration):
    """Piecewise-constant signals measured through the ``Integration`` matrix."""


class DifferentationPiecewise(PiecewiseConstant, Differentiation):
    """Piecewise-constant signals measured through the ``Differentiation`` matrix.

    The class name keeps its historical misspelling for compatibility; the correctly
    spelled alias ``DifferentiationPiecewise`` refers to the same class.
    """


DifferentiationPiecewise = DifferentationPiecewise


class BlurPiecewise(PiecewiseConstant, Blur):
    """Piecewise-constant signals measured through the ``Blur`` matrix; constructor follows ``Blur``."""


class DownsamplingPiecewise(PiecewiseConstant, Downsampling):
    """Piecewise-constant signals measured through the ``Downsampling`` matrix; constructor follows ``Downsampling``."""
