from numpy.core.fromnumeric import shape
import torch
import torch.nn as nn
from copy import deepcopy
import sys
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
from numpy import mean
from Layers.common import initialize_conv
from Layers.layers import channel_prune, fake_channel_prune
from Layers import layers
from Models.cifar_shufflenet import ShuffleNet as cifar_shufflenet
from Models import cifar_mobilenetv2, imagenet_mobilenetv2
from tqdm import tqdm
from math import sqrt
from Utils import metrics
from Utils.generator import pointwise
import pickle
import math
import random


class Pruner:
    def __init__(self, masked_parameters, skip_last, prune_pw_only):
        """Store the (mask, parameter) pairs and reset all bookkeeping state.

        Args:
            masked_parameters: iterable of ``(mask, param)`` pairs; it is
                materialised into a list.
            skip_last: when truthy, the final pair is dropped from the list.
            prune_pw_only: stored as-is; read later by ``set_group_number``.
        """
        pairs = list(masked_parameters)
        if skip_last:
            # Leave the last (mask, param) pair out of everything that follows.
            pairs = pairs[:-1]
        self.masked_parameters = pairs
        self.skip_last = skip_last
        self.prune_pw_only = prune_pw_only

        # Scores are keyed by id(param).
        self.scores = {}
        self.iter = 0
        self.group_number_list = None
        self.total_score = []
        self.sparsity_path = []

    def score(self, model, loss, dataloader, device, double=False):
        """Compute per-parameter scores into ``self.scores``; must be overridden.

        The signature accepts ``double`` so that it matches both the subclass
        overrides and the five-argument call made by ``rescore``; without it a
        non-overriding subclass would fail with ``TypeError`` instead of the
        intended ``NotImplementedError``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def rescore(self):
        """Re-run ``score`` with the arguments remembered from the last call.

        Relies on ``self.model``, ``self.loss``, ``self.dataloader``,
        ``self.device`` and ``self.double`` having been set by an earlier
        ``score`` call of a subclass.
        """
        remembered = (
            self.model,
            self.loss,
            self.dataloader,
            self.device,
            self.double,
        )
        self.score(*remembered)

    def _global_mask(self, sparsity):
        r"""Updates masks of model with scores by sparsity level globally.
        """
        # Set score for masked parameters to -inf
        # Threshold scores
        global_scores = torch.cat([torch.flatten(v) for v in self.scores.values()])
        k = int((1.0 - sparsity) * global_scores.numel())
        if not k < 1:
            threshold, _ = torch.kthvalue(global_scores, k)
            for mask, param in self.masked_parameters:
                score = self.scores[id(param)]
                zero = torch.tensor([0.]).to(mask.device)
                one = torch.tensor([1.]).to(mask.device)
                mask.copy_(torch.where(score <= threshold, zero, one))
                # print('in mask: ', mask.detach().sum().item()/mask.numel())
                # print(id(mask))

    def _random_filter_mask(self, sparsity_list, sparsity, verbose=False):
        r"""Updates masks of model with scores by sparsity level filter-wise.
        The last layer keeps all parameters in this setting.
        """
        for i, (mask, param) in enumerate(self.masked_parameters):
            filter_len = mask.shape[0] * mask.shape[1]
            keep_filter = int(filter_len * sparsity_list[i])
            perm = torch.randperm(filter_len).reshape(mask.shape[0], mask.shape[1])
            mask.zero_()
            mask[perm <= keep_filter, ...] = 1

    def random_global_mask(self, sparsity_list, sparsity, verbose=False):
        for i, (mask, param) in enumerate(self.masked_parameters):
            self.scores[id(param)] = torch.zeros_like(param).to(param.device)
            weight_num = mask.numel()
            keep_filter = round(weight_num * sparsity_list[i])
            perm = torch.randperm(weight_num).reshape(mask.shape)
            mask.zero_()
            mask[perm <= keep_filter] = 1

    @torch.no_grad()
    def sum_score(self):
        total_scores = 0
        for mask, param in self.masked_parameters:
            score = self.scores[id(param)]
            if score.shape != mask.shape:
                score = score.sum(dim=[2, 3], keepdim=True)
            temp_score = (mask.to(score.device) * score).sum().item()
            total_scores += temp_score
            # print(
            #     f'sparsity = {mask.sum().item()/mask.numel()}, param = {mask.sum().item()}, score = {temp_score}'
            # )
        return total_scores

    def pregrouping(self, sparsity_list, sparsity, trials=500, eps=1e-4, avg_times=3, relax_factor=1.0, verbose=False):
        """Randomly search per-layer group counts and keep the best-scoring plan.

        For every layer, each ``g`` that divides both of the first two score
        dimensions is a candidate; its mask is block-diagonal with ``g``
        blocks. Candidates are sampled with a softmax-like weight that favours
        ``1 / g`` close to the layer's entry in ``sparsity_list``. A sampled
        plan is evaluated only if it is new and its kept fraction is below
        ``sparsity * relax_factor``; evaluation re-initialises the model
        ``avg_times`` times, rescoring and averaging ``sum_score``. The best
        plan is stored in ``self.group_number_list``.

        Sampling stops after ``trials`` draws or once ``trials * 0.5``
        distinct plans were evaluated. If no plan was feasible, the search is
        retried with ``relax_factor * 1.2``. Masks are left at the last
        sampled plan, not necessarily the best one.

        Args:
            sparsity_list: per-layer target kept fraction.
            sparsity: overall target kept fraction (must be non-zero).
            trials: maximum number of sampled plans.
            eps: unused; kept for interface compatibility.
            avg_times: number of re-initialisations averaged per plan.
            relax_factor: multiplier on ``sparsity`` for the feasibility test.
            verbose: print progress information.
        """
        if verbose:
            print(f'relax factor = {relax_factor}')

        # One 2-D score matrix per layer: scores summed over dims 2 and 3.
        filter_scores = []
        for v in self.scores.values():
            assert len(v.shape) == 4
            filter_scores.append(v.sum(dim=[2, 3]).detach().cpu().numpy())

        choices = []
        probabilitys = []
        masks = []
        for l, layer_scores in enumerate(filter_scores):
            choice = []
            probability = []
            mask = []
            length = layer_scores.shape[0]
            width = layer_scores.shape[1]
            for g in range(1, min(length, width) + 1):
                if length % g == 0 and width % g == 0:
                    # Block-diagonal pattern with g equally sized blocks.
                    temp_mask = np.zeros_like(layer_scores)
                    l_step = length // g
                    w_step = width // g
                    for i in range(g):
                        temp_mask[i * l_step:(i + 1) * l_step, i * w_step:(i + 1) * w_step] = 1
                    choice.append(g)
                    probability.append(-abs(1 / g - sparsity_list[l]))
                    mask.append(temp_mask)
            choices.append(choice)
            # Softmax over the negative distances; longdouble plus the max
            # subtraction keep the exponentials from under/overflowing.
            probability = np.array(probability).astype(np.longdouble)
            probability = np.exp((probability - probability.max()) * 10 / sparsity)
            probability = np.array([float(x) for x in probability])
            probability = probability / np.sum(probability)
            probabilitys.append(probability)
            masks.append(mask)

        used_plan = []
        best_group_number_list = []
        # Start from -inf so the first evaluated plan is always recorded,
        # even when every score sum is zero or negative.
        best_score = float('-inf')
        for_range = range(trials)
        if verbose:
            for_range = tqdm(for_range)
        for trial in for_range:
            if len(used_plan) >= trials * 0.5:
                break
            group_number_list = []
            remaining_param, total_param = 0, 0
            remaining_param_list = []
            for mask, _ in self.masked_parameters:
                mask.fill_(1)
            for l, (mask, param) in enumerate(self.masked_parameters):
                choice = np.random.choice(len(choices[l]), 1, p=list(probabilitys[l]))[0]
                group_number_list.append(choices[l][choice])
                with torch.no_grad():
                    # copy_ converts dtype/device itself and broadcasts the
                    # (out, in, 1, 1) pattern over the remaining mask dims.
                    mask.copy_(torch.Tensor(np.expand_dims(masks[l][choice], (2, 3))))
                    remaining_param += mask.sum().item()
                    remaining_param_list.append(int(mask.sum().item()))
                    total_param += mask.numel()
                initialize_conv(param, mask)
            if group_number_list not in used_plan:
                if remaining_param / total_param < sparsity * relax_factor:
                    used_plan.append(deepcopy(group_number_list))
                    score = 0
                    for i in range(avg_times):
                        self.model._initialize_weights(sparse_init=True)
                        self.rescore()
                        score += self.sum_score()
                    score /= avg_times
                    if score > best_score:
                        best_score = score
                        best_group_number_list = group_number_list
                    if verbose:
                        info = f'param sparsity={round(remaining_param/total_param, 5)}'
                        print(info + f', target sparsity ={round(sparsity, 5)}, length of candidates = {len(used_plan)}, score = {round(score, 5)}')
        if len(used_plan) == 0:
            # Nothing met the budget: retry with a looser budget, forwarding
            # the caller's search settings instead of falling back to defaults.
            self.pregrouping(sparsity_list=sparsity_list,
                             sparsity=sparsity,
                             trials=trials,
                             eps=eps,
                             avg_times=avg_times,
                             relax_factor=relax_factor * 1.2,
                             verbose=verbose)
            return
        if len(used_plan) < trials * 0.5:
            print(f'WARNING: no enough candidates evaluated: should be {trials * 0.5}, but only {len(used_plan)}')
        self.group_number_list = best_group_number_list

    def precropping(self, sparsity_list, sparsity, verbose=False):
        """Derive per-module channel counts to remove and call ``channel_prune``.

        Walks ``self.model.pruned_types`` (a mapping from module to a prune
        type string) in iteration order, pairing entry ``index`` with
        ``sparsity_list[index]``. For each module it records how many input
        and output channels to prune, then hands both dicts to
        ``channel_prune(self.model, ...)``.

        Args:
            sparsity_list: one value per entry of ``self.model.pruned_types``;
                used as a kept fraction (pruned count is ``(1 - sqrt(s)) * C``
                in most branches).
            sparsity: not used in this method.
            verbose: when truthy, prints ``sparsity_list`` up front. The
                per-module prints below are unconditional.

        Raises:
            ValueError: on an unknown prune type string.
        """
        if verbose:
            print(sparsity_list)

        # for VGG16BN only, hack
        # avg = True
        # if avg:
        #     assert len(sparsity_list) == 13
        #     new_sparsity_lisy = [0 for i in range(13)]
        #     sep = [0, 2, 4, 7, 10, 13]
        #     stage = 0
        #     for i in range(13):
        #         if i + 1 > sep[stage]:
        #             stage += 1
        #         new_sparsity_lisy[i] = mean(sparsity_list[sep[stage - 1]:sep[stage]])
        #     sparsity_list = new_sparsity_lisy
        #     print(sparsity_list)

        assert len(self.model.pruned_types) == len(sparsity_list)
        pruned_out_channels, pruned_in_channels = {}, {}

        # State carried from one module to the next: how many output channels
        # the previous (state-updating) module pruned, and how many it kept.
        prev_pruned_out = 0
        prev_max_out = 0
        for index, module in enumerate(self.model.pruned_types):
            in_channels = module.conv.in_channels
            out_channels = module.conv.out_channels
            if self.model.pruned_types[module] == 'in':
                # Split the kept fraction evenly (sqrt) between in and out dims.
                sqrt_target_sparse = sqrt(sparsity_list[index])
                pruned_in_channels[module] = round((1 - sqrt_target_sparse) * in_channels)
                pruned_out_channels[module] = round((1 - sqrt_target_sparse) * out_channels)
                # Kept inputs may not exceed the outputs kept by the previous module.
                if in_channels - pruned_in_channels[module] > prev_max_out:
                    pruned_in_channels[module] = in_channels - prev_max_out
                # pruned_out_channels[module] = out_channels - round(
                #     (in_channels * out_channels * target_sparse) / (in_channels - pruned_in_channels[module]))
                prev_pruned_out = pruned_out_channels[module]
                prev_max_out = max(0, out_channels - prev_pruned_out)
            elif self.model.pruned_types[module] == 'out':
                target_sparse = sparsity_list[index]
                pruned_in_channels[module] = prev_pruned_out
                # NOTE(review): divides by (in_channels - prev_pruned_out); this is
                # zero if the previous module pruned as many channels as this one
                # has inputs — confirm that cannot happen.
                pruned_out_channels[module] = out_channels - int(target_sparse * in_channels * out_channels / (in_channels - prev_pruned_out))
                prev_pruned_out = pruned_out_channels[module]
                prev_max_out = max(0, out_channels - prev_pruned_out)
            elif self.model.pruned_types[module] == 'only_out':
                # Does not update prev_pruned_out / prev_max_out.
                target_sparse = sqrt(sparsity_list[index])
                pruned_in_channels[module] = 0
                pruned_out_channels[module] = round((1 - target_sparse) * out_channels)
            elif self.model.pruned_types[module] == 'vgg_out':
                target_sparse = sqrt(sparsity_list[index])
                pruned_in_channels[module] = prev_pruned_out
                pruned_out_channels[module] = round((1 - target_sparse) * out_channels)
                prev_pruned_out = pruned_out_channels[module]
                prev_max_out = max(0, out_channels - prev_pruned_out)
            elif self.model.pruned_types[module] == 'skip':
                # Nothing is pruned; carried state is left unchanged.
                pruned_in_channels[module] = 0
                pruned_out_channels[module] = 0
            elif self.model.pruned_types[module] == 'nonresidual_in':
                target_sparse = sqrt(sparsity_list[index])
                # All inputs are marked here; the guard below then reduces this
                # by one so a single input channel survives.
                pruned_in_channels[module] = in_channels
                pruned_out_channels[module] = round((1 - target_sparse) * out_channels)
                prev_pruned_out = pruned_out_channels[module]
                prev_max_out = max(0, out_channels - prev_pruned_out)
            else:
                raise ValueError(f'no prune type {self.model.pruned_types[module]}')
            print(self.model.pruned_types[module], sparsity_list[index])
            print(in_channels, out_channels, pruned_in_channels[module], pruned_out_channels[module])

            # avoid pruned out the whole layer
            if pruned_in_channels[module] == in_channels:
                pruned_in_channels[module] -= 1
            if pruned_out_channels[module] == out_channels:
                pruned_out_channels[module] -= 1
                # NOTE(review): prev_pruned_out is decremented even for
                # 'only_out'/'skip', which never set it, and prev_max_out is
                # not adjusted to match — confirm this is intended.
                prev_pruned_out -= 1
        channel_prune(self.model, pruned_in_channels, pruned_out_channels)

    def mask(self, sparsity, scope, verbose=False):
        r"""Updates masks of model with scores by sparsity according to scope.
        """
        if scope == 'global':
            self._global_mask(sparsity)
        elif scope in ['random_weight', 'random_filter']:
            if scope == 'random_weight':
                self.random_global_mask(self.sparsity_list, sparsity, verbose=verbose)
            elif scope == 'random_filter':
                self._random_filter_mask(self.sparsity_list, sparsity, verbose=verbose)
        elif scope in ['pregrouping', 'precropping']:
            if not hasattr(self, 'sparsity_list'):
                if isinstance(self, SynFlow):
                    prune_iter = 100
                elif isinstance(self, SNIP) or isinstance(self, GraSP):
                    prune_iter = 1
                assert sparsity > 0
                for_range = range(prune_iter)
                if verbose:
                    for_range = tqdm(for_range)
                for i in for_range:
                    sparse = 1.0 - (1.0 - sparsity) * ((i + 1) / prune_iter)
                    self.mask(sparse, 'global')
                    self.rescore()
                self.sparsity_list = []
                for mask, param in self.masked_parameters:
                    self.sparsity_list.append(mask.sum().item() / mask.numel())
                if verbose:
                    print(self.sparsity_list)
            if scope == 'pregrouping':
                self.pregrouping(self.sparsity_list, sparsity, verbose=verbose)
            elif scope == 'precropping':
                self.precropping(self.sparsity_list, sparsity, verbose=verbose)
        else:
            raise NotImplementedError
        # total_scores = self.sum_score()
        # self.total_score.append(float(total_scores))
        # self.sparsity_path.append(sparsity)

    @torch.no_grad()
    def apply_mask(self):
        r"""Applies mask to prunable parameters.
        """
        for mask, param in self.masked_parameters:
            param.mul_(mask)

    def alpha_mask(self, alpha):
        r"""Set all masks to alpha in model.
        """
        for mask, _ in self.masked_parameters:
            mask.fill_(alpha)

    # Based on https://github.com/facebookresearch/open_lth/blob/master/utils/tensor_utils.py#L43
    def shuffle(self, filter=False):
        for mask, param in self.masked_parameters:
            shape = mask.shape
            if filter:
                mask.reshape(-1, shape[2], shape[3])
                perm = torch.randperm(mask.shape[0])
                mask.reshape(shape)
            else:
                perm = torch.randperm(mask.nelement())
                mask = mask.reshape(-1)[perm].reshape(shape)

    def stats(self):
        r"""Returns remaining and total number of prunable parameters.
        """
        remaining_params, total_params = 0, 0
        for mask, param in self.masked_parameters:
            remaining_param = mask.detach().cpu().numpy().sum()
            total_param = param.numel()
            # times filter if filter-wise pruning
            if mask.shape != param.shape:
                filter_size = param.shape[2] * param.shape[3]
                remaining_param *= filter_size
                total_param *= filter_size
                assert mask.numel() * filter_size == param.numel()
            remaining_params += remaining_param
            total_params += param.numel()
        return remaining_params, total_params

    def conv_stats(self):
        r"""Returns number of params in conv1x1 and conv3x3 respectively.
        """
        conv1x1, conv3x3 = 0, 0
        for mask, param in self.masked_parameters:
            shape = param.shape
            if len(shape) == 4:
                if shape[3] == shape[2] and shape[2] == 1:
                    conv1x1 += param.numel()
                elif shape[3] == shape[2] and shape[2] == 3:
                    conv3x3 += param.numel()
                else:
                    raise NotImplementedError
        return conv1x1, conv3x3

    def set_group_number(self):
        """Hand the entries of ``self.group_number_list`` to the model's modules.

        Modules are visited in ``named_modules()`` order. A module is selected
        when ``prune_pw_only`` is set and ``pointwise(module)`` holds, or when
        ``prune_pw_only`` is unset and the module is a ``layers.Conv2d``. Each
        selected module receives the next list entry via its
        ``set_group_number``. With ``skip_last`` the walk stops once every
        entry has been handed out; in all cases the number of selected
        modules must equal the list length.
        """
        assigned = 0
        for _, module in self.model.named_modules():
            is_pointwise = pointwise(module)
            if self.prune_pw_only:
                selected = is_pointwise
            else:
                selected = isinstance(module, layers.Conv2d)
            if selected:
                module.set_group_number(self.group_number_list[assigned])
                assigned += 1
            if self.skip_last and len(self.group_number_list) == assigned:
                break
        assert len(self.group_number_list) == assigned

    def prune_detatched(self):
        """Zero the mask of every weight that is non-zero yet scores exactly 0.

        The model is rescored (with ``loss=None`` and ``double=False``); mask
        entries where ``score == 0`` and ``param != 0`` are set to 0 and all
        other entries to 1. For the MobileNetV2 models the activation is
        swapped to ``'relu'`` for the scoring pass and back to ``'relu6'``
        afterwards.
        """
        mobilenet_types = (cifar_mobilenetv2.MobileNetV2, imagenet_mobilenetv2.MobileNetV2)
        if isinstance(self.model, mobilenet_types):
            layers.replace_activation(self.model, 'relu')

        self.score(self.model, None, self.dataloader, self.device, False)
        for mask, param in self.masked_parameters:
            layer_score = self.scores[id(param)]
            detached = (layer_score == 0) & (param != 0)
            off = torch.tensor([0.]).to(mask.device)
            on = torch.tensor([1.]).to(mask.device)
            mask.copy_(torch.where(detached, off, on))

        if isinstance(self.model, mobilenet_types):
            layers.replace_activation(self.model, 'relu6')


class SynFlow(Pruner):
    """Data-free scorer: ``|grad * param|`` on an all-ones input with all
    state-dict tensors replaced by their absolute values."""

    def score(self, model, loss, dataloader, device, double=False):
        """Fill ``self.scores`` and remember the arguments for ``rescore``.

        One batch is drawn from ``dataloader`` only to read the shape of a
        single sample; the forward pass itself uses a tensor of ones. When
        ``double`` is set the model is cast with ``double_model`` for the
        pass and cast back with ``undouble_model`` afterwards.
        """
        self.model, self.loss = model, loss
        self.dataloader, self.device = dataloader, device
        self.double = double

        # Linearize: remember each tensor's sign, then take abs in place.
        with torch.no_grad():
            # use double for MobileNetv2, otherwise numerical issues arise
            if double:
                double_model(model)
            signs = {}
            for name, tensor in model.state_dict().items():
                signs[name] = torch.sign(tensor)
                tensor.abs_()

        sample_batch, _ = next(iter(dataloader))
        input_dim = list(sample_batch[0, :].shape)
        self.input_shape = input_dim
        probe_shape = [1] + input_dim
        if double:
            probe = torch.ones(probe_shape, dtype=torch.float64).to(device)
        else:
            probe = torch.ones(probe_shape).to(device)
        torch.sum(model(probe)).backward()

        for _, p in self.masked_parameters:
            self.scores[id(p)] = torch.clone(p.grad * p).detach().abs_()
            p.grad.data.zero_()

        # Nonlinearize: restore the original signs.
        with torch.no_grad():
            for name, tensor in model.state_dict().items():
                tensor.mul_(signs[name])
            if double:
                undouble_model(model)


# Based on https://github.com/mi-lad/snip/blob/master/snip.py#L18
class SNIP(Pruner):
    """Scorer using the absolute gradient of the loss w.r.t. each mask,
    accumulated over the whole dataloader and normalised to sum to one."""

    def score(self, model, loss, dataloader, device, double=False):
        """Fill ``self.scores`` and remember the arguments for ``rescore``."""
        self.model, self.loss = model, loss
        self.dataloader, self.device = dataloader, device
        self.double = double

        # Masks need gradients for the duration of the scoring pass.
        for layer_mask, _ in self.masked_parameters:
            layer_mask.requires_grad = True

        if double:
            double_model(model)

        # Gradients accumulate across batches (no zeroing in between).
        for data, target in dataloader:
            data = data.to(device)
            target = target.to(device)
            if double:
                data = data.double()
            loss(model(data), target).backward()

        # Score is |d loss / d mask|; gradients are cleared afterwards.
        for layer_mask, p in self.masked_parameters:
            self.scores[id(p)] = torch.clone(layer_mask.grad).detach().abs_()
            p.grad.data.zero_()
            layer_mask.grad.data.zero_()
            layer_mask.requires_grad = False

        # Normalise by the sum over every score.
        flat = [torch.flatten(v) for v in self.scores.values()]
        norm = torch.sum(torch.cat(flat))
        for _, p in self.masked_parameters:
            self.scores[id(p)].div_(norm)

        if double:
            undouble_model(model)


# Based on https://github.com/alecwangcq/GraSP/blob/master/pruner/GraSP.py#L49
class GraSP(Pruner):
    """Scorer based on a Hessian-gradient product times the weights."""

    def __init__(self, masked_parameters, skip_last, prune_pw_only):
        super().__init__(masked_parameters, skip_last, prune_pw_only)
        self.temp = 200   # model outputs are divided by this before the loss
        self.eps = 1e-8   # added to the normaliser to avoid division by zero

    def score(self, model, loss, dataloader, device, double=True):
        """Fill ``self.scores`` and remember the arguments for ``rescore``.

        Two passes over ``dataloader``: the first accumulates the flattened
        gradient without a graph; the second differentiates the inner product
        of that vector with the graph-carrying gradient. Scores are
        ``p.grad * p.data`` divided by ``|sum of all scores| + eps``.
        """
        self.model, self.loss = model, loss
        self.dataloader, self.device = dataloader, device
        self.double = double
        if double:
            double_model(model)

        def flat_grads(data, target, keep_graph):
            # Flattened gradient of the temperature-scaled loss w.r.t. the
            # prunable parameters for one batch.
            data, target = data.to(device), target.to(device)
            if double:
                data = data.double()
            batch_loss = loss(model(data) / self.temp, target)
            weights = [p for (_, p) in self.masked_parameters]
            grads = torch.autograd.grad(batch_loss, weights, create_graph=keep_graph)
            return torch.cat([g.reshape(-1) for g in grads if g is not None])

        # first gradient vector without computational graph
        stopped_grads = 0
        for data, target in dataloader:
            stopped_grads += flat_grads(data, target, False)

        # second gradient vector with computational graph
        for data, target in dataloader:
            gnorm = (stopped_grads * flat_grads(data, target, True)).sum()
            gnorm.backward()

        # score = (accumulated p.grad) * theta; no sign flip is applied here
        for _, p in self.masked_parameters:
            self.scores[id(p)] = torch.clone(p.grad * p.data).detach()
            p.grad.data.zero_()

        # normalize score
        pooled = torch.cat([torch.flatten(v) for v in self.scores.values()])
        norm = torch.abs(torch.sum(pooled)) + self.eps
        for _, p in self.masked_parameters:
            self.scores[id(p)].div_(norm)

        if double:
            undouble_model(model)


class Mag(Pruner):
    """Magnitude scorer: each weight's score is its absolute value."""

    def score(self, model, loss, dataloader, device, double=False):
        """Fill ``self.scores`` with ``|p|`` and remember the arguments for
        ``rescore``; the model, loss and data are not evaluated."""
        self.model, self.loss = model, loss
        self.dataloader, self.device = dataloader, device
        self.double = double
        for pair in self.masked_parameters:
            weight = pair[1]
            self.scores[id(weight)] = torch.clone(weight.data).detach().abs_()


