import torch_pruning as tp
import torch
from torch import nn

import scipy
import numpy as np

import typing
from torch_pruning import function


from pruning import utils


class ZCAImportance(tp.importance.Importance):
    """Importance scores derived from a layer's cross-correlation matrix ``C``.

    The scores are taken from the first dependency in the group that prunes
    the input channels/features of an ``nn.Linear`` or ``nn.Conv2d``.  They
    are computed as ``1 / diag(utils.mpow(C, -1/2)) ** 2`` and cached on the
    layer as ``layer.imp``; later calls return the cached tensor unchanged.

    NOTE(review): ``utils.mpow`` is presumably a matrix power (so the scores
    come from the diagonal of ``C^(-1/2)``) -- confirm against ``pruning.utils``.
    NOTE(review): the ``layer.imp`` cache is never invalidated here; if ``C``
    is updated after the first call the cached scores must be removed by the
    caller.
    """

    # Number of channels that make up one attention head in 'attn.proj' layers.
    head_dim = 64

    def __call__(self, group, **kwargs):
        """Return the importance tensor for ``group``.

        Args:
            group: iterable of ``(dependency, indices)`` pairs.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The importance tensor of the first matching layer, or ``None``
            when no dependency in the group matches.

        Raises:
            NotImplementedError: if the matching layer has no attribute ``C``.
        """
        for dep, idxs in group:
            layer = dep.target.module
            if not isinstance(layer, (nn.Linear, nn.Conv2d)):
                continue
            if dep.handler.__name__ not in ('prune_in_channels', 'prune_in_features'):
                continue

            if not hasattr(layer, 'C'):
                raise NotImplementedError("To compute ZCA importance each parameterized layer should hold a cross-correlation matrix 'C'")
            if hasattr(layer, 'imp'):
                return layer.imp

            # Computed once for both branches below (mpow on C is the costly step).
            D = torch.diag(utils.mpow(layer.C.clone(), -1/2))
            D = 1 / (D * D)

            if 'attn.proj' in dep.target._name:
                # Derive the head count from the score vector itself.  Reading
                # it from ``layer.imp`` is impossible here: that attribute is
                # only assigned after this branch has finished.
                num_heads = D.numel() // self.head_dim

                # Average each within-head position over all heads, then rank
                # the positions so that every head shares the same ordering.
                ranks = D.reshape(num_heads, -1).mean(dim=0).argsort().argsort()

                # The per-head offset is < 1, so it never reorders the integer
                # ranks; it only keeps equal ranks of different heads apart so
                # precision errors do not mix up indices.
                local_imp = torch.cat(
                    [ranks + (head / num_heads) for head in range(num_heads)]
                )
            else:
                local_imp = D

            layer.imp = local_imp
            return local_imp
        return None


class VarImportance(tp.importance.Importance):
    """Importance scores built from a layer's per-channel vector ``D``.

    The first dependency in the group that prunes the input channels/features
    of an ``nn.Linear`` or ``nn.Conv2d`` determines the result.  ``D`` is
    normalised to sum to one and passed through ``utils.rev_cumsum``; for
    ``'attn.proj'`` layers this is done separately for every block of 64
    positions (as given by ``layer.order``).  The scores are finally indexed
    with ``layer.pivots`` to revert the ordering.

    NOTE(review): ``utils.rev_cumsum`` is presumably a reversed cumulative
    sum, and ``D`` presumably holds values sorted in the layout described by
    ``layer.order`` -- confirm against ``pruning.utils`` and the code that
    attaches ``D``/``order``/``pivots`` to the layers.
    """

    @staticmethod
    def _normalised_rev_cumsum(values):
        """Scale ``values`` to unit sum and apply ``utils.rev_cumsum``."""
        return utils.rev_cumsum(values / torch.sum(values))

    def __call__(self, group, **kwargs):
        """Return the score tensor for ``group``, or ``None`` if nothing matches.

        Raises:
            NotImplementedError: if the matching layer has no attribute ``D``.
        """
        input_pruners = ('prune_in_channels', 'prune_in_features')

        for dep, idxs in group:
            module = dep.target.module
            if not isinstance(module, (nn.Linear, nn.Conv2d)):
                continue
            if dep.handler.__name__ not in input_pruners:
                continue

            if not hasattr(module, 'D'):
                raise NotImplementedError("To compute ZCA importance each parameterized layer should hold a diagonal matrix 'D'")

            diag = module.D

            if 'attn.proj' not in dep.target._name:
                # Plain layer: one normalisation over the whole vector.
                whole = self._normalised_rev_cumsum(diag)
                return whole[module.pivots]  # revert ordering

            # Attention projection: normalise inside each 64-wide head.
            layout = module.order.clone()
            head_of = layout // 64
            head_count = len(layout) // 64

            per_head = torch.empty_like(diag)
            for head in range(head_count):
                members = torch.nonzero(head_of == head, as_tuple=True)[0]
                block = diag[members].clone()
                per_head[members] = self._normalised_rev_cumsum(block)

            return per_head[module.pivots]
        return None
