import torch, numbers, math
import torch.nn as nn
import torch.nn.functional as torchfunc
from operators.operator import LinearOperator

# from scipy.fft import dct
import numpy as np
from scipy.fftpack import dct, idct
import random, dival, odl
from odl.contrib.torch import OperatorModule
from odl.contrib.torch.operator import OperatorModule
# from odl.contrib import torch as odl_torch


class addNoise(LinearOperator):
    """Pass-through operator: every method returns its argument unchanged.

    Despite the class name, no noise is generated anywhere in this class;
    ``forward``, ``gramian`` and ``adjoint`` all hand back the very object
    they receive (no copy is made).
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Apply the operator; returns ``input`` itself."""
        return input

    def gramian(self, input):
        """Apply adjoint-after-forward; for a pass-through map this is ``input`` itself."""
        return input

    def adjoint(self, input):
        """Apply the adjoint; returns ``input`` itself."""
        return input


def gen_sampMatrix(args):
    """Build the square sampling matrix selected by ``args.A_type``.

    The matrix has ``args.img_dim ** 2`` rows and columns.

    Supported values of ``args.A_type``:

    * ``'Gaussian'`` -- i.i.d. normal entries with mean 0 and standard
      deviation ``args.Gauss_var``; every column is then scaled to unit
      L2 norm.
    * ``'Bin'`` -- 0/1 entries drawn with ``torch.bernoulli`` from
      per-entry probabilities that are themselves uniform on [0, 1);
      every column is then scaled to unit L2 norm.
    * ``'DCT'`` -- the orthonormal DCT matrix (scipy's default DCT type,
      ``norm='ortho'``) as a float32 tensor.

    Args:
        args: object exposing ``A_type`` and ``img_dim`` (plus ``Gauss_var``
            for the Gaussian case).

    Returns:
        A 2-D ``torch.Tensor`` of shape ``(img_dim ** 2, img_dim ** 2)``.

    Raises:
        ValueError: if ``args.A_type`` is not one of the supported names.
            (Previously the function fell through and returned ``None``.)
    """
    n = args.img_dim ** 2
    kind = args.A_type

    if kind == 'Gaussian':
        # NOTE: torch.normal takes a standard deviation, although the
        # attribute is called ``Gauss_var``. The column normalisation below
        # removes the overall scale anyway.
        A = torch.normal(0, args.Gauss_var, size=(n, n))
        col_norm = torch.norm(A, dim=0)
        return A / col_norm[None, :]

    if kind == 'Bin':
        probs = torch.empty(n, n).uniform_(0, 1)
        A = torch.bernoulli(probs)
        # NOTE: an all-zero column would have norm 0 and turn into NaN here.
        col_norm = torch.norm(A, dim=0)
        return A / col_norm[None, :]

    if kind == 'DCT':
        # Transforming the identity column by column yields the DCT matrix.
        basis = dct(np.eye(n), axis=0, norm='ortho')
        return torch.from_numpy(basis).type(torch.FloatTensor)

    raise ValueError(
        "Unknown A_type {!r}; expected 'Gaussian', 'Bin' or 'DCT'".format(kind)
    )


def gen_partial_row_F(A, args):
    """Return a copy of a random subset of the rows of ``A``.

    ``int(args.img_dim ** 2 * args.s)`` rows are kept in total: row 0 is
    always included and the remaining ones are drawn without replacement
    (via the ``random`` module) from rows ``1 .. args.img_dim ** 2 - 1``.
    The kept rows appear in ascending index order.

    Args:
        A: tensor whose rows are subsampled.
        args: object exposing ``img_dim`` and the sampling ratio ``s``.

    Returns:
        A new tensor holding the selected rows of ``A``.
    """
    total_rows = args.img_dim ** 2
    n_keep = int(total_rows * args.s)

    # Row 0 is forced in, so only n_keep - 1 indices are drawn at random.
    kept = random.sample(range(1, total_rows), n_keep - 1)
    kept.append(0)
    row_idx = np.sort(kept)

    return A[row_idx, :].clone()


