# ===========================================================================
# Project:      Sparse Model Soups
# File:         strategies/ensembleStrategies.py
# Description:  Strategies for building model soups (weight averages of candidate models): uniform, learned, greedy and permutation-aligned.
# ===========================================================================
import copy
import importlib
import sys

import torch
import torch.nn.utils.prune as prune
from collections import OrderedDict

import metrics.metrics
from strategies import strategies as usual_strategies
import numpy as np
from utilities.utilities import LAMPUnstructured
from utilities.utilities import Utilities as Utils
from utilities.utilities import Candidate
from typing import NamedTuple
from collections import defaultdict
from scipy.optimize import linear_sum_assignment
from torch import nn
import torch.nn.functional as F





#### Base Class
class EnsemblingBaseClass(usual_strategies.Dense):
    """Ensembling base class shared by all soup strategies.

    Holds the candidate models and the runner used to load and evaluate them, and provides the shared helpers:
    collecting per-candidate metrics, averaging candidate checkpoints and computing group ("soup") metrics.
    Subclasses override `create_ensemble` to choose the ingredients and their weights.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.candidate_model_list = kwargs['candidate_models']
        self.runner = kwargs['runner']
        # None until create_ensemble runs; then either the string 'all' or the list of selected candidates
        self.selected_models = None
        self.soup_metrics = {soup_type: {} for soup_type in ['candidates', 'selected']}

    @torch.no_grad()
    def get_soup_metrics(self, soup_list: list[Candidate]):
        """Compute group statistics of the weights of the models in `soup_list`.

        Returns a dict with the max/min barycentre L2 distance and the max/min/mean of the L2 distance and
        angle, aggregated over the group by `Utils.aggregate_group_metrics`.
        """
        # Load the models
        model_list = [candidate.get_model_weights() for candidate in soup_list]

        soup_metrics = {
            'max_barycentre_distance': Utils.get_barycentre_l2_distance(model_list),
            'min_barycentre_distance': Utils.get_barycentre_l2_distance(model_list, maximize=False),
        }

        for metric_name, metric_fn in zip(['l2_distance', 'angle'], [Utils.get_l2_distance, Utils.get_angle]):
            for agg_name, agg_fn in zip(['max', 'min', 'mean'], [torch.max, torch.min, torch.mean]):
                soup_metrics[f'{agg_name}_{metric_name}'] = Utils.aggregate_group_metrics(models=model_list,
                                                                                          metric_fn=metric_fn,
                                                                                          aggregate_fn=agg_fn)
        return soup_metrics

    @staticmethod
    def _aggregate_split_metrics(metrics_dict, target: dict):
        """Store mean and max of every collected test/ood metric in `target` under '<split>.<metric>_<agg>'."""
        for split in ['test', 'ood']:
            for agg_name, agg_fn in zip(['mean', 'max'], [np.mean, np.max]):
                for metric, values in metrics_dict[split].items():
                    target[f'{split}.{metric}_{agg_name}'] = agg_fn(values)

    def collect_candidate_information(self):
        """Evaluate every candidate individually and store the results.

        For each candidate its checkpoint is loaded into the runner, `runner.recalibrate_bn()` is called and the
        model is evaluated on the 'test', 'ood' and 'val' splits; the results are stored on the candidate via
        `set_metrics`. Afterwards the group metrics, the mean/max of the individual test/ood metrics and the
        test metrics of the prediction ensemble are stored in `self.soup_metrics['candidates']`.
        """
        metrics_dict = {split: defaultdict(list) for split in ['test', 'ood']}
        for candidate in self.candidate_model_list:
            if self.runner.model is not None:
                del self.runner.model
                torch.cuda.empty_cache()

            # Load to CPU to avoid memory overhead
            state_dict = torch.load(candidate.file, map_location=torch.device('cpu'))
            self.runner.load_soup_model(ensemble_state_dict=state_dict)
            # Release the checkpoint right away instead of keeping every candidate's weights alive
            del state_dict
            self.runner.recalibrate_bn()

            # Collect and set test/ood metrics
            for split in ['test', 'ood']:
                single_model_metrics = self.runner.evaluate_soup(data=split)
                for metric, value in single_model_metrics.items():
                    metrics_dict[split][metric].append(value)
                candidate.set_metrics(metrics=single_model_metrics, split=split)

            # Collect metrics that are needed for other strategies to perform model selection
            single_model_val_metrics = self.runner.evaluate_soup(data='val')
            candidate.set_metrics(metrics=single_model_val_metrics, split='val')

        # Collect a lot of soup metrics
        self.soup_metrics['candidates'] = self.get_soup_metrics(soup_list=self.candidate_model_list)
        self._aggregate_split_metrics(metrics_dict, target=self.soup_metrics['candidates'])

        # Collect prediction ensemble metrics
        ensemble_labels = self.runner.collect_avg_output_full(data='test',
                                                              candidate_model_list=self.candidate_model_list)
        ensemble_metrics = {
            'pred_ensemble.test': self.runner.evaluate_soup(data='test', ensemble_labels=ensemble_labels)}
        self.soup_metrics['candidates'].update(ensemble_metrics)

        sys.stdout.write(f"Test accuracies of ensemble runs: {metrics_dict['test']['accuracy']}.\n")

    def create_ensemble(self, **kwargs):
        """Common preparation for all strategies: require at least two candidates and enforce their prunedness."""
        n_models = len(self.candidate_model_list)
        assert n_models >= 2, "Not enough models to ensemble"
        self.enforce_prunedness()

    @torch.no_grad()
    def enforce_prunedness(self, device=torch.device('cpu')):
        """Enforce prunedness of the model by calling `enforce_prunedness` on every candidate."""
        for candidate in self.candidate_model_list:
            candidate.enforce_prunedness(device=device)

    @torch.no_grad()
    def average_models(self, soup_list: list[Candidate], soup_weights: torch.Tensor = None,
                       device: torch.device = torch.device('cpu')):
        """Return the weighted average of the checkpoints (`candidate.file`) of the candidates in `soup_list`.

        :param soup_list: non-empty list of candidates to average.
        :param soup_weights: one weight per candidate (tensor or sequence of numbers); uniform if None.
        :param device: device the checkpoints are loaded onto.
        :return: OrderedDict with the weighted sum of all non-mask entries. Entries whose key contains '_mask'
            are not averaged but copied from the last candidate in `soup_list`.
        :raises ValueError: if `soup_list` is empty or the number of weights does not match the number of models.
        """
        if len(soup_list) == 0:
            raise ValueError("Cannot average an empty list of models.")
        if soup_weights is None:
            soup_weights = torch.ones(len(soup_list)) / len(soup_list)
        if len(soup_weights) != len(soup_list):
            raise ValueError(f"Got {len(soup_weights)} soup weights for {len(soup_list)} models.")
        ensemble_state_dict = OrderedDict()

        for idx, candidate in enumerate(soup_list):
            state_dict = torch.load(candidate.file, map_location=device)
            factor = float(soup_weights[idx])  # Plain Python scalar, computed once per model
            for key, val in state_dict.items():
                if '_mask' in key:
                    # We don't want to average the masks, hence we skip them and add them later
                    continue
                # `factor * val` allocates a new tensor, so the loaded checkpoint tensors are never modified
                if key not in ensemble_state_dict:
                    ensemble_state_dict[key] = factor * val.detach()
                else:
                    ensemble_state_dict[key] += factor * val.detach()

        # Add the masks from the last state_dict
        for key, val in state_dict.items():
            if '_mask' in key:
                ensemble_state_dict[key] = val.detach().clone()

        return ensemble_state_dict

    def final(self):
        """Invoke the 'final_log_callback' callback."""
        self.callbacks['final_log_callback']()

    def get_ensemble_metrics(self):
        """Fill in and return `self.soup_metrics`, including the metrics of the selected soup ingredients.

        If all candidates were selected, the candidate metrics are reused. Otherwise the mean/max of the
        individual metrics, the group metrics and the prediction-ensemble test metrics of the selected
        models are stored under `self.soup_metrics['selected']`.
        """
        if self.selected_models == 'all':
            # We have already collected the metrics for all models
            self.soup_metrics['selected'] = self.soup_metrics['candidates']
        else:
            assert self.selected_models is not None and len(self.selected_models) > 0, "No models selected for metrics."
            # Collect individual metrics for the selected models, which we already have
            metrics_dict = defaultdict(lambda: defaultdict(list))
            for split in ['test', 'ood']:
                for candidate in self.selected_models:
                    for metric, value in candidate.get_metrics(split=split).items():
                        metrics_dict[split][metric].append(value)
            self._aggregate_split_metrics(metrics_dict, target=self.soup_metrics['selected'])

            # Collect group_metrics for the selected models
            group_metrics = self.get_soup_metrics(soup_list=self.selected_models)
            self.soup_metrics['selected'].update(group_metrics)

            # Collect prediction ensemble metrics, only for test for now
            ensemble_labels = self.runner.collect_avg_output_full(data='test',
                                                                  candidate_model_list=self.selected_models)
            ensemble_metrics = {
                'pred_ensemble.test': self.runner.evaluate_soup(data='test', ensemble_labels=ensemble_labels)}
            self.soup_metrics['selected'].update(ensemble_metrics)
        return self.soup_metrics


class UniformEnsembling(EnsemblingBaseClass):
    """Uniform soup: every candidate enters the average with the same weight 1/n."""

    @torch.no_grad()
    def create_ensemble(self, **kwargs):
        """Return the uniformly averaged state dict of all candidates."""
        super().create_ensemble(**kwargs)

        candidates = self.candidate_model_list
        averaged_state_dict = self.average_models(soup_list=candidates,
                                                  soup_weights=self.get_soup_weights(soup_list=candidates),
                                                  device=torch.device('cpu'))
        self.selected_models = 'all'
        return averaged_state_dict

    def get_soup_weights(self, soup_list: list[Candidate]):
        """Return a 1-D tensor holding 1/len(soup_list) once per candidate."""
        n_candidates = len(soup_list)
        return torch.full((n_candidates,), 1. / n_candidates)


class LearnedSoup(EnsemblingBaseClass):
    """Soup over all candidates whose mixing weights are learned by the runner (`runner.learn_soup`)."""

    # Deliberately not decorated with @torch.no_grad(): learning the soup weights requires gradients.
    def create_ensemble(self, **kwargs):
        """Learn the soup weights, then return the correspondingly weighted average of all candidates."""
        super().create_ensemble(**kwargs)

        cpu = torch.device('cpu')
        candidates = self.candidate_model_list
        learned_weights = self.get_soup_weights(soup_list=candidates).to(device=cpu)
        with torch.no_grad():
            # Only the optimization above needs autograd; averaging the checkpoints does not
            soup_state_dict = self.average_models(soup_list=candidates, soup_weights=learned_weights, device=cpu)
        self.selected_models = 'all'
        return soup_state_dict

    def get_soup_weights(self, soup_list: list[Candidate]):
        """Delegate the optimization of the soup weights to `runner.learn_soup`."""
        return self.runner.learn_soup(candidate_model_list=soup_list)


class GreedySoup(EnsemblingBaseClass):
    """Greedy soup.

    Candidates are visited in order of decreasing validation accuracy; each one is kept if the soup including it
    reaches at least the best validation accuracy seen so far.
    """

    @torch.no_grad()
    def create_ensemble(self, **kwargs):
        """Greedily select ingredients and return the uniform average of the selected candidates."""
        super().create_ensemble(**kwargs)
        scored = [(candidate, candidate.get_single_metric(metric='accuracy', split='val'))
                  for candidate in self.candidate_model_list]
        cpu = torch.device('cpu')

        # Best model first (stable sort, so ties keep their original order)
        scored.sort(key=lambda pair: pair[1], reverse=True)

        soup, best_accuracy = [scored[0][0]], scored[0][1]
        for candidate, _ in scored[1:]:
            # Evaluate the soup with the candidate added on the validation split
            trial_state_dict = self.average_models(soup_list=soup + [candidate], device=cpu)
            self.callbacks['load_soup_callback'](ensemble_state_dict=trial_state_dict)
            self.callbacks['recalibrate_bn_callback']()
            trial_accuracy = self.callbacks['soup_evaluation_callback'](data='val')['accuracy']
            if trial_accuracy >= best_accuracy:
                soup.append(candidate)
                best_accuracy = trial_accuracy

        if len(soup) == len(self.candidate_model_list):
            self.selected_models = 'all'
            sys.stdout.write("GreedySoup used all candidates.\n")
        else:
            self.selected_models = soup
            sys.stdout.write(
                f"GreedySoup used candidates with ids: {[candidate.id for candidate in soup]}.\n")
        return self.average_models(soup_list=soup, device=cpu)

class PermutationSpec(NamedTuple):
    """Describes how neuron permutations act on the entries of a state dict.

    Both mappings describe the same relation from opposite sides.
    """
    # permutation name -> sequence of (parameter key, axis) pairs the permutation acts on
    perm_to_axes: dict
    # parameter key -> tuple holding, per axis, the name of the permutation acting on it (or None)
    axes_to_perm: dict

class PermutationEnsembling(EnsemblingBaseClass):
    """Aligns the candidates by permuting their neurons (weight matching) and then averages them uniformly.

    The permuted candidates are written back to their checkpoint files (`candidate.file`) before averaging.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Reduction of the L2 distance achieved by the permutation step (set by merge_two_models/merge_multiple_models)
        self.delta_l2_distance = None

    @torch.no_grad()
    def create_ensemble(self, **kwargs):
        """Permute all candidates into alignment, then return their uniformly averaged state dict."""
        super().create_ensemble(**kwargs)
        device = torch.device('cpu')
        run_files = [candidate.file for candidate in self.candidate_model_list]

        if len(run_files) == 2:
            self.merge_two_models(run_list=run_files, device=device)
        else:
            self.merge_multiple_models(run_list=run_files, device=device)

        # The candidate files now hold the permuted weights, so the average is taken over the aligned models
        ensemble_state_dict = self.average_models(soup_list=self.candidate_model_list, device=device)
        self.selected_models = 'all'

        return ensemble_state_dict

    def _get_permutation_spec(self):
        """Return `models.<dataset>.<arch>.get_permutation_spec()` for the configured dataset and architecture."""
        model_module = importlib.import_module('models.' + self.run_config['dataset'])
        permutation_spec = getattr(model_module, self.run_config['arch']).get_permutation_spec()
        sys.stdout.write(f"Pulling {permutation_spec}.\n")
        return permutation_spec

    @torch.no_grad()
    def merge_two_models(self, run_list, device):
        """Permute the second checkpoint of `run_list` to match the first one and overwrite the second file."""
        model_a = torch.load(run_list[0], map_location=device)
        model_b = torch.load(run_list[1], map_location=device)

        # Split models into original weights and masks; model_a is only a reference and stays untouched
        model_a, _ = Utils.split_weights_and_masks(model_a)
        model_b, model_b_masks = Utils.split_weights_and_masks(model_b)

        permutation_spec = self._get_permutation_spec()

        l2_distance_before = Utils.get_l2_distance(model_a=model_a, model_b=model_b)
        sys.stdout.write("#### Permuting models.\n")
        final_permutation = PermutationEnsembling.weight_matching(permutation_spec, model_a, model_b)
        model_b = PermutationEnsembling.apply_permutation(permutation_spec, final_permutation, model_b)

        l2_distance_after = Utils.get_l2_distance(model_a=model_a, model_b=model_b)
        sys.stdout.write(f"L2-distance between models changed from {l2_distance_before} to {l2_distance_after}.\n")
        self.delta_l2_distance = l2_distance_before - l2_distance_after

        # The masks must follow the same permutation as the weights they belong to
        model_b_masks = PermutationEnsembling.apply_permutation(permutation_spec, final_permutation, model_b_masks)

        # Join weights and masks again and overwrite the existing state dict of model_b
        model_b = Utils.join_weights_and_masks(model_b, model_b_masks)
        torch.save(model_b, run_list[1])

    @torch.no_grad()
    def merge_multiple_models(self, run_list, device):
        """Permute all checkpoints of `run_list` into mutual alignment and overwrite every file."""
        models, models_masks = [], []
        for run in run_list:
            weights, masks = Utils.split_weights_and_masks(torch.load(run, map_location=device))
            models.append(weights)
            models_masks.append(masks)

        permutation_spec = self._get_permutation_spec()

        l2_distance_before = Utils.get_group_l2_distance(models)
        sys.stdout.write("#### Permuting multiple models.\n")

        # Overwriting models
        models, models_masks = PermutationEnsembling.multi_merge_weight_matching(permutation_spec, models, models_masks)

        l2_distance_after = Utils.get_group_l2_distance(models)
        sys.stdout.write(
            f"Maximal L2-distance between two models changed from {l2_distance_before} to {l2_distance_after}.\n")
        self.delta_l2_distance = l2_distance_before - l2_distance_after

        # Join weights and masks again and overwrite the existing state dicts
        for run, model, model_masks in zip(run_list, models, models_masks):
            torch.save(Utils.join_weights_and_masks(model, model_masks), run)

    @staticmethod
    def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
        """Get parameter `k` from `params`, with the permutations applied (optionally skipping `except_axis`)."""
        w = params[k]
        for axis, p in enumerate(ps.axes_to_perm[k]):
            # Skip the axis we're trying to permute.
            if axis == except_axis:
                continue

            # None indicates that there is no permutation relevant to that axis.
            if p is not None:
                w = torch.index_select(w, axis, perm[p].int())

        return w

    @staticmethod
    def apply_permutation(ps: PermutationSpec, perm, params):
        """Apply `perm` to every entry of `params` (a state dict or a module) and return a new dict.

        BatchNorm statistics (keys ending in running_mean, running_var, num_batches_tracked) are returned
        unpermuted. NOTE(review): presumably fine because BN statistics are recalibrated after loading a soup —
        confirm for every caller.
        """
        if isinstance(params, torch.nn.Module):
            params = params.state_dict()
        unpermuted_suffixes = ('running_mean', 'running_var', 'num_batches_tracked')
        return {k: params[k] if k.endswith(unpermuted_suffixes)
                else PermutationEnsembling.get_permuted_param(ps, perm, k, params)
                for k in params.keys()}

    @staticmethod
    def weight_matching(ps: PermutationSpec, params_a, params_b, max_iter=100, init_perm=None):
        """Find a permutation of `params_b` to make them match `params_a`.

        Coordinate ascent over the permutations of `ps`, visited in random order: each permutation is replaced by
        the solution of a linear assignment problem that maximizes the summed inner products between the units of
        `params_a` and the permuted `params_b`, all other permutations held fixed. Stops after a sweep without
        improvement or after `max_iter` sweeps.

        :param init_perm: optional dict of initial permutations; it is copied, never modified.
        :return: dict mapping permutation name -> index tensor (float dtype, as built by `torch.Tensor`).
        """
        if isinstance(params_a, torch.nn.Module):
            params_a = params_a.state_dict()
        if isinstance(params_b, torch.nn.Module):
            params_b = params_b.state_dict()

        perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}

        # Copy init_perm so that the caller's dict is not overwritten by the updates below
        perm = {p: torch.arange(n) for p, n in perm_sizes.items()} if init_perm is None else dict(init_perm)
        perm_names = list(perm.keys())

        for iteration in range(max_iter):
            progress = False
            for p_ix in torch.randperm(len(perm_names)):
                p = perm_names[p_ix]
                n = perm_sizes[p]
                # Similarity between unit i of a and unit j of b, summed over all parameters permuted by p
                A = torch.zeros((n, n))
                for wk, axis in ps.perm_to_axes[p]:
                    w_a = params_a[wk]
                    w_b = PermutationEnsembling.get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                    w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1))
                    w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1))

                    A += w_a @ w_b.T

                ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
                assert (torch.tensor(ri) == torch.arange(len(ri))).all()
                oldL = torch.einsum('ij,ij->i', A, torch.eye(n)[perm[p].long()]).sum()
                newL = torch.einsum('ij,ij->i', A, torch.eye(n)[ci, :]).sum()
                sys.stdout.write(f"Iteration {iteration} - Permutation {p}: {newL - oldL}\n")
                # L is a similarity here, so progress means it increased
                progress = progress or newL > oldL + 1e-12

                perm[p] = torch.Tensor(ci)

            if not progress:
                break

        return perm

    @staticmethod
    def multi_merge_weight_matching(ps: PermutationSpec, models, models_masks, max_iter=100, init_perm=None):
        """Align several models by repeatedly matching each one to the average of all others.

        In every round the models are visited in random order. Model i (weights and masks) is permuted to match
        the uniform average of the remaining models. Rounds repeat until no model moved closer to that average,
        or until `max_iter` rounds have been performed.

        :param models: list of state dicts (or modules); entries are replaced by their permuted versions.
        :param models_masks: list of mask dicts, permuted alongside the weights; entries are replaced as well.
        :param init_perm: unused; kept for interface compatibility.
        :return: tuple (models, models_masks).
        """
        for idx in range(len(models)):
            if isinstance(models[idx], torch.nn.Module):
                models[idx] = models[idx].state_dict()

        if len(models) < 2:
            # Nothing to align against
            return models, models_masks

        for iteration in range(max_iter):
            progress = False
            for i in torch.randperm(len(models)).tolist():
                # Select one model
                model_i = models[i]

                # Create average ensemble of all other models
                rest_model_list = [models[j] for j in range(len(models)) if j != i]
                factor = 1. / len(rest_model_list)
                rest_model = OrderedDict()
                for model in rest_model_list:
                    for key, val in model.items():
                        if key not in rest_model:
                            rest_model[key] = val.clone().detach()  # Important otherwise we overwrite the original parameters
                        else:
                            rest_model[key] += val

                for key in rest_model:
                    rest_model[key] = rest_model[key] * factor

                old_distance = Utils.get_l2_distance(model_a=rest_model, model_b=model_i)
                new_perm = PermutationEnsembling.weight_matching(ps=ps, params_a=rest_model, params_b=model_i,
                                                                 max_iter=100)
                models[i] = PermutationEnsembling.apply_permutation(ps=ps, perm=new_perm, params=model_i)
                models_masks[i] = PermutationEnsembling.apply_permutation(ps=ps, perm=new_perm, params=models_masks[i])
                new_distance = Utils.get_l2_distance(model_a=rest_model, model_b=models[i])

                sys.stdout.write(f"Iteration {iteration} - Model {i}: {new_distance - old_distance}\n")
                # These are L2 *distances* (lower is better): progress means model i moved closer to the others
                progress = progress or new_distance < old_distance - 1e-12

            if not progress:
                break

        return models, models_masks

    def get_ensemble_metrics(self):
        """Return the base ensemble metrics extended by the L2-distance reduction of the permutation step."""
        generalMetrics = super().get_ensemble_metrics()
        joinedMetrics = generalMetrics | {'permutation_delta_L2_distance': self.delta_l2_distance}
        return joinedMetrics


