import torch
import numpy as np
import torch.nn as nn
import copy
# import generator


def reset_BN(net):
    """Reset the running statistics of every ``nn.BatchNorm2d`` in ``net``.

    ``reset_running_stats`` sets ``running_mean`` to 0, ``running_var`` to 1
    and ``num_batches_tracked`` to 0.  ``nn.GroupNorm`` layers keep no running
    statistics (and have no ``reset_running_stats`` method), so they are
    skipped: calling the method on them would raise ``AttributeError``.

    Args:
        net: any ``nn.Module``; all its sub-modules are visited.
    """
    with torch.no_grad():
        for m in net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.reset_running_stats()


class Pruner:
    r"""Base class for score-based pruners.

    ``masked_parameters`` is an iterable of ``(mask, param)`` pairs.  Subclasses
    implement :meth:`score` to fill ``self.scores`` (keyed by ``id(param)``);
    the masking methods then write 0/1 values into each ``mask`` in place.
    """

    def __init__(self, masked_parameters):
        self.masked_parameters = list(masked_parameters)
        self.scores = {}
        # Pair-wise scores filled by AB-linking scorers in subclasses.
        self.link_scores = {}
        # Entries ``((mat_index, row_index, column_index), score)`` consumed by
        # ``_global_mask_AB_linking``.
        self.index_mapping = []

    def score(self, model, loss, dataloader, device):
        raise NotImplementedError

    def _global_mask(self, sparsity):
        r"""Updates masks of model with scores by sparsity level globally.

        With ``N`` scores across all parameters and ``k = int((1 - sparsity) * N)``,
        every entry whose score is ``<=`` the k-th smallest score gets mask 0 and
        the rest get 1 (so ``sparsity`` acts as the fraction kept).  Masks are
        left unchanged when ``k < 1``.
        """
        # # Set score for masked parameters to -inf
        # for mask, param in self.masked_parameters:
        #     score = self.scores[id(param)]
        #     score[mask == 0.0] = -np.inf

        # Threshold scores
        global_scores = torch.cat([torch.flatten(v) for v in self.scores.values()])
        k = int((1.0 - sparsity) * global_scores.numel())
        if not k < 1:
            threshold, _ = torch.kthvalue(global_scores, k)
            for mask, param in self.masked_parameters:
                score = self.scores[id(param)]
                zero = torch.tensor([0.]).to(mask.device)
                one = torch.tensor([1.]).to(mask.device)
                mask.copy_(torch.where(score <= threshold, zero, one))

    def _global_mask_AB_linking(self, sparsity):
        r"""Prunes linked (row, column) pairs globally from ``self.index_mapping``.

        Each entry is ``((mat_index, row_index, column_index), score)``.  With
        ``k = int((1 - sparsity) * len(index_mapping))``, every pair whose score
        is ``<=`` the k-th smallest score is pruned: row ``row_index`` of the mask
        at ``masked_parameters[2 * mat_index]`` and column ``column_index`` of the
        mask at ``masked_parameters[2 * mat_index + 1]`` are zeroed in place.
        Does nothing when ``index_mapping`` is empty or ``k < 1``.
        """
        print(len(self.index_mapping))
        if not self.index_mapping:
            return

        # Scores are 0-d tensors; stack them into one 1-d tensor.
        global_scores = torch.stack([score for _, score in self.index_mapping])
        k = int((1.0 - sparsity) * global_scores.numel())
        if k < 1:
            return
        threshold, _ = torch.kthvalue(global_scores, k)

        num = 0
        for (mat_index, row_index, column_index), score in self.index_mapping:
            # `<=` matches _global_mask/_local_mask so the k-th pair is pruned too.
            if score <= threshold:
                num += 1
                self.masked_parameters[mat_index * 2][0][row_index].zero_()
                # Index the column explicitly; `[:][col]` would select a row.
                self.masked_parameters[mat_index * 2 + 1][0][:, column_index].zero_()

        print('how many is pruned', num)
        print('pruning ratio', num / len(self.index_mapping))

    def _local_mask(self, sparsity):
        r"""Updates masks of model with scores by sparsity level parameter-wise.

        Same rule as :meth:`_global_mask`, but ``k`` and the threshold are
        computed separately for each parameter's scores.
        """
        for mask, param in self.masked_parameters:
            score = self.scores[id(param)]
            k = int((1.0 - sparsity) * score.numel())
            if not k < 1:
                threshold, _ = torch.kthvalue(torch.flatten(score), k)
                zero = torch.tensor([0.]).to(mask.device)
                one = torch.tensor([1.]).to(mask.device)
                mask.copy_(torch.where(score <= threshold, zero, one))

    def mask(self, sparsity, scope):
        r"""Updates masks of model with scores by sparsity according to scope.

        ``scope`` is ``'global'`` or ``'local'``; any other value is a no-op.
        """
        if scope == 'global':
            self._global_mask(sparsity)
        if scope == 'local':
            self._local_mask(sparsity)

    @torch.no_grad()
    def apply_mask(self):
        r"""Applies mask to prunable parameters (multiplies each param in place).
        """
        for mask, param in self.masked_parameters:
            param.mul_(mask)

    def alpha_mask(self, alpha):
        r"""Set all masks to alpha in model.
        """
        for mask, _ in self.masked_parameters:
            mask.fill_(alpha)

    # Based on https://github.com/facebookresearch/open_lth/blob/master/utils/tensor_utils.py#L43
    def shuffle(self):
        r"""Randomly permutes the entries of every mask in place (density is kept)."""
        for mask, _ in self.masked_parameters:
            shape = mask.shape
            perm = torch.randperm(mask.nelement())
            # Write back into the stored mask; rebinding the loop variable
            # would leave it untouched.
            mask.copy_(mask.reshape(-1)[perm].reshape(shape))

    def invert(self):
        r"""Replaces every score ``v`` by ``v / v**2`` (i.e. ``1 / v``) in place.

        Zero scores become NaN.
        """
        for v in self.scores.values():
            v.div_(v ** 2)

    def stats(self):
        r"""Returns remaining and total number of prunable parameters.
        """
        remaining_params, total_params = 0, 0
        for mask, _ in self.masked_parameters:
            remaining_params += mask.detach().cpu().numpy().sum()
            total_params += mask.numel()
        return remaining_params, total_params


