import math
import transformers
import torch
import utils
import torch.nn as nn
import numpy as np


def prune(x, sparsity=0.5):
    """Zero the smallest-magnitude entries of ``x`` along its last dimension.

    For every vector along the last axis (e.g. every token of a
    ``[batch_size, seq_len, hidden_dim]`` activation, or every row of the
    ``[tokens, hidden_dim]`` tensor OPT feeds to ``fc1`` after its reshape),
    the threshold is the ``k``-th smallest absolute value (0-based) with
    ``k = int(hidden_dim * sparsity)``. Entries whose magnitude is strictly
    below the threshold are zeroed. Entries equal to it are kept, so at most
    ``k`` entries per vector are removed.

    Any input rank >= 1 is accepted, and the output has the same shape as ``x``.

    Args:
        x: activation tensor; pruning is applied along ``dim=-1``.
        sparsity: fraction of each vector to drop; expected to be in [0, 1).

    Returns:
        ``x * mask``, with the same shape and dtype as ``x``.
    """
    magnitude = torch.abs(x)
    k = int(x.shape[-1] * sparsity)
    # Slice ``k:k + 1`` (instead of indexing ``k``) so the reduced axis is kept
    # and the threshold broadcasts against ``x`` for any number of leading dims.
    thresh = torch.sort(magnitude, dim=-1)[0][..., k:k + 1]
    # thresh, _ = torch.kthvalue(magnitude, k + 1, dim=-1, keepdim=True)  # equivalent alternative
    mask = magnitude >= thresh
    return x * mask

class ActPruner(torch.nn.Module):

    '''
        Magnitude-based pruner for activations.

        ``forward`` passes its input through ``prune`` (zero the smallest-magnitude
        entries along the last dimension) when ``sparsity > 0``. Otherwise it
        returns the input unchanged.

        ``mask``, ``annealing``, ``annealer`` and ``annealer_cnter`` are stored for
        external use. Within this class only ``free`` touches them: it resets
        ``mask`` and ``annealer_cnter``.
    '''

    def __init__(self):
        super(ActPruner, self).__init__()
        self.sparsity = 0.
        self.mask = None
        # Initialise the same attribute name that ``configure`` writes.
        self.annealer = 0.
        self.annealing = False
        self.annealer_cnter = 0


    def free(self):
        '''Reset the cached mask and the annealing step counter.'''
        self.mask = None
        self.annealer_cnter = 0

    def forward(self, x):
        '''Prune ``x`` when ``sparsity > 0``; otherwise return it unchanged.'''
        if self.sparsity > 0:
            x = prune(x, self.sparsity)
        return x

    def configure(self, sparsity, annealing=False, annealer=0.0):
        '''
            Set the pruning configuration.

            Args:
                sparsity: fraction in [0, 1) of each activation vector to drop.
                annealing: stored flag (not read by this class).
                annealer: stored value (not read by this class).

            Raises:
                ValueError: if ``sparsity`` is outside [0, 1). Validation happens
                    before any attribute is modified, so a rejected call leaves
                    the pruner's state unchanged.
        '''
        # Explicit check instead of ``assert`` so it is not stripped under ``-O``.
        if not 0 <= sparsity < 1:
            raise ValueError('sparsity should be in [0, 1)')
        self.sparsity = sparsity
        self.annealing = annealing
        self.annealer = annealer

class ActPruneWrapper(torch.nn.Module):
    '''
        Wraps an ``nn.Linear`` so that its input activations go through an
        ``ActPruner`` before the linear layer runs.

        ``weight`` and ``bias`` alias the wrapped layer's parameters. ``name`` is a
        plain label (``'nobody'`` when no truthy name is given) that the selective
        enable/disable helpers match against.
    '''

    def __init__(self, module:torch.nn.Linear, name=None):
        super().__init__()
        assert isinstance(module, torch.nn.Linear)
        self.name = name or 'nobody'
        # Registration order (module, weight, bias, pruner) is kept deliberately:
        # it determines child/parameter ordering in named_* iterators and state_dict.
        self.module = module
        self.weight, self.bias = module.weight, module.bias
        self.pruner = ActPruner()


    def extra_repr(self) -> str:
        '''Show the current input sparsity in ``repr(model)``.'''
        return f'Input Pruner Sparsity: {self.pruner.sparsity}'

    def forward(self, x):
        '''Prune the input activations, then apply the wrapped linear layer.'''
        return self.module(self.pruner(x))