def mlp_permutation_spec(num_hidden_layers: int) -> PermutationSpec:
    """Build the PermutationSpec of an MLP whose linear layers are named layer0 ... layer{num_hidden_layers}.

    Hidden layer i owns permutation P_i, acting on axis 0 of its weight and bias and on axis 1 of the next
    layer's weight. The units of the output layer are never permuted.
    We assume that one permutation cannot appear in two axes of the same weight array.
    """
    assert num_hidden_layers >= 1
    last = num_hidden_layers
    axes_to_perm = {"layer0.weight": ("P_0", None)}
    for i in range(1, last):
        axes_to_perm[f"layer{i}.weight"] = (f"P_{i}", f"P_{i - 1}")
    for i in range(last):
        axes_to_perm[f"layer{i}.bias"] = (f"P_{i}",)
    axes_to_perm[f"layer{last}.weight"] = (None, f"P_{last - 1}")
    axes_to_perm[f"layer{last}.bias"] = (None,)
    return Utils.permutation_spec_from_axes_to_perm(axes_to_perm)


def cnn_permutation_spec() -> PermutationSpec:
    """Build the PermutationSpec of the small test CNN (conv1 -> conv2 -> fc1 -> fc2).

    Conv weights get (p_out, p_in, None, None) and dense weights (p_out, p_in), i.e. the spec follows the
    PyTorch layouts; kernel axes and the output units of fc2 are never permuted.
    """

    def conv(name, p_in, p_out, bias=True):
        axes = {f"{name}.weight": (p_out, p_in, None, None)}
        if bias:
            axes[f"{name}.bias"] = (p_out,)
        return axes

    def dense(name, p_in, p_out, bias=True):
        axes = {f"{name}.weight": (p_out, p_in)}
        if bias:
            axes[f"{name}.bias"] = (p_out,)
        return axes

    return Utils.permutation_spec_from_axes_to_perm({
        **conv("conv1", None, "P_bg0"),
        **conv("conv2", "P_bg0", "P_bg1", bias=False),
        **dense("fc1", "P_bg1", "P_bg2"),
        **dense("fc2", "P_bg2", None, bias=True),
    })


