import torch
import torch.nn as nn
import torch.nn.functional as F
from . import quantized_ops


def str2layer(qtype):
    """Look up ``qtype`` by name in the ``quantized_ops`` module and return it.

    Args:
        qtype: attribute name of a quantization function in ``quantized_ops``
            (e.g. ``"global_prune_ternarize"``, used by the prunable layers).

    Returns:
        The object bound to that name in ``quantized_ops``.

    Raises:
        AttributeError: if ``quantized_ops`` has no attribute ``qtype``.
    """
    resolved = getattr(quantized_ops, qtype)
    return resolved


# -- Fully Connected Layers
class QuantizedLinear(nn.Linear):
    """Bias-free ``nn.Linear`` that runs its forward pass on quantized weights.

    Args:
        *args, **kwargs: forwarded to ``nn.Linear``; any ``bias`` keyword is
            overridden to ``False``.
        qfx: name of a function in ``quantized_ops`` (resolved with
            ``str2layer``) mapping the full-precision weight to its quantized
            counterpart. Required.

    The quantization offset is computed under ``torch.no_grad()``, so the
    gradient reaches ``self.weight`` as if the quantizer were the identity
    (straight-through estimator).
    """

    def __init__(self, *args, qfx=None, **kwargs):
        kwargs["bias"] = False
        super().__init__(*args, **kwargs)
        assert qfx is not None, "Quantization function is not set for QuantizedLinear"
        self.qfx = str2layer(qfx)

    def get_quantized_weights(self):
        """Return ``qfx`` applied to the current full-precision weight."""
        return self.qfx(self.weight)

    def forward(self, nx):
        with torch.no_grad():
            # Constant (autograd-invisible) shift from full-precision to quantized weight.
            shift = self.get_quantized_weights() - self.weight
        # Equals the quantized weight up to float rounding; the gradient w.r.t.
        # self.weight is passed straight through the quantizer.
        return F.linear(nx, self.weight + shift, None)


class PrunableQuantizedLinear(QuantizedLinear):
    """Quantized linear layer driven by ``quantized_ops.global_prune_ternarize``.

    The quantization function is only applied once a positive threshold has
    been set with :meth:`set_prune`. Until then (``th_prune`` starts at ``-1``),
    the forward pass uses the raw weights. ``structured`` is accepted for
    signature parity with the Conv2d variant but is not used. Any ``bias``
    keyword is overridden to ``False``.
    """

    def __init__(self, *args, soft=False, structured=False, **kwargs):
        no_bias_kwargs = {**kwargs, "bias": False}
        super().__init__(*args, qfx="global_prune_ternarize", **no_bias_kwargs)
        self.th_prune = -1
        # Forwarded to the quantization function as its ``soft`` keyword.
        self.soft = soft

    def set_prune(self, th, prune_perc):
        """Store ``th`` as the pruning threshold; ``prune_perc`` is unused."""
        self.th_prune = th

    @torch.no_grad()
    def force_prune(self, th, prune_perc):
        """Set the threshold, then overwrite ``self.weight`` in place with its quantized value."""
        self.set_prune(th, prune_perc)
        self.weight[:] = self.get_quantized_weights()

    def get_quantized_weights(self):
        """Apply ``qfx`` to the weight with the current threshold and ``soft`` flag."""
        return self.qfx(self.weight, self.th_prune, soft=self.soft)

    def forward(self, nx):
        # The quantization call is skipped entirely while pruning is inactive.
        with torch.no_grad():
            offset = (self.get_quantized_weights() - self.weight) if self.th_prune > 0 else 0
        # Straight-through estimator: the forward pass sees the quantized weight,
        # and the gradient w.r.t. self.weight passes through unchanged.
        return F.linear(nx, self.weight + offset, None)

