import torch
from torch import nn
from math import floor
import ShiftingWindowSetting


# NOTE: currently brittle with respect to bias naming: a bias parameter's name must contain
# 'bias' for this code to work.
# Also, in this implementation we do not retune on masked-out params, due to the assumption
# that we only see data once (if it is not saved to memory).
class PackNet(ShiftingWindowSetting.CLLearningAlgo):
    """PackNet-style weight masking built on ``CLLearningAlgo``.

    Every non-bias parameter of a maskable module has two companion tensors,
    both keyed by ``(module_name, param_name)``:

    * ``weight_mask``: id (starting at 1) of the first task that claimed the
      weight, or 0 while no task has claimed it yet.
    * ``saved_weights``: the frozen value of each claimed weight, 0 elsewhere.

    Bias parameters are recognised purely by name: the parameter name must
    contain ``'bias'``.
    """

    def __init__(self, args, mask_rate=0.25, maskable_modules=nn.Conv2d,
                 batch_norm_modules=(nn.BatchNorm2d, nn.InstanceNorm2d)):
        """Create zeroed masks and saved-weight buffers for the model.

        Args:
            args: forwarded unchanged to the base-class constructor.
            mask_rate: fraction of the still-free weights of each parameter
                that is claimed at the end of every task.
            maskable_modules: module class (or tuple of classes) whose
                non-bias parameters are masked; used with ``isinstance``.
            batch_norm_modules: module class (or tuple of classes) treated as
                normalisation layers whose parameters are frozen after the
                first task.
        """
        super().__init__(args=args)
        self.maskable_modules = maskable_modules
        self.batch_norm_modules = batch_norm_modules
        self.mask_rate = mask_rate
        # internal task ids start at 1 so that 0 can mean "not claimed yet"
        self.task_index = 1

        self.weight_mask = {}
        self.saved_weights = {}
        for key, param in self._maskable_weights():
            self.weight_mask[key] = torch.zeros_like(param)
            self.saved_weights[key] = torch.zeros_like(param)

    def _maskable_weights(self):
        """Yield ``((module_name, param_name), param)`` for every maskable non-bias parameter."""
        for module_name, module in self.model.named_modules():
            if not self.should_be_masked(module):
                continue
            for param_name, param in module.named_parameters():
                if 'bias' not in param_name:
                    yield (module_name, param_name), param

    def should_be_masked(self, module):
        """Return True if ``module`` is an instance of ``maskable_modules``."""
        return isinstance(module, self.maskable_modules)

    def is_batch_norm_layer(self, module):
        """Return True if ``module`` is an instance of ``batch_norm_modules``."""
        return isinstance(module, self.batch_norm_modules)

    def _after_first_task(self):
        """Freeze normalisation-layer parameters and maskable-module biases."""
        for _, module in self.model.named_modules():
            if self.is_batch_norm_layer(module):
                # While the original PackNet paper freezes both params and running
                # stats after the first task, we found that this hurt performance in
                # our case and instead just keep the param values constant while still
                # calculating the running statistics. This is consistent with other
                # PyTorch implementations of PackNet.
                for param in module.parameters():
                    param.requires_grad = False
                    # this should hopefully make weight decay and momentum behave
                    # well with the frozen batch norm params
                    param.grad = None

            if self.should_be_masked(module):
                for name, param in module.named_parameters():
                    if 'bias' in name:
                        param.requires_grad = False
                        # this should hopefully make weight decay and momentum
                        # behave well with the frozen biases
                        param.grad = None

    def after_optimiser_step(self):
        """Restore every claimed weight to its saved value.

        Values are reset after the step (rather than zeroing gradients before
        it) so that momentum and weight decay keep working for free weights.
        """
        with torch.no_grad():
            for key, param in self._maskable_weights():
                claimed = self.weight_mask[key] != 0
                # torch.where, unlike multiplying by a 0/1 mask, also restores
                # a claimed weight the optimiser has driven to inf/NaN
                param.copy_(torch.where(claimed, self.saved_weights[key], param))

    def at_end_of_task(self):
        """Claim the largest-magnitude free weights for the current task.

        For each maskable parameter, ``floor(num_free * mask_rate)`` of the
        still-free weights (largest absolute value first) are tagged with the
        current task id and their values copied into ``saved_weights``. The
        first call also freezes normalisation parameters and biases. The
        internal task counter is incremented afterwards.
        """
        if self.task_index == 1:
            self._after_first_task()

        with torch.no_grad():
            for key, param in self._maskable_weights():
                mask = self.weight_mask[key]
                saved = self.saved_weights[key]

                free = mask == 0
                num_free = int(free.sum().item())
                # clamp so that a claimed weight can never be selected
                num_to_claim = min(floor(num_free * self.mask_rate), num_free)
                if num_to_claim <= 0:
                    continue

                # Claimed weights get a score of -1, below any |value|, so they
                # cannot tie with free weights that happen to be exactly zero
                # and be re-assigned (which would wipe an earlier task's weight).
                scores = param.abs().masked_fill(~free, -1.0).reshape(-1)
                top = torch.topk(scores, num_to_claim).indices

                # boolean selection avoids needing a flat *view* of mask/saved,
                # so non-contiguous parameter layouts work as well
                chosen = torch.zeros_like(scores, dtype=torch.bool)
                chosen[top] = True
                chosen = chosen.view(param.shape)

                mask[chosen] = self.task_index
                saved[chosen] = param[chosen]

        self.task_index += 1

    def before_eval_on_task(self, task_id):
        """Load the weights belonging to tasks up to ``task_id`` into the model.

        ``task_id`` is 0-based from the caller's side and 1-based internally.
        Weights claimed by a later task are set to zero; weights not claimed
        by any task also become zero because their saved value is zero. Note
        that this overwrites the current values of the masked parameters.
        """
        internal_id = task_id + 1
        with torch.no_grad():
            for key, param in self._maskable_weights():
                saved = self.saved_weights[key]
                visible = self.weight_mask[key] <= internal_id
                param.copy_(torch.where(visible, saved, torch.zeros_like(saved)))


# Script entry point: build a PackNet learner with a 10% mask rate and call run_setting()
# (run_setting is not defined in this file; presumably inherited from CLLearningAlgo).
# NOTE(review): PackNet.__init__ declares `args` as a required parameter with no default,
# but none is passed here, so this call raises TypeError as written. Supply the args
# object that ShiftingWindowSetting.CLLearningAlgo expects — TODO confirm its source.
if __name__ == '__main__':
    learning_algo = PackNet(mask_rate=0.1)
    learning_algo.run_setting()