class Rand(Pruner):
    """Random pruner: each masked parameter gets i.i.d. standard-normal scores."""

    def __init__(self, masked_parameters, *args, **kwargs):
        super(Rand, self).__init__(masked_parameters)

    def score(self, model, loss, dataloader, device, *args, **kwargs):
        # model, loss, dataloader and device are accepted for API parity only.
        self.scores.update(
            {id(weight): torch.randn_like(weight) for _, weight in self.masked_parameters}
        )


class Mag(Pruner):
    """Magnitude pruner: the score of each weight is its absolute value."""

    def __init__(self, masked_parameters, *args, **kwargs):
        super(Mag, self).__init__(masked_parameters)

    def score(self, model, loss, dataloader, device, *args, **kwargs):
        # Only the weights are needed; the other arguments exist for API parity.
        for _, weight in self.masked_parameters:
            magnitude = weight.detach().abs()
            self.scores[id(weight)] = magnitude


# Based on https://github.com/mi-lad/snip/blob/master/snip.py#L18
class SNIP(Pruner):
    """SNIP pruner: scores each weight by ``|dL/dm|``, the loss gradient w.r.t.
    its mask, accumulated over the whole dataloader and then divided by the
    sum of all stored scores."""

    def __init__(self, masked_parameters, *args, **kwargs):
        super(SNIP, self).__init__(masked_parameters)

    def score(self, model, loss, dataloader, device, *args, **kwargs):
        # Masks must track gradients so that dL/dm gets populated.
        for mask, _ in self.masked_parameters:
            mask.requires_grad = True

        # Gradients accumulate across batches (nothing is zeroed in between).
        for data, target in dataloader:
            data = data.to(device)
            target = target.to(device)
            loss(model(data), target).backward()

        # |dL/dm| equals |g * theta| for a multiplicative mask.
        for mask, weight in self.masked_parameters:
            self.scores[id(weight)] = mask.grad.detach().clone().abs_()
            weight.grad.data.zero_()
            mask.grad.data.zero_()
            mask.requires_grad = False

        total = torch.cat([v.flatten() for v in self.scores.values()]).sum()
        for _, weight in self.masked_parameters:
            self.scores[id(weight)].div_(total)