def test_weight_matching_mlp():
    """If we just have a single hidden layer then it should converge after just one step."""
    ps = mlp_permutation_spec(num_hidden_layers=1)
    print(ps.axes_to_perm)
    rng = torch.Generator()
    rng.manual_seed(13)
    num_hidden = 10
    # PyTorch nn.Linear layout (out_features, in_features), which is what mlp_permutation_spec assumes
    shapes = {
        "layer0.weight": (num_hidden, 2),
        "layer0.bias": (num_hidden,),
        "layer1.weight": (3, num_hidden),
        "layer1.bias": (3,)
    }

    params_a = {k: torch.randn(shape, generator=rng) for k, shape in shapes.items()}
    params_b = {k: torch.randn(shape, generator=rng) for k, shape in shapes.items()}
    perm = PermutationEnsembling.weight_matching(ps, params_a, params_b)
    print(perm)
    # The single hidden-layer permutation must be a permutation of all hidden units
    assert sorted(perm["P_0"].long().tolist()) == list(range(num_hidden))


def test_weight_matching_cnn():
    """Align two randomly initialised CNNs via weight matching and report the L2 distance before and after."""

    class CNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 32, 3, 1, bias=True)
            self.conv2 = nn.Conv2d(32, 128, 3, 1, bias=False)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.avg = nn.AvgPool2d(kernel_size=1, stride=1)
            self.fc1 = nn.Linear(128, 128)
            self.fc2 = nn.Linear(128, 10, bias=True)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = self.avg(self.dropout1(x))
            x = x.view(x.size(0), -1)
            x = self.dropout2(F.relu(self.fc1(x)))
            return F.log_softmax(self.fc2(x), dim=1)

    model_a, model_b = CNN(), CNN()
    spec = cnn_permutation_spec()

    distance_before = Utils.get_l2_distance(model_a=model_a, model_b=model_b)
    state_a, state_b = model_a.state_dict(), model_b.state_dict()
    permutation = PermutationEnsembling.weight_matching(spec, state_a, state_b)
    model_b = PermutationEnsembling.apply_permutation(spec, permutation, state_b)

    distance_after = Utils.get_l2_distance(model_a=model_a, model_b=model_b)
    sys.stdout.write(f"L2-distance between models changed from {distance_before} to {distance_after}.")


