import os
import json
import torch.nn
import torch.nn.functional as F
import wandb
import matplotlib.pyplot as plt


from eufm_numerics.models import *
from eufm_numerics.dsets import *



class TrainerAndAnalyser:
    """
    The class for the training of pure DUFM models.

    The model is called as ``model.forward()`` without inputs; its trainable input features live in
    ``model.weight0`` with one column per training sample, samples grouped class by class (this is the layout
    the NC metrics below assume when multiplying by the one-hot target matrix ``self.Y``).
    Besides training, the class records losses, per-layer DNC metrics (NC1-NC3) and gradient statistics, and
    writes them (as JSON and plots) below ``experiments/``.
    """
    def __init__(self, model, optimizer, scheduler, loss_fcn, wds, num_classes, num_samples):
        """
        Initialization of the training handler.
        :param model: Model to be trained on.
        :param optimizer: Optimizer which trains the model (it must be built on ``model``'s parameters).
        :param scheduler: Scheduler which schedules the optimizer's learning rate.
        :param loss_fcn: The training loss function to be used, called as ``loss_fcn(outputs, targets)``.
        :param wds: A list of weight decays per layer to be used. Requires L+1 length
            (index 0 regularizes ``model.weight0``, the rest the ``nn.Linear`` layers in module order).
        :param num_classes: Number of classes of the classification problem.
        :param num_samples: Number of training samples per a single class.
        """
        # One-hot targets with num_samples consecutive rows per class: (num_classes * num_samples, num_classes).
        self.Y = torch.eye(num_classes).repeat_interleave(num_samples, dim=0)
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fcn = loss_fcn
        self.wds = wds
        self.samples_per_class = num_samples
        self.num_classes = num_classes
        self.samples_total = self.samples_per_class * self.num_classes

        # Metric buffers, allocated by _setup_logging when run_and_analyze(analyze=True) is called.
        self.losses = None
        self.fit_losses = None
        self.regularization_losses = None
        self.NC1s = None
        self.NC2s = None
        self.NC2s_clean = None
        self.NC3s = None
        self.matrix_dict = None
        self.grad_norms = None
        self.reg_grad_norms = None
        self.fit_grad_norms = None
        self.grad_angles = None
        self.feature_orthogonalities = None

    def run_and_analyze(self, steps, analyze=False, log_every=1, return_matrices=False, return_every_n_steps=False,
                        save_model=False, load_model=False, load_dir=None, exp_name=None, perturb_every=None, perturbation_scale=None):
        """
        The method to be run to execute the training.
        :param steps: Number of full GD steps to be run.
        :param analyze: If true, will save the loss and DNC metrics. We recommend to be set True.
        :param log_every: The frequency (in steps) of loss and DNC metrics computation.
        :param return_matrices: If true, the method will also return the weight matrices (only when analyze is
            true). We recommend the option save_model instead.
        :param return_every_n_steps: The frequency of storing the weight matrices, if return_matrices is True.
            A falsy value means every step. Not recommended.
        :param save_model: If true, saves the full model in the end of the training.
        :param load_model: If true, loads the weights to train from instead of the random initialization.
        :param load_dir: The file to load the model from. Only used if load_model is true.
        :param exp_name: The finest subfolder name, where runs will be saved.
        :param perturb_every: Not recommended. If set, perturbs the weights of the model randomly every
            ``perturb_every`` steps (offset by half a period).
        :param perturbation_scale: If perturb_every is set, the perturbation scale (multiplied by the current lr).
        :return: None if analyze is false. Otherwise the tuple (losses, fit_losses, regularization_losses, NC1s,
            NC2s, NC2s_clean, NC3s, grad_norms, fit_grad_norms, reg_grad_norms, grad_angles), with matrix_dict
            appended when return_matrices is true.
        """
        if load_model:
            self._load_weights(load_dir)

        if return_matrices and not return_every_n_steps:
            # The default (False) would otherwise cause a modulo by zero below.
            return_every_n_steps = 1

        if analyze:
            self._setup_logging(steps, return_matrices, log_every=log_every, return_every_n_steps=return_every_n_steps)

        for step in range(steps):

            # Logging:
            if step % 1000 == 0:
                print(step)

            if perturb_every is not None and step % perturb_every == (perturb_every // 2):
                self.model.perturb_weights(perturbation_scale * self.scheduler.get_last_lr()[0])

            # To be sure:
            self.model.zero_grad()

            log_now = analyze and step % log_every == 0
            log_idx = step // log_every

            if log_now:
                self._record_fit_loss(log_idx)

            outputs = self.model.forward()
            loss = self.loss_fcn(outputs, self.Y) * self.num_classes / 2
            loss = loss + self._regularization_loss(log_idx if log_now else None)

            if log_now:
                self._record_metrics(log_idx, loss)

            # Snapshots are only ever returned together with the analysis results, so they need analyze=True.
            if analyze and return_matrices and step % return_every_n_steps == 0:
                self._record_matrices(step // return_every_n_steps)

            loss.backward()
            self.model.do_hook = True
            self.optimizer.step()
            self.scheduler.step()

        if not analyze:
            return

        results = {
            'wds': self.wds,
            'num_classes': self.num_classes,
            'samples_per_class': self.samples_per_class,
            'input_dims': self.model.input_dims,
            'total_losses': self.losses,
            'fit_losses': self.fit_losses,
            'regularization_losses': self.regularization_losses,
            'NC1s': self.NC1s,
            'NC2s': self.NC2s,
            'NC2s_clean': self.NC2s_clean,
            'NC3s': self.NC3s,
            'grad_norms': self.grad_norms,
        }
        if return_matrices:
            results['matrices'] = self.matrix_dict
        results['fit_grad_norms'] = self.fit_grad_norms
        results['reg_grad_norms'] = self.reg_grad_norms
        results['grad_angles'] = self.grad_angles
        self._save_results(results=results, specified_experiment=exp_name, save_model=save_model)

        metrics = (self.losses, self.fit_losses, self.regularization_losses, self.NC1s, self.NC2s, self.NC2s_clean,
                   self.NC3s, self.grad_norms, self.fit_grad_norms, self.reg_grad_norms, self.grad_angles)
        if return_matrices:
            return metrics + (self.matrix_dict,)
        return metrics

    def _load_weights(self, load_dir):
        """
        Load a checkpoint into the existing model in place.

        The parameters are copied into ``self.model`` instead of replacing the object, so that the optimizer and
        scheduler (built on this model's parameters) keep updating the tensors that are actually used.
        :param load_dir: File written by ``torch.save`` containing either a whole model or a state dict.
        """
        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
        loaded = torch.load(load_dir)
        state_dict = loaded.state_dict() if isinstance(loaded, torch.nn.Module) else loaded
        self.model.load_state_dict(state_dict)

    def _record_fit_loss(self, log_idx):
        """
        Run a fit-loss-only forward/backward pass, store the fit loss at ``log_idx`` and reset the gradients.

        The backward pass is done for its side effects on the model (whatever it records while ``do_hook`` is
        set); afterwards ``do_hook`` is switched off until the full-loss backward pass of this step has run.
        """
        dummy_outputs = self.model.forward()
        dummy_loss = self.loss_fcn(dummy_outputs, self.Y) * self.num_classes / 2
        self.fit_losses[log_idx] = dummy_loss.detach()
        dummy_loss.backward()
        self.model.do_hook = False
        self.optimizer.zero_grad(set_to_none=True)

    def _regularization_loss(self, log_idx=None):
        """
        Weight-decay term: sum over weight0 and every nn.Linear weight of 1/2 * wd_l * ||W_l||_F^2.
        :param log_idx: If not None, each per-layer term is also stored in ``regularization_losses[log_idx]``.
        :return: The (differentiable) total regularization term.
        """
        weights = [self.model.weight0]
        weights += [module.weight for _, module in self.model.named_modules() if isinstance(module, nn.Linear)]
        total = 0
        for layer_idx, weight in enumerate(weights):
            # Sum of squares == squared Frobenius norm, without the deprecated torch.frobenius_norm.
            current_reg_term = 1 / 2 * self.wds[layer_idx] * weight.pow(2).sum()
            total = total + current_reg_term
            if log_idx is not None:
                self.regularization_losses[log_idx, layer_idx] = current_reg_term.detach()
        return total

    def _record_metrics(self, log_idx, loss):
        """Store the DNC metrics, gradient statistics and the total loss at index ``log_idx``."""
        self.NC1s[log_idx] = self.computeNC1s().detach()
        self.NC2s[log_idx], self.NC2s_clean[log_idx] = self.computeNC2s()
        self.NC3s[log_idx] = self.computeNC3s().detach()
        self.grad_norms[log_idx] = self.compute_grad_norms().detach()
        self.fit_grad_norms[log_idx] = self.compute_fit_grad_norms().detach()
        self.reg_grad_norms[log_idx] = self.compute_reg_grad_norms().detach()
        self.grad_angles[log_idx] = self.compute_grad_angles().detach()
        self.losses[log_idx] = loss.detach()

    def _record_matrices(self, snapshot_idx):
        """Copy the current weight0 and nn.Linear weights (detached from autograd) into snapshot ``snapshot_idx``."""
        self.matrix_dict['weight0'][snapshot_idx] = self.model.weight0.detach()
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                self.matrix_dict[name][snapshot_idx] = module.weight.detach()

    def _setup_logging(self, steps, return_matrices=False, log_every=1, return_every_n_steps=None):
        """
        Allocate the metric buffers for a run of ``steps`` steps.

        Metrics are logged at steps 0, log_every, 2 * log_every, ... < steps, i.e. ceil(steps / log_every) times;
        weight snapshots analogously every ``return_every_n_steps`` steps (every step if falsy).
        """
        num_layers = self.model.num_layers
        eff_num_steps = (steps + log_every - 1) // log_every  # ceil(steps / log_every)

        self.losses = torch.zeros(eff_num_steps)
        self.fit_losses = torch.zeros(eff_num_steps)
        self.regularization_losses = torch.zeros((eff_num_steps, num_layers + 1))
        self.NC1s = torch.zeros((eff_num_steps, num_layers))
        self.NC2s = torch.zeros((eff_num_steps, num_layers))
        self.NC2s_clean = torch.zeros((eff_num_steps, num_layers))
        self.NC3s = torch.zeros((eff_num_steps, num_layers))
        self.grad_norms = torch.zeros((eff_num_steps, num_layers + 1))
        self.reg_grad_norms = torch.zeros((eff_num_steps, num_layers + 1))
        self.fit_grad_norms = torch.zeros((eff_num_steps, num_layers + 1))
        self.grad_angles = torch.zeros((eff_num_steps, num_layers + 1))

        if return_matrices:
            every = return_every_n_steps or 1
            num_snapshots = (steps + every - 1) // every  # ceil(steps / every)
            self.matrix_dict = {
                'weight0': torch.zeros((num_snapshots,) + tuple(self.model.weight0.size()))}
            for name, module in self.model.named_modules():
                if isinstance(module, nn.Linear):
                    self.matrix_dict[name] = torch.zeros((num_snapshots,) + tuple(module.weight.size()))

    def _class_means(self, features):
        """
        Per-class means of sample features.
        :param features: Tensor of shape (samples_total, d), rows grouped class by class.
        :return: Tensor of shape (num_classes, d).
        """
        indicator = self.Y.T.to(features)  # (num_classes, samples_total), matches features' dtype/device
        return 1 / self.samples_per_class * indicator @ features

    def _nc1_of(self, features):
        """
        NC1 = ||pinv(Sigma_B) @ Sigma_W||_F for features of shape (d, samples_total), columns grouped by class.
        Returns 0 if the linear algebra fails (e.g. pinv does not converge).
        """
        fmeans = 1 / self.samples_per_class * features @ self.Y.to(features)  # (d, num_classes)
        gmeans = fmeans.mean(dim=1, keepdim=True)
        between = fmeans - gmeans
        sigma_b = 1 / self.num_classes * between @ between.T
        centered_features = features - fmeans.repeat_interleave(self.samples_per_class, dim=1)
        sigma_w = 1 / self.samples_total * centered_features @ centered_features.T
        try:
            return torch.linalg.matrix_norm(torch.linalg.pinv(sigma_b) @ sigma_w)  # Frobenius norm by default
        except RuntimeError:  # torch.linalg.LinAlgError is a RuntimeError subclass
            return 0.0

    def computeNC1s(self):
        """
        NC1 (within-class variability relative to between-class spread) for weight0 and the outputs of
        layers 1 .. num_layers - 1.
        :return: Tensor of shape (num_layers,).
        """
        with torch.no_grad():
            num_layers = self.model.num_layers
            NC1s = torch.zeros(num_layers)
            NC1s[0] = self._nc1_of(self.model.weight0)
            for layer_idx in range(1, num_layers):
                NC1s[layer_idx] = self._nc1_of(self.model.forward_until_layer(layer_idx).T)
            return NC1s

    def computeNC2s(self):
        """
        NC2 as the condition number of the class-mean matrix, with (NC2s) and without (NC2s_clean) the
        model's post-activation; assumes no bias.
        :return: Two tensors of shape (num_layers,).
        """
        with torch.no_grad():
            num_layers = self.model.num_layers
            NC2s = torch.zeros(num_layers)
            NC2s_clean = torch.zeros(num_layers)
            NC2s[0] = torch.linalg.cond(self._class_means(self.model.weight0.T))
            NC2s_clean[0] = NC2s[0]

            for layer_idx in range(1, num_layers):
                NC2s[layer_idx] = torch.linalg.cond(
                    self._class_means(self.model.forward_until_layer(layer_idx)))
                NC2s_clean[layer_idx] = torch.linalg.cond(
                    self._class_means(self.model.forward_until_layer_clean(layer_idx)))

        return NC2s.detach(), NC2s_clean.detach()

    def _nc3_of(self, features, weights, min_row_norm=0.000001):
        """
        NC3 misalignment between a layer's weight rows and the class means of the layer's input.

        For each weight row with norm >= min_row_norm: 1 - (largest cosine similarity with any class mean),
        weighted by the row's share of the total row norm; summed and clamped at zero.
        :param features: Layer input of shape (samples_total, d), rows grouped class by class.
        :param weights: Weight matrix of shape (width, d).
        :return: Scalar tensor.
        """
        rows = weights[weights.norm(dim=1) >= min_row_norm]
        row_norms = rows.norm(dim=1)
        fmeans = self._class_means(features).T  # (d, num_classes)
        cosines = (rows @ fmeans) / row_norms.unsqueeze(1) / fmeans.norm(dim=0, keepdim=True)
        misalignment = 1 - cosines.max(dim=1)[0]
        return torch.nn.functional.relu((misalignment * row_norms / row_norms.sum()).sum())

    def computeNC3s(self):
        """
        NC3 for every layer: alignment of the layer's weight rows with its input class means; assumes no bias.
        :return: Tensor of shape (num_layers,).
        """
        with torch.no_grad():
            num_layers = self.model.num_layers
            NC3s = torch.zeros(num_layers)
            for layer_idx in range(num_layers):
                if layer_idx == 0:
                    features = self.model.weight0.data.T
                else:
                    features = self.model.forward_until_layer(layer_idx)
                weights = getattr(self.model, self.model.weight_names[layer_idx]).weight.data
                NC3s[layer_idx] = self._nc3_of(features, weights)
        return NC3s

    def compute_grad_norms(self):
        """Gradient norms most recently recorded by the model (``model.current_grad_norms``)."""
        return self.model.current_grad_norms

    def compute_fit_grad_norms(self):
        """Fit-loss gradient norms most recently recorded by the model."""
        return self.model.current_fit_grad_norms

    def compute_reg_grad_norms(self):
        """Regularization gradient norms most recently recorded by the model."""
        return self.model.current_reg_grad_norms

    def compute_grad_angles(self):
        """Gradient angles most recently recorded by the model."""
        return self.model.current_grad_angle

    @staticmethod
    def _to_jsonable(value):
        """Recursively convert tensors (also inside dicts/lists/tuples) to nested Python lists for json."""
        if isinstance(value, torch.Tensor):
            return value.tolist()
        if isinstance(value, dict):
            return {key: TrainerAndAnalyser._to_jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [TrainerAndAnalyser._to_jsonable(item) for item in value]
        return value

    @staticmethod
    def _plot_per_layer(curves, num_curves, label_offset):
        """Plot the first ``num_curves`` columns of ``curves``, labelled ``layer_<column + label_offset>``."""
        for layer_idx in range(num_curves):
            plt.plot(curves[:, layer_idx], label='layer_' + str(layer_idx + label_offset))

    @staticmethod
    def _finish_figure(path, title, ylabel, figname):
        """Add legend, title and axis labels to the current figure, save it to ``path/figname`` and close it."""
        plt.legend()
        plt.title(title)
        plt.xlabel('step')
        plt.ylabel(ylabel)
        plt.savefig(os.path.join(path, figname))
        plt.close()

    def _save_results(self, results, specified_experiment=None, save_model=False):
        """
        Write ``results`` (results.json), optionally the model (model.pt), and metric plots into a fresh
        ``run_<idx>`` directory below a path derived from the model configuration.
        """
        # Directory tree: experiments/<relu>/<batch_norm>/<bias>/num_layers_L/num_per_class_n/<experiment>/run_i
        path = os.path.join(
            'experiments',
            'with_relu' if self.model.relu else 'without_relu',
            'with_batch_norm' if self.model.batch_norm else 'without_batch_norm',
            'with_bias' if self.model.bias else 'without_bias',
            'num_layers_' + str(self.model.num_layers),
            'num_per_class_' + str(self.model.num_per_class),
            specified_experiment if specified_experiment is not None else 'default_experiments',
        )
        os.makedirs(path, exist_ok=True)
        exp_idx = len(os.listdir(path))
        path = os.path.join(path, 'run_' + str(exp_idx))
        os.mkdir(path)

        # saving the model:
        if save_model:
            torch.save(self.model, os.path.join(path, 'model.pt'))

        # saving the results (nested tensors, e.g. the 'matrices' dict, are converted as well)
        with open(os.path.join(path, 'results.json'), 'w') as file:
            file.write(json.dumps(self._to_jsonable(results)))

        total_optimum = compute_optimum_of_dufm(num_layers=len(self.wds)-1, wds=self.wds,
                                                samples_per_class=self.samples_per_class)

        plt.rc('font', size=14)  # controls default text sizes
        plt.rc('axes', titlesize=14)  # fontsize of the axes title
        plt.rc('axes', labelsize=14)  # fontsize of the x and y labels
        plt.rc('xtick', labelsize=14)  # fontsize of the tick labels
        plt.rc('ytick', labelsize=14)  # fontsize of the tick labels
        plt.rc('legend', fontsize=14)  # legend fontsize
        plt.rc('figure', titlesize=14)  # fontsize of the figure title

        num_layers = self.model.num_layers

        # saving the plots
        plt.plot(torch.log10(results['total_losses']), label='total_loss')
        plt.plot(torch.log10(results['fit_losses']), label='fit_loss')
        for i in range(results['regularization_losses'].size()[1]):
            plt.plot(torch.log10(results['regularization_losses'][:, i]))
        plt.axhline(y=torch.log10(total_optimum), label='theoretical_optimum', color='r', linestyle='dotted')
        self._finish_figure(path, 'Log losses', 'log10 loss', 'losses.png')

        self._plot_per_layer(torch.log10(results['NC1s']), num_layers, 1)
        self._finish_figure(path, 'Log NC1s', 'log10 NC1', 'nc1s.png')

        self._plot_per_layer(torch.log10(results['NC2s']), num_layers, 1)
        if num_layers > 1:  # a second layer is needed for the pre-ReLU column to exist
            plt.plot(torch.log10(results['NC2s_clean'][:, 1]), label='layer_2_pre_relu', linestyle='dashed')
        self._finish_figure(path, 'Log NC2s', 'log10 NC2', 'nc2s.png')

        self._plot_per_layer(torch.log10(results['NC2s_clean']), num_layers, 1)
        self._finish_figure(path, 'Log NC2s w/o ReLU', 'log10 NC2', 'nc2s_clean.png')

        # Small offset because NC3 is clamped at 0 and log10(0) = -inf.
        self._plot_per_layer(torch.log10(results['NC3s'] + 0.000000001), num_layers, 1)
        self._finish_figure(path, 'Log NC3s', 'log10 NC3', 'nc3s.png')

        self._plot_per_layer(torch.log10(results['grad_norms']), num_layers + 1, 0)
        self._finish_figure(path, 'Log grad norms', 'log grad norms', 'grad_norms.png')

        self._plot_per_layer(torch.log10(results['reg_grad_norms']), num_layers + 1, 0)
        self._finish_figure(path, 'Log reg grad norms', 'log reg grad norms', 'reg_grad_norms.png')

        self._plot_per_layer(torch.log10(results['fit_grad_norms']), num_layers + 1, 0)
        self._finish_figure(path, 'Log fit grad norms', 'log fit grad norms', 'fit_grad_norms.png')

        self._plot_per_layer(results['grad_angles'], num_layers + 1, 0)
        self._finish_figure(path, 'Grad angles', 'Grad angles', 'grad_angles.png')


class RealNNTrainerAndAnalyzer:
    """
    The class for the training of joint DNN and DUFM models on real data.
    """
    def __init__(self, train_dataloader, eval_dataloader, test_dataloader, model, optimizer, scheduler, loss_fcn, wds,
                 num_classes, num_samples, num_test_samples, num_backbone_layers, exp_path, logger, device='cpu',
                 use_wandb=False, penalize_unconstrained_features=False):
        """
        Store everything needed to train a backbone + feature-refiner model on real data.

        :param train_dataloader: Dataloader iterated for the training steps.
        :param eval_dataloader: Dataloader used to evaluate the training loss/accuracy and DNC metrics.
        :param test_dataloader: Dataloader of the test set.
        :param model: Model to train; must expose ``backbone`` and ``fr`` (feature refiner) attributes.
        :param optimizer: Optimizer that updates the model.
        :param scheduler: Learning-rate scheduler of the optimizer.
        :param loss_fcn: Training loss, called as ``loss_fcn(outputs, one_hot_targets)``.
        :param wds: Per-layer weight decays (L+1 entries).
        :param num_classes: Number of classes of the classification problem.
        :param num_samples: Number of training samples per class.
        :param num_test_samples: Total number of test samples.
        :param num_backbone_layers: Number of layers of the backbone (ResNet).
        :param exp_path: Directory the results are saved into.
        :param logger: Logger used for progress messages.
        :param device: Device the training runs on ('cpu' or a cuda device).
        :param use_wandb: Whether to report metrics to wandb.
        :param penalize_unconstrained_features: If true, the backbone's output features are regularized
            instead of the backbone's weights.
        """
        # Data sources.
        self.train_dataloader, self.eval_dataloader, self.test_dataloader = (
            train_dataloader, eval_dataloader, test_dataloader)

        # One-hot targets, num_samples consecutive rows per class.
        self.Y = torch.eye(num_classes).repeat_interleave(num_samples, dim=0)

        # Model and its two components.
        self.model = model
        self.backbone = model.backbone
        self.feature_refiner = model.fr

        # Optimisation setup.
        self.optimizer, self.scheduler, self.loss_fcn, self.wds = optimizer, scheduler, loss_fcn, wds

        # Problem sizes.
        self.samples_per_class = num_samples
        self.num_classes = num_classes
        self.samples_total = num_samples * num_classes
        self.num_test_samples = num_test_samples
        self.num_backbone_layers = num_backbone_layers

        # Runtime / bookkeeping options.
        self.device = device
        self.exp_path = exp_path
        self.the_logger = logger
        self.use_wandb = use_wandb
        self.penalize_uf = penalize_unconstrained_features

        # Metric buffers; filled in by run_and_analyze when analyze=True.
        for attr_name in ('losses', 'tr_accs', 'test_losses', 'test_accs', 'NC1s', 'NC2s', 'NC3s', 'matrix_dict'):
            setattr(self, attr_name, None)

    def run_and_analyze(self, steps, analyze=False, log_every=1, return_matrices=False, return_every_n_steps=False,
                        save_model=False, load_model=False, load_dir=None):
        """
        :param steps: Number of training epochs to perform.
        :param analyze: If true, will save the loss and DNC metrics. Recommended.
        :param log_every: The frequency of loss and DNC computations in epochs per computation.
        :param return_matrices: Not recommended.
        :param return_every_n_steps: Not recommended.
        :param save_model: If true, will save the model in the end of the training.
        :param load_model: If true, will load the model to start the training with.
        :param load_dir: The directory to load the model from, if load_model true.
        :return: Returns the saved loss and DNC statistics.
        """

        if load_model:
            self.model = torch.load(load_dir)
            self.model = self.model.to(self.device)

        if analyze:
            self._setup_logging(steps, return_matrices, log_every=log_every, return_every_n_steps=return_every_n_steps)

        for step in range(steps):

            # Logging:
            if step % 1 == 0:
                self.the_logger.info(f"big_step: {step}")

            # To be sure:
            self.model.zero_grad()

            if analyze and step % log_every == 0:
                self.the_logger.info("eval_start")
                clean_means, means_afer_relu = self.compute_class_means()
                self.NC1s[step // log_every] = self.compute_NC1s(clean_means, means_afer_relu)
                self.NC2s[step // log_every], self.NC2s_clean[step // log_every] = self.compute_NC2s(clean_means, means_afer_relu)
                self.NC3s[step // log_every] = self.compute_NC3s(clean_means, means_afer_relu)
                tmp_loss = torch.zeros(1, device=self.device)
                tmp_tr_acc = torch.zeros(1, device=self.device)
                if self.penalize_uf:
                    tmp_features_reg_loss = torch.zeros(1, device=self.device)
                for batch, labels in self.eval_dataloader:
                    batch, labels = batch.to(self.device), labels.to(self.device)
                    with torch.no_grad():
                        outputs = self.model.forward(batch)
                        if self.penalize_uf:
                            tmp_features_reg_loss += 1 / 2 * self.wds[0] * torch.frobenius_norm(outputs)**2
                        loss = self.loss_fcn(outputs, F.one_hot(labels.long(), num_classes=self.num_classes).float())
                        tmp_loss += loss * len(labels) * self.num_classes
                        predictions = torch.argmin(torch.abs(outputs-1), dim=1)
                        tmp_tr_acc += torch.sum(predictions == labels)
                self.losses[step // log_every] = (tmp_loss / self.samples_total).cpu().detach()
                self.tr_accs[step // log_every] = (tmp_tr_acc / self.samples_total).cpu().detach()
                self.test_losses[step // log_every], self.test_accs[step // log_every] = self.compute_test_loss_and_acc()
                if self.penalize_uf:
                    self.regularization_losses[step // log_every, 0] = tmp_features_reg_loss.cpu().detach()
                    for layer_idx in range(self.model.fr.num_layers):
                        self.regularization_losses[step // log_every, layer_idx+1] = 1/2 * self.wds[layer_idx+1] * torch.frobenius_norm(getattr(self.model.fr, self.model.fr.weight_names[layer_idx]).weight.data)**2
                    self.total_losses[step // log_every] = self.losses[step // log_every] + self.regularization_losses[step // log_every].sum()
                if self.use_wandb:
                    wandb.log({
                        'log10_tr_loss': torch.log10(self.losses[step // log_every]),
                        'tr_acc': self.tr_accs[step // log_every],
                        'log10_test_loss': torch.log10(self.test_losses[step // log_every]),
                        'test_acc': self.test_accs[step // log_every]
                    })
                    for layer_idx in range(self.model.fr.num_layers):
                        wandb.log({
                            f'log10_NC1s_layer_{layer_idx}': torch.log10(self.NC1s[step // log_every][layer_idx]),
                            f'log10_NC2s_layer_{layer_idx}': torch.log10(self.NC2s[step // log_every][layer_idx]),
                            f'log10_NC2s_clean_layer_{layer_idx}': torch.log10(self.NC2s_clean[step // log_every][layer_idx]),
                            f'log10_NC3s_layer_{layer_idx}': torch.log10(self.NC3s[step // log_every][layer_idx]),
                        })
                    if self.penalize_uf:
                        wandb.log({
                            'log10_total_tr_loss': torch.log10(self.total_losses[step // log_every]),
                            'log10_features_loss': torch.log10(self.regularization_losses[step // log_every][0])
                        })
                        for layer_idx in range(self.model.fr.num_layers):
                            wandb.log({
                                f'log10_reg_loss_layer_{layer_idx}': torch.log10(self.regularization_losses[step // log_every][layer_idx+1])
                            })
                self.the_logger.info("eval_stop")

            if return_matrices and step % return_every_n_steps == 0:
                for name, module in self.model.fr.named_modules():
                    if isinstance(module, nn.Linear):
                        self.matrix_dict[name][step // return_every_n_steps] = module.weight.data.cpu().detach()

            # To be sure:
            self.model.zero_grad()

            for tr_step, (batch, labels) in enumerate(self.train_dataloader):
                self.the_logger.info(f"training_step: {tr_step}")
                batch, labels = batch.to(self.device), labels.to(self.device)

                if not self.penalize_uf:
                    outputs = self.model.forward(batch)
                    loss = self.loss_fcn(outputs, F.one_hot(labels.long(), num_classes=self.num_classes).float()) * self.num_classes / 2
                if self.penalize_uf:
                    prelim_outputs = self.model.backbone.forward(batch)
                    outputs = self.model.fr.forward(prelim_outputs)
                    loss = self.loss_fcn(outputs, F.one_hot(labels.long(), num_classes=self.num_classes).float()) * self.num_classes / 2
                    loss += 1/2 * self.wds[0] * torch.frobenius_norm(prelim_outputs)**2 / self.train_dataloader.batch_size * self.samples_total
                    for layer_idx in range(self.model.fr.num_layers):
                        loss += 1/2 * self.wds[layer_idx+1] * torch.frobenius_norm(getattr(self.model.fr, self.model.fr.weight_names[layer_idx]).weight.data)**2

                loss.backward()
                self.optimizer.step()

            self.scheduler.step()

        if not analyze:
            return
        elif analyze and not return_matrices:
            if not self.penalize_uf:
                results = {
                    'wds': self.wds,
                    'num_classes': self.num_classes,
                    'samples_per_class': self.samples_per_class,
                    'input_dims': self.model.fr.input_dims,
                    'fit_losses': self.losses,
                    'train_accs': self.tr_accs,
                    'test_losses': self.test_losses,
                    'test_accs': self.test_accs,
                    'NC1s': self.NC1s,
                    'NC2s': self.NC2s,
                    'NC2s_clean': self.NC2s_clean,
                    'NC3s': self.NC3s,
                }
                self._save_results(results=results, save_model=save_model)
                return self.losses, self.tr_accs, self.test_losses, self.test_accs, self.NC1s, self.NC2s, self.NC2s_clean, self.NC3s
            if self.penalize_uf:
                results = {
                    'wds': self.wds,
                    'num_classes': self.num_classes,
                    'samples_per_class': self.samples_per_class,
                    'input_dims': self.model.fr.input_dims,
                    'fit_losses': self.losses,
                    'regularization_losses': self.regularization_losses,
                    'total_losses': self.total_losses,
                    'train_accs': self.tr_accs,
                    'test_losses': self.test_losses,
                    'test_accs': self.test_accs,
                    'NC1s': self.NC1s,
                    'NC2s': self.NC2s,
                    'NC2s_clean': self.NC2s_clean,
                    'NC3s': self.NC3s,
                }
                self._save_results(results=results, save_model=save_model)
                return self.losses, self.tr_accs, self.test_losses, self.test_accs, self.NC1s, self.NC2s, self.NC2s_clean, self.NC3s

        else:
            results = {
                'wds': self.wds,
                'num_classes': self.num_classes,
                'samples_per_class': self.samples_per_class,
                'input_dims': self.model.fr.input_dims,
                'total_losses': self.losses,
                'train_accs': self.tr_accs,
                'test_losses': self.test_losses,
                'test_accs': self.test_accs,
                'NC1s': self.NC1s,
                'NC2s': self.NC2s,
                'NC2s_clean': self.NC2s_clean,
                'NC3s': self.NC3s,
                'matrices': self.matrix_dict,
            }
            self._save_results(results=results, save_model=save_model)
            return self.losses, self.tr_accs, self.test_losses, self.test_accs, self.NC1s, self.NC2s, self.NC2s_clean, self.NC3s, self.matrix_dict

    def _setup_logging(self, steps, return_matrices=False, log_every=1, return_every_n_steps=None):

        layer_names = ['features']
        for name, module in self.model.fr.named_modules():
            if isinstance(module, nn.Linear):
                layer_names.append(name)
        num_layers = self.model.fr.num_layers
        eff_num_steps = steps // log_every

        self.losses = torch.zeros(eff_num_steps)
        if self.penalize_uf:
            self.regularization_losses = torch.zeros((eff_num_steps, num_layers + 1))
            self.total_losses = torch.zeros(eff_num_steps)
        self.tr_accs = torch.zeros(eff_num_steps)
        self.test_losses = torch.zeros(eff_num_steps)
        self.test_accs = torch.zeros(eff_num_steps)
        self.NC1s = torch.zeros((eff_num_steps, num_layers))
        self.NC2s = torch.zeros((eff_num_steps, num_layers))
        self.NC2s_clean = torch.zeros((eff_num_steps, num_layers))
        self.NC3s = torch.zeros((eff_num_steps, num_layers))

        if return_matrices:  # TODO: FIX!!!
            self.matrix_dict = {
                'features': torch.zeros((steps // return_every_n_steps) + tuple(self.model.weight0.size))}
            for name, module in self.model.named_modules():
                if isinstance(module, nn.Linear):
                    dimensions = (steps // return_every_n_steps) + tuple(module.weight.data.size)
                    self.matrix_dict[name] = torch.zeros(dimensions)

    def compute_class_means(self):
        with torch.no_grad():
            clean_means = {}
            means_after_relu = {}

            for layer_idx in range(self.model.fr.num_layers):
                if layer_idx == 0:
                    clean_means[layer_idx] = torch.zeros((self.model.fr.input_dims[layer_idx], self.num_classes), device=self.device)
                else:
                    clean_means[layer_idx] = torch.zeros((self.model.fr.input_dims[layer_idx], self.num_classes), device=self.device)
                    means_after_relu[layer_idx] = torch.zeros((self.model.fr.input_dims[layer_idx], self.num_classes), device=self.device)

            for idx, (batch, labels) in enumerate(self.eval_dataloader):
                batch, labels = batch.to(self.device), labels.to(self.device)
                current_features = self.model.backbone(batch)

                for class_idx in range(self.num_classes):
                    clean_means[0][:, class_idx] += current_features[labels == class_idx].sum(dim=0)

                for layer_idx in range(self.model.fr.num_layers-1):
                    clean_outputs = self.model.fr.forward_until_layer_clean(current_features, layer_idx+1)
                    relu_outputs = self.model.fr.forward_until_layer(current_features, layer_idx+1)

                    for class_idx in range(self.num_classes):
                        clean_means[layer_idx+1][:, class_idx] += clean_outputs[labels == class_idx].sum(dim=0)
                        means_after_relu[layer_idx+1][:, class_idx] += relu_outputs[labels == class_idx].sum(dim=0)

            for layer_idx in range(self.model.fr.num_layers):
                if layer_idx == 0:
                    clean_means[layer_idx] /= self.model.fr.num_per_class
                else:
                    clean_means[layer_idx] /= self.model.fr.num_per_class
                    means_after_relu[layer_idx] /= self.model.fr.num_per_class
        return clean_means, means_after_relu

    def compute_NC1s(self, clean_means, means_after_relu):
        with torch.no_grad():
            num_layers = self.model.fr.num_layers
            NC1s = torch.zeros(num_layers)
            global_means = {}
            btwn_class_vars = {}
            wthn_class_vars = {}
            for layer_idx in range(num_layers):
                wthn_class_vars[layer_idx] = torch.zeros((self.model.fr.input_dims[layer_idx], self.model.fr.input_dims[layer_idx]), device=self.device)
                if layer_idx == 0:
                    global_means[layer_idx] = clean_means[layer_idx].mean(dim=1, keepdim=True).repeat_interleave(self.num_classes, dim=1)
                    btwn_class_vars[layer_idx] = 1/self.num_classes*(clean_means[layer_idx]-global_means[layer_idx])@(clean_means[layer_idx]-global_means[layer_idx]).T
                else:
                    global_means[layer_idx] = means_after_relu[layer_idx].mean(dim=1, keepdim=True).repeat_interleave(self.num_classes, dim=1)
                    btwn_class_vars[layer_idx] = 1/self.num_classes*(clean_means[layer_idx]-global_means[layer_idx])@(clean_means[layer_idx]-global_means[layer_idx]).T

            for batch, labels in self.eval_dataloader:
                batch, labels = batch.to(self.device), labels.to(self.device)
                current_features = self.model.backbone(batch)

                for class_idx in range(self.num_classes):
                    extended_class_mean = clean_means[0][:, class_idx].unsqueeze(dim=1).repeat_interleave(torch.sum(labels == class_idx), dim=1)
                    wthn_class_vars[0] += (current_features[labels == class_idx].T-extended_class_mean)@(current_features[labels == class_idx].T-extended_class_mean).T

                for layer_idx in range(num_layers-1):
                    relu_outputs = self.model.fr.forward_until_layer(current_features, layer_idx+1)

                    for class_idx in range(self.num_classes):
                        extended_class_mean = means_after_relu[layer_idx+1][:, class_idx].unsqueeze(dim=1).repeat_interleave(torch.sum(labels == class_idx), dim=1)
                        wthn_class_vars[layer_idx+1] += (relu_outputs[labels == class_idx].T-extended_class_mean)@(relu_outputs[labels == class_idx].T-extended_class_mean).T

            for layer_idx in range(num_layers):
                wthn_class_vars[layer_idx] /= self.samples_total
                try:
                    NC1s[layer_idx] = torch.frobenius_norm(torch.linalg.pinv(btwn_class_vars[layer_idx]) @ wthn_class_vars[layer_idx]).cpu().detach()
                except:
                    NC1s[layer_idx] = 0
        return NC1s

    def compute_NC2s(self, clean_means, means_after_relu):
        with torch.no_grad():
            num_layers = self.model.fr.num_layers
            NC2s = torch.zeros(num_layers)
            NC2s_clean = torch.zeros(num_layers)
            for layer_idx in range(num_layers):
                if layer_idx == 0:
                    NC2s[layer_idx] = torch.linalg.cond(clean_means[layer_idx]).cpu()
                    NC2s_clean[layer_idx] = NC2s[layer_idx].cpu()
                else:
                    NC2s[layer_idx] = torch.linalg.cond(means_after_relu[layer_idx]).cpu()
                    NC2s_clean[layer_idx] = torch.linalg.cond(clean_means[layer_idx]).cpu()
        return NC2s.detach(), NC2s_clean.detach()

    def compute_NC3s(self, clean_means, means_after_relu):
        with torch.no_grad():
            num_layers = self.model.fr.num_layers
            NC3s = torch.zeros(num_layers, device='cpu')
            for layer_idx in range(num_layers):
                if layer_idx == 0:
                    fmeans = clean_means[layer_idx]
                    weights = getattr(self.model.fr, self.model.fr.weight_names[layer_idx]).weight.data
                    large_enough_weights = weights[weights.norm(dim=1) >= 0.000001]
                    the_metric = torch.nn.functional.relu(((1 - ((
                                (large_enough_weights @ fmeans) / large_enough_weights.norm(dim=1,
                                                                                            keepdim=True).repeat_interleave(
                            self.num_classes, dim=1) / fmeans.norm(dim=0, keepdim=True).repeat_interleave(
                            large_enough_weights.size()[0], dim=0)).max(dim=1)[0])) * large_enough_weights.norm(
                        dim=1) / large_enough_weights.norm(dim=1).sum()).sum())
                    NC3s[layer_idx] = the_metric
                else:
                    fmeans = means_after_relu[layer_idx]
                    weights = getattr(self.model.fr, self.model.fr.weight_names[layer_idx]).weight.data
                    large_enough_weights = weights[weights.norm(dim=1) >= 0.000001]
                    the_metric = torch.nn.functional.relu(((1 - ((
                                (large_enough_weights @ fmeans) / large_enough_weights.norm(dim=1,
                                                                                            keepdim=True).repeat_interleave(
                            self.num_classes, dim=1) / fmeans.norm(dim=0, keepdim=True).repeat_interleave(
                            large_enough_weights.size()[0], dim=0)).max(dim=1)[0])) * large_enough_weights.norm(
                        dim=1) / large_enough_weights.norm(dim=1).sum()).sum())
                    NC3s[layer_idx] = the_metric
        return NC3s.detach()

    def compute_test_loss_and_acc(self):
        with torch.no_grad():
            loss = 0
            num_correct = 0
            for inputs, labels in self.test_dataloader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss += self.loss_fcn(outputs, F.one_hot(labels.long(), num_classes=self.num_classes).float()) / 2 * len(labels)
                predictions = torch.argmin(torch.abs(outputs-1), dim=1)
                num_correct += torch.sum(predictions == labels)
            acc = num_correct / self.num_test_samples
            loss /= self.num_test_samples
        return loss.cpu().detach(), acc.cpu().detach()

    def _save_results(self, results, save_model=False):
        """
        Write the experiment results to ``self.exp_path``.
        Saves ``model.pt`` (optional), ``results.json`` and PNG plots of the logged losses,
        accuracies and NC metrics.
        :param results: Dict of results. Tensor values, including tensors nested inside dicts or
            lists (e.g. a 'matrices' dict of tensors), are converted to lists for JSON.
        :param save_model: Whether to save the whole model object with torch.save.
        """
        path = self.exp_path
        if save_model:
            torch.save(self.model, os.path.join(path, 'model.pt'))

        def to_jsonable(value):
            # json cannot serialize tensors; convert them (recursively inside containers) to lists.
            if isinstance(value, torch.Tensor):
                return value.tolist()
            if isinstance(value, dict):
                return {key: to_jsonable(item) for key, item in value.items()}
            if isinstance(value, (list, tuple)):
                return [to_jsonable(item) for item in value]
            return value

        # Context manager guarantees the file is closed even if serialization/writing fails.
        with open(os.path.join(path, 'results.json'), 'w') as file:
            file.write(json.dumps(to_jsonable(results)))

        def finish_plot(figname, title, ylabel, legend):
            # Shared tail of every figure: decorate, save under `path`, and release the figure.
            if legend:
                plt.legend()
            plt.title(title)
            plt.xlabel('step')
            plt.ylabel(ylabel)
            plt.savefig(os.path.join(path, figname))
            plt.close()

        if self.penalize_uf:
            plt.plot(torch.log10(results['total_losses']), label='total_loss')
            plt.plot(torch.log10(results['fit_losses']), label='fit_loss')
            for i in range(results['regularization_losses'].size()[1]):
                plt.plot(torch.log10(results['regularization_losses'][:, i]), label='layer_' + str(i))
            finish_plot('losses.png', 'Log losses', 'log10 loss', legend=True)
        else:
            plt.plot(torch.log10(results['fit_losses']))
            finish_plot('losses.png', 'Log losses', 'log10 loss', legend=False)

        plt.plot(results['train_accs'])
        finish_plot('tr_accs.png', 'Train accuracies', 'acc', legend=False)

        plt.plot(torch.log10(results['test_losses']))
        finish_plot('test_losses.png', 'Log test losses', 'log10 loss', legend=False)

        plt.plot(results['test_accs'])
        finish_plot('test_accs.png', 'Test accuracies', 'acc', legend=False)

        # Per-layer NC metric plots: (results key, file name, title, y-label).
        nc_plots = (
            ('NC1s', 'nc1s.png', 'Log NC1s', 'log10 NC1'),
            ('NC2s', 'nc2s.png', 'Log NC2s', 'log10 NC2'),
            ('NC2s_clean', 'nc2s_clean.png', 'Log NC2s w/o ReLU', 'log10 NC2'),
            ('NC3s', 'nc3s.png', 'Log NC3s', 'log10 NC3'),
        )
        for key, figname, title, ylabel in nc_plots:
            for layer_idx in range(self.model.fr.num_layers):
                plt.plot(torch.log10(results[key][:, layer_idx]), label='layer_' + str(layer_idx + 1))
            finish_plot(figname, title, ylabel, legend=True)


def compute_optimum_of_dufm(num_layers, wds, samples_per_class, q_max=10.0, num_points=10000):
    """
    Grid-search the minimum of a scalar objective in q over ``[0, q_max]``.
    With L = ``num_layers``, n = ``samples_per_class`` and lambda = ``wds``, the objective is
    ``lambda_L / (lambda_{L-1}^(L-1) / n / prod(lambda_0 .. lambda_{L-2}) * q^L + 2 * lambda_L)
    + L * lambda_{L-1} * q``.
    :param num_layers: Number of layers L; ``wds`` must have at least L + 1 entries.
    :param wds: Weight decays, as a list/tuple (converted to a float tensor) or a tensor.
    :param samples_per_class: Number of training samples per class n.
    :param q_max: Upper end of the search grid (default matches the original hard-coded 10).
    :param num_points: Number of evenly spaced grid points, endpoints included.
    :return: 0-dim tensor holding the smallest objective value on the grid.
    """
    if isinstance(wds, (list, tuple)):
        wds = torch.Tensor(wds)
    nl = num_layers

    def obj_fcn(q):
        return (wds[nl] / (wds[nl - 1] ** (nl - 1) / samples_per_class / wds[0:nl - 1].prod() * q ** nl
                           + 2 * wds[nl])
                + nl * wds[nl - 1] * q)

    grid = torch.linspace(start=0, end=q_max, steps=num_points)
    return torch.min(obj_fcn(grid))