# Based on https://github.com/mi-lad/snip/blob/master/snip.py#L18
class IterSNIP(Pruner):
    """SNIP-style pruner whose saliency is ``|grad * param|``.

    ``score`` drives backward with a ``(data, target)`` dataloader and a loss
    function.  The ``*_llm*`` variants instead call a trainer object's
    ``training_step(model, inputs)`` once per batch and rely on it to run the
    backward pass (e.g. a ``transformers.Trainer`` -- confirm for your caller).
    """

    def __init__(self, masked_parameters, *args, **kwargs):
        super(IterSNIP, self).__init__(masked_parameters)

    @staticmethod
    def _normalize_(scores, keys):
        """Divide ``scores[key]`` in place, for each key, by the sum of all values in ``scores``."""
        total = torch.sum(torch.cat([torch.flatten(v) for v in scores.values()]))
        for key in keys:
            scores[key].div_(total)

    def score(self, model, loss, dataloader, device, *args, **kwargs):
        """Score every masked parameter by ``|grad * param|`` accumulated over ``dataloader``."""
        # allow masks to have gradient
        for m, _ in self.masked_parameters:
            m.requires_grad = True

        # compute gradient (accumulated over all batches)
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss(output, target).backward()

        # calculate score |g * theta|
        for m, p in self.masked_parameters:
            self.scores[id(p)] = torch.clone(p.grad * p).detach().abs_()
            p.grad.data.zero_()
            m.grad.data.zero_()
            m.requires_grad = False

        self._normalize_(self.scores, [id(p) for _, p in self.masked_parameters])

    def score_llm(self, model_instance, model, dataloader):
        """LLM variant of ``score``.

        ``model_instance.training_step(model, inputs)`` must run the backward
        pass for each batch (make sure model_instance is a transformers.Trainer
        or equivalent).  Gradients accumulate over all batches before scoring.
        """
        for m, _ in self.masked_parameters:
            m.requires_grad = True

        for inputs in dataloader:
            model_instance.training_step(model, inputs)  # backward is conducted here

        for m, p in self.masked_parameters:
            self.scores[id(p)] = torch.clone(p.grad * p).detach().abs_()
            p.grad.data.zero_()
            m.grad.data.zero_()
            m.requires_grad = False

        self._normalize_(self.scores, [id(p) for _, p in self.masked_parameters])

    def sec_score_llm(self, model_instance, model, dataloader):
        """Second-order saliency ``|g * p - 0.5 * p * F * p|`` per masked parameter.

        ``F`` is a diagonal empirical Fisher: the mean over batches of the squared
        per-batch gradient, accumulated in float64 on ``p.acc_grad_64``.  ``g`` is
        the gradient accumulated over a second pass through ``dataloader``.
        ``model_instance.local_trainer.training_step(model, inputs)`` must run
        the backward pass (e.g. a ``transformers.Trainer``).
        """
        for m, _ in self.masked_parameters:
            m.requires_grad = True

        num_batches = len(dataloader)

        # Drop any Fisher estimate left over from a previous call.
        for _, p in self.masked_parameters:
            if hasattr(p, 'acc_grad_64'):
                del p.acc_grad_64

        # Pass 1: diagonal empirical Fisher.
        for inputs in dataloader:
            # Clear gradients first so only this batch's gradient is squared
            # (harmless if training_step already clears them).
            model.zero_grad()
            model_instance.local_trainer.training_step(model, inputs)  # backward is conducted here
            for _, p in self.masked_parameters:
                grad_64 = p.grad.double()
                p.grad_64 = grad_64 * grad_64 / num_batches
                if hasattr(p, 'acc_grad_64'):
                    p.acc_grad_64 += p.grad_64
                else:
                    p.acc_grad_64 = p.grad_64.clone()
        model.zero_grad()

        # Pass 2: first-order gradient accumulated over all batches.
        for inputs in dataloader:
            model_instance.local_trainer.training_step(model, inputs)  # backward is conducted here

        for m, p in self.masked_parameters:
            self.scores[id(p)] = torch.clone(p.grad * p - 0.5 * p * p.acc_grad_64 * p).detach().abs_()
            p.grad.data.zero_()
            m.grad.data.zero_()
            m.requires_grad = False

        self._normalize_(self.scores, [id(p) for _, p in self.masked_parameters])

    def score_llm_AB_linking(self, model_instance, model, dataloader):
        """Structured saliency for consecutive (A, B) parameter pairs.

        Each 2-D masked parameter's ``|grad * p|`` is summed over its longer
        dimension, giving one score per index of the shorter dimension.
        Consecutive score vectors are combined by ``compute_iterative_tensors``
        into pair-wise matrices (``self.link_scores``), normalized, and
        flattened into ``self.index_mapping`` as
        ``((matrix_index, i, j), score)`` for every ``j >= i``.
        """
        for m, _ in self.masked_parameters:
            m.requires_grad = True

        for inputs in dataloader:
            model_instance.local_trainer.training_step(model, inputs)  # backward is conducted here

        for m, p in self.masked_parameters:
            saliency = torch.clone(p.grad * p).detach().abs_()

            # structured saliency score: collapse the longer dimension
            if p.size()[0] < p.size()[1]:
                self.scores[id(p)] = torch.sum(saliency, dim=1)
            else:
                self.scores[id(p)] = torch.sum(saliency, dim=0)

            p.grad.data.zero_()
            m.grad.data.zero_()
            m.requires_grad = False

        # pair-wise value calculation, do broadcast for every two tensors
        self.link_scores = self.compute_iterative_tensors(self.scores)

        self._normalize_(self.link_scores, list(self.link_scores))

        # Link matrices are stored in pair order (tensor0_plus_1, tensor2_plus_3, ...),
        # so the enumeration index is the matrix index used by the masking step.
        # The loop bounds come from each matrix's shape instead of a fixed rank.
        self.index_mapping = []
        for matrix_index, link in enumerate(self.link_scores.values()):
            n_rows, n_cols = link.size()
            for i in range(n_rows):
                for j in range(i, n_cols):
                    self.index_mapping.append(((matrix_index, i, j), link[i][j]))

    def compute_iterative_tensors(self, dict):
        """Build outer-sum matrices for consecutive pairs of 1-d score vectors.

        Values of ``dict`` are taken in insertion order as ``(a_0, b_0, a_1, b_1, ...)``;
        for each pair the result stored under ``f'tensor{i}_plus_{i+1}'`` is the
        matrix ``result[r, c] = a[r] + b[c]``.  The input mapping is not modified.
        """
        result_dict = {}
        tensors = list(dict.values())

        for i in range(0, len(tensors), 2):
            column = tensors[i].reshape(tensors[i].size()[0], 1)
            row = tensors[i + 1].reshape(1, tensors[i + 1].size()[0])
            result_dict[f'tensor{i}_plus_{i+1}'] = column + row

        return result_dict