class RPrunableQuantizedLinear(QuantizedLinear):
    """Prunable quantized linear layer that calls its quantizer with ``rescale=True``.

    The layer uses ``quantized_ops.global_prune_ternarize``. What the rescaling
    does is defined by that function, not here. Quantization is only applied
    while ``th_prune > 0`` (it starts at ``-1``). ``structured`` is accepted for
    signature parity but is not used. Any ``bias`` keyword is overridden to
    ``False``.
    """

    def __init__(self, *args, soft=False, structured=False, **kwargs):
        no_bias_kwargs = {**kwargs, "bias": False}
        super().__init__(*args, qfx="global_prune_ternarize", **no_bias_kwargs)
        self.th_prune = -1
        # Forwarded to the quantization function as its ``soft`` keyword.
        self.soft = soft

    def set_prune(self, th, prune_perc):
        """Store ``th`` as the pruning threshold; ``prune_perc`` is unused."""
        self.th_prune = th

    @torch.no_grad()
    def force_prune(self, th, prune_perc):
        """Set the threshold, then overwrite ``self.weight`` in place with its quantized value."""
        self.set_prune(th, prune_perc)
        self.weight[:] = self.get_quantized_weights()

    def get_quantized_weights(self):
        """Apply ``qfx`` with the current threshold, ``soft`` flag and ``rescale=True``."""
        return self.qfx(self.weight, self.th_prune, soft=self.soft, rescale=True)

    def forward(self, nx):
        # The quantization call is skipped entirely while pruning is inactive.
        with torch.no_grad():
            offset = (self.get_quantized_weights() - self.weight) if self.th_prune > 0 else 0
        # Straight-through estimator: the forward pass sees the quantized weight,
        # and the gradient w.r.t. self.weight passes through unchanged.
        return F.linear(nx, self.weight + offset, None)


class PrunableLinear(nn.Linear):
    """Bias-free linear layer with cumulative magnitude pruning through a binary mask.

    Hard mode (``soft=False``, default): :meth:`set_prune` zeroes every weight
    with ``|w| <= th`` and records it in ``_mask``. Every forward pass re-applies
    the mask to the weights in place, and a gradient hook multiplies the weight
    gradient by the mask, so pruned weights stay at zero.

    Soft mode (``soft=True``): the forward pass does not re-apply the mask.
    Instead, when ``th_prune > 0``, it uses the soft-thresholded weight
    ``sign(w) * max(|w| - th_prune, 0)`` with a straight-through gradient to
    ``w``. :meth:`set_prune` and the gradient hook behave as in hard mode.

    ``structured`` is accepted for signature parity with ``PrunableConv2d`` but
    is ignored. Any ``bias`` keyword is overridden to ``False``.
    """

    def __init__(self, *args, structured=False, soft=False, **kwargs):
        super().__init__(*args, **{**kwargs, "bias": False})
        self.th_prune = -1  # negative: no threshold applied yet
        self.prune_perc = 0  # last prune_perc accepted by set_prune
        self.soft = soft
        # NOTE(review): the mask is a frozen nn.Parameter rather than a buffer, so it
        # shows up in parameters() and in state_dict as "_mask". The hook below
        # captures this exact tensor object. Turning it into a buffer could leave
        # the hook holding a stale tensor after Module.to(); verify before changing.
        self._mask = nn.Parameter(torch.ones_like(self.weight), requires_grad=False)
        # Zero the gradient of pruned entries before it reaches the optimizer.
        self.hook_handle = self.weight.register_hook(lambda grad, mask=self._mask: grad * mask)

    @torch.no_grad()
    def set_prune(self, th, prune_perc):
        """Extend the pruning mask with threshold ``th`` and apply it to the weights.

        Pruning is cumulative: a pruned weight stays at zero until the mask is
        reset. A reset happens when ``th`` is *negative*; ``th == 0`` still
        prunes, namely the weights that are exactly zero. A reset only makes
        the pruned weights trainable again; their values are not restored.

        If ``prune_perc`` equals the stored ``self.prune_perc`` (initially 0),
        the current mask is merely re-applied. In that case ``th`` is ignored
        and neither the mask nor ``th_prune``/``prune_perc`` change.

        Args:
            th: magnitude threshold; weights with ``|w| <= th`` are pruned.
                A negative value resets the mask to all ones.
            prune_perc: value used only to detect a new pruning step; it is
                compared and stored but does not influence which weights are pruned.
        """
        if prune_perc == self.prune_perc:
            self.weight[:] = self.weight * self._mask
            return

        if th < 0:
            self._mask[:] = 1
        else:
            pruned = self.weight.abs() <= th
            self._mask[pruned] = 0
            self.weight[:] = self.weight * self._mask

        self.prune_perc = prune_perc
        self.th_prune = th

    @torch.no_grad()
    def force_prune(self, th, prune_perc):
        """Re-apply the current mask to the weights in place.

        ``th`` and ``prune_perc`` are ignored. They are accepted so the method
        has the same signature as ``force_prune`` on the other prunable layers.
        """
        self.weight[:] = self.weight * self._mask

    def forward(self, nx):
        if not self.soft:
            # Re-zero pruned weights before every use. NOTE(review): presumably this
            # guards against optimizer state (e.g. momentum accumulated before
            # pruning) moving them despite the masked gradient. Because it mutates
            # self.weight in place, running forward twice before backward() will
            # likely trip autograd's in-place modification check.
            self.force_prune(self.th_prune, self.prune_perc)
            return super().forward(nx)

        with torch.no_grad():
            if self.th_prune > 0:
                # Soft-thresholding: shrink every magnitude by th_prune, clip at 0.
                shrunk = torch.sign(self.weight) * (self.weight.abs() - self.th_prune).clamp_(min=0)
                delta_q = shrunk - self.weight
            else:
                delta_q = 0
        # Straight-through estimator: the forward pass sees the thresholded weight.
        # The gradient w.r.t. self.weight passes through unchanged (apart from the
        # mask hook).
        quantized_weight = self.weight + delta_q
        return F.linear(nx, quantized_weight, None)


