import time

import pywt
import pywt.data
import torch
from numpy import concatenate
import numpy as np
from torch import nn
from torch.autograd import Function
import torch.nn.functional as F

# Default torch device string for this module (first CUDA GPU).
DEVICE = "cuda:0"


######## deprecated: superseded by wt / iwt below, kept for benchmarking ########
def depreciated_wt(x, filters, in_size, level):
    """
    Deprecated multi-level 2-D wavelet transform; superseded by ``wt``.

    Based on https://github.com/t-vi/pytorch-tvmisc/blob/master/misc/2D-Wavelet-Transform.ipynb

    :param x: 4-D input tensor (batch, in_size, h, w).
    :param filters: grouped analysis bank with 4 kernels per input channel.
    :param in_size: number of input channels (used as ``groups``).
    :param level: number of decomposition levels.
    :return: tensor shaped like ``x`` holding the rearranged coefficients.
    """
    height, width = x.size(2), x.size(3)
    res = torch.nn.functional.conv2d(x, filters, stride=2, groups=in_size)
    if level > 1:
        # Every 4th output channel is recursively decomposed again.
        lowpass = list(range(0, filters.shape[0], 4))
        res[:, lowpass] = depreciated_wt(res[:, lowpass], filters, in_size, level - 1)
    interleaved = res.view(-1, 2, height // 2, width // 2).transpose(1, 2)
    return interleaved.contiguous().view(-1, in_size, height, width)


def depreciated_iwt(x, inv_filters, in_size, level):
    """
    Deprecated inverse of ``depreciated_wt``; superseded by ``iwt``.

    :param x: 4-D coefficient tensor (batch, in_size, h, w) in the layout produced by the
              forward transform.
    :param inv_filters: grouped filter bank used by ``conv_transpose2d``.
    :param in_size: number of channels (used as ``groups``).
    :param level: number of decomposition levels to undo.
    :return: reconstructed tensor of shape (batch, in_size, h, w).
    """
    _, _, height, width = x.shape
    half_h, half_w = height // 2, width // 2
    # Undo the row interleaving of the forward pass, then copy so the in-place
    # sub-band update below can never touch the caller's tensor.
    res = (
        x.view(-1, half_h, 2, half_w)
        .transpose(1, 2)
        .contiguous()
        .view(-1, 4 * in_size, half_h, half_w)
        .clone()
    )
    if level > 1:
        lowpass = list(range(0, inv_filters.shape[0], 4))
        res[:, lowpass] = depreciated_iwt(res[:, lowpass], inv_filters, in_size, level - 1)
    return torch.nn.functional.conv_transpose2d(res, inv_filters, stride=2, groups=in_size)


######## end of deprecated section ########


def unsigned_quantization(data, scale, bins):
    """
    Uniformly quantize ``data`` onto ``bins`` steps spanning ``[0, scale]``.

    Values above ``scale`` are clipped to ``scale``. There is no lower clip, so negative
    inputs map to negative multiples of the step ``scale / bins``.
    NOTE(review): presumably meant for non-negative activations — confirm with callers.
    ``scale`` must be non-zero; the computation divides by it.

    :param data: tensor to quantize.
    :param scale: upper clipping value and full-range value.
    :param bins: number of quantization steps between 0 and ``scale``.
    :return: tensor of the same shape holding the de-quantized values.
    """
    clipped = data.clamp(max=scale)
    # Evaluated as (x / scale) * bins and levels / (bins / scale) to keep the exact float rounding.
    levels = (clipped / scale * bins).round()
    return levels / (bins / scale)


def quantize(activation):
    """
    Quantize every entry of ``activation`` to 4 bins with a shared, data-derived scale.

    :param activation: iterable of ``(values, tag)`` pairs. All ``values`` are stacked with
                       ``np.array``, so presumably they share a shape — confirm with callers.
    :return: list of ``(quantized_tensor, tag)`` pairs in the original order.
    """
    stacked = np.array([values for values, _ in activation])
    # Shared clipping scale: three times the global mean (a variance term was tried and dropped).
    scale = 3 * stacked.mean()
    quantized = []
    for values, tag in activation:
        quantized.append((unsigned_quantization(torch.Tensor(values), scale, bins=4), tag))
    return quantized


def get_threshold(coeff, topk):
    """
    Return the magnitude threshold that keeps the largest ``topk`` fraction of coefficients.

    :param coeff: ``pywt.wavedec2`` output ``[cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)]``.
    :param topk: fraction of all coefficients to keep (normally in ``[0, 1]``).
    :return: the ``int(n * topk)``-th largest absolute coefficient value. Returns ``inf`` when
             that count is 0, so ``chop`` zeroes everything. Values above 1 are clamped
             to keep-all. Ties at the threshold are all kept by ``chop`` (it zeroes ``< t``).
    """
    parts = [coeff[0].reshape(-1)]
    parts += [band.reshape(-1) for details in coeff[1:] for band in details]
    magnitudes = np.abs(concatenate(parts))
    n = magnitudes.size
    size = min(int(n * topk), n)
    if size <= 0:
        # Indexing with -0 would silently pick the *smallest* magnitude and keep everything.
        return np.inf
    kth = n - size
    # Only the kth order statistic is needed; partition is O(n) instead of a full sort.
    return np.partition(magnitudes, kth)[kth]


def chop(x, t):
    """
    Hard-threshold ``x`` in place: zero every entry whose magnitude is below ``t``.

    :param x: array to modify in place.
    :param t: magnitude threshold; entries with ``|x| == t`` are kept.
    :return: the same (mutated) ``x`` object, for chaining.
    """
    below_threshold = np.abs(x) < t
    x[below_threshold] = 0
    return x


def wave_transform(activation, level=3, topk=0.25):
    """
    Compress a 2-D array by Haar-wavelet hard thresholding and return its reconstruction.

    :param activation: 2-D array.
    :param level: number of ``pywt.wavedec2`` decomposition levels.
    :param topk: fraction of coefficients (by magnitude) to keep; see ``get_threshold``.
    :return: reconstructed array cropped to ``activation``'s shape, because the
             zero-padded reconstruction may be larger than the input.
    """
    rows, cols = activation.shape
    coeffs = pywt.wavedec2(activation, 'haar', mode='zero', level=level)
    threshold = get_threshold(coeffs, topk=topk)
    # Layout: cA_n followed by (cH, cV, cD) detail triples; chop mutates each array in place.
    approx, *details = coeffs
    kept = [chop(approx, threshold)]
    for c_h, c_v, c_d in details:
        kept.append([chop(c_h, threshold), chop(c_v, threshold), chop(c_d, threshold)])
    reconstructed = pywt.waverec2(kept, 'haar', mode='zero')
    return reconstructed[:rows, :cols]


def create_filter(wave, in_size, type=torch.float):
    """
    Build a grouped-convolution analysis filter bank for a separable 2-D wavelet.

    The 1-D decomposition filters are reversed because ``conv2d`` computes a
    cross-correlation. The four kernels are ``kernel[r][c] = f_row[r] * f_col[c]`` with
    ``(f_row, f_col)`` = (lo, lo), (hi, lo), (lo, hi), (hi, hi), in that order.

    :param wave: pywt wavelet name (e.g. ``'db1'``).
    :param in_size: number of input channels; the 4-kernel set is tiled this many times.
    :param type: torch dtype of the bank (the name shadows the builtin; kept for API compatibility).
    :return: tensor of shape ``(4 * in_size, 1, k, k)`` for use with ``groups=in_size``.
    """
    wavelet = pywt.Wavelet(wave)
    lo = torch.tensor(wavelet.dec_lo[::-1], dtype=type)
    hi = torch.tensor(wavelet.dec_hi[::-1], dtype=type)

    def kernel(row, col):
        # Outer product: result[r][c] = row[r] * col[c].
        return col.unsqueeze(0) * row.unsqueeze(1)

    bank = torch.stack([kernel(lo, lo), kernel(hi, lo), kernel(lo, hi), kernel(hi, hi)], dim=0)
    return bank.unsqueeze(1).repeat(in_size, 1, 1, 1)


def wt(x, filters, in_size, level):
    """
    Multi-level 2-D discrete wavelet transform implemented with grouped strided convolutions.

    Based on https://github.com/t-vi/pytorch-tvmisc/blob/master/misc/2D-Wavelet-Transform.ipynb

    :param x: 4-D tensor (batch, in_size, h, w); h and w presumably divisible by 2 ** level.
    :param filters: analysis bank with 4 kernels per input channel (see ``create_filter``).
    :param in_size: number of input channels (used as ``groups``).
    :param level: number of decomposition levels.
    :return: tensor shaped like ``x`` holding the rearranged coefficients.
    """
    _, _, height, width = x.shape
    coeffs = F.conv2d(x, filters, stride=2, groups=in_size)
    if level > 1:
        # Every 4th channel (the first kernel of each group) is decomposed again.
        coeffs[:, ::4] = wt(coeffs[:, ::4], filters, in_size, level - 1)
    half_h, half_w = height // 2, width // 2
    # Interleave pairs of sub-band maps row-wise back into the input's spatial shape.
    pairs = coeffs.reshape(-1, 2, half_h, half_w)
    return pairs.transpose(1, 2).reshape(-1, in_size, height, width)


def iwt(x, inv_filters, in_size, level):
    """
    Inverse of ``wt``: undo the row interleaving, then apply transposed convolutions.

    :param x: 4-D coefficient tensor (batch, in_size, h, w) in the layout produced by ``wt``.
    :param inv_filters: filter bank passed to ``conv_transpose2d`` (4 kernels per channel).
    :param in_size: number of channels (used as ``groups``).
    :param level: number of decomposition levels to undo.
    :return: reconstructed tensor of shape (batch, in_size, h, w).
    """
    _, _, height, width = x.shape
    half_h, half_w = height // 2, width // 2
    # Split interleaved rows back into separate sub-band channels.
    rows = x.reshape(-1, half_h, 2, half_w).transpose(1, 2)
    coeffs = rows.reshape(-1, 4 * in_size, half_h, half_w)
    if level > 1:
        # Rebuild the coarser approximation bands (every 4th channel) first.
        coeffs[:, ::4] = iwt(coeffs[:, ::4], inv_filters, in_size, level - 1)
    return F.conv_transpose2d(coeffs, inv_filters, stride=2, groups=in_size)


def inverse_wavelet_transform_init(weight, in_size, level):
    """
    Build an autograd-aware inverse wavelet transform bound to a fixed filter bank.

    The forward pass runs ``iwt``. The backward pass runs ``wt`` on the incoming gradient.
    ``wt`` (conv2d + reshapes) and ``iwt`` (inverse reshapes + conv_transpose2d) use the same
    ``weight``, so for shapes where the reshapes are valid they are adjoints of each other.

    :param weight: filter bank captured by the closure (not differentiated).
    :param in_size: number of channels (used as ``groups``).
    :param level: number of decomposition levels.
    :return: callable ``f(input) -> tensor`` (the Function's ``apply``).
    """

    class InverseWaveletTransform(Function):
        """
        based on https://pytorch.org/docs/stable/notes/extending.html
        """

        @staticmethod
        def forward(ctx, input):
            with torch.no_grad():
                return iwt(input, weight, in_size, level)

        @staticmethod
        def backward(ctx, grad_output):
            # Exactly one gradient: ``input`` is the only forward argument.
            return wt(grad_output, weight, in_size, level)

    # ``apply`` is a classmethod; instantiating a Function is deprecated in PyTorch.
    return InverseWaveletTransform.apply


def wavelet_transform_init(weight, in_size, level):
    """
    Build an autograd-aware forward wavelet transform bound to a fixed filter bank.

    The forward pass runs ``wt``. The backward pass runs ``iwt`` on the incoming gradient.
    ``wt`` (conv2d + reshapes) and ``iwt`` (inverse reshapes + conv_transpose2d) use the same
    ``weight``, so for shapes where the reshapes are valid they are adjoints of each other.

    :param weight: filter bank captured by the closure (not differentiated).
    :param in_size: number of channels (used as ``groups``).
    :param level: number of decomposition levels.
    :return: callable ``f(input) -> tensor`` (the Function's ``apply``).
    """

    class WaveletTransform(Function):
        """
        based on https://pytorch.org/docs/stable/notes/extending.html
        """

        @staticmethod
        def forward(ctx, input):
            with torch.no_grad():
                return wt(input, weight, in_size, level)

        @staticmethod
        def backward(ctx, grad_output):
            # Exactly one gradient: ``input`` is the only forward argument.
            return iwt(grad_output, weight, in_size, level)

    # ``apply`` is a classmethod; instantiating a Function is deprecated in PyTorch.
    return WaveletTransform.apply


class DwtCompress(nn.Module):
    """
    Wavelet-transform a feature map and keep only its highest-energy spatial positions.

    ``forward`` returns the kept coefficients together with their flat spatial indexes
    (a sparse representation); see ``compress``.
    """

    def __init__(self, in_size, level, compress_rate, wave='db1', mode='zero', dtype=torch.float):
        """
        :param in_size: number of input channels.
        :param level: number of wavelet decomposition levels.
        :param compress_rate: fraction of the ``h * w`` positions to keep.
        :param wave: pywt wavelet name (db1 is the same wavelet as haar).
        :param mode: accepted for API compatibility but currently unused.
        :param dtype: dtype of the filter bank.
        """
        super().__init__()
        self.level = level
        self.in_size = in_size
        # Fixed filter bank, stored as a Parameter with requires_grad=False (never trained).
        self.filter = nn.Parameter(create_filter(wave, in_size, dtype), requires_grad=False)
        self.compress_rate = compress_rate
        self.wavelet_transform = wavelet_transform_init(self.filter, in_size, level)

    def forward(self, x):
        """Return ``(topk, indexes)`` for the wavelet coefficients of ``x`` (see ``compress``)."""
        coefficients = self.wavelet_transform(x)
        return self.compress(coefficients)

    def compress(self, x):
        """
        Keep the ``k = int(h * w * compress_rate)`` positions with the largest channel energy.

        :param x: tensor of shape (b, c, h, w).
        :return: ``(topk, indexes)``. ``indexes`` has shape (b, 1, k) and holds row-major flat
                 positions in ``h * w``, ordered by descending energy. ``topk`` has shape
                 (b, c, k) and holds every channel's coefficient at those positions.
        """
        b, c, h, w = x.shape
        # Energy per position = squared L2 norm across the channel dimension.
        acc = x.norm(dim=1).pow(2).view(b, h * w)
        k = int(h * w * self.compress_rate)
        indexes = acc.topk(k, dim=1)[1].unsqueeze(1)
        # expand() broadcasts the indexes over channels without materialising a copy.
        topk = x.reshape(b, c, h * w).gather(dim=2, index=indexes.expand(b, c, k))
        return topk, indexes


class DwtCompressDense(nn.Module):
    """
    Dense counterpart of ``DwtCompress``: zero out all but the highest-energy positions.

    The output keeps the shape (and dtype) of the wavelet coefficients. It uses a
    straight-through estimator: the forward value is the masked tensor, while the gradient
    reaches the coefficients as if no mask had been applied.
    """

    def __init__(self, in_size, level, compress_rate, wave='db1', mode='zero', dtype=torch.float):
        """
        :param in_size: number of input channels.
        :param level: number of wavelet decomposition levels.
        :param compress_rate: fraction of the ``h * w`` positions to keep.
        :param wave: pywt wavelet name (db1 is the same wavelet as haar).
        :param mode: accepted for API compatibility but currently unused.
        :param dtype: dtype of the filter bank.
        """
        super().__init__()
        self.level = level
        self.in_size = in_size
        # Fixed filter bank, stored as a Parameter with requires_grad=False (never trained).
        self.filter = nn.Parameter(create_filter(wave, in_size, dtype), requires_grad=False)
        self.compress_rate = compress_rate
        self.wavelet_transform = wavelet_transform_init(self.filter, in_size, level)

    def forward(self, x):
        """Return the masked wavelet coefficients of ``x`` with a straight-through gradient."""
        x = self.wavelet_transform(x)
        topk = self.compress(x)
        # Value equals ``topk``; the gradient w.r.t. ``x`` is the identity.
        return (topk - x).detach() + x

    def compress(self, x):
        """
        Zero every position outside the ``k = int(h * w * compress_rate)`` highest-energy ones.

        :param x: tensor of shape (b, c, h, w).
        :return: tensor of the same shape and dtype as ``x``.
        """
        b, c, h, w = x.shape
        # Energy per position = squared L2 norm across the channel dimension.
        acc = x.norm(dim=1).pow(2).view(b, h * w)
        k = int(h * w * self.compress_rate)
        indexes = acc.topk(k, dim=1)[1].unsqueeze(1).expand(b, c, k)
        # Build the mask in x's dtype so half/bfloat16 inputs are not promoted to float32.
        keep_mask = torch.zeros((b, c, h * w), device=x.device, dtype=x.dtype)
        keep_mask.scatter_(2, indexes, 1.0)
        return x * keep_mask.view(b, c, h, w)


def depreciation_performance_tests(in_channels=256, level=3, device=DEVICE, batch_size=64, size=128):
    """
    Time ``wt``/``iwt`` against their deprecated counterparts and print the results.

    CUDA kernels run asynchronously, so every measurement is bracketed by a device
    synchronisation; otherwise only kernel-launch time would be reported. The first timed
    call may still include one-off CUDA/cuDNN initialisation overhead.
    NOTE: the defaults allocate four 64x256x128x128 float32 inputs (1 GiB each).

    :param in_channels: number of channels of the random inputs.
    :param level: number of decomposition levels.
    :param device: torch device string to run on.
    :param batch_size: batch dimension of the random inputs.
    :param size: spatial height and width of the random inputs.
    """
    shape = (batch_size, in_channels, size, size)
    # Fresh inputs per call, created directly on the target device.
    x = torch.rand(shape, device=device)
    x2 = torch.rand(shape, device=device)
    xx = torch.rand(shape, device=device)
    xx2 = torch.rand(shape, device=device)
    filters = create_filter(wave='db1', in_size=in_channels).to(device)
    iwt_filters = create_filter(wave='db1', in_size=in_channels).to(device)
    on_cuda = torch.device(device).type == "cuda"

    def _sync():
        # Wait for all queued kernels so wall-clock time covers actual execution.
        if on_cuda:
            torch.cuda.synchronize(device)

    def _timed(label, fn, *args):
        _sync()
        start = time.perf_counter()
        fn(*args)
        _sync()
        end = time.perf_counter()
        print(f"{label} time {end - start}")

    _timed("wt", wt, x, filters, in_channels, level)
    _timed("depreciated_wt", depreciated_wt, x2, filters, in_channels, level)
    _timed("iwt", iwt, xx, iwt_filters, in_channels, level)
    _timed("depreciated_iwt", depreciated_iwt, xx2, iwt_filters, in_channels, level)

# depreciation_performance_tests()
# dwt_comp = DwtCompress(3).cuda()
# dwt_comp(x)
#
# x = torch.zeros(4, 3, 10, 10)
# batch, channels, h, w = x.shape
# for b in range(batch):
#     for c in range(channels):
#         for i in range(h):
#             x[b, c, i, :] = torch.arange(start=1, end=11)
# x = x.cuda()
# dwt_comp = DwtCompress(3).cuda()
# dwt_comp(x)