def add_actprune(module, name='', layers=None):
    """Recursively wrap linear layers of ``module`` in ``ActPruneWrapper`` (in place).

    The function wraps three kinds of targets:

    - Attributes whose exact type is in ``layers`` are replaced by an
      ``ActPruneWrapper`` named after the attribute.
    - Matching modules stored directly inside an ``nn.Sequential`` or
      ``nn.ModuleList`` attribute are wrapped as well, with the wrapper's
      default name. The container is then rebuilt.
    - After that, the function recurses into all children.

    Modules that already are ``ActPruneWrapper`` instances are never wrapped
    again. Calling this function twice on the same model is therefore safe;
    previously the second call failed because ``ActPruneWrapper`` only accepts
    ``nn.Linear``.

    Args:
        module: root module to modify in place.
        name: dotted path of ``module``; only used to build child paths while recursing.
        layers: exact module types to wrap. Defaults to ``nn.Linear``,
            ``ActPruneWrapper`` (never re-wrapped, see above) and Falcon's
            ``FalconLinear``.
    """
    if layers is None:
        layers = [torch.nn.Linear,
                  ActPruneWrapper,
                  transformers.models.falcon.modeling_falcon.FalconLinear]
    if isinstance(module, ActPruneWrapper):
        return

    def _should_wrap(candidate):
        # Exact-type match (subclasses are not matched unless listed); existing
        # wrappers are skipped so repeated calls do not try to wrap a wrapper.
        return type(candidate) in layers and not isinstance(candidate, ActPruneWrapper)

    def _wrap_children(container):
        # Wrap matching direct children of a Sequential/ModuleList; keep the rest.
        return [ActPruneWrapper(child) if _should_wrap(child) else child
                for child in container.children()]

    for attr in dir(module):
        tmp = getattr(module, attr)
        if _should_wrap(tmp):
            setattr(module, attr, ActPruneWrapper(tmp, name=attr))
        if type(tmp) is torch.nn.Sequential:
            setattr(module, attr, torch.nn.Sequential(*_wrap_children(tmp)))
        if type(tmp) is torch.nn.ModuleList:
            setattr(module, attr, torch.nn.ModuleList(_wrap_children(tmp)))
    for name1, child in module.named_children():
        add_actprune(child, name + '.' + name1 if name != '' else name1, layers)


def find_layers(module, layers=[torch.nn.Linear,
                                ActPruneWrapper], name=''):
    """Map dotted names to the submodules of ``module`` whose exact type is in ``layers``.

    The search stops descending as soon as a module matches, so the inner
    ``nn.Linear`` of a matched ``ActPruneWrapper`` is not reported separately.
    ``name`` is the dotted prefix for ``module`` itself (``''`` at the root).
    """
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        child_path = child_name if name == '' else '.'.join((name, child_name))
        found.update(find_layers(child, layers=layers, name=child_path))
    return found


def disable_act_sparsity(module):
    """Set the input sparsity of every ``ActPruneWrapper`` inside ``module`` to 0."""
    wrappers = (sub for sub in module.modules() if isinstance(sub, ActPruneWrapper))
    for wrapper in wrappers:
        wrapper.pruner.sparsity = 0.0

#! For Llama-2, the supported names are [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj] (the attribute names add_actprune gives each wrapper).
def disable_act_sparsity_selective(module, names=[]):
    """Set sparsity to 0 for each ``ActPruneWrapper`` whose ``name`` is in ``names``."""
    for sub in module.modules():
        if not isinstance(sub, ActPruneWrapper):
            continue
        if sub.name not in names:
            continue
        sub.pruner.sparsity = 0.0


def enable_act_sparsity(module, sparsity):
    """Set the input sparsity of every ``ActPruneWrapper`` inside ``module``.

    The value is written directly to ``pruner.sparsity`` (``ActPruner.configure``
    is not called, so it is not range-checked here).
    """
    is_wrapper = lambda sub: isinstance(sub, ActPruneWrapper)
    for wrapper in filter(is_wrapper, module.modules()):
        wrapper.pruner.sparsity = sparsity

def enable_act_sparsity_selective(module, sparsity, names=[]):
    """Set ``sparsity`` on each ``ActPruneWrapper`` whose ``name`` is in ``names``.

    Like ``enable_act_sparsity``, the value is written directly without validation.
    """
    selected = (
        sub for sub in module.modules()
        if isinstance(sub, ActPruneWrapper) and sub.name in names
    )
    for wrapper in selected:
        wrapper.pruner.sparsity = sparsity


def spectral_entropy_stable(XTX, eps=1e-10):
    """Return the spectral entropy of a symmetric matrix as a 0-dim tensor.

    The eigenvalues of ``XTX`` that are greater than ``eps`` are normalised to
    sum to 1 and treated as a distribution ``p``. The result is
    ``-sum(p * log(p))`` (natural log).

    ``XTX`` is expected to be a real symmetric Gram matrix such as ``X.T @ X``.
    Its symmetric part is decomposed with ``torch.linalg.eigvalsh``, the solver
    for symmetric/Hermitian matrices. It returns real eigenvalues directly,
    whereas ``torch.linalg.eigvals`` runs the general non-symmetric solver and
    returns complex values.

    Args:
        XTX: square, symmetric floating-point matrix.
        eps: eigenvalues <= eps (zero or numerically negative) are discarded so
            that ``log`` stays finite.
    """
    # Average with the transpose so tiny floating-point asymmetries cannot bias
    # the Hermitian solver, which only reads one triangle of its input.
    sym = 0.5 * (XTX + XTX.transpose(-2, -1))
    eigvals = torch.linalg.eigvalsh(sym)
    eigvals = eigvals[eigvals > eps]
    p = eigvals / torch.sum(eigvals)
    return -torch.sum(p * torch.log(p))



