# ===========================================================================
# Project:      Sparse Model Soups
# File:         ensembleRunner.py
# Description:  Runner class for starting from pruned models
# ===========================================================================
import itertools
import math
import os
import sys
import json
import warnings
from typing import List
from utilities.lr_schedulers import SequentialSchedulers
import numpy as np
import torch
import wandb
from tqdm.auto import tqdm
from collections import OrderedDict

from torch import nn
from torch.nn.utils import parametrize

from utilities.utilities import Utilities as Utils, WorstClassAccuracy, CalibrationError, Candidate

from runners.baseRunner import baseRunner
from runners.pretrainedRunner import pretrainedRunner
from strategies import ensembleStrategies
from torch.cuda.amp import autocast
from torchmetrics import MeanMetric
from torchmetrics.classification import MulticlassAccuracy as Accuracy

import torch.nn.utils.prune as prune


class ensembleRunner(baseRunner):

    def __init__(self, **kwargs):
        """Initialise the base runner and the ensemble-specific attributes.

        All keyword arguments are forwarded unchanged to ``baseRunner``.
        """
        super().__init__(**kwargs)
        # Not assigned anywhere else in this class; the per-run value is read from self.config.
        self.k_splits_per_ensemble = None
        # (module, tensor name) pairs registered by get_weighted_model and undone by remove_parametrizations.
        self.parametrization_list = None
        # Teacher network, only loaded by learn_soup when knowledge distillation is enabled.
        self.teacher_model = None

    def find_multiple_existing_models(self, filterDict):
        """Find the finished wandb runs of the candidate models and download their model files.

        ``filterDict`` (a wandb/MongoDB-style ``{'$and': [...]}`` filter) is extended in place with the
        phase/split/ensemble settings of the current config. Entries for ``weight_decay``, ``penalty`` and
        ``group_penalty`` are removed from the query and compared manually against each run's config instead.
        As a side effect, ``self.seed`` is set to the seed of the last run whose model was downloaded.

        Args:
            filterDict: Filter dictionary with an ``'$and'`` list; mutated by this method.

        Returns:
            List of ``Candidate`` objects, restricted to the combination selected by ``config.split_id`` if that
            is set.

        Raises:
            AssertionError: If no model could be found, a model file is unavailable, or the number of candidates
                differs from ``config.n_splits_total``.
        """
        current_phase = self.config.phase  # We are in the same phase
        filters = filterDict['$and']  # Same list object, so all changes below are visible through filterDict
        filters.append({'config.phase': current_phase})
        filters.append({'config.n_splits_total': self.config.n_splits_total})
        sys.stdout.write(f"Structured pruning: {self.config.prune_structured}.\n")
        if current_phase > 1:
            # We need to specify the previous ensemble method as well
            filters.append({'config.ensemble_method': self.config.ensemble_method})
            # This restricts us to stay with the same split in every phase, but this is okay
            filters.append({'config.split_id': self.config.split_id})
            filters.append({'config.k_splits_per_ensemble': self.config.k_splits_per_ensemble})

            # Add the learned soup parameters
            filters.append({'config.ls_n_epochs': self.config.ls_n_epochs})
            filters.append({'config.ls_lr': self.config.ls_lr})
            filters.append({'config.ls_use_softmax': self.config.ls_use_softmax})
            filters.append({'config.ls_dataset': self.config.ls_dataset})
            filters.append({'config.ls_use_kd': self.config.ls_use_kd})
            filters.append({'config.ls_kd_temp': self.config.ls_kd_temp})
            filters.append({'config.ls_optimizer': self.config.ls_optimizer})

        filters.append({'config.ensemble_by': self.config.ensemble_by})
        filters.append({'config.prune_structured': self.config.prune_structured})
        entity, project = wandb.run.entity, wandb.run.project
        api = wandb.Api()
        candidate_model_list = []

        # Some variables have to be extracted from the filterDict and checked manually, e.g. weight decay in scientific format
        manualVariables = ['weight_decay', 'penalty', 'group_penalty']
        manVarDict = {}
        # A set, so that an entry carrying several manual keys is dropped exactly once
        dropIndices = set()
        for var in manualVariables:
            key = f"config.{var}"
            for i, entry in enumerate(filters):
                if key in entry:
                    dropIndices.add(i)
                    manVarDict[var] = entry[key]
        # Pop from the back so that the remaining indices stay valid
        for idx in sorted(dropIndices, reverse=True):
            filters.pop(idx)

        runs = api.runs(f"{entity}/{project}", filters=filterDict)
        runsExist = False  # If True, then there exist runs that try to set a fixed init
        # Initialise here: if no run passes the checks below, the name would otherwise be unbound afterwards
        checkpoint_file = None
        for run in runs:
            if run.state != 'finished':
                # Ignore this run
                continue
            # Check if run satisfies the manual variables
            conflict = any(var in run.config and run.config[var] != val for var, val in manVarDict.items())
            if conflict:
                continue
            sys.stdout.write(f"Trying to access {run.name}.\n")
            checkpoint_file = run.summary.get('final_model_file')
            try:
                if checkpoint_file is not None:
                    runsExist = True
                    sys.stdout.write(
                        f"Downloading pruned model with split {run.config['ensemble_by']} value: {run.config['split_val']}.\n")
                    run.file(checkpoint_file).download(root=self.tmp_dir)
                    self.seed = run.config['seed']
                    candidate_id = run.config['split_val']
                    candidate_model_list.append(
                        Candidate(candidate_id=candidate_id, candidate_file=os.path.join(self.tmp_dir, checkpoint_file),
                                  candidate_run=run))
            except Exception as e:  # The run is online, but the model is not uploaded yet -> results in failing runs
                print(e)
                checkpoint_file = None
                break
        assert not (
                runsExist and checkpoint_file is None), "Runs found, but one of them has no model available -> abort."
        outputStr = f"Found {len(candidate_model_list)} pruned models with split vals {sorted([c.id for c in candidate_model_list])}" \
            if checkpoint_file is not None else "Nothing found."
        sys.stdout.write(f"Trying to find reference pruned models in project: {outputStr}\n")
        assert checkpoint_file is not None, "One of the pruned models has no model file to download, Aborting."
        assert len(candidate_model_list) == self.config.n_splits_total, "Not all pruned models were found, Aborting.\n"

        # Check whether we want to find a specific split-set
        if self.config.split_id is not None:
            sorted_split_vals = sorted(
                [c.id for c in candidate_model_list])  # Sort this to ensure deterministic order of combinations
            # Generate the set of all possible split combinations
            splitCombinations = itertools.combinations(sorted_split_vals, self.config.k_splits_per_ensemble)
            # Pick the combination with the split_id (split ids are 1-based)
            desired_split = list(list(splitCombinations)[self.config.split_id - 1])
            # Filter the candidate model list
            candidate_model_list = [c for c in candidate_model_list if c.id in desired_split]
            sys.stdout.write(
                f"Desired split: {desired_split} - Reduced the candidate model list to {len(candidate_model_list)} models with split vals {sorted([c.id for c in candidate_model_list])}.\n")

        return candidate_model_list

    def define_optimizer_scheduler(self):
        # Define the optimizer
        if self.config.optimizer == 'SGD':
            self.optimizer = torch.optim.SGD(params=self.model.parameters(), lr=0.)

    def transport_information(self, ref_run):
        """Copy configuration values and final metrics of a reference run into the current run.

        Missing (or ``None``) config entries for momentum, warmup epochs and epochs are filled from
        ``ref_run.config``; the final learning rate and train/test metrics of ``ref_run`` are added to the config
        and stored on the runner. The iteration-lr dictionary of ``ref_run`` is downloaded and re-uploaded for
        the current run.

        Args:
            ref_run: wandb run object of the reference (candidate) run.
        """
        missing_config_keys = ['momentum',
                               'n_epochs_warmup',
                               'n_epochs']  # Have to have n_epochs even though it might be specified, otherwise ALLR doesnt have this

        summary = ref_run.summary
        additional_dict = {
            'last_training_lr': summary['final.learning_rate'],
            'final.test.accuracy': summary['final.test']['accuracy'],
            'final.train.accuracy': summary['final.train']['accuracy'],
            'final.train.loss': summary['final.train']['loss'],
        }
        for key in missing_config_keys:
            if key not in self.config or self.config[key] is None:
                # Allow_val_change = true because e.g. momentum defaults to None, but shouldn't be passed here
                val = ref_run.config.get(key)  # If not found, defaults to None
                self.config.update({key: val}, allow_val_change=True)
        self.config.update(additional_dict)

        self.trained_test_accuracy = additional_dict['final.test.accuracy']
        self.trained_train_loss = additional_dict['final.train.loss']
        self.trained_train_accuracy = additional_dict['final.train.accuracy']

        # Get the wandb information about lr, which can then be used by rewinders
        downloaded = ref_run.file('iteration-lr-dict.json').download(root=self.tmp_dir)
        try:
            with open(downloaded.name) as json_file:
                lr_dict = OrderedDict(json.load(json_file))
        finally:
            # download() hands back an open file handle; close it instead of leaking it
            downloaded.close()
        # Upload the iteration-lr dict of the reference run to be used during retraining
        Utils.dump_dict_to_json_wandb(dumpDict=lr_dict, name='iteration-lr-dict')

    def remove_parametrizations(self):
        """Remove all parametrizations from the model."""
        for module, param_type in self.parametrization_list:
            parametrize.remove_parametrizations(module, param_type)

    def remove_pruning_hooks(self, model):
        """Remove all pruning hooks from the model."""
        mask_dict = {}
        parameters_to_prune = [(module, 'weight') for name, module in model.named_modules() if
                               hasattr(module, 'weight')
                               and not isinstance(module.weight, type(None)) and not isinstance(module,
                                                                                                torch.nn.BatchNorm2d)]
        for module, param_type in parameters_to_prune:
            if prune.is_pruned(module):
                # Save the mask
                mask = getattr(module, param_type + '_mask')
                mask_dict[(module, param_type)] = mask.detach().clone()
                prune.remove(module=module, name=param_type)
        return mask_dict

    def reinit_pruning_hooks(self, mask_dict):
        """Reinitialize all pruning hooks from the model."""
        for (module, param_type), mask in mask_dict.items():
            prune.custom_from_mask(module=module, name=param_type, mask=mask)

    def get_weighted_model(self, state_dict_list: List, ensemble_weights: torch.Tensor):
        # Note: THIS DOES NOT WORK WITH DATAPARALLEL MODELS
        self.parametrization_list = []  # Reset at the beginning of every call to this fn

        # Create a new state dict for the weighted average model
        avg_state_dict = {}

        if self.config.ls_use_softmax:
            # Softmax the weights
            ensemble_weights = torch.nn.functional.softmax(ensemble_weights, dim=0)

        # Push the ensemble weights to CPU
        ensemble_weights = ensemble_weights.cpu()

        # Iterate through the model's named parameters
        for name, _ in self.model.named_parameters():
            # Strip the initial 'module.' from the name
            strippedName = name
            if name.startswith('module.'):
                strippedName = name[7:]

            # Accumulate the weighted parameters from the loaded state dicts
            weighted_params = []
            for i, state_dict in enumerate(state_dict_list):
                if strippedName in state_dict:
                    weighted_params.append(state_dict[strippedName] * ensemble_weights[i])
                elif strippedName + '_orig' in state_dict:
                    weighted_params.append(state_dict[strippedName + '_orig'] * ensemble_weights[i])

            avg_state_dict[name] = sum(weighted_params)
            avg_state_dict[name] = avg_state_dict[name].to(device=self.device)

        # We implement a parametrization class that uses the corresponding avg_state_dict entry as weight
        class ReplaceParametrization(nn.Module):
            def __init__(self, replace_tensor: torch.Tensor):
                super().__init__()
                self.replace_tensor = replace_tensor

            def forward(self, input):
                return self.replace_tensor

        def get_module(obj, names):
            """Recursive function to get a module by its state_dict name"""
            if len(names) == 1:
                return obj, names[0]
            else:
                return get_module(getattr(obj, names[0]), names[1:])

        for name, _ in list(self.model.named_parameters()):
            module, param_type = get_module(obj=self.model, names=name.split('.'))

            # We disable the gradient of the actual weights
            getattr(module, param_type).requires_grad = False

            # We register a parametrization, since otherwise we will lose the grad_fn
            reparametrization_instance = ReplaceParametrization(avg_state_dict[name])

            parametrize.register_parametrization(module, param_type, reparametrization_instance)
            self.parametrization_list.append((module, param_type))

        self.model = self.model.to(device=self.device)

        # Push the ensemble weights back to CUDA
        ensemble_weights = ensemble_weights.to(device=self.device)

    def learn_soup(self, candidate_model_list: List[Candidate]):
        """Learn one mixing weight per candidate by gradient descent on the loss of the weighted model.

        In every iteration the model parameters are replaced by the weighted sum of the candidate parameters
        (see ``get_weighted_model``), the loss is backpropagated into the weights, and the parametrizations are
        removed again. Training runs for ``config.ls_n_epochs`` epochs on the validation loader, or on the train
        loader if ``config.ls_dataset == 'train'``. The learning rate is warmed up linearly during the first 10%
        of the iterations and decayed linearly to zero afterwards. If ``config.ls_use_kd`` is set, the softened
        predictions of a teacher model are used as targets instead of the labels.

        Args:
            candidate_model_list: Candidates whose ``file`` attribute points to a saved state dict.

        Returns:
            Detached weight tensor of the epoch with the lowest mean loss; softmax-normalised if
            ``config.ls_use_softmax`` is set.

        Raises:
            ValueError: If ``config.ls_optimizer`` is neither 'Adam' nor 'SGD', or if knowledge distillation is
                requested but no pretrained model is found.
        """
        # Define a tensor of weights for the candidates that requires grad
        n_candidates = len(candidate_model_list)
        ensemble_weights = torch.ones(n_candidates, device=self.device) / n_candidates
        ensemble_weights.requires_grad = True

        # Define the optimizer
        if self.config.ls_optimizer == 'Adam':
            ensemble_optimizer = torch.optim.Adam(params=[ensemble_weights], lr=self.config.ls_lr)
        elif self.config.ls_optimizer == 'SGD':
            ensemble_optimizer = torch.optim.SGD(params=[ensemble_weights], lr=self.config.ls_lr, momentum=0.9)
        else:
            # Fail here with a clear message instead of an unbound variable further down
            raise ValueError(f"Unsupported ls_optimizer '{self.config.ls_optimizer}', expected 'Adam' or 'SGD'.")
        ensemble_gradScaler = torch.cuda.amp.GradScaler(enabled=(self.config.use_amp is True))

        loader = self.valLoader
        if self.config.ls_dataset is not None:
            assert self.config.ls_dataset in ['train', 'val']
            if self.config.ls_dataset == 'train':
                loader = self.trainLoader

        # Define the linear learning rate scheduler
        n_iter_total = self.config.ls_n_epochs * len(loader)
        n_iter_warmup = int(0.1 * n_iter_total)
        n_iter_remaining = n_iter_total - n_iter_warmup
        # Set the initial learning rate
        for param_group in ensemble_optimizer.param_groups: param_group['lr'] = self.config.ls_lr
        warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=ensemble_optimizer,
                                                             start_factor=1e-20, end_factor=1.,
                                                             total_iters=n_iter_warmup)
        scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=ensemble_optimizer,
                                                      start_factor=1., end_factor=0.,
                                                      total_iters=n_iter_remaining)
        # Reset base lrs to make this work
        scheduler.base_lrs = [self.config.ls_lr for _ in ensemble_optimizer.param_groups]

        scheduler = SequentialSchedulers(optimizer=ensemble_optimizer, schedulers=[warmup_scheduler, scheduler],
                                         milestones=[n_iter_warmup + 1])

        ls_hparams = {key: val for key, val in self.config.items() if key.startswith('ls_')}
        sys.stdout.write(
            f"Learning the soup with {n_candidates} candidates and the following settings: {ls_hparams}.\n")

        # Preload the state dicts
        state_dict_paths = [candidate.file for candidate in candidate_model_list]
        state_dict_list = [torch.load(state_dict_path, map_location=torch.device('cpu')) for state_dict_path in
                           state_dict_paths]

        # Freeze the model weights of each state dict
        for state_dict in state_dict_list:
            for k, v in state_dict.items():
                state_dict[k] = v.clone().detach().requires_grad_(False)

        lossMeter, accMeter = MeanMetric().to(device=self.device), Accuracy(num_classes=self.n_classes).to(
            device=self.device)

        # Remove the pruning hooks; the masks are kept and re-attached at the end of this method
        mask_dict = self.remove_pruning_hooks(self.model)

        kd_temp = self.config.ls_kd_temp or 1.0
        use_kd = self.config.ls_use_kd is True
        if use_kd:
            sys.stdout.write(f"Using KD with temperature {kd_temp}. Loading the model of the previous phase.\n")
            filterDict = {"$and": [{"config.run_id": self.config.run_id},
                                   {"config.arch": self.config.arch},
                                   {"config.optimizer": self.config.optimizer},
                                   ]}
            if self.config.learning_rate is not None:
                warnings.warn(
                    "You specified an explicit learning rate for retraining. Note that this only controls the selection of the pretrained model.")
                filterDict["$and"].append({"config.learning_rate": self.config.learning_rate})
            if self.config.n_epochs is not None:
                warnings.warn(
                    "You specified n_epochs for retraining. Note that this only controls the selection of the pretrained model.")
                filterDict["$and"].append({"config.n_epochs": self.config.n_epochs})
            checkpoint_file, _, _ = pretrainedRunner.find_existing_model(self, filterDict=filterDict)
            if checkpoint_file is None:
                raise ValueError("Could not find a pretrained model to use for KD.")
            sys.stdout.write(f"Loading the model from {checkpoint_file}.\n")
            # get_model is called with self.checkpoint_file temporarily pointing to the teacher checkpoint
            old_checkpoint = self.checkpoint_file
            self.checkpoint_file = checkpoint_file
            self.teacher_model = self.get_model(reinit=True, temporary=True)
            self.checkpoint_file = old_checkpoint

        best_loss_yet, best_weights_yet = float(
            'inf'), ensemble_weights.detach().clone()  # We keep track of the best weights since this is a relatively small tensor

        # Iterate over the chosen loader
        for epoch in range(self.config.ls_n_epochs):
            sys.stdout.write(f"Learning the soup: Epoch {epoch + 1}/{self.config.ls_n_epochs}.\n")
            lossMeter.reset()
            accMeter.reset()
            with tqdm(loader, leave=True) as pbar:
                for x_input, y_target, indices in pbar:
                    x_input = x_input.to(self.device, non_blocking=True)
                    y_target = y_target.to(self.device, non_blocking=True)
                    ensemble_optimizer.zero_grad()  # Zero the gradient buffers

                    # Ensemble the candidates with the current weights
                    self.get_weighted_model(state_dict_list=state_dict_list, ensemble_weights=ensemble_weights)

                    with autocast(enabled=(self.config.use_amp is True)):
                        output = self.model.train()(x_input)
                        if use_kd:
                            with torch.no_grad():
                                # This will also change the teacher since we update BN statistics, but we do not save it anywhere to use in the future
                                teacher_output = self.teacher_model.train()(x_input)  # Logits
                                # Apply the temperature
                                teacher_output /= kd_temp
                                # Get probabilities
                                teacher_output = torch.nn.functional.softmax(teacher_output, dim=1)  # Softmax(Logits)

                            loss = self.loss_criterion(output, teacher_output)
                        else:
                            loss = self.loss_criterion(output, y_target)

                    ensemble_gradScaler.scale(loss).backward()  # Scaling + Backpropagation
                    ensemble_gradScaler.step(ensemble_optimizer)  # Optimization step
                    ensemble_gradScaler.update()
                    scheduler.step()
                    # Undo the parametrizations so that the next iteration can register fresh ones
                    self.remove_parametrizations()
                    lossMeter(value=loss, weight=len(y_target))
                    accuracy_labels = y_target
                    if use_kd:
                        # Get the labels from the teacher
                        accuracy_labels = torch.argmax(teacher_output, dim=1)
                    accMeter(output, accuracy_labels)
                epoch_loss = lossMeter.compute()
                if epoch_loss < best_loss_yet:
                    best_loss_yet = epoch_loss
                    best_weights_yet = ensemble_weights.detach().clone()
                    sys.stdout.write(f"New best loss - {best_loss_yet}\n")
            sys.stdout.write("Loss: {:.4f} - Accuracy: {:.4f}\n".format(epoch_loss, accMeter.compute()))

        if self.config.ls_use_softmax:
            # Apply a final softmax to the weights
            best_weights_yet = torch.nn.functional.softmax(best_weights_yet, dim=0)

        sys.stdout.write(f"Finished learning the soup - Best Ensemble Weights {best_weights_yet}.\n")

        # Reinit pruning mask
        self.reinit_pruning_hooks(mask_dict)

        return best_weights_yet

    def load_soup_model(self, ensemble_state_dict):
        """Write ``ensemble_state_dict`` to ``ensemble_model.pt`` in the temporary directory and load it.

        ``self.checkpoint_file`` is pointed to the written file before ``self.model`` is replaced by the result
        of ``get_model``.
        """
        checkpoint_name = "ensemble_model.pt"
        torch.save(ensemble_state_dict, os.path.join(self.tmp_dir, checkpoint_name))  # Save the state_dict
        self.checkpoint_file = checkpoint_name

        # Actually load the model from the checkpoint that was just written
        self.model = self.get_model(reinit=True, temporary=True)

    def evaluate_soup(self, data='val', ensemble_labels: torch.Tensor = None):
        # Perform an evaluation pass
        AccuracyMeter = Accuracy(num_classes=self.n_classes).to(device=self.device)
        ECEMeter = CalibrationError(norm='l1').to(device=self.device)
        MCEMeter = CalibrationError(norm='max').to(device=self.device)
        WorstClassAccuracyMeter = WorstClassAccuracy(num_classes=self.n_classes).to(device=self.device)

        if data == 'val':
            loader = self.valLoader
        elif data == 'test':
            loader = self.testLoader
        elif data == 'ood':
            loader = self.oodLoader
            if loader is None:
                sys.stdout.write(f"No OOD data found, skipping OOD evaluation.\n")
                return {}
        else:
            raise NotImplementedError

        if ensemble_labels is not None:
            sys.stdout.write(f"Performing computation of prediction ensemble {data} accuracy.\n")
        else:
            sys.stdout.write(f"Performing computation of soup {data} accuracy.\n")
        with tqdm(loader, leave=True) as pbar:
            for x_input, y_target, indices in pbar:
                # Move to CUDA if possible
                x_input = x_input.to(self.device, non_blocking=True)
                indices = indices.to(self.device, non_blocking=True)
                if ensemble_labels is not None:
                    y_target = ensemble_labels[indices]  # Avg probs/predictions of batch
                y_target = y_target.to(self.device, non_blocking=True)

                with autocast(enabled=(self.config.use_amp is True)):
                    output = self.model.train(mode=False)(x_input)
                    AccuracyMeter(output, y_target)
                    ECEMeter(output, y_target)
                    MCEMeter(output, y_target)
                    WorstClassAccuracyMeter(output, y_target)

        outputDict = {
            'accuracy': AccuracyMeter.compute().item(),
            'ece': ECEMeter.compute().item(),
            'mce': MCEMeter.compute().item(),
            'worst_class_accuracy': WorstClassAccuracyMeter.compute().item(),
        }
        return outputDict

    @torch.no_grad()
    def collect_avg_output_full(self, data: str, candidate_model_list: List[Candidate],
                                output_type: str = 'soft_prediction'):
        """Collect the combined output of all candidates on a full data split.

        Every candidate is loaded into ``self.model`` in turn (replacing the current model) and evaluated on the
        whole loader; the per-sample results are accumulated at the sample indices yielded by the loader.

        Args:
            data: 'val' or 'test'.
            candidate_model_list: Candidates whose ``file`` attribute points to a saved state dict.
            output_type: 'soft_prediction' (default) returns the argmax of the averaged probabilities,
                'hard_prediction' the class predicted by most candidates, and 'softmax' the averaged
                probabilities themselves.

        Returns:
            Tensor with one entry per sample (class index), or one row of class probabilities per sample for
            ``output_type='softmax'``.
        """
        assert data in ['val', 'test']
        assert output_type in ['softmax', 'soft_prediction', 'hard_prediction']
        loader = self.valLoader if data == 'val' else self.testLoader
        sys.stdout.write("\nCollecting ensemble prediction.\n")

        compute_avg_probs = (output_type in ['softmax', 'soft_prediction'])
        store_tensor = torch.zeros(len(loader.dataset), self.n_classes, device=self.device)  # On CUDA for now

        for candidate in candidate_model_list:
            # Free the previous model before the next candidate is loaded
            if self.model is not None:
                del self.model
                torch.cuda.empty_cache()

            state_dict = torch.load(candidate.file, map_location=torch.device('cpu'))
            self.load_soup_model(ensemble_state_dict=state_dict)
            with tqdm(loader, leave=True) as pbar:
                for x_input, _, indices in pbar:
                    x_input = x_input.to(self.device, non_blocking=True)  # Move to CUDA if possible
                    with autocast(enabled=(self.config.use_amp is True)):
                        output = self.model.eval()(x_input)  # Logits
                        probabilities = torch.nn.functional.softmax(output, dim=1)  # Softmax(Logits)
                        if compute_avg_probs:
                            # Just add the probabilities for the average
                            store_tensor[indices] += probabilities
                        else:
                            # Add the prediction as one hot vote
                            votes = torch.nn.functional.one_hot(torch.argmax(probabilities, dim=1),
                                                                num_classes=self.n_classes)
                            store_tensor[indices] += votes.to(dtype=store_tensor.dtype)

        if compute_avg_probs:
            store_tensor.mul_(1. / len(candidate_model_list))  # Weighting
        else:
            # Every candidate must have cast exactly one vote per sample
            assert store_tensor.sum() == (len(candidate_model_list) * len(loader.dataset))

        if output_type in ['soft_prediction', 'hard_prediction']:
            # Take the prediction given average probabilities OR Take the most frequent prediction
            store_tensor = torch.argmax(store_tensor, dim=1)

        return store_tensor

    def run(self):
        """Function controlling the workflow of ensembleRunner.

        Steps: validate the split configuration, find and download the candidate models, seed the random number
        generators with the seed of the candidate runs, build the ensemble with the configured ensemble strategy,
        load it as the model, optionally prune it again, evaluate it, and save the resulting model.

        Raises:
            AssertionError: If the ensemble configuration is inconsistent or the candidates cannot be found.
        """
        assert self.config.ensemble_by in ['pruned_seed', 'weight_decay', 'retrain_length', 'retrain_schedule']
        assert self.config.n_splits_total is not None
        assert self.config.split_val is None
        assert not (self.config.k_splits_per_ensemble is None) ^ (
                self.config.split_id is None), "Both should either be None or not None"

        if self.config.k_splits_per_ensemble is not None:
            # Compute the number of available splits as n choose k
            n = self.config.n_splits_total
            k = self.config.k_splits_per_ensemble
            assert 1 <= self.config.split_id <= math.comb(n,
                                                          k), f"Split id {self.config.split_id} > {math.comb(n, k)} is not valid, Aborting."

        # Find the reference run
        filterDict = {"$and": [{"config.run_id": self.config.run_id},
                               {"config.arch": self.config.arch},
                               {"config.optimizer": self.config.optimizer},
                               {"config.goal_sparsity": self.config.goal_sparsity},
                               {"config.n_epochs_per_phase": self.config.n_epochs_per_phase},
                               {"config.n_phases": self.config.n_phases},
                               {"config.retrain_schedule": self.config.retrain_schedule},
                               {"config.strategy": 'IMP'},
                               {"config.extended_imp": self.config.extended_imp},
                               {'config.prune_structured': self.config.prune_structured}
                               ]}

        if self.config.learning_rate is not None:
            warnings.warn(
                "You specified an explicit learning rate for retraining. Note that this only controls the selection of the pretrained model.")
            filterDict["$and"].append({"config.learning_rate": self.config.learning_rate})
        if self.config.n_epochs is not None:
            warnings.warn(
                "You specified n_epochs for retraining. Note that this only controls the selection of the pretrained model.")
            filterDict["$and"].append({"config.n_epochs": self.config.n_epochs})

        # This also sets self.seed to the seed of the candidate runs
        candidate_models = self.find_multiple_existing_models(filterDict=filterDict)
        wandb.config.update({'seed': self.seed})  # Push the seed to wandb

        # Set a unique random seed
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # Remark: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all().
        torch.cuda.manual_seed(self.seed)  # This works if CUDA not available

        torch.backends.cudnn.benchmark = True

        self.transport_information(ref_run=candidate_models[0].run)

        self.trainLoader, self.valLoader, self.testLoader, self.trainLoader_unshuffled = self.get_dataloaders()
        self.oodLoader = self.get_ood_dataloaders()

        # We first define the ensembling strategy, create the ensemble, then use the 'Dense' strategy and regularly
        # load the model
        # Callbacks handed to the ensemble strategy
        callbackDict = {
            'final_log_callback': self.final_log,
            'soup_evaluation_callback': self.evaluate_soup,
            'load_soup_callback': self.load_soup_model,
            'recalibrate_bn_callback': self.recalibrate_bn,
        }
        self.ensemble_strategy = getattr(ensembleStrategies, self.config.ensemble_method)(model=None,
                                                                                          n_classes=self.n_classes,
                                                                                          config=self.config,
                                                                                          candidate_models=candidate_models,
                                                                                          runner=self,
                                                                                          callbacks=callbackDict)

        self.ensemble_strategy.collect_candidate_information()

        # Create ensemble
        ensemble_state_dict = self.ensemble_strategy.create_ensemble()

        # Save the ensemble state dict and load it as the model (same steps as the load_soup callback)
        self.load_soup_model(ensemble_state_dict)

        if self.config.prune_after_merge:
            sys.stdout.write("Pruning to restore original sparsity.\n")
            assert self.config.pruning_selector in [None, 'global']
            # We prune globally again
            parameters_to_prune = [(module, 'weight') for name, module in self.model.named_modules() if
                                   hasattr(module, 'weight')
                                   and not isinstance(module.weight, type(None)) and not isinstance(module,
                                                                                                    torch.nn.BatchNorm2d)]
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=self.config.goal_sparsity,
            )
            # Make the pruning permanent, i.e. drop the masks and hooks again
            for module, param_type in parameters_to_prune:
                if prune.is_pruned(module):
                    prune.remove(module, param_type)

        # Create 'Dense' as the Base Strategy
        self.strategy = self.define_strategy(use_dense_base=True)
        self.strategy.after_initialization()

        # Define optimizer to not get errors in the main evaluation (even though we do not actually use the optimizer)
        self.define_optimizer_scheduler()

        # Evaluate ensemble
        self.ensemble_strategy.final()

        self.checkpoint_file = self.save_model(model_type='ensemble')
        wandb.summary['final_model_file'] = f"ensemble_model_{self.config.ensemble_method}_{self.config.phase}.pt"