class DifferentiablePrunableLinear(nn.Linear):
    """Bias-free linear layer that applies threshold pruning inside the autograd graph.

    The pruned weight returned by :meth:`get_quantized_weights` is used directly
    in the forward pass (no straight-through trick), so gradients flow through
    the thresholding itself. Pruning is disabled while ``th_prune`` is not
    positive (it starts at ``-1``). Any ``bias`` keyword is overridden to ``False``.
    """

    def __init__(self, *args, soft=False, **kwargs):
        no_bias_kwargs = {**kwargs, "bias": False}
        super().__init__(*args, **no_bias_kwargs)
        self.th_prune = -1
        self.soft = soft

    @torch.no_grad()
    def set_prune(self, th, prune_perc):
        """Store ``th`` as the pruning threshold; ``prune_perc`` is unused."""
        self.th_prune = th

    @torch.no_grad()
    def force_prune(self, th, prune_perc):
        """Set the threshold and bake the pruned weights into ``self.weight`` in place."""
        self.set_prune(th, prune_perc)
        self.weight[:] = self.get_quantized_weights()

    def get_quantized_weights(self):
        """Return the thresholded weight, or ``self.weight`` itself when pruning is off.

        Soft mode shrinks magnitudes: ``sign(w) * max(|w| - th, 0)``.
        Hard mode zeroes entries with ``|w| <= th`` on a copy of the weight.
        """
        w = self.weight
        th = self.th_prune
        if not (th > 0):
            return w
        if self.soft:
            return torch.sign(w) * (w.abs() - th).clamp_(min=0)
        hard = w.clone()
        hard[hard.abs() <= th] = 0
        return hard

    def forward(self, nx):
        return F.linear(nx, self.get_quantized_weights(), None)


# -- Conv2d Layers
class QuantizedConv2d(nn.Conv2d):
    """Bias-free ``nn.Conv2d`` that runs its forward pass on quantized weights.

    Args:
        *args, **kwargs: forwarded to ``nn.Conv2d``; any ``bias`` keyword is
            overridden to ``False``. ``padding_mode`` is honoured in the
            forward pass, as in ``nn.Conv2d``.
        qfx: name of a function in ``quantized_ops`` (resolved with
            ``str2layer``) mapping the full-precision weight to its quantized
            counterpart. Required.

    The quantization offset is computed under ``torch.no_grad()``, so the
    gradient reaches ``self.weight`` as if the quantizer were the identity
    (straight-through estimator).
    """

    def __init__(self, *args, qfx=None, **kwargs):
        super().__init__(*args, **{**kwargs, "bias": False})
        assert qfx is not None, "Quantization function is not set for QuantizedConv2d"
        self.qfx = str2layer(qfx)

    def get_quantized_weights(self):
        """Return ``qfx`` applied to the current full-precision weight."""
        return self.qfx(self.weight)

    def forward(self, nx):
        with torch.no_grad():
            delta_q = self.get_quantized_weights() - self.weight

        # Numerically the quantized weight; gradient passes straight to self.weight.
        quantized_weight = self.weight + delta_q
        # _conv_forward is what nn.Conv2d.forward uses: it applies padding_mode
        # ('reflect', 'replicate', 'circular') before convolving. A bare
        # F.conv2d call always zero-pads and would silently ignore it.
        return self._conv_forward(nx, quantized_weight, None)