def double_model(model):
    """Cast the conv, batch-norm and linear modules of ``model`` to float64.

    A module's ``weight_mask`` attribute, when present, is re-attached
    unchanged after the cast (presumably so mask tensors referenced elsewhere
    keep their identity and dtype — confirm). Modules without a
    ``weight_mask``, such as a plain ``nn.Linear``, are simply cast instead
    of raising ``AttributeError``.
    """
    castable = (layers.Conv2d, layers.BatchNorm2d, layers.Linear, nn.Linear)
    for name, module in model.named_modules():
        if isinstance(module, castable):
            has_mask = hasattr(module, 'weight_mask')
            mask = module.weight_mask if has_mask else None
            module.double()
            if has_mask:
                module.weight_mask = mask


def undouble_model(model):
    """Cast the conv, batch-norm and linear modules of ``model`` to float32.

    A module's ``weight_mask`` attribute, when present, is re-attached
    unchanged after the cast (presumably so mask tensors referenced elsewhere
    keep their identity and dtype — confirm). Modules without a
    ``weight_mask``, such as a plain ``nn.Linear``, are simply cast instead
    of raising ``AttributeError``.
    """
    castable = (layers.Conv2d, layers.BatchNorm2d, layers.Linear, nn.Linear)
    for name, module in model.named_modules():
        if isinstance(module, castable):
            has_mask = hasattr(module, 'weight_mask')
            mask = module.weight_mask if has_mask else None
            module.float()
            if has_mask:
                module.weight_mask = mask
