import pywt
import pywt.data
import torch
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F


def create_filter(wave, in_size, type=torch.float):
    """Return the separable 2-D analysis kernels of ``wave``, tiled per channel.

    The result has shape ``(4 * in_size, 1, k, k)`` where ``k`` is the filter
    length of the wavelet, i.e. the weight layout of a depthwise ``conv2d``
    with ``groups=in_size``. Every channel gets the same four kernels, in the
    order ``outer(lo, lo)``, ``outer(hi, lo)``, ``outer(lo, hi)``,
    ``outer(hi, hi)``, where ``outer(a, b)[i, j] = a[i] * b[j]`` (``i`` indexes
    height, ``j`` width).

    Args:
        wave: wavelet name passed to ``pywt.Wavelet``.
        in_size: number of channels the kernel group is repeated for.
        type: dtype of the returned tensor.
    """
    wavelet = pywt.Wavelet(wave)
    # conv2d computes a cross-correlation, so reverse the decomposition taps.
    lo = torch.tensor(wavelet.dec_lo[::-1], dtype=type)
    hi = torch.tensor(wavelet.dec_hi[::-1], dtype=type)
    kernels = [
        row[:, None] * col[None, :]
        for row, col in ((lo, lo), (hi, lo), (lo, hi), (hi, hi))
    ]
    return torch.stack(kernels).unsqueeze(1).repeat(in_size, 1, 1, 1)