def test_weight_matching_wrn():
    """Align two randomly initialised WideResNet20 models and report the L2 distance before and after."""
    from models.cifar10 import WideResNet20

    model_a, model_b = WideResNet20(), WideResNet20()
    spec = WideResNet20.get_permutation_spec()

    distance_before = Utils.get_l2_distance(model_a=model_a, model_b=model_b)
    state_a, state_b = model_a.state_dict(), model_b.state_dict()
    permutation = PermutationEnsembling.weight_matching(spec, state_a, state_b)
    model_b = PermutationEnsembling.apply_permutation(spec, permutation, state_b)

    distance_after = Utils.get_l2_distance(model_a=model_a, model_b=model_b)
    sys.stdout.write(f"L2-distance between models changed from {distance_before} to {distance_after}.")


def test_weight_matching_wrn_test():
    """Scramble a copy of a WideResNet20 with random permutations and report how far weight matching undoes it."""
    from models.cifar10 import WideResNet20

    reference = WideResNet20().state_dict()
    scrambled = copy.deepcopy(reference)
    spec = WideResNet20.get_permutation_spec()

    # One random permutation per group, sized by the first (key, axis) pair the group acts on
    random_perm = {}
    for p, axes in spec.perm_to_axes.items():
        first_key, first_axis = axes[0][0], axes[0][1]
        random_perm[p] = torch.randperm(reference[first_key].shape[first_axis])
    scrambled = PermutationEnsembling.apply_permutation(spec, random_perm, scrambled)

    distance_before = Utils.get_l2_distance(model_a=reference, model_b=scrambled)
    recovered_perm = PermutationEnsembling.weight_matching(spec, reference, scrambled)
    scrambled = PermutationEnsembling.apply_permutation(spec, recovered_perm, scrambled)

    distance_after = Utils.get_l2_distance(model_a=reference, model_b=scrambled)
    sys.stdout.write(f"L2-distance between models changed from {distance_before} to {distance_after}.")