# Based on https://github.com/alecwangcq/GraSP/blob/master/pruner/GraSP.py#L49
class GraSP(Pruner):
    """GraSP pruner: scores weights by ``(H g) * theta``.

    ``g`` is the temperature-scaled loss gradient summed over the dataloader
    (without a graph), and ``H g`` is obtained by back-propagating ``g . grad``
    through a second, graph-building gradient pass.  Scores are then divided
    by ``|sum of all scores| + eps``.
    """

    def __init__(self, masked_parameters, *args, **kwargs):
        super(GraSP, self).__init__(masked_parameters)
        self.temp = 200
        self.eps = 1e-10

    def score(self, model, loss, dataloader, device, *args, **kwargs):
        weights = [p for (_, p) in self.masked_parameters]

        def flat_grad(data, target, create_graph):
            # Flattened gradient of the temperature-scaled loss w.r.t. the weights.
            data, target = data.to(device), target.to(device)
            objective = loss(model(data) / self.temp, target)
            grads = torch.autograd.grad(objective, weights, create_graph=create_graph)
            return torch.cat([g.reshape(-1) for g in grads if g is not None])

        # Pass 1: graph-free gradient summed over batches; keep batch copies.
        cached_batches = []
        stopped_grads = 0
        for data, target in dataloader:
            cached_batches.append((data.clone(), target.clone()))
            stopped_grads += flat_grad(data, target, create_graph=False)

        # Pass 2: backprop g_stopped . grad to leave the Hessian-gradient product in p.grad.
        for data, target in cached_batches:
            (stopped_grads * flat_grad(data, target, create_graph=True)).sum().backward()

        # calculate score Hg * theta (negate to remove top percent)
        for _, weight in self.masked_parameters:
            self.scores[id(weight)] = (weight.grad * weight.data).detach().clone()
            weight.grad.data.zero_()

        denom = torch.cat([v.flatten() for v in self.scores.values()]).sum().abs() + self.eps
        for _, weight in self.masked_parameters:
            self.scores[id(weight)].div_(denom)