def wt(x, filters, in_size, level):
    """Apply a ``level``-deep 2-D wavelet transform and pack it into ``x``'s shape.

    Each level runs a depthwise stride-2 ``conv2d`` that produces four
    half-resolution outputs per channel, then tiles them into one image per
    channel as a 2x2 grid of quadrants. The first of each channel's four
    outputs goes to the top-left quadrant. For filters built by
    ``create_filter`` this is the low/low (approximation) band. When
    ``level > 1`` that band is itself replaced by its own ``wt`` packing.

    Args:
        x: 4-D tensor ``(batch, in_size, h, w)``; ``h`` and ``w`` must be
            divisible by ``2 ** level``.
        filters: depthwise analysis filters, expected shape
            ``(4 * in_size, 1, 2, 2)``.
        in_size: number of channels (the conv ``groups``).
        level: number of decomposition levels; values <= 1 mean one level.

    Returns:
        Tensor of shape ``(batch, in_size, h, w)``.

    Raises:
        ValueError: if a spatial dim is odd at some level, or if ``filters``
            do not exactly halve the input (only 2-tap filters such as
            Haar/db1 fit the packing below).
    """
    _, _, h, w = x.shape
    if h % 2 or w % 2:
        raise ValueError(f"wt expects even spatial dims at every level, got {h}x{w}")
    res = F.conv2d(x, filters, stride=2, groups=in_size)
    # Without this check a longer filter can still be reshaped into a tensor
    # with the wrong batch size instead of failing.
    if res.shape[-2:] != (h // 2, w // 2):
        raise ValueError(
            f"filters of size {tuple(filters.shape[-2:])} do not halve a {h}x{w} input; "
            "only 2-tap (Haar/db1) filters are supported"
        )
    if level > 1:
        # Recurse on the first sub-band of every channel (channels 0, 4, 8, ...).
        res[:, ::4] = wt(res[:, ::4], filters, in_size, level - 1)
    # Interleave each channel's (band0, band1) and (band2, band3) pairs row by
    # row, giving the 2x2 quadrant layout at full resolution.
    res = res.reshape(-1, 2, h // 2, w // 2).transpose(1, 2).reshape(-1, in_size, h, w)
    return res


def iwt(x, inv_filters, in_size, level):
    """Invert the packing and filtering done by ``wt``.

    The input is a ``(batch, in_size, h, w)`` tensor whose channels are tiled as
    a 2x2 grid of sub-band quadrants. Each level unpacks the quadrants into
    ``4 * in_size`` half-resolution channels. When ``level > 1`` it first
    recursively inverts the top-left quadrant. It then applies a depthwise
    stride-2 ``conv_transpose2d``, which is the transpose of ``wt`` for the same
    filters and its exact inverse when they are orthogonal (e.g. Haar).

    Args:
        x: 4-D tensor ``(batch, in_size, h, w)``; ``h`` and ``w`` must be
            divisible by ``2 ** level``.
        inv_filters: depthwise filters, expected shape ``(4 * in_size, 1, 2, 2)``.
        in_size: number of channels (the conv ``groups``).
        level: number of decomposition levels; values <= 1 mean one level.

    Returns:
        Tensor of shape ``(batch, in_size, h, w)``.

    Raises:
        ValueError: if a spatial dim is odd at some level, or if the filters do
            not map ``h/2 x w/2`` back to exactly ``h x w`` (non-2-tap filters).
    """
    _, _, h, w = x.shape
    if h % 2 or w % 2:
        raise ValueError(f"iwt expects even spatial dims at every level, got {h}x{w}")
    # Undo wt's row interleaving: split each channel back into its four sub-bands.
    res = x.reshape(-1, h // 2, 2, w // 2).transpose(1, 2)
    # Always take a fresh contiguous copy. ``reshape`` may instead return a view
    # of ``x`` -- e.g. when ``x`` is an expanded (stride-0) gradient -- and the
    # in-place write below must neither target overlapping memory nor mutate x.
    res = res.clone(memory_format=torch.contiguous_format).view(-1, 4 * in_size, h // 2, w // 2)
    if level > 1:
        # The first sub-band of every channel (0, 4, 8, ...) holds a deeper packing.
        res[:, ::4] = iwt(res[:, ::4], inv_filters, in_size, level - 1)
    res = F.conv_transpose2d(res, inv_filters, stride=2, groups=in_size)
    if res.shape[-2:] != (h, w):
        raise ValueError(
            f"filters of size {tuple(inv_filters.shape[-2:])} do not reconstruct a {h}x{w} "
            "output; only 2-tap (Haar/db1) filters are supported"
        )
    return res


def inverse_wavelet_transform_init(weight, in_size, level):
    """Build a differentiable inverse wavelet transform bound to fixed filters.

    Args:
        weight: filters passed to ``iwt`` (forward) and ``wt`` (backward); the
            tensor object is captured in the closure.
        in_size: number of channels.
        level: number of decomposition levels.

    Returns:
        A callable ``f(coeffs) -> tensor`` that applies ``iwt`` in the forward
        pass and ``wt`` in the backward pass.
    """

    class InverseWaveletTransform(Function):

        @staticmethod
        def forward(ctx, coeffs):
            with torch.no_grad():
                x = iwt(coeffs, weight, in_size, level)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # iwt is linear, so its vector-Jacobian product is its transpose,
            # which is wt with the same filters. Exactly one gradient is
            # returned because forward has exactly one tensor input; ``weight``
            # is a closure constant, not an input.
            return wt(grad_output, weight, in_size, level)

    # ``apply`` must be called on the class: instantiating an autograd Function
    # is deprecated in recent PyTorch releases.
    return InverseWaveletTransform.apply


def wavelet_transform_init(weight, in_size, level):
    """Build a differentiable forward wavelet transform bound to fixed filters.

    Args:
        weight: filters passed to ``wt`` (forward) and ``iwt`` (backward); the
            tensor object is captured in the closure.
        in_size: number of channels.
        level: number of decomposition levels.

    Returns:
        A callable ``f(images) -> tensor`` that applies ``wt`` in the forward
        pass and ``iwt`` in the backward pass.
    """

    class WaveletTransform(Function):

        @staticmethod
        def forward(ctx, images):
            with torch.no_grad():
                x = wt(images, weight, in_size, level)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # wt is linear, so its vector-Jacobian product is its transpose,
            # which is iwt with the same filters (conv_transpose2d is the adjoint
            # of conv2d, and the packing permutations invert each other).
            # Exactly one gradient is returned because forward has exactly one
            # tensor input; ``weight`` is a closure constant, not an input.
            return iwt(grad_output, weight, in_size, level)

    # ``apply`` must be called on the class: instantiating an autograd Function
    # is deprecated in recent PyTorch releases.
    return WaveletTransform.apply


class DwtCompress(nn.Module):
    """Wavelet-transform a feature map and keep its highest-energy positions.

    ``forward`` applies a ``level``-deep 2-D wavelet transform (built by
    ``wavelet_transform_init``) to an input of shape ``(batch, in_size, h, w)``.
    It then keeps the ``int(h * w * compress_rate)`` spatial positions whose
    coefficient energy, summed over channels, is largest.

    Args:
        in_size: number of input channels.
        level: number of decomposition levels.
        compress_rate: fraction of the ``h * w`` positions to keep, in ``[0, 1]``.
        wave: wavelet name passed to ``create_filter``; ``'db1'`` is the same
            as Haar.
        mode: currently unused; kept for API compatibility.
        dtype: dtype of the filter tensor.

    Raises:
        ValueError: if ``compress_rate`` is outside ``[0, 1]`` (or NaN).
    """

    def __init__(self, in_size, level, compress_rate, wave='db1', mode='zero', dtype=torch.float):
        super().__init__()
        # Fail early: a rate above 1 would otherwise surface later as an
        # out-of-range ``topk``, and a negative rate as a negative k.
        if not 0 <= compress_rate <= 1:
            raise ValueError(f"compress_rate must be in [0, 1], got {compress_rate}")
        self.level = level
        self.in_size = in_size
        # Fixed, non-trainable analysis filters.
        self.filter = nn.Parameter(create_filter(wave, in_size, dtype), requires_grad=False)
        self.compress_rate = compress_rate
        # NOTE(review): the transform closes over this specific Parameter object.
        # Module copies (copy.deepcopy, nn.DataParallel replicas) likely keep
        # using the original filter tensor -- verify before relying on them.
        self.wavelet_transform = wavelet_transform_init(self.filter, in_size, level)

    def forward(self, x):
        """Transform ``x`` and return ``(topk, ids)`` as described in ``compress``."""
        return self.compress(self.wavelet_transform(x))

    def compress(self, x):
        """Keep the ``k`` spatial positions of ``x`` with the largest channel energy.

        Args:
            x: tensor of shape ``(b, c, h, w)``.

        Returns:
            ``(topk, ids)``. ``ids`` has shape ``(b, 1, k)`` and holds flat
            indices into the ``h * w`` positions, in no particular order
            (``sorted=False``). ``topk`` has shape ``(b, c, k)`` and holds the
            values of every channel at those positions. Here
            ``k = int(h * w * self.compress_rate)``.
        """
        b, c, h, w = x.shape
        k = int(h * w * self.compress_rate)
        # Only the indices are needed from the selection step, so there is no
        # reason to record it in the autograd graph.
        with torch.no_grad():
            # Squared L2 norm over channels, computed directly (no sqrt round-trip).
            energy = x.pow(2).sum(dim=1).reshape(b, h * w)
            ids = energy.topk(k, dim=1, sorted=False)[1]
        ids = ids.unsqueeze(1)
        # expand() broadcasts the index over channels without copying it;
        # gather stays differentiable with respect to x.
        topk = x.reshape(b, c, h * w).gather(dim=2, index=ids.expand(-1, c, -1))
        return topk, ids
