import time
import copy
import numpy as np
import torch
import torch.nn.utils.prune as prune
import optimizers.optimizers as optimizers
from config import datasetDict, trainTransformDict, testTransformDict
from utilities import AverageMeter, RunningAverage, MultiStepLRExtended
from utilities import Utilities as Utils
from strategies import strategies
from metrics import metrics
from typing import Dict, Union, List
import wandb
import importlib
import os
from barbar import Bar
import sys
from math import sqrt, cos, pi
import json
from collections import OrderedDict


class Runner:

    def __init__(self, config, debug_mode=False):
        """Select the device and initialise bookkeeping attributes, loss and metric meters.

        Args:
            config: Run configuration object. The constructor itself reads
                ``config.dataset`` and ``config.device``; other methods read more fields.
            debug_mode: If True, a shorter list of desired sparsities is used. Per the
                original note it is also meant to disable the abort criterion.
        """
        self.config = config
        # DataParallel is only used for ImageNet, and only when more than one GPU is visible
        self.dataParallel = self.config.dataset == 'imagenet' and torch.cuda.device_count() > 1
        if not self.dataParallel:
            self.device = torch.device(config.device)
            # Make the configured GPU the current CUDA device. The previous check looked for
            # the substring 'gpu', which no valid torch device string contains, so it never
            # ran. set_device needs an explicit index, hence the second condition.
            if self.device.type == 'cuda' and self.device.index is not None:
                torch.cuda.set_device(self.device)
        else:
            # Use all GPUs, with the first one as the primary device
            self.device = torch.device("cuda:0")
            # torch.cuda.device(...) is a context manager and does nothing unless entered,
            # so the device has to be selected with set_device instead.
            torch.cuda.set_device(self.device)

        # Set a couple useful variables
        # presumably the k of the top-k accuracy meters: 5 for CIFAR-100/ImageNet, else 3
        self.k_accuracy = 5 if self.config.dataset in ['cifar100', 'imagenet'] else 3
        self.aborted_model = False
        self.checkpoint_file = None
        self.comparison_model = None
        self.trained_test_accuracy = None
        self.after_pruning_metrics = {}  # Filled per desired sparsity by after_pruning_callback
        self.totalTrainTime = None
        self.totalFinetuneTime = None
        self.loss_search_rewinding_occurence = None
        self.debug_mode = debug_mode  # If active, don't use abort criterion, use less desired_sparsities, etc
        self.seed = None
        self.trainedModelFile = None
        self.trained_norm_square = None
        self.rewinding_checkpoints = None
        self.warm_up_fn = None
        self.n_warmup_epochs = None
        self.trainIteration = None
        self.momentum_trainIteration = None
        self.momentum_warmup_iterations = None
        self.momentum_warmup_fn = None
        self.lr_trainIteration = None
        self.lr_warmup_iterations = None
        self.lr_warmup_fn = None
        self.gradient_callback = None
        self.trainIterationCtr = 1

        # IMP Sparsity settings
        if self.debug_mode:
            self.desired_sparsities = [0.9, 0.98]
        else:
            # Default settings for CIFAR-10/100
            self.desired_sparsities = [0.9, 0.93, 0.95, 0.98]

        # Define the loss object and metrics
        # Important note: for the correct computation of loss/accuracy it's important to have reduction == 'mean'
        self.loss_criterion = torch.nn.CrossEntropyLoss().to(device=self.device)

        self.train_loss, self.train_accuracy, self.train_k_accuracy = AverageMeter(), AverageMeter(), AverageMeter()
        self.test_loss, self.test_accuracy, self.test_k_accuracy = AverageMeter(), AverageMeter(), AverageMeter()

    def reset_averaged_metrics(self):
        """Reset the three train meters and the three test meters (loss, accuracy, k-accuracy)."""
        all_meters = (
            self.train_loss, self.train_accuracy, self.train_k_accuracy,
            self.test_loss, self.test_accuracy, self.test_k_accuracy,
        )
        for meter in all_meters:
            meter.reset()

    def get_metrics(self):
        """Assemble the logging dictionary from the meters and the current model statistics.

        Everything is computed under ``torch.no_grad()`` and between the strategy's
        ``start_forward_mode()`` / ``end_forward_mode()`` calls.

        Returns:
            dict: Nested ``train``/``test`` meter results plus sparsity, compression, FLOP,
            parameter-count, learning-rate, distance and threshold-histogram entries.
        """
        with torch.no_grad():
            self.strategy.start_forward_mode()  # Necessary to have correct computations for DPF
            # The first test batch serves as example input for the FLOP computation
            batch_x, batch_y = next(iter(self.testLoader))
            batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)  # Move to CUDA if possible
            flops_baseline, flops_nonzero = metrics.get_flops(model=self.model, x_input=batch_x)
            params_total, params_nonzero = metrics.get_parameter_count(model=self.model)

            # Empty dicts are passed so that wandb does not list these entries
            active_ratio, inactive_ratio = {}, {}
            if self.config.extended_logging:
                active_ratio, inactive_ratio = metrics.get_active_inactive_ratio(
                    model=self.model, comparison_model=self.comparison_model)

            layer_thresholds = self.strategy.get_per_layer_thresholds()
            if len(layer_thresholds) > 0:
                # Use at most 512 bins, and never more bins than thresholds
                histogram = wandb.Histogram(layer_thresholds, num_bins=min(512, len(layer_thresholds)))
            else:
                histogram = {}

            train_part = {
                'loss': self.train_loss.result(),
                'accuracy': self.train_accuracy.result(),
                'k_accuracy': self.train_k_accuracy.result(),
            }
            test_part = {
                'loss': self.test_loss.result(),
                'accuracy': self.test_accuracy.result(),
                'k_accuracy': self.test_k_accuracy.result(),
            }
            collected = {
                'train': train_part,
                'test': test_part,
                'global_sparsity': metrics.global_sparsity(module=self.model),
                'global_almost_sparsity': metrics.global_almost_sparsity(self.model),
                'global_compression': metrics.compression_rate(module=self.model),
                'global_almost_compression': metrics.almost_compression_rate(module=self.model),
                'nonzero_inference_flops': flops_nonzero,
                'baseline_inference_flops': flops_baseline,
                'theoretical_speedup': metrics.get_theoretical_speedup(n_flops=flops_baseline,
                                                                       n_nonzero_flops=flops_nonzero),
                'n_total_params': params_total,
                'n_nonzero_params': params_nonzero,
                # Only the first parameter group's learning rate is reported
                'learning_rate': float(self.optimizer.param_groups[0]['lr']),
                'active_params': active_ratio,
                'inactive_params': inactive_ratio,
                'distance_to_origin': metrics.get_distance_to_origin(self.model),
                'threshold_histogram': histogram,
            }
            self.strategy.end_forward_mode()    # Necessary to have correct computations for DPF
        return collected

    def get_dataloaders(self):
        """Create the train and test ``DataLoader`` for ``self.config.dataset``.

        The dataset class and the transforms are looked up in ``datasetDict``,
        ``trainTransformDict`` and ``testTransformDict``. ImageNet is read from a fixed
        path via ``split=``; every other dataset uses ``train=`` and downloads the
        training split if necessary.

        Returns:
            tuple: ``(trainLoader, testLoader)``; only the train loader shuffles.
        """
        dataset_name = self.config.dataset
        if dataset_name in ['imagenet']:
            # Training data lies on NO_BACKUP
            data_root = '/home/mzimmer/NO_BACKUP/imagenet-data'
            train_args = dict(split='train')
            test_args = dict(split='val')
        else:
            # Load datasets from torchvision
            data_root = f'./datasets_pytorch/{dataset_name}-data'
            train_args = dict(train=True, download=True)
            test_args = dict(train=False)

        trainData = datasetDict[dataset_name](root=data_root, transform=trainTransformDict[dataset_name],
                                              **train_args)
        testData = datasetDict[dataset_name](root=data_root, transform=testTransformDict[dataset_name],
                                             **test_args)

        # Worker processes and pinned memory are only used when CUDA is available
        cuda_available = torch.cuda.is_available()
        if not cuda_available:
            n_workers = 0
        elif dataset_name in ['imagenet', 'cifar100']:
            n_workers = 4
        else:
            n_workers = 2

        shared_args = dict(batch_size=self.config.batch_size, pin_memory=cuda_available,
                           num_workers=n_workers)
        trainLoader = torch.utils.data.DataLoader(trainData, shuffle=True, **shared_args)
        testLoader = torch.utils.data.DataLoader(testData, shuffle=False, **shared_args)

        return trainLoader, testLoader

    def get_model(self, load_checkpoint: bool = False, load_initial: bool = False) -> torch.nn.Module:
        """Create a fresh model or reload weights into the existing ``self.model``.

        Args:
            load_checkpoint: If True, reuse ``self.model`` (which must already exist), remove
                its pruning hooks and load a saved state_dict into it. If False, a new model
                is instantiated from ``models_pytorch.<dataset>``.
            load_initial: Selects the weights file. With ``load_checkpoint`` it switches from
                ``self.checkpoint_file`` to ``self.trainedModelFile`` in the wandb run
                directory; without it, the new model is initialised from that same file.

        Returns:
            The model, moved to ``self.device``. A newly created model is wrapped in
            ``DataParallel`` when ``self.dataParallel`` is set.
        """
        def load_weights(target, path):
            # Load the state_dict stored at `path` into `target`, regardless of whether the
            # file or the target carries the DataParallel 'module.' prefix.
            prefix = "module."
            state_dict = torch.load(path, map_location=self.device)
            cleaned = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, val)
                for key, val in state_dict.items()
            )
            # The keys are now prefix-free, so they must go into the wrapped module. Loading
            # them into the DataParallel wrapper itself fails on the missing 'module.' keys.
            if isinstance(target, torch.nn.DataParallel):
                target = target.module
            target.load_state_dict(cleaned)

        if load_checkpoint:
            # self.model must exist already
            model = self.model
            # Note, we have to get rid of all existing prunings, otherwise we cannot load the state_dict as it is unpruned
            self.strategy.remove_pruning_hooks(model=self.model)
            file = self.checkpoint_file if not load_initial else os.path.join(wandb.run.dir, self.trainedModelFile)
            load_weights(model, file)
        else:
            # Define the model
            model = getattr(importlib.import_module('models_pytorch.' + self.config.dataset), self.config.model)()
            if load_initial:
                # Load initial model from specified path
                load_weights(model, os.path.join(wandb.run.dir, self.trainedModelFile))
            if self.dataParallel:
                # Only apply DataParallel when re-initializing the model!
                model = torch.nn.DataParallel(model)

        model = model.to(device=self.device)
        return model

    def define_optimizer_scheduler(self):
        """Create ``self.optimizer`` and, if one is configured, ``self.scheduler``.

        ``self.config.learning_rate`` is either a non-string value, used directly as a
        constant learning rate without scheduler, or a string ``(name, initial_lr, arg, arg)``
        with ``name`` one of StepLR, MultiStepLR, MultiStepLRExtended or
        CosineAnnealingWarmRestarts. List arguments are written as ``[a|b|c]``.

        Sets ``self.per_iteration_update``, ``self.initial_lr``, ``self.optimizer`` and
        ``self.scheduler``.

        Raises:
            NotImplementedError: Unknown scheduler name, optimizer name or weight_decay_ord
                for the proximal operator.
            AssertionError: The configured strategy does not fit the optimizer settings.
        """
        def strip_pair(text, opening, closing):
            # Remove one leading `opening` and one trailing `closing` character, if present
            if text[0] == opening:
                text = text[1:]
            if text[-1] == closing:
                text = text[:-1]
            return text

        cfg = self.config

        # Check if we want to use a learning rate scheduler
        # Only the cosine schedule below sets this flag to True
        self.per_iteration_update = False
        if type(cfg.learning_rate) is str:
            # Learning rate scheduler in the form (type, kwargs)
            spec = strip_pair(cfg.learning_rate.strip(), '(', ')')
            name, *kwargs = spec.split(',')
            if name not in ['StepLR', 'MultiStepLR', 'CosineAnnealingWarmRestarts', 'MultiStepLRExtended']:
                raise NotImplementedError(f"LR Scheduler {name} not implemented.")
            # Keep the parsed spec until the optimizer exists; the scheduler is built below
            self.scheduler = (name, kwargs)
            self.initial_lr = float(kwargs[0])
        else:
            self.scheduler = None
            self.initial_lr = cfg.learning_rate

        # Compatibility checks between strategy and optimizer settings
        strategy_name = cfg.strategy
        if strategy_name in ['GSM', 'GSM_IMP_grad', 'GSM_IMP', 'LC', 'SparseSGD_IMP']:
            assert cfg.optimizer == 'SGD', f'{cfg.strategy} only works with SGD, but {cfg.optimizer} specified.'
            if strategy_name == 'LC':
                assert cfg.weight_decay > 0 and cfg.weight_decay_ord == 2, 'LearningCompression requires L2 weight_decay.'
            else:
                assert cfg.momentum != 0, f'{cfg.strategy} requires active momentum.'
        elif strategy_name == 'GREG':
            assert cfg.weight_decay_ord == 2, "GREG requires L2-wd."

        # Define the optimizer
        if cfg.optimizer in ['SGD', 'Proj_SGD']:
            if cfg.strategy != 'CS':
                wd = cfg.weight_decay
            else:
                wd = None  # For CS, apply wd manually
            self.optimizer = optimizers.SGD(params=self.model.parameters(), lr=self.initial_lr,
                                            momentum=cfg.momentum, dampening=cfg.dampening,
                                            weight_decay=wd,
                                            weight_decay_ord=cfg.weight_decay_ord,
                                            nesterov=cfg.nesterov, global_constraint=None,
                                            gradient_callback=self.gradient_callback)
        elif cfg.optimizer == 'Prox_SGD':
            # Define the proximal operator, then initialize Prox_SGD optimizer
            if cfg.weight_decay_ord == 1:
                prox_operator = optimizers.ProximalOperator.soft_thresholding(weight_decay=cfg.weight_decay)
                global_constraint = False
            elif cfg.weight_decay_ord == "1-[k]":
                prox_operator = optimizers.ProximalOperator.knorm_soft_thresholding(
                    weight_decay=cfg.weight_decay, k=self.strategy.get_K())
                global_constraint = True
            else:
                raise NotImplementedError(f"Proximal operator for weight_decay_ord {self.config.weight_decay_ord} not implemented.")
            self.optimizer = optimizers.Prox_SGD(params=self.model.parameters(), prox_operator=prox_operator,
                                                 global_constraint=global_constraint,
                                                 lr=self.initial_lr, momentum=cfg.momentum,
                                                 dampening=cfg.dampening, nesterov=cfg.nesterov)
        else:
            raise NotImplementedError(f"Optimizer {self.config.optimizer} not implemented.")

        if not self.scheduler:
            return

        # We define a scheduler; kwargs[0] is the initial learning rate consumed above
        name, kwargs = self.scheduler
        if name == 'StepLR':
            # Tuple of form ('StepLR', initial_lr, step_size, gamma)
            # Reduces initial_lr by gamma every step_size epochs
            step_size, gamma = int(kwargs[1]), float(kwargs[2])
            self.scheduler = torch.optim.lr_scheduler.StepLR(optimizer=self.optimizer, step_size=step_size,
                                                             gamma=gamma)
        elif name == 'MultiStepLR':
            # Tuple of form ('MultiStepLR', initial_lr, milestones, gamma)
            # Reduces initial_lr by gamma every epoch that is in the list milestones
            raw_milestones, gamma = kwargs[1].strip(), float(kwargs[2])
            raw_milestones = strip_pair(raw_milestones, '[', ']')
            milestones = [int(ms) for ms in raw_milestones.split('|')]
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=self.optimizer,
                                                                  milestones=milestones, gamma=gamma)
        elif name == 'MultiStepLRExtended':
            # Tuple of form ('MultiStepLRExtended', initial_lr, milestones, gammas)
            # Reduces initial_lr by gamma[epoch] every epoch that is in the list milestones
            raw_milestones, raw_gammas = kwargs[1].strip(), kwargs[2].strip()
            raw_milestones = strip_pair(raw_milestones, '[', ']')
            raw_gammas = strip_pair(raw_gammas, '[', ']')
            milestones = [int(ms) for ms in raw_milestones.split('|')]
            gammas = [float(gm) for gm in raw_gammas.split('|')]
            self.scheduler = MultiStepLRExtended(optimizer=self.optimizer, milestones=milestones,
                                                 gammas=gammas)
        elif name == 'CosineAnnealingWarmRestarts':
            eta_min, epochs_per_restart = float(kwargs[1].strip()), float(kwargs[2].strip())
            # The restart period is given in epochs but passed to the scheduler in iterations
            n_it = int(epochs_per_restart * len(self.trainLoader))
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=self.optimizer,
                                                                                  T_0=n_it, eta_min=eta_min)
            self.per_iteration_update = True


    def define_strategy(self):
        """Instantiate the pruning strategy selected by ``self.config.strategy``.

        Strategies that take a list of sparsities receive ``[config.gsm_desired_sparsity]``
        when that value is set and ``self.desired_sparsities`` otherwise.

        Returns:
            The strategy object, or None if the configured name matches none of the
            strategies handled here.
        """
        cfg = self.config
        if cfg.gsm_desired_sparsity is not None:
            sparsities = [cfg.gsm_desired_sparsity]
        else:
            sparsities = self.desired_sparsities
        name = cfg.strategy

        if name == 'Dense':
            return strategies.Dense()

        if name == 'IMP':
            return strategies.IMP(desired_sparsities=sparsities,
                                  n_epochs_finetune=cfg.n_epochs_finetune, n_phases=cfg.n_phases)

        # IMP variants whose constructors take the same two arguments
        if name in ('IRP', 'LAMP_IMP', 'Uniform_IMP', 'UniformPlus_IMP', 'ERK_IMP'):
            return getattr(strategies, name)(desired_sparsities=sparsities,
                                             n_epochs_finetune=cfg.n_epochs_finetune)

        # Strategies working with a single desired sparsity
        if name in ('GSM', 'GGSM'):
            return getattr(strategies, name)(desired_sparsity=cfg.gsm_desired_sparsity,
                                             n_epochs_finetune=cfg.n_epochs_finetune)

        if name == 'LC':
            return strategies.LC(desired_sparsity=cfg.gsm_desired_sparsity,
                                 n_epochs_finetune=cfg.n_epochs_finetune,
                                 change_weight_decay_callback=self.change_weight_decay_callback,
                                 n_epochs_total=cfg.nepochs,
                                 initial_weight_decay=cfg.weight_decay)

        if name == 'GradualPruning':
            # Warm-up callbacks are only handed over when enabled in the config
            momentum_callback = self.change_momentum_warmup_callback if cfg.momentum_warmup else None
            lr_callback = self.change_lr_warmup_callback if cfg.lr_warmup else None
            if cfg.use_uniform:
                pruner_cls = strategies.GradualPruning_Uniform
            else:
                pruner_cls = strategies.GradualPruning
            return pruner_cls(model=self.model, n_train_epochs=cfg.nepochs,
                              n_epochs_finetune=cfg.n_epochs_finetune,
                              desired_sparsity=cfg.gsm_desired_sparsity,
                              pruning_steps=cfg.pruning_steps,
                              allow_recovering=cfg.allow_recovering,
                              after_pruning_callback=self.after_pruning_callback,
                              change_warmup_callback=momentum_callback,
                              change_lr_callback=lr_callback,
                              time_mode=cfg.time_mode)

        if name == 'DPF':
            return strategies.DPF(model=self.model, n_train_epochs=cfg.nepochs,
                                  n_epochs_finetune=cfg.n_epochs_finetune,
                                  desired_sparsity=cfg.gsm_desired_sparsity,
                                  pruning_steps=cfg.pruning_steps,
                                  after_pruning_callback=self.after_pruning_callback)

        if name == 'CS':
            # T_it is the total number of training iterations
            return strategies.CS(n_epochs_finetune=cfg.n_epochs_finetune, s_initial=cfg.s_initial,
                                 beta_final=cfg.beta_final,
                                 T_it=int(len(self.trainLoader) * cfg.nepochs))

        if name == 'STR':
            return strategies.STR(n_epochs_finetune=cfg.n_epochs_finetune, s_initial=cfg.s_initial,
                                  use_global_threshold=cfg.use_global_threshold)

        if name == 'DST':
            return strategies.DST(n_epochs_finetune=cfg.n_epochs_finetune)

        if name == 'DNW':
            return strategies.DNW(desired_sparsity=cfg.gsm_desired_sparsity,
                                  n_epochs_finetune=cfg.n_epochs_finetune)

        return None

    def update_learning_rates(self, training_warmup_fn_epoch=None):
        """Set the learning rate from the warm-up function or step the scheduler.

        Args:
            training_warmup_fn_epoch: If truthy, ``self.training_warmup_fn`` is evaluated at
                this value and the result is written to every parameter group. Otherwise
                the scheduler, if any, is stepped, unless it is updated per iteration.
        """
        if not training_warmup_fn_epoch:
            if self.scheduler and not self.per_iteration_update:
                self.scheduler.step()
            return

        warmup_lr = self.training_warmup_fn(training_warmup_fn_epoch)
        for group in self.optimizer.param_groups:
            group['lr'] = warmup_lr

    def log(self, runTime, finetuning: bool = False, desired_sparsity=None):
        """Send the current metrics to wandb.

        Args:
            runTime: Stored under ``epoch_run_time`` in the logged metrics.
            finetuning: If False, the metrics are logged at top level and mirrored into the
                run summary under ``trained.*``. If True, they are nested under ``finetune``.
            desired_sparsity: Only used while finetuning. If given, the metrics recorded
                directly after pruning for this sparsity are logged as well, and the
                per-layer sparsity distribution is written to a JSON file and uploaded.
        """
        metrics_dict = self.get_metrics()
        self.strategy.start_forward_mode()
        metrics_dict['epoch_run_time'] = runTime

        if finetuning:
            if desired_sparsity is None:
                wandb.log({'finetune': metrics_dict})
            else:
                wandb.log({
                    'finetune': metrics_dict,
                    'pruned': self.after_pruning_metrics[desired_sparsity],   # Metrics directly after pruning
                    'desired_sparsity': desired_sparsity,
                    'total_finetune_time': self.totalFinetuneTime,
                })
                # Dump sparsity distribution to json and upload
                distribution = metrics.per_layer_sparsity(model=self.model)
                json_path = os.path.join(wandb.run.dir, f'sparsity_distribution_{desired_sparsity}.json')
                with open(json_path, 'w') as handle:
                    json.dump(distribution, handle)
                wandb.save(json_path)
        else:
            target_sparsity = self.config.gsm_desired_sparsity
            if target_sparsity is not None:
                abs_distance, rel_distance = metrics.get_distance_to_pruned(model=self.model,
                                                                            sparsity=target_sparsity)
                metrics_dict['distance_to_pruned'] = abs_distance
                metrics_dict['relative_distance_to_pruned'] = rel_distance

            # Update final trained metrics (necessary to be able to filter via wandb)
            for metric_type, val in metrics_dict.items():
                wandb.run.summary[f"trained.{metric_type}"] = val
            if self.totalTrainTime:
                # Total train time captured, hence training is done
                wandb.run.summary["trained.total_train_time"] = self.totalTrainTime
            # The usual logging of one epoch
            wandb.log(metrics_dict)

        self.strategy.end_forward_mode()

    def final_log(self, actual_sparsity=None):
        """Re-evaluate the model and write the final metrics to the wandb run summary.

        Args:
            actual_sparsity: Used, and stored in the summary, only when
                ``config.gsm_desired_sparsity`` is None.

        Raises:
            AssertionError: Neither ``config.gsm_desired_sparsity`` nor ``actual_sparsity``
                is available.
        """
        sparsity_key = self.config.gsm_desired_sparsity
        if sparsity_key is None:
            if actual_sparsity is None:
                # Note: This function may only be called if a desired_sparsity has been given upfront, i.e. GSM etc
                raise AssertionError("Final logging was called even though no gsm_desired_sparsity was given.")
            wandb.run.summary["actual_sparsity"] = actual_sparsity
            sparsity_key = actual_sparsity

        # Recompute accuracy and loss
        sys.stdout.write("\nFinal logging\n")
        self.reset_averaged_metrics()
        for split in ('train', 'test'):
            self.evaluate_model(data=split)

        # Update final trained metrics (necessary to be able to filter via wandb)
        for metric_type, val in self.get_metrics().items():
            wandb.run.summary[f"final.{metric_type}"] = val

        # Update after prune metrics
        if sparsity_key in self.after_pruning_metrics:
            for metric_type, val in self.after_pruning_metrics[sparsity_key].items():
                wandb.run.summary[f"pruned.{metric_type}"] = val

    def after_pruning_callback(self, desired_sparsity: float, prune_momentum: bool = False, reset_momentum: bool = False) -> None:
        """Evaluate the freshly pruned model and store the results for ``desired_sparsity``.

        This function must be called once for every sparsity, directly after pruning. The
        results are stored in ``self.after_pruning_metrics[desired_sparsity]``. Quantities
        that cannot be computed are stored as empty dicts.

        Args:
            desired_sparsity: Key under which the metrics are stored.
            prune_momentum: If True and ``reset_momentum`` is False, the strategy prunes the
                optimizer's momentum buffer.
            reset_momentum: If True, the optimizer's momentum is reset; takes precedence
                over ``prune_momentum``.
        """
        # Compute losses, accuracies after pruning
        sys.stdout.write(f"\nDesired sparsity {desired_sparsity} - Computing incurred losses after pruning.\n")
        self.reset_averaged_metrics()
        for split in ('train', 'test'):
            self.evaluate_model(data=split)

        # Norm lost through pruning, relative to the norm of the trained model
        norm_drop, relative_norm_drop = {}, {}
        if self.trained_norm_square is not None:
            pruned_norm_square = Utils.get_model_norm_square(self.model)
            norm_drop = sqrt(abs(self.trained_norm_square - pruned_norm_square))
            trained_norm = float(sqrt(self.trained_norm_square))
            if trained_norm > 0:
                relative_norm_drop = norm_drop / trained_norm

        # Relative test accuracy lost through pruning
        pruning_instability, pruning_stability = {}, {}
        reference_accuracy = self.trained_test_accuracy
        if reference_accuracy is not None and reference_accuracy > 0:
            pruning_instability = (reference_accuracy - self.test_accuracy.result()) / reference_accuracy
            pruning_stability = 1 - pruning_instability

        train_part = {
            'loss': self.train_loss.result(),
            'accuracy': self.train_accuracy.result(),
            'k_accuracy': self.train_k_accuracy.result(),
        }
        test_part = {
            'loss': self.test_loss.result(),
            'accuracy': self.test_accuracy.result(),
            'k_accuracy': self.test_k_accuracy.result(),
        }
        self.after_pruning_metrics[desired_sparsity] = {
            'train': train_part,
            'test': test_part,
            'norm_drop': norm_drop,
            'relative_norm_drop': relative_norm_drop,
            'pruning_instability': pruning_instability,
            'pruning_stability': pruning_stability,
        }

        if reset_momentum:
            sys.stdout.write("Resetting momentum_buffer (if existing) for potential finetuning.\n")
            self.optimizer.reset_momentum()
        elif prune_momentum:
            sys.stdout.write("Pruning momentum_buffer (if existing).\n")
            self.strategy.prune_momentum(optimizer=self.optimizer)

    def change_scheduler_callback(self, scheduler):
        """Replace ``self.scheduler`` with the given scheduler object."""
        self.scheduler = scheduler

    def change_momentum_warmup_callback(self, n_epochs_warmup: float):
        """Install a cosine momentum warm-up lasting ``n_epochs_warmup`` epochs.

        Sets ``self.momentum_trainIteration`` to 1, stores the warm-up length in iterations
        and assigns ``self.momentum_warmup_fn``. That function maps an iteration ``t`` to a
        momentum value rising from 0 at ``t == 0`` to ``config.momentum`` at the last
        warm-up iteration, and unregisters itself when called with that last iteration.
        """
        self.momentum_trainIteration = 1
        self.momentum_warmup_iterations = int(n_epochs_warmup * len(self.trainLoader))

        def momentum_at(step):
            if step == self.momentum_warmup_iterations:
                # Last warm-up iteration reached: remove the warm-up function
                self.momentum_warmup_fn = None
            target = self.config.momentum
            cosine_term = 1 + cos(pi * step / self.momentum_warmup_iterations)
            return target + 0.5 * (0 - target) * cosine_term

        self.momentum_warmup_fn = momentum_at

    def change_lr_warmup_callback(self, n_epochs_warmup: float):
        """Install a cosine learning-rate warm-up lasting ``n_epochs_warmup`` epochs.

        The learning rate of the first parameter group at the time of this call is the
        value the warm-up ends at. Sets ``self.lr_trainIteration`` to 1, stores the warm-up
        length in iterations and assigns ``self.lr_warmup_fn``. That function maps an
        iteration ``t`` to a learning rate starting at 0 for ``t == 0`` and unregisters
        itself when called with the last warm-up iteration.
        """
        # Captured once, so later changes to the optimizer do not move the target
        target_lr = float(self.optimizer.param_groups[0]['lr'])
        self.lr_trainIteration = 1
        self.lr_warmup_iterations = int(n_epochs_warmup * len(self.trainLoader))

        def lr_at(step):
            if step == self.lr_warmup_iterations:
                # Last warm-up iteration reached: remove the warm-up function
                self.lr_warmup_fn = None
            cosine_term = 1 + cos(pi * step / self.lr_warmup_iterations)
            return target_lr + 0.5 * (0 - target_lr) * cosine_term

        self.lr_warmup_fn = lr_at

    def change_weight_decay_callback(self, penalty):
        """Set ``weight_decay`` of every optimizer parameter group to ``penalty`` and report it."""
        for param_group in self.optimizer.param_groups:
            param_group.update(weight_decay=penalty)
        print(f"Changed weight decay to {penalty}.")

    def restore_model(self, from_initial=False) -> None:
        """Reload weights into ``self.model`` through ``get_model(load_checkpoint=True, ...)``.

        Args:
            from_initial: Selects which source is named in the console message. Only the
                literal ``True`` makes ``get_model`` load ``self.trainedModelFile``; any
                other value loads ``self.checkpoint_file``.
        """
        if from_initial:
            label, source = 'initial', self.trainedModelFile
        else:
            label, source = 'checkpoint', self.checkpoint_file
        sys.stdout.write(f"Restoring {label} model from {source}.\n")
        self.model = self.get_model(load_checkpoint=True, load_initial=(from_initial is True))

    def save_model(self, model_type: str, remove_pruning_hooks: bool = False) -> str:
        """Save the state_dict of ``self.model`` into the wandb run directory.

        Args:
            model_type: ``'initial'`` saves under ``self.trainedModelFile``. Any other value
                becomes part of a file name built from dataset, optimizer, model, run id and
                seed. For ``'trained'`` that name is also stored in the wandb summary.
            remove_pruning_hooks: If True, the strategy removes the pruning hooks from the
                model before saving.

        Returns:
            Path of the written file.
        """
        if model_type == 'initial':
            file_name = self.trainedModelFile
        else:
            cfg = self.config
            file_name = f"{cfg.dataset}_{cfg.optimizer}_{cfg.model}_{model_type}_{cfg.run_id}_{self.seed}.pt"
            if model_type == 'trained':
                # Save the trained model name to wandb to use it in lr_analysis
                wandb.summary['trained_model_file'] = file_name

        target_path = os.path.join(wandb.run.dir, file_name)
        if remove_pruning_hooks:
            self.strategy.remove_pruning_hooks(model=self.model)
        # Save the state_dict to the wandb directory
        torch.save(self.model.state_dict(), target_path)
        return target_path

    def evaluate_model(self, data='train'):
        """Run ``train_epoch`` on the given data split with ``evaluation_mode=True`` and return its result."""
        outcome = self.train_epoch(data=data, evaluation_mode=True)
        return outcome

    def rewinding(self, current_epoch: int, n_epochs_finetune: int, desired_sparsity: float) -> None:
        """Set the learning rate for a finetuning epoch according to ``config.lr_rewinding``.

        Supported methods:
            * ``'constant'`` / ``'retraction'``: the learning rate is left untouched.
            * ``'lr-rewinding'``: use the learning rate recorded in ``strategy.lr_dict`` for
              the epoch ``nepochs - n_epochs_finetune + current_epoch`` (as in Renda2020).
            * ``'copycat'``: use ``self.copycat_lr[current_epoch]``.
            * ``'SLR'``: untouched during the first ``self.n_warmup_epochs`` epochs,
              afterwards ``self.SLR_lr[current_epoch]``; ``self.warm_up_fn`` is then cleared.

        Args:
            current_epoch: Epoch of the finetuning phase.
            n_epochs_finetune: Total number of finetuning epochs.
            desired_sparsity: Not used by this method.

        Raises:
            NotImplementedError: ``config.lr_rewinding`` is none of the methods above.
                Previously such a value failed later with an UnboundLocalError.
        """
        method = self.config.lr_rewinding
        if method in ['constant', 'retraction']:
            sys.stdout.write(f"Rewinding method {method}: No rewinding to specific epoch.\n")
            return

        if method == 'lr-rewinding':
            # Learning rate rewinding as in Renda2020
            rewind_epoch = self.config.nepochs - n_epochs_finetune + current_epoch
            epoch_lr = self.strategy.lr_dict[rewind_epoch]
        elif method == 'copycat':
            rewind_epoch = 'N/A'
            epoch_lr = self.copycat_lr[current_epoch]
        elif method == 'SLR':
            if current_epoch <= self.n_warmup_epochs:
                # Warmup is handled in train_epoch
                sys.stdout.write(
                    f"Rewinding method {method}: No rewinding, still in warmup phase. Current lr: {float(self.optimizer.param_groups[0]['lr'])}\n")
                return
            rewind_epoch = 'N/A'
            self.warm_up_fn = None  # If set to None, then warm up has finished
            epoch_lr = self.SLR_lr[current_epoch]
        else:
            raise NotImplementedError(f"Rewinding method {method} not implemented.")

        # Rewind learning rate
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = epoch_lr
        sys.stdout.write(f"Rewinding method {method}: Rewinding to epoch {rewind_epoch}, lr {epoch_lr}.\n")

    def fine_tuning(self, desired_sparsity, n_epochs_finetune, phase=1):
        n_phases = self.config.n_phases or 1
        self.trainIteration = 0
        if self.config.momentum_warmup_epochs and self.config.momentum_warmup_epochs > 0:
            self.change_momentum_warmup_callback(n_epochs_warmup=self.config.momentum_warmup_epochs)
        if phase == 1:
            self.finetuneStartTime = time.time()
        if self.config.lr_rewinding == 'SLR':
            maxLR = max(self.strategy.lr_dict.values()) # This is equivalent to the initial lr!
            after_warmup_index = min([idx for idx, val in self.strategy.lr_dict.items() if val == maxLR and idx > 0])
            minLR = min(list(self.strategy.lr_dict.values())[after_warmup_index:])
            self.n_warmup_epochs = int(
                0.1 * n_epochs_finetune)  # 10% of retraining time, makes sense for n_epochs_finetune >= 10
            n_warmup_iterations = float(len(self.trainLoader) * self.n_warmup_epochs)

            def warm_up_fn(t):
                return maxLR + 0.5 * (minLR - maxLR) * (1 + cos(pi * t / n_warmup_iterations))

            self.warm_up_fn = warm_up_fn

            # After self.n_warmup_epochs, do copycat_start
            epochLRs = [self.strategy.lr_dict[i] if i >= after_warmup_index else maxLR for i in range(1, self.config.nepochs + 1, 1)]
            remainingEpochs = n_epochs_finetune - self.n_warmup_epochs
            interpolation_width = self.config.nepochs / remainingEpochs  # in general not an integer
            reducedLRs = [epochLRs[int(j * interpolation_width)] for j in range(remainingEpochs)]
            self.SLR_lr = {self.n_warmup_epochs + j + 1: reducedLRs[j] for j in range(remainingEpochs)}


        for epoch in range(1, n_epochs_finetune + 1, 1):
            self.reset_averaged_metrics()
            sys.stdout.write(f"\nDesired sparsity {desired_sparsity} - Finetuning: phase {phase}/{n_phases} | epoch {epoch}/{n_epochs_finetune}\n")
            # Train
            t = time.time()
            if self.lr_warmup_fn is None or epoch == 1:
                self.rewinding(current_epoch=epoch, n_epochs_finetune=n_epochs_finetune, desired_sparsity=desired_sparsity)
            if epoch == 1 and self.config.lr_warmup_epochs and self.config.lr_warmup_epochs > 0:
                self.change_lr_warmup_callback(n_epochs_warmup=self.config.lr_warmup_epochs)
            self.train_epoch(data='train')
            self.evaluate_model(data='test')
            sys.stdout.write(
                f"\nTest accuracy after this epoch: {self.test_accuracy.result()} (lr = {float(self.optimizer.param_groups[0]['lr'])})\n")

            if epoch == n_epochs_finetune and phase == n_phases:
                # Training complete, log the time
                self.totalFinetuneTime = time.time() - self.finetuneStartTime

            # As opposed to previous runs, push information every epoch, but only link to desired_sparsity at end
            dsParam = None
            if epoch == n_epochs_finetune and phase == n_phases:
                dsParam = desired_sparsity
            self.log(runTime=time.time() - t, finetuning=True, desired_sparsity=dsParam)

    def train_epoch(self, data='train', evaluation_mode=False):
        if data == 'train':
            loader = self.trainLoader
            mean_loss, mean_accuracy, mean_k_accuracy = self.train_loss, self.train_accuracy, self.train_k_accuracy
        elif data == 'test':
            loader = self.testLoader
            mean_loss, mean_accuracy, mean_k_accuracy = self.test_loss, self.test_accuracy, self.test_k_accuracy
        assert not (data == ' test' and evaluation_mode == False), "Can't train on test set."
        if evaluation_mode:
            with torch.no_grad():
                sys.stdout.write(f"Evaluation of {data} data:\n")
                self.strategy.start_forward_mode()  # Called once before evaluating whole dataset, different to training
                for x_input, y_target in Bar(loader):
                    x_input, y_target = x_input.to(self.device), y_target.to(self.device)  # Move to CUDA if possible
                    output = self.model.eval()(x_input)

                    loss = self.loss_criterion(output, y_target)
                    mean_loss(loss.item(), len(y_target))
                    mean_accuracy(Utils.categorical_accuracy(y_true=y_target, output=output), len(y_target))
                    mean_k_accuracy(Utils.categorical_accuracy(y_true=y_target, output=output, topk=self.k_accuracy),
                                    len(y_target))
                self.strategy.end_forward_mode()   # Called once after evaluating whole dataset, different to training

        else:
            sys.stdout.write(f"Training:\n")
            for x_input, y_target in Bar(self.trainLoader):
                if self.warm_up_fn is not None:
                    # This is the original warmup from SLR
                    iteration_lr = self.warm_up_fn(self.trainIteration)
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = iteration_lr
                    self.trainIteration += 1
                if self.momentum_warmup_fn is not None:
                    iteration_momentum = self.momentum_warmup_fn(self.momentum_trainIteration)
                    for param_group in self.optimizer.param_groups:
                        param_group['momentum'] = iteration_momentum
                    self.momentum_trainIteration += 1
                if self.lr_warmup_fn is not None:
                    # This is the new general lr warmup
                    iteration_lr = self.lr_warmup_fn(self.lr_trainIteration)
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = iteration_lr
                    self.lr_trainIteration += 1
                x_input, y_target = x_input.to(self.device), y_target.to(self.device)  # Move to CUDA if possible
                self.optimizer.zero_grad()  # Zero the gradient buffers
                self.strategy.start_forward_mode(enable_grad=True)
                output = self.model.train()(x_input)
                loss = self.loss_criterion(output, y_target)
                loss = self.strategy.before_backward(loss=loss, weight_decay=self.config.weight_decay, penalty=self.config.lmbd)
                loss.backward()  # Backpropagation
                self.strategy.during_training(opt=self.optimizer, trainIteration=self.trainIterationCtr)
                self.optimizer.step()
                #self.strategy.enforce_prunedness()
                self.strategy.end_forward_mode()    # Has no effect for DPF
                self.train_loss(loss.item(), len(y_target))
                self.train_accuracy(Utils.categorical_accuracy(y_true=y_target, output=output), len(y_target))
                self.train_k_accuracy(Utils.categorical_accuracy(y_true=y_target, output=output, topk=self.k_accuracy),
                                      len(y_target))
                self.strategy.after_training_iteration(it=self.trainIterationCtr)
                self.trainIterationCtr += 1
                if self.per_iteration_update:
                    self.scheduler.step()

    def train(self):
        if self.config.momentum_warmup_epochs and self.config.momentum_warmup_epochs > 0:
            self.change_momentum_warmup_callback(n_epochs_warmup=self.config.momentum_warmup_epochs)
        if self.config.lr_warmup_epochs and self.config.lr_warmup_epochs > 0:
            self.change_lr_warmup_callback(n_epochs_warmup=self.config.lr_warmup_epochs)
        self.training_warmup_fn = None
        if self.config.training_warmup_epochs and self.config.training_warmup_epochs > 0:
            goalLR = self.initial_lr
            def fn(epoch):
                if epoch == self.config.training_warmup_epochs:
                    self.training_warmup_fn = None
                return goalLR * float(epoch)/self.config.training_warmup_epochs
            self.training_warmup_fn = fn
        trainStartTime = time.time()
        for epoch in range(self.config.nepochs + 1):
            self.reset_averaged_metrics()
            sys.stdout.write(f"\n\nEpoch {epoch}/{self.config.nepochs}\n")
            t = time.time()
            if epoch == 0:
                # Just evaluate the model once to get the metrics
                if self.debug_mode:
                    # Skip thistep
                    sys.stdout.write(f"Skipping since we are in debug mode")
                    continue
                self.evaluate_model(data='train')
                epoch_lr = float(self.optimizer.param_groups[0]['lr'])
            else:
                # Train
                if self.training_warmup_fn:
                    self.update_learning_rates(training_warmup_fn_epoch=epoch)
                self.train_epoch(data='train')
                # Save the learning rate for potential rewinding before updating
                epoch_lr = float(self.optimizer.param_groups[0]['lr'])
            self.evaluate_model(data='test')
            self.strategy.at_epoch_end(epoch=epoch, epoch_lr=epoch_lr, train_loss=self.train_loss.result())

            # Save info to wandb
            if epoch == self.config.nepochs:
                # Training complete, log the time
                self.totalTrainTime = time.time() - trainStartTime
            self.log(runTime=time.time() - t)
            if epoch > 0 and epoch < self.config.nepochs:
                # Update the learning rates (moved this to ensure that learning rates are logged correctly)
                self.update_learning_rates()
            if self.config.abort_active:
                if not self.debug_mode and Utils.check_abort_condition(accuracy=self.train_accuracy.result(), epoch=epoch,
                                                                       nepochs=self.config.nepochs, threshold=0.25):
                    self.aborted_model = True
                    wandb.run.tags = ['abort_condition']
                    wandb.run.save()
                    break
        self.trained_test_accuracy = self.test_accuracy.result()

    def run(self):
        """Execute the complete experiment: setup, training, pruning and final logging.

        Steps, in order:

        1. With ``config.fixed_init``, search the current wandb project for a
           non-failed run with the same ``run_id`` that stored both an initial
           model file and a seed, and download that file to reuse it.
        2. Pick a seed (reused or freshly generated), push it to wandb and seed
           numpy and torch.
        3. Build data loaders and the model; with ``config.fixed_init`` store and
           upload the initial model if none was reused.
        4. Create the strategy, optimizer and scheduler, optionally prune
           initially, and train.
        5. Unless the run was aborted, save the trained model and hand control to
           the strategy's ``at_train_end`` / ``final`` hooks.
        """
        if self.config.fixed_init:
            # If not existing, start a new model, otherwise use existing one with same run-id
            entity, project = wandb.run.entity, wandb.run.project
            api = wandb.Api()
            filterDict = {"$and": [{"config.run_id": self.config.run_id}, {"config.fixed_init": True}, ]}
            runs = api.runs(f"{entity}/{project}", filters=filterDict)
            for run in runs:
                if run.state == 'failed':
                    # Ignore this run
                    continue
                # Keep candidates local: the attributes must only change once a run is
                # accepted, otherwise a run with a model file but no seed would leave
                # self.trainedModelFile pointing at a file that was never downloaded.
                candidate_file = run.summary.get('initial_model_file')
                candidate_seed = run.config.get('seed')
                if candidate_file is not None and candidate_seed is not None:
                    run.file(candidate_file).download(root=wandb.run.dir)
                    self.trainedModelFile = candidate_file
                    self.seed = candidate_seed
                    break
            outputStr = f"Found {self.trainedModelFile} with seed {self.seed}" if self.trainedModelFile is not None else "Nothing found."
            sys.stdout.write(f"Trying to find reference initial model in project: {outputStr}\n")

        if self.seed is None:
            # Generate a random seed
            self.seed = int( (os.getpid() + 1) * time.time()) % 2**32

        wandb.config.update({'seed':self.seed}) # Push the seed to wandb

        # Set a unique random seed
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # Remark: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all().
        torch.cuda.manual_seed(self.seed)  # This works if CUDA not available

        torch.backends.cudnn.benchmark = True

        self.trainLoader, self.testLoader = self.get_dataloaders()
        self.model = self.get_model(load_initial=(self.trainedModelFile is not None))
        # Save initial model before training
        if self.config.fixed_init:
            if self.trainedModelFile is None:
                self.trainedModelFile = f"initial_model_run-{self.config.run_id}_seed-{self.seed}.pt"
                sys.stdout.write(f"Creating {self.trainedModelFile}.\n")
                self.save_model(model_type='initial')
            wandb.summary['initial_model_file'] = self.trainedModelFile
            wandb.save(self.trainedModelFile)
        self.strategy = self.define_strategy()
        self.strategy.after_initialization(model=self.model)
        self.define_optimizer_scheduler()

        self.strategy.at_train_begin(model=self.model, LRScheduler=self.scheduler)
        self.save_model(model_type='untrained')

        # wandb.watch(self.model, log='all')
        if self.config.extended_logging:
            # Track additional metrics (doubles up memory)
            self.comparison_model = copy.deepcopy(self.model)

        # Do initial prune if necessary
        self.strategy.initial_prune()

        self.train()
        if self.aborted_model and not self.debug_mode: return  # Abort the run
        # Save trained (unpruned) model
        self.checkpoint_file = self.save_model(model_type='trained')

        # The squared norm of the trained model is computed inside the strategy's forward mode
        self.strategy.start_forward_mode()
        self.trained_norm_square = Utils.get_model_norm_square(model=self.model)

        self.strategy.end_forward_mode()
        self.strategy.at_train_end(model=self.model, finetuning_callback=self.fine_tuning,
                                   restore_callback=self.restore_model, save_model_callback=self.save_model,
                                   after_pruning_callback=self.after_pruning_callback, opt=self.optimizer)

        self.strategy.final(model=self.model, final_log_callback=self.final_log)