class PrunableQuantizedConv2d(QuantizedConv2d):
    """Quantized 2-D convolution driven by ``quantized_ops.global_prune_ternarize``.

    The quantization function receives ``(weight, th_prune, structured, soft)``
    positionally. It is only applied while ``th_prune > 0`` (it starts at
    ``-1``); otherwise the raw weights are used. ``padding_mode`` is honoured
    as in ``nn.Conv2d``. Any ``bias`` keyword is overridden to ``False``.
    """

    def __init__(self, *args, structured=False, soft=False, **kwargs):
        super().__init__(*args, qfx="global_prune_ternarize", **{**kwargs, "bias": False})
        self.th_prune = -1
        self.structured = structured
        self.soft = soft

    @torch.no_grad()
    def set_prune(self, th, prune_perc):
        """Store ``th`` as the pruning threshold; ``prune_perc`` is unused."""
        self.th_prune = th

    @torch.no_grad()
    def force_prune(self, th, prune_perc):
        """Set the threshold, then overwrite ``self.weight`` in place with its quantized value."""
        self.set_prune(th, prune_perc)
        self.weight[:] = self.get_quantized_weights()

    def get_quantized_weights(self):
        """Apply ``qfx`` with the current threshold, ``structured`` and ``soft`` settings."""
        return self.qfx(self.weight, self.th_prune, self.structured, self.soft)

    def forward(self, nx):
        # The quantization call is skipped entirely while pruning is inactive.
        with torch.no_grad():
            if self.th_prune > 0:
                delta_q = self.get_quantized_weights() - self.weight
            else:
                delta_q = 0

        # Straight-through estimator: gradient w.r.t. self.weight passes unchanged.
        quantized_weight = self.weight + delta_q
        # _conv_forward (used by nn.Conv2d.forward) honours padding_mode;
        # a bare F.conv2d call would always zero-pad.
        return self._conv_forward(nx, quantized_weight, None)

class RPrunableQuantizedConv2d(QuantizedConv2d):
    """Prunable quantized 2-D convolution that calls its quantizer with rescaling on.

    The quantization function (``quantized_ops.global_prune_ternarize``)
    receives ``(weight, th_prune, structured, soft, True)`` positionally. What
    the trailing ``True`` (rescale) does is defined by that function. It is only
    applied while ``th_prune > 0`` (it starts at ``-1``). ``padding_mode`` is
    honoured as in ``nn.Conv2d``. Any ``bias`` keyword is overridden to ``False``.
    """

    def __init__(self, *args, structured=False, soft=False, **kwargs):
        super().__init__(*args, qfx="global_prune_ternarize", **{**kwargs, "bias": False})
        self.th_prune = -1
        self.structured = structured
        self.soft = soft

    @torch.no_grad()
    def set_prune(self, th, prune_perc):
        """Store ``th`` as the pruning threshold; ``prune_perc`` is unused."""
        self.th_prune = th

    @torch.no_grad()
    def force_prune(self, th, prune_perc):
        """Set the threshold, then overwrite ``self.weight`` in place with its quantized value."""
        self.set_prune(th, prune_perc)
        self.weight[:] = self.get_quantized_weights()

    def get_quantized_weights(self):
        """Apply ``qfx`` with the current threshold, ``structured``, ``soft`` and rescale on."""
        return self.qfx(self.weight, self.th_prune, self.structured, self.soft, True)

    def forward(self, nx):
        # The quantization call is skipped entirely while pruning is inactive.
        with torch.no_grad():
            if self.th_prune > 0:
                delta_q = self.get_quantized_weights() - self.weight
            else:
                delta_q = 0

        # Straight-through estimator: gradient w.r.t. self.weight passes unchanged.
        quantized_weight = self.weight + delta_q
        # _conv_forward (used by nn.Conv2d.forward) honours padding_mode;
        # a bare F.conv2d call would always zero-pad.
        return self._conv_forward(nx, quantized_weight, None)

