from enum import IntEnum

import torch as th
import numpy as np
import os

def dct(x, norm=None):
    """
    Type-II Discrete Cosine Transform (the "standard" DCT) along the last axis.

    Uses a single length-N FFT: the even-indexed samples followed by the
    odd-indexed samples in reverse order are transformed, and each frequency
    bin k is then rotated by exp(-i * pi * k / (2N)); the real part is kept.

    Scaling matches scipy's ``dct(..., type=2)``:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html

    :param x: the input signal; the transform acts on ``x.shape[-1]``
    :param norm: ``None`` (unnormalised) or ``'ortho'`` (orthonormal); any
        other value is treated like ``None``
    :return: a tensor with the same shape as ``x`` holding its DCT-II
    """
    orig_shape = x.shape
    length = orig_shape[-1]
    rows = x.contiguous().view(-1, length)

    # Even samples first, then the odd samples in reverse order.
    reordered = th.cat([rows[:, ::2], rows[:, 1::2].flip([1])], dim=1)
    spectrum = th.view_as_real(th.fft.fft(reordered, dim=1))
    re, im = spectrum[:, :, 0], spectrum[:, :, 1]

    # Real part of spectrum * exp(i * phase), with phase = -pi * k / (2N).
    phase = (-th.arange(length, dtype=x.dtype, device=x.device)[None, :]
             * th.pi / (2 * length))
    out = re * th.cos(phase) - im * th.sin(phase)

    if norm == 'ortho':
        out[:, 0] /= (length ** 0.5) * 2
        out[:, 1:] /= ((length / 2) ** 0.5) * 2

    return 2 * out.view(*orig_shape)

def idct(X, norm=None):
    """
    Inverse of ``dct`` along the last axis (a scaled DCT-III), defined so that
    ``idct(dct(x, norm=n), norm=n) == x`` up to floating-point rounding.

    The coefficients are turned back into a Hermitian-symmetric spectrum,
    inverted with ``irfft``, and the even/odd sample reordering used by
    ``dct`` is undone.

    Scaling conventions follow scipy:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html

    :param X: DCT-II coefficients; the transform acts on ``X.shape[-1]``
    :param norm: ``None`` or ``'ortho'`` (must match what ``dct`` used); any
        other value is treated like ``None``
    :return: a tensor with the same shape as ``X`` holding the reconstructed signal
    """
    shape = X.shape
    length = shape[-1]

    # Undo dct's trailing factor of 2. The division allocates a new tensor,
    # so the in-place rescaling below never touches the caller's X.
    coeffs = X.contiguous().view(-1, length) / 2

    if norm == 'ortho':
        coeffs[:, 0] *= (length ** 0.5) * 2
        coeffs[:, 1:] *= ((length / 2) ** 0.5) * 2

    phase = (th.arange(length, dtype=X.dtype, device=X.device)[None, :]
             * th.pi / (2 * length))
    cos_p, sin_p = th.cos(phase), th.sin(phase)

    # Imaginary input: 0 at bin 0, then -coeffs[N-1], ..., -coeffs[1].
    zero_col = coeffs[:, :1] * 0
    mirrored = -coeffs.flip([1])[:, :-1]
    imag_in = th.cat([zero_col, mirrored], dim=1)

    # Rotate (coeffs + i * imag_in) by exp(i * phase).
    rot_re = coeffs * cos_p - imag_in * sin_p
    rot_im = coeffs * sin_p + imag_in * cos_p
    spectrum = th.view_as_complex(th.stack([rot_re, rot_im], dim=-1))

    signal = th.fft.irfft(spectrum, n=length, dim=1)

    # Undo the reordering: first ceil(N/2) samples are the even positions,
    # the reversed tail supplies the odd positions.
    half = length // 2
    result = signal.new_zeros(signal.shape)
    result[:, ::2] += signal[:, :length - half]
    result[:, 1::2] += signal.flip([1])[:, :half]

    return result.view(*shape)

def dct_2d(x, norm=None):
    """
    2-dimensional Discrete Cosine Transform, Type II, over the last two axes.

    Implemented separably: a 1-D ``dct`` is applied along the last axis, the
    last two axes are swapped, and the step is repeated. After the second
    swap the axes are back in their original order.

    For the meaning of ``norm`` see
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html

    :param x: the input signal (at least 2-D)
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last 2 dimensions
    """
    out = x
    for _ in range(2):
        out = dct(out, norm=norm).transpose(-1, -2)
    return out

