import torch
from torch import nn

from pruning import utils

import numpy as np

import abc

class Prunable(nn.Module):
    """Base wrapper around a layer whose inputs and outputs can be pruned.

    A copy of the wrapped layer's parameters is taken at construction time so
    that ``reset`` can restore them later. Only ``nn.Linear`` and ``nn.Conv2d``
    layers are supported when pruning by removal.

    Args:
        layer: The layer to wrap.
        pivots: Optional pair; element 0 is stored as ``pre_pivots`` and
            element 1 as ``post_pivots``. Subclasses use them to remap
            indices (presumably permutations - confirm with callers).
        prune_by_removal: When true, ``_prune_weights`` physically drops
            all-zero rows/columns from the layer; otherwise it is a no-op.
        *args, **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, layer, pivots=None, prune_by_removal=True, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.layer = layer

        if pivots is None:
            self.pre_pivots, self.post_pivots = None, None
        else:
            self.pre_pivots, self.post_pivots = pivots[0], pivots[1]

        # Snapshot of the parameters, in ``layer.parameters()`` order.
        self.params = [param.data.clone() for param in self.layer.parameters()]

        self.prune_by_removal = prune_by_removal

        # Subclasses may overwrite this; see the grouped branch of
        # ``_prune_weights`` for how it is used.
        self.pre_dim = None

    @abc.abstractmethod
    def prune_inputs(self, indices):
        """Prune the given input indices. Implemented by subclasses."""

    @abc.abstractmethod
    def prune_outputs(self, indices):
        """Prune the given output indices. Implemented by subclasses."""

    def __refresh_layer_in_out_map(self):
        """Sync the layer's in/out size attributes with its weight shape."""
        if isinstance(self.layer, nn.Linear):
            in_attr, out_attr = "in_features", "out_features"
        elif isinstance(self.layer, nn.Conv2d):
            in_attr, out_attr = "in_channels", "out_channels"
        else:
            raise NotImplementedError()

        weight_shape = self.layer.weight.shape
        setattr(self.layer, in_attr, weight_shape[1])
        setattr(self.layer, out_attr, weight_shape[0])

    def _prune_weights(self, indices, prune_outputs=False):
        """Drop rows (outputs) or columns (inputs) whose weights are all zero.

        Args:
            indices: Not read here; the kept positions are derived from which
                weight slices are entirely zero.
            prune_outputs: Remove outputs (dim 0) when true, inputs (dim 1)
                otherwise.

        Returns:
            The tensor of kept indices, or ``None`` when ``prune_by_removal``
            is false (in which case the layer is left untouched).
        """
        if not self.prune_by_removal:
            return None

        # Reduce over every dimension except the one being pruned.
        reduce_dims = [int(prune_outputs)]
        if isinstance(self.layer, nn.Conv2d):
            reduce_dims += [2, 3]
        magnitude = self.layer.weight.abs().sum(dim=reduce_dims)
        keep_indices = torch.nonzero(magnitude).flatten()

        if prune_outputs:
            self.layer.weight.data = self.layer.weight[keep_indices]
            if self.layer.bias is not None:
                self.layer.bias.data = self.layer.bias[keep_indices]
        elif self.pre_dim is not None and self.layer.weight.shape[1] != self.pre_dim:
            # The input dimension is ``pre_dim`` groups of equal size laid out
            # contiguously; a group is kept if any of its weights is non-zero.
            grouped = self.layer.weight.data.clone()
            group_size = grouped.shape[1] // self.pre_dim
            grouped = grouped.reshape(len(grouped), -1, group_size)

            group_magnitude = grouped.abs().sum(dim=[0, 2])
            keep_indices = torch.nonzero(group_magnitude).flatten()

            self.layer.weight.data = grouped[:, keep_indices].flatten(start_dim=1, end_dim=2)
        else:
            self.layer.weight.data = self.layer.weight[:, keep_indices]

        self.__refresh_layer_in_out_map()
        return keep_indices

    def reset(self):
        """Restore the layer's parameters from the construction-time snapshot."""
        snapshot = self.params
        self.layer.weight.data = snapshot[0].clone()
        if len(snapshot) <= 1:
            # The layer had no second parameter when wrapped; clear any bias
            # that was attached afterwards.
            self.layer.bias = None
        else:
            self.layer.bias.data = snapshot[1].clone()

        if self.prune_by_removal:
            self.__refresh_layer_in_out_map()

    def forward(self, x):
        """Apply the wrapped layer to ``x``."""
        return self.layer.forward(x)
    