class PrunableConv2d(nn.Conv2d):
    """Bias-free 2-D convolution with cumulative magnitude pruning through a binary mask.

    Hard mode (``soft=False``, default): :meth:`set_prune` zeroes weights whose
    magnitude is ``<= th`` and records them in ``_mask``. Every forward pass
    re-applies the mask to the weights in place, and a gradient hook multiplies
    the weight gradient by the mask, so pruned weights stay at zero.

    With ``structured=True``, whole ``kH x kW`` kernels are pruned together,
    scored by their mean absolute weight.

    Soft mode (``soft=True``): the forward pass does not re-apply the mask.
    Instead, when ``th_prune > 0``, it convolves with the soft-thresholded
    weight ``sign(w) * max(|w| - th_prune, 0)`` using a straight-through
    gradient to ``w``. :meth:`set_prune` and the gradient hook behave as in
    hard mode.

    ``padding_mode`` is honoured in both modes. Any ``bias`` keyword is
    overridden to ``False``.
    """

    def __init__(self, *args, structured=False, soft=False, **kwargs):
        super().__init__(*args, **{**kwargs, "bias": False})
        self.th_prune = -1  # negative: no threshold applied yet
        self.prune_perc = 0  # last prune_perc accepted by set_prune
        self.structured = structured
        self.soft = soft
        # NOTE(review): the mask is a frozen nn.Parameter rather than a buffer, so it
        # shows up in parameters() and in state_dict as "_mask". The hook below
        # captures this exact tensor object. Turning it into a buffer could leave
        # the hook holding a stale tensor after Module.to(); verify before changing.
        self._mask = nn.Parameter(torch.ones_like(self.weight), requires_grad=False)
        # Zero the gradient of pruned entries before it reaches the optimizer.
        self.hook_handle = self.weight.register_hook(lambda grad, mask=self._mask: grad * mask)

    @torch.no_grad()
    def set_prune(self, th, prune_perc):
        """Extend the pruning mask with threshold ``th`` and apply it to the weights.

        Pruning is cumulative: a pruned weight stays at zero until the mask is
        reset. A reset happens when ``th`` is *negative*; ``th == 0`` still
        prunes, namely the weights (or kernels) whose magnitude is zero. A reset
        only makes the pruned weights trainable again; their values are not
        restored.

        If ``prune_perc`` equals the stored ``self.prune_perc`` (initially 0),
        the current mask is merely re-applied. In that case ``th`` is ignored
        and neither the mask nor ``th_prune``/``prune_perc`` change.

        Args:
            th: magnitude threshold. Entries with ``|w| <= th`` are pruned, or,
                with ``structured=True``, whole kernels whose mean ``|w|`` is
                ``<= th``. A negative value resets the mask to all ones.
            prune_perc: value used only to detect a new pruning step; it is
                compared and stored but does not influence which weights are pruned.
        """
        if prune_perc == self.prune_perc:
            self.weight[:] = self.weight * self._mask
            return

        if th < 0:
            self._mask[:] = 1
        else:
            if self.structured:
                # Mean |w| over the two trailing (spatial) dims gives one score per
                # kernel. Boolean-indexing the two leading dims of the mask then
                # zeroes every entry of a selected kernel.
                pruned = self.weight.abs().mean(dim=(-2, -1)) <= th
            else:
                pruned = self.weight.abs() <= th
            self._mask[pruned] = 0
            self.weight[:] = self.weight * self._mask

        self.prune_perc = prune_perc
        self.th_prune = th

    @torch.no_grad()
    def force_prune(self, th, prune_perc):
        """Re-apply the current mask to the weights in place.

        ``th`` and ``prune_perc`` are ignored. They are accepted so the method
        has the same signature as ``force_prune`` on the other prunable layers.
        """
        self.weight[:] = self.weight * self._mask

    def forward(self, nx):
        if not self.soft:
            # Re-zero pruned weights before every use. NOTE(review): presumably this
            # guards against optimizer state (e.g. momentum accumulated before
            # pruning) moving them despite the masked gradient. Because it mutates
            # self.weight in place, running forward twice before backward() will
            # likely trip autograd's in-place modification check.
            self.force_prune(self.th_prune, self.prune_perc)
            return super().forward(nx)

        with torch.no_grad():
            if self.th_prune > 0:
                # Soft-thresholding: shrink every magnitude by th_prune, clip at 0.
                shrunk = torch.sign(self.weight) * (self.weight.abs() - self.th_prune).clamp_(min=0)
                delta_q = shrunk - self.weight
            else:
                delta_q = 0
        # Straight-through estimator: the forward pass sees the thresholded weight.
        # The gradient w.r.t. self.weight passes through unchanged (apart from the
        # mask hook).
        quantized_weight = self.weight + delta_q
        # Same convolution path as super().forward in hard mode, so padding_mode
        # is honoured here too (a bare F.conv2d would always zero-pad).
        return self._conv_forward(nx, quantized_weight, None)


