import math
import torch
import numpy as np
import torch.nn as nn
from copy import deepcopy


def magnitude_redistribution(masking, name, weight, mask):
    """Return the mean absolute value of the weights that ``mask`` keeps active.

    Used as the per-layer redistribution score: layers whose surviving
    weights are larger on average receive a bigger share of the regrowth
    budget once the scores are normalized.

    Args:
        masking: unused; kept for the redistribution-function call signature.
        name: unused; kept for the redistribution-function call signature.
        weight: parameter tensor of the layer.
        mask: mask tensor with the same shape as ``weight``; expected to be
            0/1-valued (nonzero entries are treated as active).

    Returns:
        A Python float. NaN when ``mask`` has no active entries (mean of an
        empty selection).
    """
    # Boolean-mask indexing; indexing with uint8 masks is deprecated in PyTorch.
    active = mask.bool()
    return torch.abs(weight)[active].mean().item()


def magnitude_prune(masking, mask, weight, name):
    """Prune the smallest-magnitude weights of one layer by zeroing ``mask``.

    The number of entries zeroed is the layer's current zero count plus
    ``ceil(prune_rate * nonzeros)``. Entries are ranked by ``|weight|``, so
    already-masked (zero) weights sort first and the remaining budget falls
    on the smallest surviving weights.

    Args:
        masking: object exposing ``name2prune_rate``, ``name2nonzeros`` and
            ``name2zeros`` dicts keyed by ``name``.
        mask: mask tensor for the layer; modified in place.
        weight: parameter tensor for the layer.
        name: key into the ``masking`` statistics dicts.

    Returns:
        ``mask`` itself after in-place pruning, or, when nothing is scheduled
        for removal, a fresh boolean tensor ``weight.data != 0.0``.
    """
    rate = masking.name2prune_rate[name]
    n_remove = math.ceil(rate * masking.name2nonzeros[name])
    n_zeros = masking.name2zeros[name]
    cutoff = math.ceil(n_zeros + n_remove)

    if n_remove == 0.0:
        # Nothing to prune: derive the mask from the weights' nonzero pattern.
        return weight.data != 0.0

    magnitudes = torch.abs(weight.data.view(-1))
    _, order = torch.sort(magnitudes)
    # view(-1) aliases the mask's storage, so this assignment edits it in place.
    mask.data.view(-1)[order[:cutoff]] = 0.0
    return mask


def random_growth(masking, name, new_mask, total_regrowth, weight):
    """Randomly activate about ``total_regrowth`` inactive mask entries.

    Every position is switched on independently with probability
    ``total_regrowth / n``, where ``n`` is the number of zero entries in
    ``new_mask``. The expected number of newly activated entries is
    therefore ``total_regrowth`` (all ``n`` when ``total_regrowth >= n``).
    Entries that are already active stay active.

    Args:
        masking: unused; kept for the growth-function call signature.
        name: unused; kept for the growth-function call signature.
        new_mask: current mask tensor; zero entries are inactive.
        total_regrowth: desired expected number of new connections.
        weight: unused; kept for the growth-function call signature.

    Returns:
        ``new_mask`` unchanged if it has no zero entries. Otherwise, the
        element-wise OR of ``new_mask.byte()`` and the random draw.
    """
    n = (new_mask == 0).sum().item()
    if n == 0:
        return new_mask
    expected_growth_probability = total_regrowth / n
    # Draw on the CPU RNG (as before) but move to the mask's device instead of
    # hard-coding .cuda(), so CPU masks work and seeded CUDA runs are unchanged.
    draw = torch.rand(new_mask.shape).to(new_mask.device)
    new_weights = draw < expected_growth_probability
    return new_mask.byte() | new_weights


class LinearDecay(object):
    """Anneals the pruning rate linearly with each step.

    Every ``step()`` lowers the rate by ``prune_rate / T_max``, so it reaches
    zero after ``T_max`` steps. The rate is clamped at zero so that extra
    steps past ``T_max`` never produce a negative prune rate.
    """
    def __init__(self, prune_rate, T_max):
        """
        Args:
            prune_rate: initial prune rate.
            T_max: number of steps over which the rate decays to zero.
        """
        self.decrement = prune_rate / float(T_max)
        self.current_prune_rate = prune_rate

    def step(self):
        """Advance the schedule by one step (never below zero)."""
        self.current_prune_rate = max(0.0, self.current_prune_rate - self.decrement)

    def get_dr(self, prune_rate):
        """Return the current decayed prune rate; ``prune_rate`` is ignored."""
        return self.current_prune_rate