@torch.no_grad()
def get_mixed_sparsity(params_a, params_b):
    """Return the fraction of exactly-zero entries in the midpoint 0.5 * (params_a + params_b).

    Only entries whose key contains '.weight' are counted; all other entries (e.g. biases) are ignored.

    :param params_a: state dict or nn.Module.
    :param params_b: state dict or nn.Module containing (at least) the '.weight' keys of `params_a`.
    :return: sparsity in [0, 1]; 0.0 if there are no '.weight' entries.
    """
    if isinstance(params_a, torch.nn.Module):
        params_a = params_a.state_dict()
    if isinstance(params_b, torch.nn.Module):
        params_b = params_b.state_dict()

    n_total, n_zero = 0, 0
    for key, p_a in params_a.items():
        if '.weight' not in key:
            continue
        midpoint = 0.5 * (p_a + params_b[key])
        n_total += midpoint.numel()
        n_zero += torch.sum(midpoint == 0).item()

    if n_total == 0:
        # No weights to measure; avoid a ZeroDivisionError
        return 0.0
    return float(n_zero) / n_total


def test_weight_matching_cnn_prune():
    """Structurally prune two random CNNs, align them via weight matching and report distance and sparsity."""

    class CNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 32, 3, 1, bias=True)
            self.conv2 = nn.Conv2d(32, 128, 3, 1, bias=False)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.avg = nn.AvgPool2d(kernel_size=1, stride=1)
            self.fc1 = nn.Linear(128, 128)
            self.fc2 = nn.Linear(128, 10, bias=True)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(F.relu(self.conv2(x)), 2)
            x = self.avg(self.dropout1(x))
            x = x.view(x.size(0), -1)
            x = self.dropout2(F.relu(self.fc1(x)))
            return F.log_softmax(self.fc2(x), dim=1)

    model_a, model_b = CNN(), CNN()
    spec = cnn_permutation_spec()

    # Prune 70% of the slices along dim 0 (L1 norm) of every conv/linear weight, then make the pruning permanent
    for model in (model_a, model_b):
        for module in model.modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.ln_structured(module, 'weight', 0.7, n=1, dim=0)
                prune.remove(module, 'weight')

    s_individual = metrics.metrics.global_sparsity(module=model_a, param_type='weight')
    s_before = get_mixed_sparsity(model_a, model_b)
    distance_before = Utils.get_l2_distance(model_a=model_a, model_b=model_b)

    permutation = PermutationEnsembling.weight_matching(spec, model_a, model_b)
    model_b = PermutationEnsembling.apply_permutation(spec, permutation, model_b)

    s_after = get_mixed_sparsity(model_a, model_b)
    distance_after = Utils.get_l2_distance(model_a=model_a, model_b=model_b)
    sys.stdout.write(f"L2-distance between models changed from {distance_before} to {distance_after}.\n")
    sys.stdout.write(f"Individual Sparsity {s_individual} - Sparsity of ensemble changed from {s_before} to {s_after}.")


if __name__ == "__main__":
    # Running this module directly executes the pruned-CNN weight-matching sanity check
    test_weight_matching_cnn_prune()