def idct_2d(X, norm=None):
    """
    Inverse of ``dct_2d`` (a scaled 2-D DCT-III) over the last two axes, so
    that ``idct_2d(dct_2d(x, norm=n), norm=n) == x`` up to rounding.

    Implemented separably: a 1-D ``idct`` is applied along the last axis, the
    last two axes are swapped, and the step is repeated.

    For the meaning of ``norm`` see
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.dct.html

    :param X: 2-D DCT-II coefficients (at least 2-D)
    :param norm: the normalization, None or 'ortho'
    :return: the inverse DCT-II of the signal over the last 2 dimensions
    """
    out = X
    for _ in range(2):
        out = idct(out, norm=norm).transpose(-1, -2)
    return out


class DCTLayer(th.nn.Linear):
    """
    Fixed (non-trainable) linear layer holding an orthonormal DCT or inverse
    DCT matrix. The original author reports it runs around 50x faster on GPU
    than the FFT-based functions in this file. The trade-off is that the
    side x side transform matrix is stored, which increases memory usage.

        :param side: length of the dimension(s) the transform acts on
        :param mode: ``DCTLayer.Mode.DCT`` or ``DCTLayer.Mode.IDCT``
        :param squared: if True, square the transform matrix element-wise
            after building it
        :param save_dir: optional existing directory; when given, the
            (un-squared) matrix is written there as ``dct.txt`` / ``idct.txt``
    """

    class Mode(IntEnum):
        # Which transform the layer's weight implements.
        DCT = 1
        IDCT = -1

    def __init__(self,
                 side: int = 32,
                 mode: Mode = Mode.DCT,
                 squared: bool = False,
                 save_dir=None,
                 ):
        # These must be set before ``super().__init__`` because
        # ``nn.Linear.__init__`` calls ``reset_parameters``, which reads them.
        self.mode = mode
        self.side = side
        self.squared = squared
        self.save_dir = save_dir
        super().__init__(side, side, bias=False)

    def reset_parameters(self):
        """
        (Re)initialise ``weight`` with the orthonormal transform matrix.

        ``dct(eye)`` / ``idct(eye)`` has the transforms of the basis vectors
        as rows. The layer computes ``x @ weight.T``, so the weight is that
        matrix transposed. Gradients are disabled afterwards.

            :raises NotImplementedError: if ``mode`` is neither DCT nor IDCT
        """
        eye = th.eye(self.side)

        if self.mode == DCTLayer.Mode.DCT:
            name, transform = 'dct', dct
        elif self.mode == DCTLayer.Mode.IDCT:
            name, transform = 'idct', idct
        else:
            raise NotImplementedError('Only allowed values are: Mode.DCT and Mode.IDCT')

        self.weight.data = transform(eye, norm='ortho').data.t()

        # Dumping the matrix is opt-in so that constructing the layer never
        # depends on a particular directory existing on disk.
        if self.save_dir is not None:
            np.savetxt(os.path.join(self.save_dir, name + ".txt"),
                       self.weight.data.detach().cpu().numpy(), delimiter=' ')

        if self.squared:
            self.weight.data = self.weight.data ** 2

        self.requires_grad_(False)  # DO NOT learn this

    def forward(self, x):
        """
        Apply the transform along both of the last two dimensions (a separable
        2-D transform, i.e. ``W @ X @ W.T`` per trailing matrix). Both of
        those dimensions must have size ``side``.
            :param x: the input signal
            :return: result of the linear layer applied over the last 2 dimensions
        """
        X1 = super().forward(x)
        X2 = super().forward(X1.transpose(-1, -2))
        return X2.transpose(-1, -2)

    def add_kron(self):
        """
        Cache ``kron(weight, weight)`` as ``self.kron``.

        For a side x side input ``X``, ``self.kron @ X.flatten()`` equals
        ``self.forward(X).flatten()`` (row-major flattening). The result is
        stored as a plain attribute, not a registered buffer, so later
        ``.to()`` / ``.cuda()`` calls on the module will not move it.
        """
        self.kron = th.kron(
            self.weight.data, self.weight.data
        ).to(self.weight.device)