class DifferentiablePrunableConv2d(nn.Conv2d):
    """Bias-free 2-D convolution that applies threshold pruning inside the autograd graph.

    The pruned weight returned by :meth:`get_quantized_weights` is used directly
    in the forward pass (no straight-through trick), so gradients flow through
    the thresholding itself. Pruning is disabled while ``th_prune`` is not
    positive (it starts at ``-1``). ``padding_mode`` is honoured as in
    ``nn.Conv2d``. Any ``bias`` keyword is overridden to ``False``.
    """

    def __init__(self, *args, soft=False, **kwargs):
        super().__init__(*args, **{**kwargs, "bias": False})
        self.th_prune = -1
        self.soft = soft

    @torch.no_grad()
    def set_prune(self, th, prune_perc):
        """Store ``th`` as the pruning threshold; ``prune_perc`` is unused."""
        self.th_prune = th

    @torch.no_grad()
    def force_prune(self, th, prune_perc):
        """Set the threshold and bake the pruned weights into ``self.weight`` in place."""
        self.set_prune(th, prune_perc)
        self.weight[:] = self.get_quantized_weights()

    def get_quantized_weights(self):
        """Return the thresholded weight, or ``self.weight`` itself when pruning is off.

        Soft mode shrinks magnitudes: ``sign(w) * max(|w| - th, 0)``.
        Hard mode zeroes entries with ``|w| <= th`` on a copy of the weight.
        """
        if self.th_prune > 0:
            if self.soft:
                return torch.sign(self.weight) * (self.weight.abs() - self.th_prune).clamp_(min=0)
            weight = self.weight.clone()
            weight[weight.abs() <= self.th_prune] = 0
            return weight
        return self.weight

    def forward(self, nx):
        # _conv_forward (used by nn.Conv2d.forward) honours padding_mode;
        # a bare F.conv2d call would always zero-pad.
        return self._conv_forward(nx, self.get_quantized_weights(), None)


# -- Activation layers
class QuantizedActivation(nn.Module):
    """Module wrapper that applies a quantization function to its input.

    Args:
        qfx: either a callable, applied as ``qfx(nx)`` in :meth:`forward`, or
            (like the ``qfx`` argument of the quantized Linear/Conv2d layers)
            the name of a function in ``quantized_ops``, resolved with
            ``str2layer``. Required.

    Raises:
        AssertionError: if ``qfx`` is ``None``.
    """

    def __init__(self, qfx=None):
        super().__init__()
        assert qfx is not None, "Quantization function is not set for QuantizedActivation"
        # Resolve names the same way the other quantized layers do; a raw string
        # would otherwise fail in forward() with "'str' object is not callable".
        # Callables are kept as given.
        self.qfx = str2layer(qfx) if isinstance(qfx, str) else qfx

    def forward(self, nx):
        """Return ``self.qfx(nx)``."""
        return self.qfx(nx)