class compressedSensing(LinearOperator):
    """Row-subsampled linear measurement operator ``y = Ap @ x``.

    ``A`` is the full square sampling matrix produced by ``gen_sampMatrix``
    and ``Ap`` is a random subset of its rows. With ``args.randA`` true a
    fresh subset is drawn on every ``forward`` call; otherwise the subset is
    drawn once and reused.

    Args:
        args: object exposing ``s`` (fraction of rows to keep), ``img_dim``,
            ``randA`` and whatever ``gen_sampMatrix`` reads.
    """

    def __init__(self, args):
        super(compressedSensing, self).__init__()
        self.A = gen_sampMatrix(args)
        # 0 is the "not generated yet" placeholder; it is replaced by a
        # tensor the first time a row subset is drawn.
        self.Ap = 0
        self.s = args.s
        self.img_dim = args.img_dim
        self.randA = args.randA

    def genA(self, A):
        """Return a copy of a random row subset of ``A``.

        ``int(img_dim ** 2 * s)`` rows are kept: row 0 always, the others
        drawn without replacement from rows ``1 .. img_dim ** 2 - 1``; the
        kept rows are in ascending index order.
        """
        nsample = int(self.img_dim ** 2 * self.s)
        k = random.sample(range(1, self.img_dim ** 2), nsample - 1)  # rows to KEEP
        k.append(0)
        k = np.sort(k)
        # Index the matrix that was passed in (the original ignored ``A``
        # and always read ``self.A``).
        return torch.clone(A[k, :])

    def gen_partial_row_A(self):
        """(Re)draw ``self.Ap`` when resampling is on or none exists yet."""
        # ``self.Ap == 0`` would become an elementwise comparison once Ap is
        # a tensor and make the ``if`` raise, so test the type instead.
        if self.randA or not torch.is_tensor(self.Ap):
            self.Ap = self.genA(self.A)

    def _ensure_Ap(self):
        """Draw ``self.Ap`` once if it is still the placeholder.

        Unlike ``gen_partial_row_A`` this never resamples, so ``adjoint`` and
        ``gramian`` keep using the subset of the most recent ``forward``.
        """
        if not torch.is_tensor(self.Ap):
            self.Ap = self.genA(self.A)

    def forward(self, x):
        """Return ``Ap @ x``, drawing a row subset first if required."""
        # The original called the non-existent ``self.gen_partial_row_F()``.
        self.gen_partial_row_A()
        return self.Ap @ x

    def gramian(self, x):
        """Return ``Ap.T @ Ap @ x`` for the current row subset."""
        self._ensure_Ap()
        return self.Ap.T @ self.Ap @ x

    def adjoint(self, y):
        """Return ``Ap.T @ y`` for the current row subset."""
        self._ensure_Ap()
        return self.Ap.T @ y