class ElementwisePruningLayer(Prunable):
    """Prunable layer that zeroes individual input columns / output rows.

    Args:
        layer: The layer to wrap.
        pivots: Optional ``(pre_pivots, post_pivots)`` pair used to remap the
            indices passed to ``prune_inputs`` / ``prune_outputs``.
        pre_dim: Optional number of input groups. When set and different from
            the weight's second dimension, input indices address whole groups
            of consecutive columns rather than single columns.
        *args, **kwargs: Forwarded to ``Prunable.__init__``.
    """

    def __init__(self, layer, pivots=None, pre_dim=None, *args, **kwargs) -> None:
        super().__init__(layer, pivots, *args, **kwargs)
        self.pre_dim = pre_dim

    def prune_inputs(self, indices):
        """Zero the selected inputs and, if enabled, remove them.

        Args:
            indices: Index tensor of inputs to prune (before pivot remapping).

        Returns:
            The kept input indices reported by ``_prune_weights``, or ``None``
            when the layer is not pruned by removal.
        """
        # If the request covers the whole input dimension, drop its first
        # entry (presumably so that at least one input survives - confirm).
        if len(indices) >= self.layer.weight.shape[1]:
            indices = indices[1:]
        if self.pre_pivots is not None:
            indices = self.pre_pivots[indices]

        W = self.layer.weight.data.clone()
        if self.pre_dim is not None and self.layer.weight.shape[1] != self.pre_dim:
            # View the columns as ``pre_dim`` groups so that one index zeroes
            # an entire group, then restore the flat layout.
            squared_kernel_size = W.shape[1] // self.pre_dim
            W = torch.reshape(W, (len(W), -1, squared_kernel_size))
            W[:, indices] = 0.
            W = torch.flatten(W, start_dim=1, end_dim=2)
        else:
            W[:, indices] = 0.
        self.layer.weight.data = W

        return self._prune_weights(indices, prune_outputs=False)

    def prune_outputs(self, indices):
        """Zero the selected outputs (weights and bias) and, if enabled, remove them.

        Args:
            indices: Index tensor of outputs to prune (before pivot remapping).

        Returns:
            The kept output indices reported by ``_prune_weights``, or ``None``
            when the layer is not pruned by removal.
        """
        # Same guard as in ``prune_inputs``: never request every output.
        if len(indices) >= len(self.layer.weight):
            indices = indices[1:]
        if self.post_pivots is not None:
            indices = self.post_pivots[indices]

        self.layer.weight.data[indices] = 0.
        if self.layer.bias is not None:
            self.layer.bias.data[indices] = 0.

        return self._prune_weights(indices, prune_outputs=True)


class FactPrunable(Prunable):
    """Prunable layer whose inputs are pruned through a matrix ``R`` and its inverse.

    Args:
        layer: The layer to wrap (``nn.Linear`` or ``nn.Conv2d``).
        R_inv: Square matrix; ``R`` is computed as ``torch.inverse(R_inv)``.
        mean: Optional input mean. When given, the layer's response to it is
            cached in ``self.Wmu`` (the attribute is not created otherwise).
        *args, **kwargs: Forwarded to ``Prunable.__init__`` (``pivots``,
            ``prune_by_removal``, ...).
    """

    def __init__(self, layer, R_inv, mean=None, *args, **kwargs) -> None:
        super().__init__(layer, *args, **kwargs)
        self.R_inv = R_inv
        self.mean = mean

        self.R = torch.inverse(R_inv)

        # Replaces the ``None`` default set by ``Prunable.__init__``.
        self.pre_dim = len(self.R)

        if self.mean is not None:
            self.Wmu = self.forward_mu()

    def prune_inputs(self, indices):
        """Fold a pruned ``R_inv @ R`` into the layer, then prune its inputs.

        Args:
            indices: Index tensor selecting the rows of ``R`` / columns of
                ``R_inv`` to zero.

        Returns:
            The kept input indices reported by ``_prune_weights``, or ``None``
            when the layer is not pruned by removal.
        """
        indices_orig = indices.clone()
        # If the request covers every row of ``R``, drop its first entry
        # (presumably so that at least one component survives - confirm).
        if len(indices) >= len(self.R):
            indices = indices[1:]

        R_pruned = self.R.clone()
        R_inv_pruned = self.R_inv.clone()

        R_pruned[indices] = 0.
        R_inv_pruned[:, indices] = 0.

        R_comb = R_inv_pruned @ R_pruned

        if self.pre_pivots is not None:
            # Reorder both the rows and the columns by ``pre_pivots``.
            R_comb = R_comb[self.pre_pivots][:, self.pre_pivots]
        # NOTE(review): ``utils.combine_weights`` is defined outside this file;
        # presumably it merges ``R_comb`` into the layer's weights - confirm.
        utils.combine_weights(self.layer, R_comb)

        if self.pre_pivots is not None:
            indices = self.pre_pivots[indices_orig]
            if len(indices) >= len(self.R):
                indices = indices[1:]

        return self._prune_weights(indices, prune_outputs=False)

    def prune_outputs(self, indices):
        """Zero the selected outputs (weights and bias) and, if enabled, remove them.

        Args:
            indices: Index tensor of outputs to prune (before pivot remapping).

        Returns:
            The kept output indices reported by ``_prune_weights``, or ``None``
            when the layer is not pruned by removal.
        """
        if self.post_pivots is not None:
            indices = self.post_pivots[indices]

        # Never request every output: drop the first entry if the request
        # covers the whole output dimension.
        if len(indices) >= len(self.layer.weight):
            indices = indices[1:]

        self.layer.weight.data[indices] = 0.
        if self.layer.bias is not None:
            self.layer.bias.data[indices] = 0.

        return self._prune_weights(indices, prune_outputs=True)

    def forward_mu(self):
        """Apply the layer's current weights (without bias) to ``self.mean``.

        Returns:
            For ``nn.Conv2d``: the convolution of the mean (with a leading
            batch dimension added) using the layer's own stride, padding,
            dilation and groups, with all size-1 dimensions squeezed out.
            For ``nn.Linear``: ``weight @ mean``.

        Raises:
            NotImplementedError: If the wrapped layer is neither ``nn.Conv2d``
                nor ``nn.Linear``.
        """
        W = self.layer.weight.clone()
        if isinstance(self.layer, nn.Conv2d):
            return nn.functional.conv2d(
                self.mean[None, ...], W,
                stride=self.layer.stride,
                padding=self.layer.padding,
                # Must mirror the layer's dilation; passing the padding here
                # only works when the two happen to be equal.
                dilation=self.layer.dilation,
                groups=self.layer.groups,
            ).squeeze()
        if isinstance(self.layer, nn.Linear):
            return W @ self.mean
        raise NotImplementedError()