#! The quantizer code below is used for experiments that combine quantization with weight pruning and activation pruning.
def quantize(x, scale, zero, maxq):
    """Fake-quantize ``x`` onto the integer grid [0, maxq] and map it back.

    Each value is divided by ``scale``, rounded, shifted by ``zero``, clamped to
    [0, maxq], then de-quantized as ``scale * (level - zero)``. ``scale`` and
    ``zero`` may be scalars or tensors that broadcast against ``x``.
    """
    level = torch.round(x / scale) + zero
    level = torch.clamp(level, 0, maxq)
    return (level - zero) * scale

class Quantizer(nn.Module):
    '''
        Uniform affine fake-quantizer with min/max calibration and an optional
        MSE grid search over shrunken ranges.

        Typical use: ``configure(...)`` -> ``find_params(x, weight=...)`` -> ``quantize(y)``.
        ``maxq``, ``scale`` and ``zero`` are registered as buffers in ``__init__``.
        NOTE: ``find_params`` reads attributes (``perchannel``, ``sym``, ``mse``, ...)
        that are only set by ``configure``, so ``configure`` must be called first.
    '''

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        # maxq == 0 means "not configured" (see ``enabled``).
        self.register_buffer('maxq', torch.tensor(0))
        # An all-zero scale means "no parameters computed yet" (see ``ready``).
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(
            self,
            bits, perchannel=False, sym=True, 
            mse=False, norm=2.4, grid=100, maxshrink=.8,
            grouprows=1
        ):
        '''
            Store the quantization settings.

            Args:
                bits: bit-width; the integer grid is [0, 2**bits - 1].
                perchannel: compute one scale/zero per row (weights) or per channel
                    (activations) instead of a single pair for the whole tensor.
                sym: symmetric range with the fixed zero point (maxq + 1) / 2.
                mse: refine the min/max range with a shrink-factor grid search.
                norm: exponent applied to |quantized - x| in the grid-search error.
                grid: the search tries shrink factors p = 1 - i / grid ...
                maxshrink: ... for i in range(int(maxshrink * grid)).
                grouprows: for weights, share one scale/zero across this many
                    consecutive rows.
        '''
        self.maxq = torch.tensor(2 ** bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink 
        self.grouprows = grouprows

    def find_params(self, x, weight=False):
        '''
            Compute ``self.scale`` and ``self.zero`` from the range of ``x``.

            Args:
                x: tensor to calibrate on.
                weight: if True, ``x`` is treated as a weight (statistics per row,
                    i.e. per dim-0 slice). Otherwise ``x`` is treated as an
                    activation (statistics per channel: dim 1 for 4-D input, the
                    last dim for 2-D/3-D input).

            The resulting ``scale``/``zero`` are reshaped so that they broadcast
            against a tensor of the same shape as ``x``.
        '''
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        # Rearrange x into a 2-D matrix whose rows are the groups that share
        # one (scale, zero) pair.
        if self.perchannel:
            if weight:
                x = x.flatten(1)
                if self.grouprows > 1: 
                    # Concatenate each run of ``grouprows`` consecutive rows into
                    # one row. Assumes shape[0] is divisible by grouprows — TODO confirm.
                    x = x.reshape((x.shape[0] // self.grouprows, -1))
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            # Per-tensor: a single row holding every element.
            x = x.flatten().unsqueeze(0)

        # Clamp the per-row range so it always contains 0.
        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            # Rows with a negative minimum get the symmetric range [-xmax, xmax];
            # rows whose minimum is 0 keep xmin == 0.
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        # All-zero rows would yield scale == 0; give them the range [-1, 1].
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
        else:
            self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Grid search: shrink the range by p and keep, per row, the (scale,
            # zero) with the smallest sum of |quantized - x| ** norm.
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid 
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                # In symmetric mode zero1 *is* self.zero, so the in-place zero
                # update below leaves it unchanged; only the scale is searched.
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                # Module-level ``quantize`` helper (not the method of this class).
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            # Expand the single per-tensor pair to one entry per row (weights)
            # or per channel (activations) so the reshapes below apply uniformly.
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            if self.grouprows > 1:
                # Give every row of a group that group's parameters.
                # NOTE(review): with perchannel=False the per-tensor pair was
                # already repeated shape[0] times above, so this yields
                # shape[0] * grouprows entries, which no longer matches the row
                # count. grouprows > 1 appears to require perchannel=True — confirm.
                self.scale = self.scale.unsqueeze(1).repeat(1, self.grouprows)
                self.zero = self.zero.unsqueeze(1).repeat(1, self.grouprows)
            # Shape [rows, 1, ..., 1] so parameters broadcast along the non-row dims.
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        # Activations: put the per-channel parameters on the channel axis.
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1)) 
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        '''Fake-quantize ``x`` with the stored parameters; return ``x`` unchanged if not ``ready()``.'''
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        '''Return a bool tensor: True once ``configure`` has set a positive ``maxq``.'''
        return self.maxq > 0

    def ready(self):
        '''Return a bool tensor: True when every entry of ``scale`` is non-zero.'''
        return torch.all(self.scale != 0)