class SynFlow(Pruner):
    """SynFlow pruner: data-free scores ``|grad * p|`` of ``sum(model(ones))``.

    While scoring, every entry of ``model.state_dict()`` (parameters and
    buffers) is temporarily replaced by its absolute value; the original
    signs are restored afterwards, even if scoring raises.
    """

    def __init__(self, masked_parameters, *args, **kwargs):
        super(SynFlow, self).__init__(masked_parameters)

    def score(self, model, loss, dataloader, device, *args, **kwargs):

        @torch.no_grad()
        def linearize(model):
            # Record signs and take abs() of every state_dict entry in place.
            signs = {}
            for name, param in model.state_dict().items():
                signs[name] = torch.sign(param)
                param.abs_()
            return signs

        @torch.no_grad()
        def nonlinearize(model, signs):
            # Multiply the recorded signs back in place.
            for name, param in model.state_dict().items():
                param.mul_(signs[name])

        signs = linearize(model)
        try:
            # Only the input shape of one batch is used.
            (data, _) = next(iter(dataloader))
            input_dim = list(data[0, :].shape)
            try:
                ones = torch.ones([1] + input_dim).to(device)
                output = model(ones)
            except RuntimeError:
                # Fallback: retry with a float64 input (e.g. a model in double precision).
                ones = torch.ones([1] + input_dim, dtype=torch.float64).to(device)
                output = model(ones)
            torch.sum(output).backward()

            for _, p in self.masked_parameters:
                self.scores[id(p)] = torch.clone(p.grad * p).detach().abs_()
                p.grad.data.zero_()
        finally:
            # Never leave the model with abs()-ed weights.
            nonlinearize(model, signs)


# CLEAN version

class NTKSAP(Pruner):
    """NTK-SAP pruner.

    Each mask entry is scored by the absolute gradient, w.r.t. the mask, of
    ``sum ||f(x; theta) - f(x; theta + epsilon * noise)||^2``.  This is
    accumulated over ``R`` passes of the dataloader, with re-initialised
    weights and Gaussian-noise inputs ``x`` shaped like each batch.

    Requires ``ntksap_epsilon`` and ``ntksap_R`` keyword arguments.
    """

    def __init__(self, masked_parameters, *args, **kwargs):
        super(NTKSAP, self).__init__(masked_parameters)
        self.epsilon = kwargs['ntksap_epsilon']
        self.R = kwargs['ntksap_R']

    def score(self, model, loss, dataloader, device, *args, **kwargs):

        # Names (within `model`) of the masked parameters.  Looking them up by
        # name finds the matching parameter in the deep copy as well.
        param_names = {id(p): name for name, p in model.named_parameters()}
        masked_names = [param_names[id(p)] for _, p in self.masked_parameters
                        if id(p) in param_names]

        def perturb(model_orig, model_copy):
            # Copy's masked params := original + epsilon * N(0, 1); the copy's
            # BatchNorm2d layers share the original's running statistics.
            orig_params = dict(model_orig.named_parameters())
            copy_params = dict(model_copy.named_parameters())
            with torch.no_grad():
                for name in masked_names:
                    p_orig = orig_params[name]
                    copy_params[name].data = p_orig.data + self.epsilon * torch.randn_like(p_orig.data)

            for module, module_mod in zip(model_orig.modules(), model_copy.modules()):
                if isinstance(module, nn.BatchNorm2d):
                    with torch.no_grad():
                        module_mod.running_mean = module.running_mean
                        module_mod.running_var = module.running_var
                        module_mod.num_batches_tracked = module.num_batches_tracked

        for m, p in self.masked_parameters:
            m.requires_grad = True
            p.requires_grad = False

        # Copy a same model
        model_mod = copy.deepcopy(model)

        # Set model mod to evaluation mode
        model_mod.eval()

        # Make two models share the same weight masks; remember BN momentum so
        # it can be restored afterwards.
        bn_momentum = {}
        for module, module_mod in zip(model.modules(), model_mod.modules()):
            if hasattr(module, 'weight_mask'):
                module_mod.weight_mask = module.weight_mask
            if isinstance(module, nn.BatchNorm2d):
                bn_momentum[module] = module.momentum
                module.momentum = 1.0
                module_mod.momentum = 1.0

        for _ in range(self.R):
            for data, _target in dataloader:
                if isinstance(model, nn.DataParallel):
                    model.module._initialize_weights()
                else:
                    model._initialize_weights()
                noise = torch.randn_like(data).to(device)

                # With momentum 1.0 this forward overwrites the BN running
                # statistics with this batch's statistics (assuming the model
                # is in train mode here); its output is not needed.
                reset_BN(model)
                with torch.no_grad():
                    model(noise)

                model.eval()
                # Compute the true graph using eval mode
                output_orig = model(noise)
                perturb(model, model_mod)
                output_mod = model_mod(noise)
                jac_approx = (torch.norm(output_orig - output_mod, dim=-1) ** 2).sum()
                jac_approx.backward()
                model.train()

        for m, p in self.masked_parameters:
            self.scores[id(p)] = torch.clone(m.grad * (m != 0)).detach().abs_()
            m.grad.data.zero_()
            m.requires_grad = False
            p.requires_grad = True

        # Restore the original momentum of every BatchNorm2d.
        for module, momentum in bn_momentum.items():
            module.momentum = momentum

        del model_mod