class GaussianBlur(LinearOperator):
    """Depthwise Gaussian blur implemented with ``conv2d``.

    Args:
        sigma: standard deviation of the Gaussian, in pixels.
        kernel_size: side length of the (square) kernel; must be a single
            number.
        n_channels: number of channels; used as ``groups`` so that every
            channel is blurred independently with the same kernel.
        n_spatial_dimensions: number of spatial axes of the kernel.
            NOTE(review): ``forward``/``adjoint`` always call ``conv2d``, so
            only the default value 2 looks usable -- confirm.

    Raises:
        TypeError: if ``kernel_size`` is not a single number. (The original
            printed the message and called ``exit()``, terminating the whole
            process.)
    """

    def __init__(self, sigma, kernel_size=5, n_channels=1, n_spatial_dimensions=2):
        super(GaussianBlur, self).__init__()
        self.groups = n_channels
        if not isinstance(kernel_size, numbers.Number):
            raise TypeError(
                'KERNEL SIZE MUST BE A SINGLE INTEGER - RECTANGULAR KERNELS NOT SUPPORTED AT THIS TIME'
            )
        # floor(k / 2) keeps the spatial size unchanged for odd k.
        self.padding = int(math.floor(kernel_size / 2))
        kernel_size = [kernel_size] * n_spatial_dimensions
        # Stored as a non-trainable parameter.
        self.gaussian_kernel = torch.nn.Parameter(self.create_gaussian_kernel(sigma, kernel_size, n_channels),
                                                  requires_grad=False)

    def create_gaussian_kernel(self, sigma, kernel_size, n_channels):
        """Return the blur kernel as a depthwise convolution weight.

        Args:
            sigma: standard deviation of the Gaussian.
            kernel_size: sequence with the kernel extent along each axis.
            n_channels: number of times the kernel is repeated along dim 0.

        Returns:
            Tensor of shape ``(n_channels, 1, *kernel_size)`` whose entries
            sum to 1 for each channel.
        """
        kernel = 1
        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
        # Separable Gaussian: product of 1-D Gaussians centred on each axis.
        for size, mgrid in zip(kernel_size, meshgrids):
            mean = (size - 1) / 2
            kernel = kernel * torch.exp(-((mgrid - mean) / sigma) ** 2 / 2)

        # Make sure the values in the gaussian kernel sum to 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(n_channels, *[1] * (kernel.dim() - 1))
        return kernel

    def forward(self, x):
        """Blur ``x``; assumes ``x`` is laid out as conv2d expects with ``n_channels`` channels -- TODO confirm."""
        return torchfunc.conv2d(x, weight=self.gaussian_kernel, groups=self.groups, padding=self.padding)

    def adjoint(self, x):
        """Apply the same convolution as ``forward``.

        The kernel is symmetric about its centre by construction.
        NOTE(review): this presumably makes the operator self-adjoint for odd
        kernel sizes with zero padding; even sizes change the output shape --
        confirm.
        """
        return torchfunc.conv2d(x, weight=self.gaussian_kernel, groups=self.groups, padding=self.padding)


class CT(LinearOperator):
    """Wrap an ODL ray transform as torch-callable forward/adjoint operators.

    Args:
        bs: stored on the instance as ``self.bs``; not otherwise used in this
            class (presumably a batch size -- confirm with callers).
        ray_trafo: ODL operator exposing ``.adjoint`` and ``.domain``.

    Attributes set here:
        forward_op: ``OperatorModule`` around ``ray_trafo``.
        adjoint_op: ``OperatorModule`` around ``ray_trafo.adjoint``.
        reg_op: ``OperatorModule`` around the sum over both axes of
            ``PartialDerivative.adjoint * PartialDerivative`` on the ray
            transform's domain.
    """

    def __init__(self, bs, ray_trafo):
        super().__init__()
        self.bs = bs
        self.ray_trafo = ray_trafo

        # An alternative based on dival.util.torch_utility's
        # get_torch_ray_trafo_parallel_2d(_adjoint) helpers was left disabled
        # here; the ODL OperatorModule wrappers below are used instead.
        self.forward_op = OperatorModule(self.ray_trafo)
        self.adjoint_op = OperatorModule(self.ray_trafo.adjoint)

        image_space = self.ray_trafo.domain
        d_axis0 = odl.PartialDerivative(image_space, axis=0)
        d_axis1 = odl.PartialDerivative(image_space, axis=1)
        gradient_gram = d_axis0.adjoint * d_axis0 + d_axis1.adjoint * d_axis1
        self.reg_op = OperatorModule(gradient_gram)

    def forward(self, x):
        """Apply the wrapped ray transform to ``x``."""
        return self.forward_op(x)

    def adjoint(self, y):
        """Apply the adjoint of the wrapped ray transform to ``y``."""
        return self.adjoint_op(y)
