import importlib
import os
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from math import sqrt

import numpy as np
import pandas as pd
import torch
import wandb
from barbar import Bar
from torch.cuda.amp import autocast
from torchmetrics import MeanMetric, Accuracy
from utilities.utilities import WorstClassAccuracy

from config import datasetDict, trainTransformDict, testTransformDict
from metrics import metrics
from optimizers.optimizers import SGD
from strategies import scratchStrategies, pretrainedStrategies
from utilities.lr_schedulers import SequentialSchedulers, FixedLR
from utilities.wd_schedulers import StepWD
from utilities.utilities import FairnessStatistics
from utilities.utilities import Utilities as Utils
import torch.nn.utils.prune as prune

class baseRunner:
    """Base class for all runners, defines the general functions"""

    def __init__(self, config):
        """Set up the device(s), bookkeeping attributes, loss criterion and metric objects.

        :param config: run configuration; read via attribute access (``device``, ``dataset``).
        :raises NotImplementedError: if ``config.dataset`` is not one of the supported datasets.
        """
        self.config = config
        self.dataParallel = (torch.cuda.device_count() > 1)
        if not self.dataParallel:
            self.device = torch.device(config.device)
            # torch device strings use the type 'cuda' (never 'gpu'); set_device requires an explicit index
            if self.device.type == 'cuda' and self.device.index is not None:
                torch.cuda.set_device(self.device)
        else:
            # Use all visible GPUs, with cuda:0 as the primary device.
            # torch.cuda.device(...) would only build an (unused) context manager; set_device actually switches.
            self.device = torch.device("cuda:0")
            torch.cuda.set_device(self.device)

        # Set a couple useful variables
        self.checkpoint_file = None
        self.trained_test_accuracy = None
        self.trained_train_loss = None
        self.trained_train_accuracy = None
        self.after_pruning_metrics = None
        self.seed = None
        self.squared_model_norm = None
        self.n_warmup_epochs = None
        self.trainIterationCtr = 1
        self.tmp_dir = tempfile.mkdtemp()
        self.ampGradScaler = None  # Note: this must be reset before training, and before retraining
        self.num_workers = None

        # Variables to be set by inheriting classes
        self.strategy = None
        self.trainLoader = None
        self.valLoader = None
        self.testLoader = None
        self.n_datapoints = None
        self.model = None
        self.dense_model = None
        self.wd_scheduler = None
        self.label_vector = None
        self.trainData = None
        self.n_total_iterations = None

        self.ultimate_log_dict = None

        # Number of classes of each supported dataset
        n_classes_per_dataset = {'mnist': 10, 'cifar10': 10, 'cifar100': 100, 'tinyimagenet': 200, 'imagenet': 1000}
        if self.config.dataset not in n_classes_per_dataset:
            raise NotImplementedError(f"Dataset {self.config.dataset} is not supported.")
        self.n_classes = n_classes_per_dataset[self.config.dataset]

        self.frequency_collected = False    # If true, sample_freq does not get updated anymore
        self.sample_freq = torch.zeros(self.n_classes).to(device=self.config.device)

        # Define the loss object and metrics
        # Important note: for the correct computation of loss/accuracy it's important to have reduction == 'mean'
        self.loss_criterion = torch.nn.CrossEntropyLoss(reduction='mean').to(device=self.device)

        # Top-k accuracy: top-5 for the datasets with many classes, top-3 otherwise
        k_accuracy = 5 if self.config.dataset in ['cifar100', 'imagenet', 'tinyimagenet'] else 3
        self.metrics = {mode: {'loss': MeanMetric().to(device=self.device),
                               'accuracy': Accuracy().to(device=self.device),
                               'k_accuracy': Accuracy(top_k=k_accuracy).to(device=self.device),
                               'worst_group_accuracy': WorstClassAccuracy(num_classes=self.n_classes).to(device=self.device),
                               'ips_throughput': MeanMetric().to(device=self.device)}
                        for mode in ['train', 'val', 'test']}

        # Per-class statistics, collected once after pruning and once after retraining
        self.class_statistics = {
            mode: {stage: FairnessStatistics(n_classes=self.n_classes, device=self.config.device) for stage in
                   ['pruned', 'retrained']} for mode in ['train', 'val', 'test']}

    def reset_averaged_metrics(self):
        """Resets all metrics"""
        for mode in self.metrics.keys():
            for metric in self.metrics[mode].values():
                metric.reset()

    def get_metrics(self):
        """Assemble a logging dict with the current metrics and model statistics.

        Contains the accumulated train/val metrics, the test metrics that can currently be computed,
        sparsity, parameter and FLOP counts, the current learning rate, and the distance of the weights
        to the origin and to the pruned model (the latter only if ``config.goal_sparsity`` is set).
        Runs without gradient tracking and inside the strategy's forward mode.

        :return: dict suitable for ``wandb.log``.
        """
        with torch.no_grad():
            self.strategy.start_forward_mode()  # Necessary to have correct computations for DPF
            try:
                n_total, n_nonzero = metrics.get_parameter_count(model=self.model)

                # One validation batch serves as sample input for the FLOP count; the targets are not needed
                x_input, _, _ = next(iter(self.valLoader))
                x_input = x_input.to(self.device)  # Move to CUDA if possible
                n_flops, n_nonzero_flops = metrics.get_flops(model=self.model, x_input=x_input)

                distance_to_pruned, rel_distance_to_pruned = {}, {}
                if self.config.goal_sparsity is not None:
                    distance_to_pruned, rel_distance_to_pruned = metrics.get_distance_to_pruned(
                        model=self.model, sparsity=self.config.goal_sparsity)

                loggingDict = dict(
                    train={metric_name: metric.compute() for metric_name, metric in self.metrics['train'].items()},
                    val={metric_name: metric.compute() for metric_name, metric in self.metrics['val'].items()},
                    global_sparsity=metrics.global_sparsity(module=self.model),
                    modular_sparsity=metrics.modular_sparsity(parameters_to_prune=self.strategy.parameters_to_prune),
                    n_total_params=n_total,
                    n_nonzero_params=n_nonzero,
                    nonzero_inference_flops=n_nonzero_flops,
                    baseline_inference_flops=n_flops,
                    theoretical_speedup=metrics.get_theoretical_speedup(n_flops=n_flops,
                                                                        n_nonzero_flops=n_nonzero_flops),
                    learning_rate=float(self.optimizer.param_groups[0]['lr']),
                    distance_to_origin=metrics.get_distance_to_origin(self.model),
                    distance_to_pruned=distance_to_pruned,
                    rel_distance_to_pruned=rel_distance_to_pruned,
                )

                loggingDict['test'] = dict()
                for metric_name, metric in self.metrics['test'].items():
                    try:
                        # Catch case where MeanMetric mode not set yet: such test metrics are simply omitted
                        loggingDict['test'][metric_name] = metric.compute()
                    except Exception:
                        continue
            finally:
                # Always leave forward mode, even if one of the computations above fails
                self.strategy.end_forward_mode()  # Necessary to have correct computations for DPF
        return loggingDict

    def get_dataloaders(self):
        """Create the train, validation and test DataLoaders for ``self.config.dataset``.

        The official training split is divided 90/10 into train/validation subsets using a generator
        seeded with 42, so the split is reproducible. Data found under the shared cluster path is first
        staged to node-local scratch storage (see ``_stage_dataset_locally``).

        Side effects: sets ``self.trainData``, ``self.n_datapoints``, ``self.label_vector`` and
        ``self.num_workers``.

        :return: tuple ``(trainLoader, valLoader, testLoader)``.
        """
        # Determine where the data lies (if no candidate directory exists, the last one is used)
        for root in ['/software/ais2t/pytorch_datasets/', './datasets_pytorch/']:
            rootPath = f"{root}{self.config.dataset}"
            if os.path.isdir(rootPath):
                break

        if root == '/software/ais2t/pytorch_datasets/':
            # We copy the data to have it on locally attached hardware
            rootPath = self._stage_dataset_locally(rootPath)

        dataset_base = datasetDict[self.config.dataset]
        trainTransform = trainTransformDict[self.config.dataset]
        testTransform = testTransformDict[self.config.dataset]
        if self.config.dataset in ['imagenet']:
            trainData = Utils.get_overloaded_dataset(dataset_base)(root=rootPath, split='train',
                                                                 transform=trainTransform)
            testData = Utils.get_overloaded_dataset(dataset_base)(root=rootPath, split='val',
                                                                transform=testTransform)
        elif self.config.dataset == 'tinyimagenet':
            trainData = Utils.get_overloaded_dataset(dataset_base)(root=os.path.join(rootPath, 'train'),
                                                                 transform=trainTransform)
            testData = Utils.get_overloaded_dataset(dataset_base)(root=os.path.join(rootPath, 'val'),
                                                                transform=testTransform)
        else:
            trainData = Utils.get_overloaded_dataset(dataset_base)(root=rootPath, train=True, download=True,
                                                                 transform=trainTransform)
            testData = Utils.get_overloaded_dataset(dataset_base)(root=rootPath, train=False,
                                                                transform=testTransform)

        # Fixed-seed 90/10 split of the training data into train and validation subsets
        train_size = int(0.9 * len(trainData))
        val_size = len(trainData) - train_size
        self.trainData, valData = torch.utils.data.random_split(trainData, [train_size, val_size],
                                                                generator=torch.Generator().manual_seed(42))
        self.n_datapoints = train_size
        # This only works if we do not have a class with label -5
        self.label_vector = torch.full((train_size + val_size,), -5, dtype=torch.int64, device=self.device)
        if self.config.dataset in ['imagenet', 'cifar100', 'tinyimagenet']:
            self.num_workers = 4 * torch.cuda.device_count() if torch.cuda.is_available() else 0
        else:
            self.num_workers = 2 if torch.cuda.is_available() else 0

        loader_kwargs = dict(batch_size=self.config.batch_size, pin_memory=torch.cuda.is_available(),
                             num_workers=self.num_workers)
        trainLoader = torch.utils.data.DataLoader(self.trainData, shuffle=True, **loader_kwargs)
        valLoader = torch.utils.data.DataLoader(valData, shuffle=False, **loader_kwargs)
        testLoader = torch.utils.data.DataLoader(testData, shuffle=False, **loader_kwargs)

        return trainLoader, valLoader, testLoader

    def _stage_dataset_locally(self, rootPath):
        """Copy the dataset directory ``rootPath`` to node-local scratch storage and return the local path.

        Concurrent processes on the same node coordinate through two marker files: an in-process lock,
        created atomically by the single process that performs the copy, and a done marker written once
        the copy has finished. Waiting processes poll every 10 seconds and give up after 360 polls
        (about one hour).

        :raises Exception: if another process's copy does not finish in time.
        """
        local = os.path.join('/scratch/local/', 'mzimmer')
        os.makedirs(local, exist_ok=True)  # exist_ok avoids a race between concurrently starting processes
        localPath = os.path.join(local, self.config.dataset)
        inProcessFile = os.path.join(local, f"{self.config.dataset}-inprocess.lock")
        doneFile = os.path.join(local, f"{self.config.dataset}-donefile.lock")

        wait_it = 0
        while True:
            if os.path.exists(doneFile) and os.path.isdir(localPath):
                # Dataset exists locally, continue with the training
                print("Local data storage: Done file exists.")
                return localPath

            try:
                # O_EXCL makes creating the lock atomic, so only one process starts copying
                lock_fd = os.open(inProcessFile, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            except FileExistsError:
                # Another process is copying: wait for 10 seconds, then check again
                time.sleep(10)
                print("Local data storage: Is still busy - wait.")
                wait_it += 1  # Counted here, otherwise the timeout below could never trigger
                if wait_it == 360:
                    # Waited 1 hour, this should be done by now, check for errors
                    raise Exception("Waiting time too long.")
                continue
            os.close(lock_fd)

            try:
                # Re-check under the lock: another process may have finished between our check and acquiring it
                if not (os.path.exists(doneFile) and os.path.isdir(localPath)):
                    # Remove leftovers of an earlier, unfinished copy (no done file was written for it)
                    shutil.rmtree(localPath, ignore_errors=True)
                    print("Local data storage: Starts copying.")
                    shutil.copytree(src=rootPath, dst=localPath)
                    print("Local data storage: Copying done.")
                    # Create the doneFile
                    open(doneFile, mode='a').close()
            finally:
                # Always release the lock so waiting processes do not hang after a failed copy
                os.remove(inProcessFile)

    def get_model(self, reinit: bool, temporary: bool = True) -> torch.nn.Module:
        if reinit:
            # Define the model
            model = getattr(importlib.import_module('models.' + self.config.dataset), self.config.arch)()
        else:
            # The model has been initialized already
            model = self.model
            # Note, we have to get rid of all existing prunings, otherwise we cannot load the state_dict as it is unpruned
            if self.strategy:
                self.strategy.make_pruning_permant(model=self.model)

        file = self.checkpoint_file
        if file is not None:
            dir = wandb.run.dir if not temporary else self.tmp_dir
            fPath = os.path.join(dir, file)

            state_dict = torch.load(fPath, map_location=self.device)

            new_state_dict = OrderedDict()
            require_DP_format = isinstance(model,
                                           torch.nn.DataParallel)  # If true, ensure all keys start with "module."
            for k, v in state_dict.items():
                is_in_DP_format = k.startswith("module.")
                if require_DP_format and is_in_DP_format:
                    name = k
                elif require_DP_format and not is_in_DP_format:
                    name = "module." + k  # Add 'module' prefix
                elif not require_DP_format and is_in_DP_format:
                    name = k[7:]  # Remove 'module.'
                elif not require_DP_format and not is_in_DP_format:
                    name = k

                v_new = v  # Remains unchanged if not in _orig format
                if k.endswith("_orig"):
                    # We loaded the _orig tensor and corresponding mask
                    name = name[:-5]  # Truncate the "_orig"
                    if f"{k[:-5]}_mask" in state_dict.keys():
                        v_new = v * state_dict[f"{k[:-5]}_mask"]

                new_state_dict[name] = v_new

            maskKeys = [k for k in new_state_dict.keys() if k.endswith("_mask")]
            for k in maskKeys:
                del new_state_dict[k]

            # Load the state_dict
            model.load_state_dict(new_state_dict)

        if self.dataParallel and reinit and not isinstance(model,
                                                           torch.nn.DataParallel):  # Only apply DataParallel when re-initializing the model!
            # We use DataParallelism
            model = torch.nn.DataParallel(model)
        model = model.to(device=self.device)
        return model

    def define_optimizer_scheduler(self):
        # Learning rate scheduler in the form (type, kwargs)
        tupleStr = self.config.learning_rate.strip()
        # Remove parenthesis
        if tupleStr[0] == '(':
            tupleStr = tupleStr[1:]
        if tupleStr[-1] == ')':
            tupleStr = tupleStr[:-1]
        name, *kwargs = tupleStr.split(',')
        if name in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'Linear', 'Cosine', 'Constant']:
            scheduler = (name, kwargs)
            self.initial_lr = float(kwargs[0])
        else:
            raise NotImplementedError(f"LR Scheduler {name} not implemented.")

        # Define the optimizer
        if self.config.optimizer == 'SGD':
            wd = self.config['weight_decay'] or 0.
            self.optimizer = SGD(params=self.model.parameters(), lr=self.initial_lr,
                                 momentum=self.config.momentum,
                                 weight_decay=wd, nesterov=wd > 0.)

        # We define a scheduler. All schedulers work on a per-iteration basis
        iterations_per_epoch = len(self.trainLoader)
        n_total_iterations = iterations_per_epoch * self.config.n_epochs
        self.n_total_iterations = n_total_iterations
        n_warmup_iterations = 0

        # Set the initial learning rate
        for param_group in self.optimizer.param_groups: param_group['lr'] = self.initial_lr

        # Define the warmup scheduler if needed
        warmup_scheduler, milestone = None, None
        if self.config.n_epochs_warmup and self.config.n_epochs_warmup > 0:
            assert int(
                self.config.n_epochs_warmup) == self.config.n_epochs_warmup, "At the moment no float warmup allowed."
            n_warmup_iterations = int(float(self.config.n_epochs_warmup) * iterations_per_epoch)
            # As a start factor we use 1e-20, to avoid division by zero when putting 0.
            warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=self.optimizer,
                                                                 start_factor=1e-20, end_factor=1.,
                                                                 total_iters=n_warmup_iterations)
            milestone = n_warmup_iterations + 1

        n_remaining_iterations = n_total_iterations - n_warmup_iterations

        name, kwargs = scheduler
        scheduler = None
        if name == 'Constant':
            scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer=self.optimizer,
                                                            factor=1.0,
                                                            total_iters=n_remaining_iterations)
        elif name == 'StepLR':
            # Tuple of form ('StepLR', initial_lr, step_size, gamma)
            # Reduces initial_lr by gamma every step_size epochs
            step_size, gamma = int(kwargs[1]), float(kwargs[2])

            # Convert to iterations
            step_size = iterations_per_epoch * step_size

            scheduler = torch.optim.lr_scheduler.StepLR(optimizer=self.optimizer, step_size=step_size,
                                                        gamma=gamma)
        elif name == 'MultiStepLR':
            # Tuple of form ('MultiStepLR', initial_lr, milestones, gamma)
            # Reduces initial_lr by gamma every epoch that is in the list milestones
            milestones, gamma = kwargs[1].strip(), float(kwargs[2])
            # Remove square bracket
            if milestones[0] == '[':
                milestones = milestones[1:]
            if milestones[-1] == ']':
                milestones = milestones[:-1]
            # Convert to iterations directly
            milestones = [int(ms) * iterations_per_epoch for ms in milestones.split('|')]
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=self.optimizer, milestones=milestones,
                                                             gamma=gamma)
        elif name == 'ExponentialLR':
            # Tuple of form ('ExponentialLR', initial_lr, gamma)
            gamma = float(kwargs[1])
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=gamma)
        elif name == 'Linear':
            if len(kwargs) == 2:
                # The final learning rate has also been passed
                end_factor = float(kwargs[1])/float(kwargs[0])
            else:
                end_factor = 0.
            scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=self.optimizer,
                                                          start_factor=1.0, end_factor=end_factor,
                                                          total_iters=n_remaining_iterations)
        elif name == 'Cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                                   T_max=n_remaining_iterations, eta_min=0.)

        # Reset base lrs to make this work
        scheduler.base_lrs = [self.initial_lr if warmup_scheduler else 0. for _ in self.optimizer.param_groups]

        # Define the Sequential Scheduler
        if warmup_scheduler is None:
            self.scheduler = scheduler
        elif name in ['StepLR', 'MultiStepLR']:
            # We need parallel schedulers, since the steps should be counted during warmup
            self.scheduler = torch.optim.lr_scheduler.ChainedScheduler(schedulers=[warmup_scheduler, scheduler])
        else:
            self.scheduler = SequentialSchedulers(optimizer=self.optimizer, schedulers=[warmup_scheduler, scheduler],
                                                  milestones=[milestone])

    def define_strategy(self):
        """Instantiate the strategy class named by ``self.config.strategy``.

        The class is looked up in ``pretrainedStrategies`` when ``config.use_pretrained`` is set and in
        ``scratchStrategies`` otherwise. The runner's callbacks are handed to the strategy by name.

        :return: the new strategy instance.
        """
        #### UNSTRUCTURED
        # Callbacks the strategy may invoke on this runner
        callbacks = dict(
            after_pruning_callback=self.after_pruning_callback,
            finetuning_callback=self.fine_tuning,
            restore_callback=self.restore_model,
            save_model_callback=self.save_model,
            final_log_callback=self.final_log,
            gradient_estimation_callback=self.gradient_estimation_callback,
            data_parallel_callback=self.data_parallel_callback,
        )
        # Pretrained model vs. training from scratch
        strategy_module = scratchStrategies if self.config.use_pretrained is None else pretrainedStrategies
        strategy_cls = getattr(strategy_module, self.config.strategy)
        return strategy_cls(model=self.model, n_classes=self.n_classes, config=self.config, callbacks=callbacks)

    def log(self, runTime, finetuning: bool = False, final_logging: bool = False):
        loggingDict = self.get_metrics()
        self.strategy.start_forward_mode()
        loggingDict.update({'epoch_run_time': runTime})
        if not finetuning:
            # Update final trained metrics (necessary to be able to filter via wandb)
            for metric_type, val in loggingDict.items():
                wandb.run.summary[f"trained.{metric_type}"] = val
            # The usual logging of one epoch
            wandb.log(
                loggingDict
            )

        else:
            if not final_logging:
                wandb.log(
                    dict(finetune=loggingDict,
                         ),
                )
            else:
                # We add the after_pruning_metrics and don't commit, since the values are updated by self.final_log
                self.ultimate_log_dict = dict(finetune=loggingDict,
                                              pruned=self.after_pruning_metrics,
                                              )
        self.strategy.end_forward_mode()

    def final_log(self):
        """This function can ONLY be called by pretrained strategies using the final sparsified model

        Re-evaluates the model (tagging class statistics as 'retrained'), writes ``final.*`` and
        ``pruned.*`` entries to the wandb run summary, merges the recomputed metrics into
        ``self.ultimate_log_dict`` (creating it if it does not exist), commits it via ``wandb.log`` and
        dumps the per-layer sparsity distribution.
        Requires ``self.after_pruning_metrics`` to be set (it is iterated below).
        """
        # Recompute accuracy and loss
        sys.stdout.write(
            f"\nFinal logging\n")
        self.reset_averaged_metrics()
        if self.config.collect_class_statistics:
            # The (expensive) train split is only re-evaluated when class statistics are collected
            self.evaluate_model(data='train', class_statistics_mode='retrained')
        self.evaluate_model(data='val', class_statistics_mode='retrained')
        self.evaluate_model(data='test', class_statistics_mode='retrained')

        # Update final trained metrics (necessary to be able to filter via wandb)
        loggingDict = self.get_metrics()
        for metric_type, val in loggingDict.items():
            wandb.run.summary[f"final.{metric_type}"] = val

        # Flattened fairness results go into the summary for every mode, regardless of collect_class_statistics
        # NOTE(review): the train statistics are only recomputed above when collect_class_statistics is set — confirm
        for mode in self.class_statistics.keys():
            generalResults, classResults = self.class_statistics[mode]['retrained'].get_results()
            serialized_dict = pd.json_normalize(generalResults, sep='.')
            for metric_type, val in serialized_dict.to_dict(orient='records')[0].items():
                wandb.run.summary[f"final.fairness.{mode}.{metric_type}"] = val
            Utils.dump_dict_to_json_wandb(dumpDict=classResults,
                                          name=f"final.fairness.{mode}")

        # Update after prune metrics
        for metric_type, val in self.after_pruning_metrics.items():
            wandb.run.summary[f"pruned.{metric_type}"] = val

        # Add to existing self.ultimate_log_dict which was not commited yet
        if self.ultimate_log_dict is not None:
            # Train metrics were reset above and not re-evaluated unless collect_class_statistics is set;
            # presumably an un-updated accuracy computes to 0, which is used as the signal here — confirm
            if loggingDict['train']['accuracy'] == 0:
                # we did not perform the recomputation, use the old values for train
                del loggingDict['train']

            self.ultimate_log_dict['finetune'].update(loggingDict)
        else:
            self.ultimate_log_dict = {'finetune':loggingDict}

        if self.config.collect_class_statistics:
            # Also log the flattened fairness results (per mode) with the committed step
            self.ultimate_log_dict['retrained.fairness'] = {mode: dict() for mode in self.class_statistics.keys()}
            for mode in self.class_statistics.keys():
                generalResults, classResults = self.class_statistics[mode]['retrained'].get_results()
                serialized_dict = pd.json_normalize(generalResults, sep='.')
                for metric_type, val in serialized_dict.to_dict(orient='records')[0].items():
                    self.ultimate_log_dict['retrained.fairness'][mode][metric_type] = val
                Utils.dump_dict_to_json_wandb(dumpDict=classResults,
                                              name=f"retrained.fairness.{mode}")
        wandb.log(self.ultimate_log_dict)
        Utils.dump_dict_to_json_wandb(metrics.per_layer_sparsity(model=self.model), 'sparsity_distribution')

    def after_pruning_callback(self):
        """Collects pruning metrics. Is called ONCE per run, namely on the LAST PRUNING step.

        Evaluates train/val/test after pruning and stores the results, together with the norm drop,
        pruning (in)stability, train loss increase and (optionally) fairness statistics, in
        ``self.after_pruning_metrics``. Quantities that cannot be computed are stored as empty dicts.
        Finally refreshes ``self.squared_model_norm`` with the norm of the pruned model.
        """
        # Make the pruning permanent (this is in conflict with strategies that do not have a permanent pruning)
        self.strategy.enforce_prunedness()

        # Compute losses, accuracies after pruning
        sys.stdout.write("\nGoal sparsity reached - Computing incurred losses after pruning.\n")
        self.reset_averaged_metrics()
        for split in ('train', 'val', 'test'):
            self.evaluate_model(data=split, class_statistics_mode='pruned')

        # Drop of the L2 norm caused by pruning, absolute and relative to the norm before pruning
        norm_drop, relative_norm_drop = {}, {}
        if self.squared_model_norm is not None:
            squared_norm_after = Utils.get_model_norm_square(self.model)
            norm_drop = sqrt(abs(self.squared_model_norm - squared_norm_after))
            norm_before = float(sqrt(self.squared_model_norm))
            if norm_before > 0:
                relative_norm_drop = norm_drop / norm_before

        # Relative test accuracy lost by pruning
        pruning_instability, pruning_stability = {}, {}
        reference_accuracy = self.trained_test_accuracy
        if reference_accuracy is not None and reference_accuracy > 0:
            accuracy_after = self.metrics['test']['accuracy'].compute()
            pruning_instability = (reference_accuracy - accuracy_after) / reference_accuracy
            pruning_stability = 1 - pruning_instability

        # Absolute and relative increase of the train loss caused by pruning
        train_loss_increase, relative_train_loss_increase_factor = {}, {}
        reference_loss = self.trained_train_loss
        if reference_loss is not None and reference_loss > 0:
            train_loss_increase = self.metrics['train']['loss'].compute() - reference_loss
            relative_train_loss_increase_factor = train_loss_increase / reference_loss

        computed_metrics = {split: {metric_name: metric.compute() for metric_name, metric in self.metrics[split].items()}
                            for split in ('train', 'val', 'test')}
        self.after_pruning_metrics = {
            **computed_metrics,
            'norm_drop': norm_drop,
            'relative_norm_drop': relative_norm_drop,
            'pruning_instability': pruning_instability,
            'pruning_stability': pruning_stability,
            'train_loss_increase': train_loss_increase,
            'relative_train_loss_increase_factor': relative_train_loss_increase_factor,
            'fairness': {mode: dict() for mode in self.class_statistics.keys()},
        }
        if self.config.collect_class_statistics:
            for mode, stats_by_stage in self.class_statistics.items():
                generalResults, classResults = stats_by_stage['pruned'].get_results()
                flat_records = pd.json_normalize(generalResults, sep='.').to_dict(orient='records')
                self.after_pruning_metrics["fairness"][mode].update(flat_records[0])
                Utils.dump_dict_to_json_wandb(dumpDict=classResults,
                                              name=f"pruned.fairness.{mode}")

        # Reset squared model norm for following pruning steps, otherwise ALLR does not work properly
        self.squared_model_norm = Utils.get_model_norm_square(model=self.model)

    def restore_model(self) -> None:
        sys.stdout.write(
            f"Restoring model from {self.checkpoint_file}.\n")
        self.model = self.get_model(reinit=False, temporary=True)

    def save_model(self, model_type: str, remove_pruning_hooks: bool = False, temporary: bool = False) -> str:
        if model_type not in ['initial', 'trained']:
            print(f"Ignoring to save {model_type} for now.")
            return None
        fName = f"{model_type}_model.pt"
        fPath = os.path.join(wandb.run.dir, fName) if not temporary else os.path.join(self.tmp_dir, fName)
        if remove_pruning_hooks:
            self.strategy.make_pruning_permant(model=self.model)

        # Only save models in their non-module version, to avoid problems when loading
        try:
            model_state_dict = self.model.module.state_dict()
        except AttributeError:
            model_state_dict = self.model.state_dict()

        torch.save(model_state_dict, fPath)  # Save the state_dict
        return fPath

    def evaluate_model(self, data='train', class_statistics_mode=None):
        return self.train_epoch(data=data, is_training=False, class_statistics_mode=class_statistics_mode)

    def define_retrain_schedule(self, n_epochs_finetune, pruning_sparsity, phase):
        """Define the retraining schedule.
            - Tuneable schedules all require both an initial value as well as a warmup length
            - Fixed schedules require no additional parameters and are mere conversions such as LRW
        """
        tuneable_schedules = ['constant',  # Constant learning rate
                              'stepped',  # Stepped Budget Aware Conversion (BAC)
                              'cosine',  # Cosine from initial value -> 0
                              'linear',  # Linear from initial value -> 0
                              ]
        fixed_schedules = ['FT',  # Use last lr of original training as schedule (Han et al.), no warmup
                           'LRW',  # Learning Rate Rewinding (Renda et al.), no warmup
                           'SLR',  # Scaled Learning Rate Restarting (Le et al.), maxLR init, 10% warmup
                           'CLR',  # Cyclic Learning Rate Restarting (Le et al.), maxLR init, 10% warmup
                           'LLR',  # Linear from the largest original lr to 0, maxLR init, 10% warmup
                           'ALLR',  # LLR, but in the last phase behave like LCN
                           'LossALLR',  # ALLR, but use the relative increase in loss instead of the norm drop
                           'AccALLR',  # ALLR, but use the relative decrease in train accuracy instead of the norm drop
                           ]

        # Define the initial lr, max lr and min lr
        maxLR = max(
            self.strategy.lr_dict.values())
        after_warmup_index = (self.config.n_epochs_warmup or 0) * len(self.trainLoader)
        minLR = min(list(self.strategy.lr_dict.values())[after_warmup_index:])  # Ignores warmup in orig. schedule

        n_total_iterations = len(self.trainLoader) * n_epochs_finetune
        if self.config.retrain_schedule in tuneable_schedules:
            assert self.config.retrain_schedule_init is not None
            assert self.config.retrain_schedule_warmup is not None

            n_warmup_iterations = int(self.config.retrain_schedule_warmup * n_total_iterations)
            after_warmup_lr = self.config.retrain_schedule_init
        elif self.config.retrain_schedule in fixed_schedules:
            assert self.config.retrain_schedule_init is None
            assert self.config.retrain_schedule_warmup is None

            # Define warmup length
            if self.config.retrain_schedule in ['FT', 'LRW']:
                n_warmup_iterations = 0
            else:
                # 10% warmup
                n_warmup_iterations = int(0.1 * n_total_iterations)

            # Define the after_warmup_lr
            if self.config.retrain_schedule == 'FT':
                after_warmup_lr = minLR
            elif self.config.retrain_schedule == 'LRW':
                after_warmup_lr = list(self.strategy.lr_dict.values())[-n_total_iterations]  # == remaining iterations since we don't do warmup
            elif self.config.retrain_schedule in ['ALLR', 'LossALLR', 'AccALLR']:

                if phase == self.config.n_phases or self.config.retrain_adaptive_in_every_cycle:
                    # Last phase, so do LCN
                    minLRThreshold = min(float(n_epochs_finetune) / self.config.n_epochs, 1.0) * maxLR
                    if self.config.retrain_schedule == 'ALLR':
                        # Use the norm drop
                        relative_norm_drop = self.after_pruning_metrics['relative_norm_drop']
                        scaling = relative_norm_drop / sqrt(pruning_sparsity)
                    elif self.config.retrain_schedule == 'LossALLR':
                        # Use the increase in loss
                        pruned_train_loss = self.after_pruning_metrics['train']['loss']
                        reference_train_loss = self.trained_train_loss
                        train_instability = (pruned_train_loss - reference_train_loss)/reference_train_loss
                        scaling = train_instability
                    elif self.config.retrain_schedule == 'AccALLR':
                        # Use the decrease in accuracy
                        pruned_train_acc = self.after_pruning_metrics['train']['accuracy']
                        reference_train_acc = self.trained_train_accuracy
                        train_instability = (reference_train_acc - pruned_train_acc)/reference_train_acc
                        scaling = train_instability

                    discounted_LR = float(scaling) * maxLR
                else:
                    minLRThreshold = maxLR
                    discounted_LR = maxLR

                after_warmup_lr = np.clip(discounted_LR, a_min=minLRThreshold, a_max=maxLR)

            elif self.config.retrain_schedule in ['SLR', 'CLR', 'LLR']:
                after_warmup_lr = maxLR
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

        # Set the optimizer lr
        for param_group in self.optimizer.param_groups:
            if n_warmup_iterations > 0:
                # If warmup, then we actually begin with 0 and increase to after_warmup_lr
                param_group['lr'] = 0.0
            else:
                param_group['lr'] = after_warmup_lr

        # Define warmup scheduler
        warmup_scheduler, milestone = None, None
        if n_warmup_iterations > 0:
            warmup_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR \
                (self.optimizer, T_max=n_warmup_iterations, eta_min=after_warmup_lr)
            milestone = n_warmup_iterations + 1

        # Define scheduler after the warmup
        n_remaining_iterations = n_total_iterations - n_warmup_iterations
        scheduler = None
        if self.config.retrain_schedule in ['FT', 'constant']:
            # Does essentially nothing but keeping the smallest learning rate
            scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer=self.optimizer,
                                                            factor=1.0,
                                                            total_iters=n_remaining_iterations)
        elif self.config.retrain_schedule == 'LRW':
            iterationsLR = list(self.strategy.lr_dict.values())[(-n_remaining_iterations):]
            iterationsLR.append(iterationsLR[-1])  # Double the last learning rate so we avoid the IndexError
            scheduler = FixedLR(optimizer=self.optimizer, lrList=iterationsLR)

        elif self.config.retrain_schedule in ['stepped', 'SLR']:
            iterationsLR = [lr if int(it) >= after_warmup_index else maxLR
                            for it, lr in self.strategy.lr_dict.items()]

            interpolation_width = (len(self.strategy.lr_dict)) / n_remaining_iterations  # In general not an integer
            reducedLRs = [iterationsLR[int(j * interpolation_width)] for j in range(n_remaining_iterations)]
            # Add a last LR to avoid IndexError
            reducedLRs = reducedLRs + [reducedLRs[-1]]

            lr_lambda = lambda it: reducedLRs[it] / float(maxLR)  # Function returning the correct learning rate factor
            scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda)

        elif self.config.retrain_schedule in ['CLR', 'cosine']:
            stopLR = 0. if self.config.retrain_schedule == 'cosine' else minLR
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR \
                (self.optimizer, T_max=n_remaining_iterations, eta_min=stopLR)

        elif self.config.retrain_schedule in ['LLR', 'ALLR', 'linear', 'LossALLR', 'AccALLR']:
            scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=self.optimizer,
                                                          start_factor=1.0, end_factor=0.,
                                                          total_iters=n_remaining_iterations)

        # Reset base lrs to make this work
        scheduler.base_lrs = [after_warmup_lr for _ in self.optimizer.param_groups]

        # Define the Sequential Scheduler
        if warmup_scheduler is None:
            self.scheduler = scheduler
        else:
            self.scheduler = SequentialSchedulers(optimizer=self.optimizer, schedulers=[warmup_scheduler, scheduler],
                                                  milestones=[milestone])

    def define_retrain_wd_schedule(self, n_epochs_finetune):
        wd = self.config.retrain_wd
        if wd not in [None, 'None', 'none']:
            # Set the optimizer weight decay
            for param_group in self.optimizer.param_groups:
                param_group['weight_decay'] = wd



        wd_scheduler = self.config.retrain_wd_schedule
        if wd_scheduler in [None, 'None', 'none']:
            return

        # wd scheduler in the form (type, kwargs)
        tupleStr = wd_scheduler.strip()
        # Remove parenthesis
        tupleStr = tupleStr[1:-1]
        name, *kwargs = tupleStr.split(',')
        if name in ['InitialOnly']:
            scheduler = (name, kwargs)
        else:
            raise NotImplementedError(f"Weight decay scheduler {name} not implemented.")

        iterations_per_epoch = len(self.trainLoader)
        n_total_iterations = iterations_per_epoch * n_epochs_finetune
        n_warmup_iterations = 0  # Without warmup for now
        n_remaining_iterations = n_total_iterations - n_warmup_iterations

        name, kwargs = scheduler
        scheduler = None
        if name == 'InitialOnly':
            # In the form (InitialOnly, x), where 0 <= x <= 1 is the fraction of total iterations until wd -> 0
            # Weight decay only for some epochs, then set to zero
            initial_length = int(n_remaining_iterations * float(kwargs[0]))
            scheduler = StepWD(optimizer=self.optimizer, step_size=initial_length, gamma=0.,
                                             verbose=False)

        self.wd_scheduler = scheduler

    def fine_tuning(self, pruning_sparsity, n_epochs_finetune, phase=1):
        if n_epochs_finetune == 0:
            return
        n_phases = self.config.n_phases or 1

        # Reset the GradScaler for AutoCast
        self.ampGradScaler = torch.cuda.amp.GradScaler(enabled=(self.config.use_amp is True))

        # Update the retrain schedule individually for every phase/cycle
        self.define_retrain_schedule(n_epochs_finetune=n_epochs_finetune,
                                     pruning_sparsity=pruning_sparsity, phase=phase)
        self.define_retrain_wd_schedule(n_epochs_finetune=n_epochs_finetune)

        self.strategy.set_to_finetuning_phase()
        for epoch in range(1, n_epochs_finetune + 1, 1):
            self.reset_averaged_metrics()
            sys.stdout.write(
                f"\nFinetuning: phase {phase}/{n_phases} | epoch {epoch}/{n_epochs_finetune}\n")
            # Train
            t = time.time()
            self.train_epoch(data='train')
            self.evaluate_model(data='val')
            sys.stdout.write(
                f"\nVal accuracy after this epoch: {self.metrics['val']['accuracy'].compute()} (lr = {float(self.optimizer.param_groups[0]['lr'])})\n")

            self.strategy.at_epoch_end(epoch=epoch)
            self.log(runTime=time.time() - t, finetuning=True,
                     final_logging=(epoch == n_epochs_finetune and phase == n_phases))

    def train_epoch(self, data='train', is_training=True, class_statistics_mode=None):
        """Run one pass over ``data``, either training the model or only evaluating it.

        Args:
            data: One of 'train', 'val' or 'test'; selects the data loader.
            is_training: If True, perform optimizer/scheduler steps (only allowed for 'train').
                Otherwise evaluate with gradients disabled.
            class_statistics_mode: Key into ``self.class_statistics[data]`` selecting the fairness meter to
                update during evaluation. Only honored if ``config.collect_class_statistics`` is set, and
                must be falsy when training.

        Updates ``self.metrics[data]``: loss, accuracy, k_accuracy, worst_group_accuracy and
        ips_throughput.

        While ``self.frequency_collected`` is False, evaluating on 'train' also accumulates per-class sample
        counts into ``self.sample_freq`` and per-sample labels into ``self.label_vector``. Collection is
        disabled after the first such evaluation.
        """
        assert not (data in ['test', 'val'] and is_training), "Can't train on test/val set."
        assert not (class_statistics_mode and is_training)
        loaderDict = {'train': self.trainLoader,
                      'val': self.valLoader,
                      'test': self.testLoader}
        loader = loaderDict[data]
        if is_training:
            # Training builds a fresh DataLoader over self.trainData on every call, using the sampler returned
            # by the strategy (shuffling only when no sampler is returned).
            self.loss_criterion = self.strategy.adjust_loss_fn(loss_criterion=self.loss_criterion)
            sampler = self.strategy.adjust_train_sampler(collection_done=self.frequency_collected,
                                                         sample_freqs=self.sample_freq, label_vector=self.label_vector)
            loader = torch.utils.data.DataLoader(self.trainData, batch_size=self.config.batch_size, shuffle=(sampler is None),
                                                 pin_memory=torch.cuda.is_available(), num_workers=self.num_workers, sampler=sampler)
        else:
            # Evaluation always uses plain mean cross-entropy
            self.loss_criterion = torch.nn.CrossEntropyLoss(reduction='mean').to(device=self.device)

        if is_training:
            sys.stdout.write("Training:\n")
        else:
            with_class_stats = self.config.collect_class_statistics and class_statistics_mode
            sys.stdout.write(f"Evaluation of {data} data{' with class stats' if with_class_stats else ''}:\n")

        with torch.set_grad_enabled(is_training):
            for x_input, y_target, indices in Bar(loader):
                # Move to CUDA if possible
                x_input = x_input.to(self.device, non_blocking=True)
                y_target = y_target.to(self.device, non_blocking=True)
                indices = indices.to(self.device, non_blocking=True)
                self.optimizer.zero_grad()  # Zero the gradient buffers

                # perf_counter is monotonic and high-resolution; time.time() is a wall clock that can jump
                # (or stand still on coarse timers), yielding negative or infinite throughput.
                # NOTE(review): this is host-side time; with asynchronous CUDA execution it may not match GPU time.
                itStartTime = time.perf_counter()

                self.strategy.start_forward_mode(enable_grad=is_training)
                if is_training:
                    with autocast(enabled=(self.config.use_amp is True)):
                        # Adjust y_target if necessary
                        # We use y_target for accuracy computation, y_target_train for training
                        y_target_train = self.strategy.adjust_train_target(x_input=x_input, y_target=y_target,
                                                                           dense_model=self.dense_model)

                        output = self.model.train()(x_input)
                        loss = self.loss_criterion(output, y_target_train)
                        loss = self.strategy.modify_loss(loss=loss, x_input=x_input, output=output,
                                                         dense_model=self.dense_model)

                        loss = self.strategy.before_backward(loss=loss, weight_decay=self.config.weight_decay)

                    self.ampGradScaler.scale(loss).backward()  # Scaling + Backpropagation
                    # Unscale the gradients manually. Normally ampGradScaler.step() would do this, but
                    # during_training() may modify the grads, so the unscale has to happen before that call.
                    self.ampGradScaler.unscale_(self.optimizer)
                    # Potentially update the gradients
                    self.strategy.during_training(trainIteration=self.trainIterationCtr)
                    self.ampGradScaler.step(self.optimizer)
                    self.ampGradScaler.update()

                    self.strategy.end_forward_mode()  # Has no effect for DPF
                    self.strategy.after_training_iteration(it=self.trainIterationCtr,
                                                           lr=float(self.optimizer.param_groups[0]['lr']))
                    # lr (and optionally wd) schedules are stepped once per iteration
                    self.scheduler.step()
                    if self.wd_scheduler:
                        self.wd_scheduler.step()
                    self.trainIterationCtr += 1
                else:
                    with autocast(enabled=(self.config.use_amp is True)):
                        # We use train(mode=True) for the training dataset such that we do not get the drop in loss because of running average of BN
                        # Note however that this will change the running stats and consequently also slightly the evaluation of val/eval datasets
                        output = self.model.train(mode=(data == 'train'))(x_input)
                        loss = self.loss_criterion(output, y_target)
                        if class_statistics_mode is not None and self.config.collect_class_statistics:
                            fairnessMeter = self.class_statistics[data][class_statistics_mode]
                            # Important: this has to be kept eval even for eval on the train set
                            # since otherwise we change the running stats of BN which affect test/val set as well
                            output_dense = self.dense_model.eval()(x_input)
                            fairnessMeter(output=output, output_dense=output_dense, y_true=y_target)

                        if not self.frequency_collected and data == 'train':
                            # Accumulate per-class counts and remember the label of every sample index
                            occ, cnt = torch.unique(y_target, return_counts=True)
                            self.sample_freq[occ] += cnt
                            self.label_vector[indices] = y_target

                    self.strategy.end_forward_mode()  # Has no effect for DPF
                itDuration = time.perf_counter() - itStartTime

                self.metrics[data]['loss'](value=loss, weight=len(y_target))
                self.metrics[data]['accuracy'](output, y_target)
                self.metrics[data]['k_accuracy'](output, y_target)
                self.metrics[data]['worst_group_accuracy'](output, y_target)
                if itDuration > 0:
                    # Images processed per second; skipped if the timer did not advance (avoids ZeroDivisionError)
                    self.metrics[data]['ips_throughput'](len(y_target) / itDuration)
        if data == 'train' and not is_training:
            # Gets disabled after the first train evaluation
            self.frequency_collected = True

    def train(self):
        self.ampGradScaler = torch.cuda.amp.GradScaler(enabled=(self.config.use_amp is True))
        for epoch in range(self.config.n_epochs + 1):
            self.reset_averaged_metrics()
            sys.stdout.write(f"\n\nEpoch {epoch}/{self.config.n_epochs}\n")
            t = time.time()
            if epoch == 0:
                # Just evaluate the model once to get the metrics
                self.evaluate_model(data='train')
            else:
                # Train
                self.train_epoch(data='train')
            self.evaluate_model(data='val')

            if epoch == self.config.n_epochs:
                # Do one complete evaluation on the test data set
                self.evaluate_model(data='test')

            self.strategy.at_epoch_end(epoch=epoch)

            self.log(runTime=time.time() - t)

        self.trained_test_accuracy = self.metrics['test']['accuracy'].compute()
        self.trained_train_loss = self.metrics['train']['loss'].compute()

    def gradient_estimation_callback(self):
        sys.stdout.write("Estimating gradients on a single batch.\n")
        loader = self.trainLoader
        ampGradScaler = torch.cuda.amp.GradScaler(enabled=(self.config.use_amp is True))
        loss_criterion = torch.nn.CrossEntropyLoss(reduction='mean')
        gradientDict = {(module, paramType):torch.zeros_like(getattr(module, paramType), requires_grad=False).to(device=self.device).flatten() for (module, paramType) in self.strategy.parameters_to_prune}
        with torch.set_grad_enabled(True):
            x_input, y_target = next(iter(loader))
            # Move to CUDA if possible
            x_input = x_input.to(self.device, non_blocking=True)
            y_target = y_target.to(self.device, non_blocking=True)
            self.optimizer.zero_grad()  # Zero the gradient buffers
            with autocast(enabled=(self.config.use_amp is True)):
                # Adjust y_target if necessary
                output = self.model.eval()(x_input)
                loss = loss_criterion(output, y_target)

            ampGradScaler.scale(loss).backward()  # Scaling + Backpropagation
            # Unscale the weights manually, normally this would be done by ampGradScaler.step(), but since
            # we might add something to the grads with during_training(), this has to be split
            ampGradScaler.unscale_(self.optimizer)
            #ampGradScaler.step(self.optimizer)
            ampGradScaler.update()
            with torch.no_grad():
                for (module, paramType) in self.strategy.parameters_to_prune:
                    if prune.is_pruned(module):
                        d_p = getattr(module, paramType + "_orig").grad
                    else:
                        d_p = getattr(module, paramType).grad
                    if d_p is not None:
                        gradientDict[(module, paramType)] = d_p.flatten()
        return gradientDict

    def data_parallel_callback(self, to_DP):
        sys.stdout.write(f"\nIn base before: {isinstance(self.model, torch.nn.DataParallel)}\n")
        if to_DP:
            # Move model to DP again
            if not isinstance(self.model, torch.nn.DataParallel):
                self.model = torch.nn.DataParallel(self.model)
        else:
            # Remove DP from model
            if isinstance(self.model, torch.nn.DataParallel):
                self.model = self.model.module
        sys.stdout.write(f"\nIn base after: {isinstance(self.model, torch.nn.DataParallel)}\n")
        return self.model