class Masking(object):
    """Wraps PyTorch model parameters with a sparse mask.

    Keeps one mask tensor per masked parameter, keyed by the parameter name
    reported by ``module.named_parameters()``. ``apply_mask()`` multiplies
    each masked parameter by its mask. ``truncate_weights()`` runs one
    prune-and-regrow cycle: it prunes with ``self.prune_func``, distributes
    the regrowth budget across layers by their normalized
    ``self.redistribution_func`` score, and regrows with
    ``self.growth_func``.

    Basic usage:
        model = MyModel()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
        mask = Masking(optimizer, T_max=len(train_loader) * args.epochs,
                       prune_rate=args.prune_rate)
        mask.add_module(model, density=args.density, init_masks=init_masks)
        # after each backward pass:  mask.step()   (also steps the optimizer)
        # after each epoch:          mask.at_end_of_epoch()
    """
    def __init__(self, optimizer, T_max, prune_rate):
        """
        Args:
            optimizer: optimizer stepped by ``step()``.
            T_max: number of ``step()`` calls over which ``LinearDecay``
                anneals the prune rate.
            prune_rate: initial fraction of each layer's active weights that
                is pruned per ``truncate_weights()`` call.
        """
        self.prune_rate_decay = LinearDecay(prune_rate, T_max)

        self.masks = {}
        self.modules = []
        self.optimizer = optimizer

        self.adjusted_growth = 0
        self.adjustments = []
        self.baseline_nonzero = None
        self.name2baseline_nonzero = {}

        # stats
        self.name2variance = {}
        self.name2zeros = {}
        self.name2nonzeros = {}
        self.name2removed = {}
        self.total_removed = 0
        self.total_nonzero = 0
        self.prune_rate = prune_rate
        self.name2prune_rate = {}

    def init_growth_prune_and_redist(self):
        """Install the default growth, prune and redistribution strategies."""
        self.growth_func = random_growth
        self.prune_func = magnitude_prune
        self.redistribution_func = magnitude_redistribution

    def at_end_of_epoch(self):
        """Run one prune-and-regrow cycle, then print the current prune rate."""
        self.truncate_weights()
        print("prune_rate: ", self.prune_rate)

    def step(self):
        """Step the optimizer, re-apply the masks and advance the prune-rate decay."""
        self.optimizer.step()
        self.apply_mask()
        self.prune_rate_decay.step()
        self.prune_rate = self.prune_rate_decay.get_dr(self.prune_rate)

    def add_module(self, module, density, init_masks):
        """Register ``module`` for masking and install its initial masks.

        Args:
            module: model whose parameters are masked. NOTE: the parameter
                count printed below is taken from ``self.modules[0]``, i.e.
                the first module ever added.
            density: fraction of each masked tensor expected to be nonzero.
                It sets ``baseline_nonzero``, the target used by the growth
                adjustment in ``truncate_weights``.
            init_masks: mapping from parameter name (as in
                ``named_parameters()``) to an initial mask tensor. Each mask
                is deep-copied. Presumably 0/1-valued — confirm with callers.
        """
        self.modules.append(module)
        self.sparsity = density  # NOTE: holds the density, despite the name
        self.init_growth_prune_and_redist()
        print(self.optimizer.param_groups[0]['lr'])
        self.baseline_nonzero = 0

        for name in init_masks:
            mask = deepcopy(init_masks[name])
            self.masks[name] = mask
            self.baseline_nonzero += mask.numel() * density

        self.apply_mask()

        total_size = 0
        for _, submodule in self.modules[0].named_modules():
            # BatchNorm1d layers are excluded from the count entirely.
            if isinstance(submodule, nn.BatchNorm1d):
                continue
            # Modules may register weight/bias as None (e.g. affine=False norms),
            # so check for None rather than just attribute presence.
            weight = getattr(submodule, 'weight', None)
            if weight is not None:
                total_size += weight.numel()
            bias = getattr(submodule, 'bias', None)
            if bias is not None:
                total_size += bias.numel()
        print('Total Model parameters:', total_size)

        total_size = sum(mask.numel() for mask in self.masks.values())
        print('Total parameters after removed layers:', total_size)
        print('Total parameters under sparsity level of {0}: {1}'.format(density, density * total_size))

    def apply_mask(self):
        """Multiply every masked parameter by its mask (rebinds ``.data``)."""
        for module in self.modules:
            for name, tensor in module.named_parameters():
                if name in self.masks:
                    tensor.data = tensor.data * self.masks[name]

    def adjust_prune_rate(self):
        """Reset each masked layer's prune rate to the global prune rate.

        A layer that is still fairly dense (sparsity < 0.2) gets a lower rate
        when its normalized score exceeds an even share (``1 / num_layers``).
        In that case its rate is capped at its current sparsity.
        """
        num_layers = len(self.name2variance)
        for module in self.modules:
            for name, _ in module.named_parameters():
                if name not in self.masks:
                    continue
                self.name2prune_rate[name] = self.prune_rate

                sparsity = self.name2zeros[name] / float(self.masks[name].numel())
                if sparsity < 0.2:
                    # determine if matrix is relatively dense but still growing
                    expected_variance = 1.0 / num_layers
                    actual_variance = self.name2variance[name]
                    # A zero (or NaN) share can never exceed the even share; skip
                    # it rather than dividing by zero.
                    if actual_variance > 0 and expected_variance / actual_variance < 1.0:
                        self.name2prune_rate[name] = min(sparsity, self.name2prune_rate[name])

    def truncate_weights(self):
        """Run one prune-and-regrow cycle over all masked parameters.

        Steps:
          1. Refresh statistics and per-layer prune rates.
          2. Prune each layer with ``self.prune_func``.
          3. Compute per-layer regrowth budgets.
          4. Regrow each layer with ``self.growth_func``.
          5. Re-apply the masks and update ``adjusted_growth`` from the gap
             between ``baseline_nonzero`` and the new nonzero count.
        """
        self.gather_statistics()
        self.adjust_prune_rate()

        total_nonzero_new = 0
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks:
                    continue
                mask = self.masks[name]
                # prune
                new_mask = self.prune_func(self, mask, weight, name)
                removed = self.name2nonzeros[name] - new_mask.sum().item()
                self.total_removed += removed
                self.name2removed[name] = removed
                self.masks[name][:] = new_mask
        print('Remove newly through pruning: {}'.format(self.name2removed))
        name2regrowth = self.calc_growth_redistribution()

        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks:
                    continue
                new_mask = self.masks[name].data.byte()
                new_mask = self.growth_func(self, name, new_mask, math.floor(name2regrowth[name]), weight)
                new_nonzero = new_mask.sum().item()

                # exchange the mask for the regrown one (stored as float)
                self.masks[name] = new_mask.float()
                total_nonzero_new += new_nonzero
        self.apply_mask()
        self.adjustments.append(self.baseline_nonzero - total_nonzero_new)
        self.adjusted_growth = 0.25 * self.adjusted_growth + (0.75 * (self.baseline_nonzero - total_nonzero_new)) + np.mean(self.adjustments)
        if self.total_nonzero > 0:
            print('Nonzero before/after: {0}/{1}. Growth adjustment: {2:.2f}.'.format(
                  self.total_nonzero, total_nonzero_new, self.adjusted_growth))

    def gather_statistics(self):
        """Recompute per-layer statistics for every masked parameter.

        Resets and refills ``name2nonzeros`` / ``name2zeros`` (counts taken
        from the mask), ``name2variance`` and the running totals.
        ``name2variance`` holds the ``self.redistribution_func`` score,
        normalized so the scores sum to 1 when their total is nonzero. NaN
        scores are left out of that total.
        """
        self.name2nonzeros, self.name2zeros, self.name2variance, self.name2removed = {}, {}, {}, {}
        self.total_variance, self.total_removed, self.total_nonzero, self.total_zero = 0, 0, 0, 0
        for module in self.modules:
            for name, weight in module.named_parameters():
                if name not in self.masks:
                    continue
                mask = self.masks[name]
                # redistribution score
                self.name2variance[name] = self.redistribution_func(self, name, weight, mask)

                if not np.isnan(self.name2variance[name]):
                    self.total_variance += self.name2variance[name]
                self.name2nonzeros[name] = mask.sum().item()
                self.name2zeros[name] = mask.numel() - self.name2nonzeros[name]
                self.total_nonzero += self.name2nonzeros[name]
                self.total_zero += self.name2zeros[name]

        for name in self.name2variance:
            if self.total_variance != 0.0:
                self.name2variance[name] /= self.total_variance
            else:
                print('Total variance was zero!')
                print(self.growth_func, self.prune_func, self.redistribution_func, self.name2variance)

    def calc_growth_redistribution(self):
        """Return a mapping ``name -> regrowth budget`` for every scored layer.

        Each layer's initial budget is
        ``ceil(score * (total_removed + adjusted_growth))``. A budget is capped
        at 99% of the slots available to that layer (current zeros plus
        planned removals). The excess is averaged over all layers and added
        back each round, for at most 1000 rounds.
        """
        residual = 9999  # any positive value, so the loop runs at least once
        mean_residual = 0
        name2regrowth = {}
        i = 0
        while residual > 0 and i < 1000:
            residual = 0
            for name in self.name2variance:
                prune_rate = self.name2prune_rate[name]
                num_remove = math.ceil(prune_rate * self.name2nonzeros[name])
                num_zero = self.name2zeros[name]
                # Max regrowth: slots freed by pruning plus slots that were already zero.
                max_regrowth = num_zero + num_remove

                if name in name2regrowth:
                    regrowth = name2regrowth[name]
                else:
                    regrowth = math.ceil(self.name2variance[name] * (self.total_removed + self.adjusted_growth))
                regrowth += mean_residual

                if regrowth > 0.99 * max_regrowth:
                    name2regrowth[name] = 0.99 * max_regrowth
                    residual += regrowth - name2regrowth[name]
                else:
                    name2regrowth[name] = regrowth
            if len(name2regrowth) == 0:
                mean_residual = 0
            else:
                mean_residual = residual / len(name2regrowth)
            i += 1

        if i == 1000:
            print('Error resolving the residual! Layers are too full! Residual left over: {0}'.format(residual))

        print('Prune mode: {}, name2regrowth {}'.format(self.prune_func, name2regrowth))
        return name2regrowth