class FactPrunableDemeanConv2d(FactPrunable):
    """``FactPrunable`` variant that adds a separate correction term in ``forward``.

    After input pruning, ``bias_term`` holds the difference between the cached
    mean response ``Wmu`` and the mean response of the pruned weights; it is
    added to the layer output on every forward pass.

    Args:
        layer: The layer to wrap.
        R_inv: Passed through to ``FactPrunable``.
        *args, **kwargs: Forwarded to ``FactPrunable.__init__``. A ``mean``
            must be supplied, since ``self.Wmu`` is read during construction.
    """

    def __init__(self, layer, R_inv, *args, **kwargs) -> None:
        super().__init__(layer, R_inv, *args, **kwargs)
        self.bias_term = torch.zeros_like(self.Wmu)

    def prune_inputs(self, indices):
        """Prune inputs and recompute the mean-correction term.

        Returns:
            The kept input indices from the parent implementation, or ``None``
            when the layer is not pruned by removal.
        """
        keep_indices = super().prune_inputs(indices)

        WRmu = self.forward_mu()
        self.bias_term = self.Wmu - WRmu

        return keep_indices

    def prune_outputs(self, indices):
        """Prune outputs and shrink the correction term to the kept outputs.

        Returns:
            The kept output indices from the parent implementation, or ``None``
            when the layer is not pruned by removal.
        """
        keep_indices = super().prune_outputs(indices)

        if keep_indices is not None:
            self.bias_term = self.bias_term[keep_indices]

        return keep_indices

    def reset(self):
        """Restore the original parameters and zero the correction term."""
        super().reset()
        self.bias_term = torch.zeros_like(self.Wmu)

    def forward(self, x):
        """Apply the wrapped layer, then add the correction term."""
        x = super().forward(x)
        x = x + self.bias_term
        return x


class FactPrunableLayer(FactPrunable):
    """``FactPrunable`` variant that folds the mean correction into the layer bias.

    Args:
        layer: The layer to wrap.
        R_inv: Passed through to ``FactPrunable``.
        mean: Optional input mean. For ``nn.Linear`` layers a multi-dimensional
            mean is flattened before use.
        *args, **kwargs: Forwarded to ``FactPrunable.__init__``.
    """

    def __init__(self, layer, R_inv, mean=None, *args, **kwargs) -> None:
        if isinstance(layer, nn.Linear) and mean is not None and len(mean.shape) > 1:
            mean = mean.flatten()
        super().__init__(layer, R_inv, mean, *args, **kwargs)

    def prune_inputs(self, indices):
        """Prune inputs and, when a mean is set, add the correction to the bias.

        The correction is the cached mean response ``Wmu`` minus the mean
        response of the pruned weights. It is added to the existing bias, or
        installed as a new bias parameter when the layer has none.

        Returns:
            The kept input indices from the parent implementation, or ``None``
            when the layer is not pruned by removal.
        """
        keep_indices = super().prune_inputs(indices)

        if self.mean is not None:
            WRmu = self.forward_mu()

            bias_term = self.Wmu - WRmu

            if self.layer.bias is not None:
                self.layer.bias.data = self.layer.bias.data + bias_term
            else:
                self.layer.bias = nn.Parameter(bias_term)

        return keep_indices
