import time
import gc
from abc import abstractmethod, ABC
from copy import deepcopy
from typing import Union, List, Tuple, Generic, TypeVar
from functools import partial

import random
import numpy as np
import torch
from sklearn.metrics import auc as compute_auc, roc_curve, precision_recall_curve, average_precision_score
from torch.nn import Module
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import Compose
from tqdm import tqdm
from torch import Tensor

import xad.utils
from xad.datasets import load_dataset, str_labels, no_classes, cset_str_description
from xad.datasets.bases import TorchvisionDataset, CombinedDataset
from xad.datasets.imagenet import ADImageNet21k
from xad.models.bases import ADNN
from xad.utils.logger import Logger, ROC, PRC, recursive_dict_update
from xad.counterfactual.dissect_trainer import XTrainer
from xad.utils.training_tools import NanGradientsError, lst_of_lsts, weight_reset, int_set_to_str


R = TypeVar('R')  # type of a single recorded metric (e.g., a ROC object)


class Result(Generic[R]):
    """
    Container for metrics recorded per class (set) and per random seed.

    Holds one list per class (len = #classes); each of those lists collects the metrics recorded for that class,
    one per random seed. ``result[cls]`` gives access to the list of a class.
    Additionally, a per-class mean (e.g., a mean ROC plot) can be stored with :meth:`set_mean` and read back with
    :meth:`mean` or :meth:`means`.

    Example::

        rocs = Result(len(classes))
        for cls in classes:
            for seed in seeds:
                ...
                rocs[cls].append(training_result_roc)
            ...
            rocs.set_mean(cls, mean_plot(rocs[cls]))
        plot_many(rocs.means())
    """

    def __init__(self, classes: int):
        """ @param classes: the number of classes (class sets) to track. """
        self.values = lst_of_lsts(classes)
        self.mean_values = [None for _ in range(classes)]

    def __getitem__(self, cls: int) -> List[R]:
        """ @return: the list of metrics recorded so far for class `cls`. """
        return self.values[cls]

    def set_mean(self, cls: int, value: R):
        """ Stores `value` as the mean metric of class `cls` (e.g., a mean ROC plot). """
        self.mean_values[cls] = value

    def mean(self, cls: int, on_none_return_latest=False) -> R:
        """
        @param cls: the class whose mean is requested.
        @param on_none_return_latest: if no mean has been set for `cls`, return the most recently recorded metric
            of that class instead of None (None if nothing has been recorded either).
        @return: the stored mean, or the fallback described above.
        """
        stored = self.mean_values[cls]
        recorded = self.values[cls]
        if stored is not None:
            return stored
        if on_none_return_latest and recorded:
            return recorded[-1]
        return None

    def means(self, on_none_return_latest=False) -> List[R]:
        """ @return: :meth:`mean` for every class, in class order. """
        return [self.mean(idx, on_none_return_latest) for idx, _ in enumerate(self.mean_values)]

    def __str__(self) -> str:
        return f'{self.values}'

    def __repr__(self) -> str:
        return f'{self.values!r}'

    def __iter__(self):
        return iter(self.values)


class XADTrainer(ABC):
    # whether to keep the model snapshots in RAM and make the `run` method return them in addition to storing them on the disk
    KEEP_SNAPSHOT_IN_RAM = False

    def __init__(self, model: ADNN, train_transform: Compose, test_transform: Compose,
                 dataset: str, oe_dataset: List[str], datapath: str, logger: Logger,
                 epochs: int, lr: float, wdk: float, milestones: List[int], batch_size: int,
                 device: Union[str, torch.device] = 'cuda',
                 oe_limit_samples: Union[int, List[int]] = np.inf, oe_limit_classes: Union[int, List[set[int]]] = np.inf,
                 workers: int = 2, xtrainer: XTrainer = None, train_split: float = 0.0):
        """
        The base trainer class.
        It defines a `run` method that iterates over all class sets and multiple random seeds per class set.
        For each class_set-seed combination, it trains and evaluates a given AD model.
        The objective needs to be implemented (see :method:`loss`, etc.); this is an abstract class.
        Pre-implemented trainers with objectives can be found in other files; e.g., :class:`xad.training.hsc.HSCTrainer`.
        The trainer treats all classes in the current class set as normal.
        The trainer always evaluates using the full test set.
        For training, it uses only normal samples and perhaps auxiliary anomalies from a different source (Outlier Exposure).
        For a list of all available training configurations, have a look at the parameters below.

        @param model: some model that is to be trained/evaluated. For training multiple classes/seeds, a separate
            copy of the model is initialized and trained. The model is moved to the CPU here (None is kept as is).
        @param train_transform: pre-processing pipeline applied to training samples (including data augmentation).
        @param test_transform: pre-processing pipeline applied to test samples (including data augmentation).
        @param dataset: string specifying the dataset, see :data:`xad.datasets.__init__.DS_CHOICES`.
        @param oe_dataset: specifies the Outlier Exposure dataset, see :data:`xad.datasets.__init__.DS_CHOICES`.
        @param datapath: filepath to where the datasets are located or automatically to be downloaded to.
            Specifies the root directory for all datasets.
        @param logger: a logger instance that is used to print current training progress and log metrics to the disk.
        @param epochs: how many full iterations of the dataset are to be trained per class-seed combination.
        @param lr: initial learning rate.
        @param wdk: weight decay.
        @param milestones: milestones for the learning rate scheduler; at each milestone the learning rate is reduced by 0.1.
        @param batch_size: batch size. If there is an OE dataset, the overall batch size will be twice as large
            as an equally-sized batch of OE samples gets concatenated to the batch of normal training samples.
        @param device: torch device to be used. 'cuda' uses the first available GPU.
        @param oe_limit_samples: limits the number of different samples for the OE dataset.
            If of type int, causes a random selection of a subset of OE samples.
            If of type list, causes a fixed selection of OE samples with the specified ids in the list.
            Defaults to infinity (no limit).
        @param oe_limit_classes: limits the number of different classes for the OE dataset.
            If of type int, causes a random selection of a subset of OE classes.
            If of type list, causes a fixed selection of OE class sets per normal class set, with the specified ids in the set.
            Defaults to infinity (no limit).
        @param workers: number of data-loading workers. See :class:`torch.utils.data.DataLoader`.
        @param xtrainer: A trainer for counterfactual training. It will be invoked after each successful training
            of a class-seed combination.
        @param train_split: Determines the ratio of samples that are split from the training data for additional validation.
            The training data contains only normal and perhaps OE samples. The split will be used to compute an AUC and
            anomaly score distribution to test how well OE generalizes. train_split can be 0.
        """
        # NOTE: must stay the first statement; `locals()` here contains exactly the constructor arguments.
        # `np.inf` is used for the defaults because the `np.infty` alias was removed in NumPy 2.0.
        logger.logsetup(
            {k: v for k, v in locals().items() if k not in ['self', 'xtrainer']}, nets=[model]
        )  # log training setup
        self.model = model.cpu() if model is not None else model
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.dsstr = dataset
        self.oe_dsstr = oe_dataset
        self.oe_limit_samples = oe_limit_samples
        self.oe_limit_classes = oe_limit_classes
        self.datapath = datapath
        self.logger = logger
        self.device = device
        self.epochs = epochs
        self.lr = lr
        self.wdk = wdk
        self.milestones = milestones
        self.batch_size = batch_size
        self.center = None  # set by `train_cls` via `prepare_metric`
        self.workers = workers
        self.ds = None  # optionally pre-loaded dataset; `run` then reuses it instead of loading one
        self.xtrainer = xtrainer
        self.train_split = train_split

    def run(self, run_classes: List[set[int]] = None, run_seeds: int = 1,
            load: List[List[Union[Module, str]]] = None, xload: List[List[str]] = None,
            seeds: List[int] = None, tqdm_mininterval: float = 0.1, **train_loss_kwargs) -> Tuple[List[List[ADNN]], dict]:
        """
        Iterates over all class sets and multiple random seeds per class set.
        For each class_set-seed combination, it trains and evaluates a given AD model using the trainer's loss
        (see method:`XADTrainer.loss`). For example, see :class:`xad.training.hsc` for an implementation of the HSC loss.
        This method also prepares the complete dataset (all splits and perhaps outlier exposure)
        at the start of the training of each class_set-seed combination.
        If training raises a :class:`NanGradientsError`, it is restarted with a freshly initialized model (and, unless
        a pre-loaded dataset is used, a freshly loaded dataset), for at most five attempts in total. If all attempts
        fail, the model and training ROC of that combination are set to None and evaluation is skipped for it.

        @param run_classes: which class sets to run on. If none, creates a collection of singletons for all classes.
            All classes of the current class set are considered normal.
        @param run_seeds: how often to train a class with different random seeds.
        @param load: a list (len = #class_sets) with each element again being a list (len = #seeds) of model snapshots.
            The snapshots can either be PyTorch Module instances defining the model directly or strings specifying
            filepaths to stored snapshots in the form of dictionaries. These contain the model, optimizer, and scheduler state.
            Whenever the trainer starts training with the j-th random seed for class set i, it tries to initialize the model
            with the snapshot found in load[i][j]. If not available, it trains a model from scratch.
            The model snapshots need to match the model architecture specified in the initialization of the trainer.
            If the snapshots are dictionaries, the trainer just continues training at the stored last epoch and, if the number
            of overall epochs has been reached already, just evaluates the model.
        @param xload: similar to load but contains model snapshots for counterfactual training.
        @param seeds: Optional. A list (len = #class_sets * run_seeds) of fixed random torch seeds to reproduce experiments.
        @param tqdm_mininterval: Minimum progress display update interval [default: 0.1] seconds.
        @param train_loss_kwargs: further keyword arguments passed on to :meth:`train_cls` (and thus the loss).
        @return: returns a tuple of
            - a list (len = #class_sets) with each element again being a list (len = #seeds) of trained AD models.
              If XADTrainer.KEEP_SNAPSHOT_IN_RAM is False, the entries are set to None to prevent out-of-memory errors.
            - a dictionary containing all important final metrics recorded at the end of the training and during
              evaluation.
            All the returned information is always also stored on the disk using the trainer's logger.
        """
        timer = self.logger.Timer(self.logger, 'Overall the experiment')
        timer.__enter__()

        # prepare variables
        run_classes = run_classes if run_classes is not None else [{i} for i in range(no_classes(self.dsstr))]
        run_classes_strings = [cset_str_description(self.dsstr, cset) for cset in run_classes]
        train_cls_rocs = Result(len(run_classes))
        eval_cls_rocs = Result(len(run_classes))
        eval_cls_prcs = Result(len(run_classes))
        models = lst_of_lsts(len(run_classes))
        ascores_log = lst_of_lsts(len(run_classes))  # per class set and seed: dict of mean ascores, or None
        xtrainer_metrics = {}
        nan_attempts = 5  # max. number of training attempts per class_set-seed combination on NaN gradients
        assert self.ds is None or len(run_classes) == 1, \
            'pre-loading DS (setting trainer.ds to something) only allowed for one class set'

        # prepare random seeds
        assert seeds is None or len(seeds) == len(run_classes) * run_seeds, \
            f'There are {len(seeds)} fixed seeds but {len(run_classes) * run_seeds} runs.'
        seeds = seeds if seeds is not None else torch.randint(
            0, 2**32-1, (len(run_classes) * run_seeds, )
        ).squeeze().numpy().tolist()
        seeds = seeds if not isinstance(seeds, int) else [seeds]  # squeeze() yields a scalar for a single run
        self.logger.logsetup(
            {'run_classes': run_classes, 'run_seeds': run_seeds, 'load': load, 'seeds': seeds}
        )
        seeds_iter = iter(seeds)

        # Loop over all class sets, considering in each step all classes in the current class set normal
        for cid, (cset, cstr) in enumerate(zip(run_classes, run_classes_strings)):
            oe_class_limit = self.oe_limit_classes[cid] if isinstance(self.oe_limit_classes, (list, tuple)) \
                else self.oe_limit_classes

            for seed in range(run_seeds):
                cur_seed = next(seeds_iter)
                torch.manual_seed(cur_seed)
                np.random.seed(cur_seed)
                random.seed(cur_seed)
                self.logger.print(
                    f'\n------ START TRAINING CLS {cset} "{cstr}" WITH {seed}'
                    f'{"st" if seed == 1 else ("nd" if seed == 2 else ("rd" if seed == 3 else "th"))}'
                    f' SEED {cur_seed} ------'
                )

                if load is not None and len(load) > cid and len(load[cid]) > seed:
                    cur_load = load[cid][seed]
                else:
                    cur_load = None
                if xload is not None and len(xload) > cid and len(xload[cid]) > seed:
                    cur_xload = xload[cid][seed]
                else:
                    cur_xload = None

                # prepare model
                def copy_model():
                    if cur_load is not None and isinstance(cur_load, Module):
                        self.logger.print('Loaded model (not snapshot).')
                        model = deepcopy(cur_load)
                    else:
                        model = deepcopy(self.model)
                        model.apply(weight_reset)
                    assert all([p.is_leaf for p in self.model.parameters()])
                    for n, p in model.named_parameters():
                        p.detach_().requires_grad_()  # otherwise jit models don't work due to grad_fn=clone_backward
                    return model

                # prepare dataset
                def prepare_dataset():
                    # Loads the dataset (or reuses the pre-loaded self.ds) and applies the train split.
                    # If the snapshot to be loaded is already trained for all epochs, the ImageNet-21k image cache is
                    # disabled meanwhile; the original cache size is restored even if loading fails.
                    orig_cache_size = ADImageNet21k.img_cache_size
                    try:
                        if isinstance(cur_load, str) and self.load_epochs_only(cur_load) >= self.epochs:
                            ADImageNet21k.img_cache_size = 0
                        prepared = load_dataset(
                            self.dsstr, self.datapath, cset, 0,
                            self.train_transform, self.test_transform, self.logger, self.oe_dsstr,
                            self.oe_limit_samples, oe_class_limit,
                        ) if self.ds is None else self.ds
                        prepared.train_split(self.train_split)
                        return prepared
                    finally:
                        ADImageNet21k.img_cache_size = orig_cache_size

                ds = prepare_dataset()

                # train; on NaN gradients, retry with a fresh model (and a freshly loaded and split dataset)
                model = roc = train_labels = train_ascores = None
                for attempt in range(nan_attempts):
                    try:
                        model = copy_model()
                        model, roc, train_labels, train_ascores = self.train_cls(
                            model, ds, cset, cstr, seed, cur_load, tqdm_mininterval, **train_loss_kwargs
                        )
                        break
                    except NanGradientsError:
                        # discard the partially trained model so that it is never evaluated
                        model = roc = train_labels = train_ascores = None
                        if attempt == nan_attempts - 1:
                            self.logger.warning(
                                f'Gradients got NaN for class {cset} "{cstr}" and seed {seed}. '
                                f'Happened {attempt + 1} times so far. Try no more. Set model and roc to None.'
                            )
                        else:
                            self.logger.warning(
                                f'Gradients got NaN for class {cset} "{cstr}" and seed {seed}. '
                                f'Happened {attempt + 1} times so far. Try once more.'
                            )
                            if self.ds is None:  # a pre-loaded dataset is reused as is (it is already split)
                                ds = prepare_dataset()
                models[cid].append(model)
                train_cls_rocs[cid].append(roc)
                self.logger.plot_many(
                    train_cls_rocs.means(True), run_classes_strings, name='training_intermediate_roc', step=cid*run_seeds+seed
                )

                # eval
                model = models[cid][-1]
                if model is not None:
                    roc, prc, eval_labels, eval_ascores = self.eval_cls(model, ds, cset, cstr, seed, tqdm_mininterval)
                else:
                    roc, prc = None, None
                eval_cls_rocs[cid].append(roc)
                eval_cls_prcs[cid].append(prc)
                self.logger.plot_many(
                    eval_cls_rocs.means(True), run_classes_strings, name='eval_intermediate_roc', step=cid*run_seeds+seed
                )
                self.logger.plot_many(
                    eval_cls_prcs.means(True), run_classes_strings, name='eval_intermediate_prc', step=cid*run_seeds+seed
                )

                # ROC and anomaly scores for train-split
                if model is not None and self.train_split > 0.0:
                    roc, train_split_labels, train_split_ascores = self.train_split_scores(
                        model, ds, cset, cstr, seed, tqdm_mininterval
                    )

                # log histogram for anomaly scores
                if model is not None and train_ascores is not None and len(train_ascores) > 0:
                    train_normal_scores = train_ascores[train_labels == ds.nominal_label]
                    train_anomalous_scores = train_ascores[train_labels == ds.anomalous_label]
                    eval_normal_scores = eval_ascores[eval_labels == ds.nominal_label]
                    eval_anomalous_scores = eval_ascores[eval_labels == ds.anomalous_label]
                    train_split_normal_scores = None
                    train_split_anom_scores = None
                    if self.train_split > 0.0:
                        train_split_normal_scores = train_split_ascores[train_split_labels == ds.nominal_label]
                        train_split_anom_scores = train_split_ascores[train_split_labels == ds.anomalous_label]
                    if train_split_normal_scores is not None and train_split_anom_scores is not None:
                        self.logger.hist_ascores(
                            [train_normal_scores, train_anomalous_scores, eval_normal_scores,
                             eval_anomalous_scores, train_split_normal_scores, train_split_anom_scores],
                            ['train_normal', 'train_anomalous', 'test_normal', 'test_anomalous', 'val_normal', 'val_anomalous'],
                            f'C{int_set_to_str(cset)}-{cstr}-S{seed}__anomaly_scores',
                        )
                    else:
                        self.logger.hist_ascores(
                            [train_normal_scores, train_anomalous_scores, eval_normal_scores, eval_anomalous_scores],
                            ['train_normal', 'train_anomalous', 'test_normal', 'test_anomalous'],
                            f'C{int_set_to_str(cset)}-{cstr}-S{seed}__anomaly_scores',
                        )
                    ascores_log[cid].append({
                        'train_normal_scores': train_normal_scores.mean().item(),
                        'eval_normal_scores': eval_normal_scores.mean().item(),
                        'train_anomalous_scores': train_anomalous_scores.mean().item(),
                        'eval_anomalous_scores': eval_anomalous_scores.mean().item(),
                    })
                else:
                    ascores_log[cid].append(None)  # keeps ascores_log[cid][seed] aligned with the seed index

                if model is not None:

                    # explanation
                    if self.xtrainer is not None:
                        compute_xad_ascore = partial(self.compute_anomaly_score, center=self.center)
                        xtrainer_metrics_cls_seed = self.xtrainer.run(
                            models[cid][-1], compute_xad_ascore,
                            ds, cset, cstr, seed, self.workers, ad_device=self.device, load_models=cur_xload,
                            tqdm_mininterval=tqdm_mininterval, ad_trainer_parent=self
                        )
                        if xtrainer_metrics_cls_seed is not None:
                            recursive_dict_update(xtrainer_metrics, xtrainer_metrics_cls_seed)

                    if not XADTrainer.KEEP_SNAPSHOT_IN_RAM:
                        models[cid][-1] = None

                del ds
                gc.collect()

            # seed-wise many_roc plots for current class
            cls_mean_roc = self.logger.plot_many(
                train_cls_rocs[cid], None, name=f'training__C{int_set_to_str(cset)}-{cstr}__roc', step=cid
            )
            train_cls_rocs.set_mean(cid, cls_mean_roc)
            cls_mean_roc = self.logger.plot_many(
                eval_cls_rocs[cid], None, name=f'eval__C{int_set_to_str(cset)}-{cstr}__roc', step=cid
            )
            eval_cls_rocs.set_mean(cid, cls_mean_roc)
            cls_mean_prc = self.logger.plot_many(
                eval_cls_prcs[cid], None, name=f'eval__C{int_set_to_str(cset)}-{cstr}__prc', step=cid
            )
            eval_cls_prcs.set_mean(cid, cls_mean_prc)

        # training: compute cls-wise roc curves and combine in a final overview roc plot
        if any([t is not None for t in train_cls_rocs.means()]):
            mean_auc = np.mean([m.auc for m in train_cls_rocs.means() if m is not None])
            std_auc = np.std([m.auc for m in train_cls_rocs.means() if m is not None])
            self.logger.logtxt(f'Training: Overall {mean_auc*100:04.2f}% +- {std_auc*100:04.2f} AUC.')
            self.logger.plot_many(train_cls_rocs.means(), run_classes_strings, name='training_roc')

            # print an overview of cls-wise rocs
            print('--------------- OVERVIEW ------------------')
            for auc, cstr in ((a.auc, c) for a, c in zip(train_cls_rocs.means(), run_classes_strings) if a is not None):
                print(f'Training: Class "{cstr}" yields {auc*100:04.2f}% AUC.')
            print(f'Training: Overall {mean_auc*100:04.2f}% +- {std_auc*100:04.2f} AUC.')

        # evaluation: compute cls-wise roc curves and combine in a final overview roc plot
        mean_auc = np.mean([m.auc for m in eval_cls_rocs.means() if m is not None])
        std_auc = np.std([m.auc for m in eval_cls_rocs.means() if m is not None])
        self.logger.plot_many(eval_cls_rocs.means(), run_classes_strings, name='eval_roc')
        mean_avg_prec = np.mean([m.avg_prec for m in eval_cls_prcs.means() if m is not None])
        std_avg_prec = np.std([m.avg_prec for m in eval_cls_prcs.means() if m is not None])
        self.logger.plot_many(eval_cls_prcs.means(), run_classes_strings, name='eval_prc')

        # summarize split-wise anomaly scores
        ascores_log = self._summarize_ascores(ascores_log)

        # print some overview of the achieved scores
        self.logger.logtxt('--------------- OVERVIEW ------------------')
        self.logger.logtxt(f'Eval: Overall {mean_avg_prec*100:04.2f}% +- {std_avg_prec*100:04.2f}% AvgPrec.')
        for auc, std, cstr in ((a.auc, a.std, c) for a, c in zip(eval_cls_rocs.means(), run_classes_strings) if a is not None):
            self.logger.logtxt(f'Eval: Class "{cstr}" yields {auc*100:04.2f}% +- {std*100:04.2f}% AUC.')
        self.logger.logtxt(f'Eval: Overall {mean_auc*100:04.2f}% +- {std_auc*100:04.2f}% AUC.')

        self.logger.logjson('results', {
            'eval_mean_auc': mean_auc, 'eval_std_auc': std_auc, 'eval_mean_avg_prec': mean_avg_prec,
            'eval_cls_rocs': [[roc.get_score() if roc is not None else None for roc in cls_roc] for cls_roc in eval_cls_rocs],
            'classes': run_classes_strings, 'ascores_dist': ascores_log
        })
        timer.__exit__(None, None, None)

        return models, {
            'mean_auc': mean_auc, 'mean_avg_prec': mean_avg_prec, 'std_auc': std_auc,
            'cls_aucs': [[roc.get_score() if roc is not None else None for roc in cls_roc] for cls_roc in eval_cls_rocs],
            'ascores_dist': ascores_log, 'counterfactual': xtrainer_metrics
        }

    @staticmethod
    def _summarize_ascores(ascores_log: List[List[Union[dict, None]]]) -> dict:
        """
        Summarizes the per-run mean anomaly scores collected by :meth:`run`.
        @param ascores_log: a list (len = #class_sets) of lists (len = #seeds) holding, per run, a dict that maps
            split names (e.g., 'eval_normal_scores') to mean anomaly scores, or None if the run recorded no scores.
        @return: a dict containing, for each split key `k`, the nested mapping `k -> class-set index -> seed -> score`,
            `k_mean` and `k_std` over all recorded runs, and the mean/std of the difference between the anomalous and
            normal eval scores. Empty if no run recorded any scores.
        """
        records = [
            (cid, seed, entry)
            for cid, cls_entries in enumerate(ascores_log)
            for seed, entry in enumerate(cls_entries)
            if entry is not None
        ]
        if not records:
            return {}
        keys = list(records[0][2].keys())

        summary = {}
        for k in keys:
            nested = {}
            for cid, seed, entry in records:
                nested.setdefault(cid, {})[seed] = entry[k]
            summary[k] = nested
        for k in keys:
            summary[f"{k}_mean"] = np.mean([entry[k] for _, _, entry in records])
        for k in keys:
            summary[f"{k}_std"] = np.std([entry[k] for _, _, entry in records])
        anom_vs_nom = [entry['eval_anomalous_scores'] - entry['eval_normal_scores'] for _, _, entry in records]
        summary["eval_anom_vs_nom_ascores_mean"] = np.mean(anom_vs_nom)
        summary["eval_anom_vs_nom_ascores_std"] = np.std(anom_vs_nom)
        return summary

    def train_cls(self, model: torch.nn.Module, ds: TorchvisionDataset, cset: set[int], clsstr: str, seed: int,
                  load: Union[Module, str] = None, tqdm_mininterval: float = 0.1, logging=True,
                  **train_loss_kwargs) -> Tuple[torch.nn.Module, ROC, Tensor, Tensor]:
        """
        Trains the given model for the current class set.
        @param model: the AD model that is to be trained.
        @param ds: the dataset containing normal training samples and perhaps Outlier Exposure.
            If it contains OE, the dataset is an instance of a CombinedDataset (see :class:`xad.datasets.bases.CombinedDataset`).
            The loader of a combined dataset returns a batch where the first half are normal training samples and
            the second half is made of Outlier Exposure.
        @param cset: the current normal class set.
        @param clsstr: A string representation for the current class set.
            (e.g., 'airplane' for ds being CIFAR-10 and the class set being {0}).
        @param seed: the current iteration of random seeds. E.g., `2` denotes the second random seed of the current class.
        @param load: if not None, initializes the AD model with `load`. `load` can either be a PyTorch module or a filepath.
            If it is a filepath, also loads the last epoch with which the stored model was trained and only trains
            for the remaining number of epochs. The architecture found in `load` needs to match the one specified in the
            trainer's initialization.
        @param tqdm_mininterval: Minimum progress display update interval [default: 0.1] seconds.
        @param logging: whether to log per-batch losses, per-epoch anomaly-score histograms and AUCs, and the final
            model snapshot.
        @param train_loss_kwargs: further keyword arguments passed on to :meth:`loss`.
        @return: a tuple of
            - the trained model (moved to the CPU and set to eval mode),
            - the training ROC of the latest epoch containing anomalous samples (label 1), or None if there were none,
            - the training labels and
            - the training anomaly scores of the last trained epoch, or, if no epochs were left to train,
              of a separate scoring pass over the training data.
        @raise NanGradientsError: if the anomaly scores of an epoch contain NaN values.
        """
        # ---- prepare model and variables
        model = model.to(self.device).train()
        epochs = self.epochs
        cls_roc = None

        # ---- optimizers and loaders
        opt = torch.optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.wdk, amsgrad=False)
        sched = torch.optim.lr_scheduler.MultiStepLR(opt, self.milestones, 0.1)
        loader, _ = ds.loaders(self.batch_size, num_workers=self.workers, persistent=True, device=self.device)
        if epochs > 0 and seed == 0 and self.logger.active:
            prev = ds.preview(40, True,  classes=[0, 1] if isinstance(ds, CombinedDataset) else [0])
            stats = ds.n_normal_anomalous()
            self.logger.logimg(
                f'training__C{int_set_to_str(cset)}-{clsstr}__preview', prev, nrow=prev.shape[0] // len(stats),
                rowheaders=[str(stats[k]) for k in sorted(stats.keys())]
            )
            del prev

        # ---- prepare trackers and loggers
        # `self.load` restores model/optimizer/scheduler from a snapshot filepath and returns the epoch to start at
        ep, loss = self.load(load if isinstance(load, str) else None, model, opt, sched), None
        start_ep, ep_labels, ep_ascores = ep, [], []
        center = self.center = self.prepare_metric(clsstr, loader, model, seed)
        to_track = {
            'ep': lambda: f'{ep+1:{len(str(epochs))}d}/{epochs}', 'loss': lambda: loss.item() if loss is not None else None,
            'roc': lambda: cls_roc.auc if cls_roc is not None else None, 'lr': lambda: sched.get_last_lr()[0]
        }
        tracker = self.logger.track([epochs, len(loader)], to_track, f'training {clsstr}', mininterval=tqdm_mininterval)

        # ---- loop over epochs
        for ep in range(ep, epochs):
            ep_labels, ep_ascores = [], []

            # ---- loop over batches
            for imgs, lbls, idcs in loader:

                # ---- compute loss and optimize
                opt.zero_grad()
                image_features = model(imgs)
                loss = self.loss(image_features, lbls, center, inputs=imgs, **train_loss_kwargs)
                loss.backward()
                opt.step()
                opt.zero_grad()
                anomaly_scores = self.compute_anomaly_score(image_features, center, inputs=imgs).cpu()
                loss = loss.cpu()

                # ---- log stuff
                ep_labels.append(lbls.detach().cpu())
                ep_ascores.append(anomaly_scores.detach().cpu())
                if logging:
                    if seed == 0:
                        self.logger.add_scalar(f'training__C{int_set_to_str(cset)}-{clsstr}__ep', ep, tracker.n)
                    self.logger.add_scalar(
                        f'training__C{int_set_to_str(cset)}-{clsstr}-S{seed}__loss', loss.item(), tracker.n,
                    )
                tracker.update([0, 1])

            # ---- prepare labels and anomaly scores of epoch
            ep_labels, ep_ascores = torch.cat(ep_labels), torch.cat(ep_ascores)
            if ep_ascores.isnan().sum() > 0:
                raise NanGradientsError()

            # ---- compute training AuROC (only possible if the epoch contained anomalous samples, e.g., OE)
            if (ep_labels == 1).sum() > 0:
                fpr, tpr, thresholds = roc_curve(ep_labels, ep_ascores.squeeze())
                auc = compute_auc(fpr, tpr)
                cls_roc = ROC(tpr, fpr, thresholds, auc)

            # ---- log epoch stuff
            self.logger.flush(1e-2)
            if logging:
                self.logger.tb_writer.add_histogram(
                    f'Training: CLS{int_set_to_str(cset)} SEED{seed} anomaly_scores normal',
                    ep_ascores[ep_labels == 0], ep,
                )
                if (ep_labels == 1).sum() > 0:
                    self.logger.tb_writer.add_histogram(
                        f'Training: CLS{int_set_to_str(cset)} SEED{seed} anomaly_scores anomalous',
                        ep_ascores[ep_labels == 1], ep,
                    )
                    self.logger.add_scalar(
                        f'training__C{int_set_to_str(cset)}-{clsstr}-S{seed}__AUC', cls_roc.auc*100, ep,
                    )

            # ---- update tracker and scheduler
            sched.step()
            tracker.update([1, 0])

        tracker.close()
        if logging and start_ep < epochs:
            self.logger.snapshot(f'snapshot_cls{int_set_to_str(cset)}_it{seed}', model, epoch=self.epochs)
        if len(ep_ascores) == 0:  # prolly continued run, fully trained AD model, still need to compute training ascores
            # Score in eval mode: a forward pass in train mode would, e.g., update BatchNorm running statistics
            # (even under no_grad) and thus alter the loaded, already fully trained model before its evaluation.
            model.eval()
            with torch.no_grad():
                for imgs, lbls, idcs in loader:
                    image_features = model(imgs)
                    # NOTE(review): the loss value is not used below; the call is kept as in the original version
                    loss = self.loss(image_features, lbls, center, inputs=imgs, **train_loss_kwargs)
                    anomaly_scores = self.compute_anomaly_score(image_features, center, inputs=imgs).cpu()
                    ep_labels.append(lbls.cpu())
                    ep_ascores.append(anomaly_scores.cpu())
                ep_labels, ep_ascores = torch.cat(ep_labels), torch.cat(ep_ascores)
                if (ep_labels == 1).sum() > 0:
                    fpr, tpr, thresholds = roc_curve(ep_labels, ep_ascores.squeeze())
                    auc = compute_auc(fpr, tpr)
                    cls_roc = ROC(tpr, fpr, thresholds, auc)
        del loader
        return model.cpu().eval(), cls_roc, ep_labels, ep_ascores

    def eval_cls(self, model: torch.nn.Module, ds: TorchvisionDataset, cset: set[int], clsstr: str, seed: int,
                 tqdm_mininterval: float = 0.1, logging=True) -> Tuple[ROC, PRC, Tensor, Tensor]:
        """
        Evaluates the given model for the current class set on the test loader of `ds`.
        Computes the ROC and PRC metrics of the anomaly scores and, if `logging` is enabled, logs them.
        The model is moved to `self.device` for inference and is always moved back to the CPU afterwards
        (independently of `logging`), so evaluation never leaves the model occupying device memory.
        @param model: the (trained) model to be evaluated.
        @param ds: the dataset to be used for evaluating (should be a test split of some dataset).
        @param cset: the current normal class set.
        @param clsstr: A string representation for the current class set.
            (e.g., 'airplane' for ds being CIFAR-10 and the class set being {0}).
        @param seed: the current iteration of random seeds. E.g., `2` denotes the second random seed of the current class.
        @param tqdm_mininterval: Minimum progress display update interval [default: 0.1] seconds.
        @param logging: whether to log the text summary and the anomaly score histograms [default: True].
        @return: a tuple (ROC, PRC, labels, anomaly scores); labels and anomaly scores are the concatenated
            per-sample CPU tensors of the whole test loader.
        """
        model = model.to(self.device).eval()
        _, loader = ds.loaders(self.batch_size, num_workers=self.workers, shuffle_test=False, device=self.device, )
        if seed == 0 and self.logger.active:
            # log a preview grid of test samples; row headers come from ds.n_normal_anomalous
            # (presumably (#normal, #anomalous) -- confirm in the dataset implementation)
            prev = ds.preview(20, False)
            stats = ds.n_normal_anomalous(False)
            self.logger.logimg(
                f'eval_cls{int_set_to_str(cset)}-{clsstr}_preview', prev, nrow=prev.shape[0] // 2,
                rowheaders=[str(stats[0]), str(stats[1])],
            )
            del prev

        center = self.center
        ep_labels, ep_ascores = [], []  # per-batch labels / anomaly scores, concatenated after the loop
        procbar = tqdm(desc=f'evaluating {clsstr}', total=len(loader), mininterval=tqdm_mininterval)
        for imgs, lbls, idcs in loader:
            with torch.no_grad():
                image_features = model(imgs)
            # NOTE(review): the score is computed outside torch.no_grad(); presumably some score
            # implementations rely on autograd w.r.t. `inputs` -- confirm before moving it inside.
            anomaly_scores = self.compute_anomaly_score(image_features, center, inputs=imgs)
            ep_labels.append(lbls.cpu())
            ep_ascores.append(anomaly_scores.cpu())
            procbar.update()
        procbar.close()
        # inference is finished: release device memory no matter whether logging is enabled
        model.cpu()
        ep_labels, ep_ascores = torch.cat(ep_labels), torch.cat(ep_ascores)

        fpr, tpr, thresholds = roc_curve(ep_labels, ep_ascores.squeeze())
        auc = compute_auc(fpr, tpr)
        cls_roc = ROC(tpr, fpr, thresholds, auc)

        prec, rec, thresholds = precision_recall_curve(ep_labels, ep_ascores.squeeze())
        average_prec = average_precision_score(ep_labels, ep_ascores.squeeze())
        cls_prc = PRC(prec, rec, thresholds, average_prec)

        if logging:
            self.logger.logtxt(
                f'Eval: class "{clsstr}" yields {auc * 100:04.2f}% AUC and {average_prec * 100:04.2f}% '
                f'average precision (seed {seed}).'
            )
            nominal_scores = ep_ascores[ep_labels == 0]
            anomalous_scores = ep_ascores[ep_labels == 1]
            # histogram writers (presumably a tensorboard SummaryWriter) reject empty tensors; skip empty groups
            if nominal_scores.shape[0] > 0:
                self.logger.tb_writer.add_histogram(
                    f'Eval: (SD{seed}) anomaly_scores cls{int_set_to_str(cset)} nominal', nominal_scores, 0, walltime=0
                )
            if anomalous_scores.shape[0] > 0:
                self.logger.tb_writer.add_histogram(
                    f'Eval: (SD{seed}) anomaly_scores cls{int_set_to_str(cset)} anomalous', anomalous_scores, 0,
                    walltime=0
                )

        return cls_roc, cls_prc, ep_labels, ep_ascores

    def train_split_scores(self, model: torch.nn.Module, ds: TorchvisionDataset, cset: set[int], clsstr: str, seed: int,
                           tqdm_mininterval: float = 0.1) -> Tuple[ROC, Tensor, Tensor]:
        """
        Scores all samples of `ds.val_loader(...)` with the given model and computes the ROC of those scores.
        The results are labelled 'Train-split' in the logs (presumably a held-out part of the training data --
        confirm in the dataset implementation). Logging is unconditional here: a text summary plus a histogram
        of the nominal scores and, if there are any anomalous samples, a histogram of the anomalous scores.
        The model is moved to `self.device` for inference and back to the CPU at the end.
        @param model: the (trained) model used to compute the anomaly scores.
        @param ds: the dataset whose `val_loader` provides the samples.
        @param cset: the current normal class set (only used for log tags).
        @param clsstr: A string representation for the current class set (used for log messages).
        @param seed: the current iteration of random seeds (used for log messages/tags).
        @param tqdm_mininterval: Minimum progress display update interval [default: 0.1] seconds.
        @return: a tuple (ROC, labels, anomaly scores) with labels and scores as concatenated CPU tensors.
        """
        model = model.to(self.device).eval()
        val_loader = ds.val_loader(self.batch_size, num_workers=self.workers, device=self.device, )
        reference = self.center
        label_chunks, score_chunks = [], []
        progress = tqdm(
            desc=f'train_split scores for cls {clsstr}', total=len(val_loader), mininterval=tqdm_mininterval
        )
        for batch_imgs, batch_lbls, _batch_idcs in val_loader:
            with torch.no_grad():
                feats = model(batch_imgs)
            batch_scores = self.compute_anomaly_score(feats, reference, inputs=batch_imgs)
            label_chunks.append(batch_lbls.cpu())
            score_chunks.append(batch_scores.cpu())
            progress.update()
        progress.close()

        labels = torch.cat(label_chunks)
        scores = torch.cat(score_chunks)
        fpr, tpr, thresholds = roc_curve(labels, scores.squeeze())
        auc = compute_auc(fpr, tpr)
        roc = ROC(tpr, fpr, thresholds, auc)
        self.logger.logtxt(f'Train-split: class "{clsstr}" yields {auc * 100:04.2f}% AUC (seed {seed}).')

        cset_tag = int_set_to_str(cset)
        writer = self.logger.tb_writer
        writer.add_histogram(
            f'Train-split: (SD{seed}) anomaly_scores cls{cset_tag} nominal', scores[labels == 0], 0,
        )
        anomalous = scores[labels == 1]
        if anomalous.shape[0] > 0:
            writer.add_histogram(
                f'Train-split: (SD{seed}) anomaly_scores cls{cset_tag} anomalous', anomalous, 0,
            )
        model.cpu()
        return roc, labels, scores

    def load(self, path: str, model: torch.nn.Module,
             opt: torch.optim.Optimizer = None, sched: _LRScheduler = None) -> int:
        """
        Loads a snapshot of the model including training state.
        The snapshot is deserialized onto the CPU, so snapshots written during a GPU run can also be restored
        on CPU-only machines. `model.load_state_dict` copies the values into the model's existing parameters
        (keeping their device), and `Optimizer.load_state_dict` casts its state to the parameters' device.
        @param path: the filepath where the snapshot is stored. If None, nothing is loaded.
        @param model: the model instance into which the parameters of the found snapshot are loaded.
            Hence, the architectures need to match.
        @param opt: the optimizer instance into which the training state is loaded.
        @param sched: the learning rate scheduler into which the training state is loaded.
        @return: the last epoch with which the snapshot's model was trained (0 if `path` is None or the
            snapshot has no 'epoch' entry).
        """
        if path is None:
            return 0
        # NOTE(security): torch.load unpickles the file; only load snapshots from trusted sources.
        snapshot = torch.load(path, map_location='cpu')
        net_state = snapshot.get('net', None)
        opt_state = snapshot.get('opt', None)
        sched_state = snapshot.get('sched', None)
        epoch = snapshot.get('epoch', 0)
        if net_state is not None:
            model.load_state_dict(net_state)
        if opt_state is not None and opt is not None:
            opt.load_state_dict(opt_state)
        if sched_state is not None and sched is not None:
            sched.load_state_dict(sched_state)
        self.logger.print(f'Loaded snapshot at epoch {epoch}')
        return epoch

    def load_epochs_only(self, path: str):
        """
        Loads the last epoch with which the snapshot's model found at `path` was trained.
        The snapshot is deserialized onto the CPU, so no device memory is allocated just to read the epoch
        and GPU-written snapshots can also be read on CPU-only machines.
        @param path: the filepath where the snapshot is stored, or None.
        @return: the stored epoch, or 0 if `path` is None or the snapshot has no 'epoch' entry.
        """
        if path is None:
            return 0
        # NOTE(security): torch.load unpickles the file; only load snapshots from trusted sources.
        return torch.load(path, map_location='cpu').get('epoch', 0)

    @abstractmethod
    def prepare_metric(self, cstr: str, loader: DataLoader, model: torch.nn.Module, seed: int, **kwargs) -> torch.Tensor:
        """
        Abstract hook that builds the reference tensor used by the anomaly score, e.g. the DSVDD 'center'.
        It runs at the start of training, even when the number of training epochs is 0.
        Outlier Exposure-based methods may not need a reference tensor; for them this hook is optional.
        The base implementation does nothing and returns None.
        @param cstr: the string representation of the current class (e.g., 'airplane' for ds being CIFAR-10 and class being 0).
            For the one vs. rest benchmark, the current class is the normal class.
        @param loader: a data loader that can be used to compute the reference tensor.
            The trainer's `train_cls` method executes `prepare_metric` and passes the training loader for this purpose.
        @param model: The model for which the reference tensor is to be computed.
        @param seed: the current iteration of random seeds. E.g., `2` denotes the second random seed of the current class.
        @param kwargs: potential further implementation-specific parameters.
        @return: the reference tensor.
        """
        return None

    @abstractmethod
    def compute_anomaly_score(self, features: torch.Tensor, center: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Abstract hook that maps a batch of image features to one anomaly score per sample.
        The trainer calls it with the features produced by the AD model and forwards extra keyword arguments
        such as `inputs=` (the raw image batch). The base implementation does nothing and returns None.
        @param features: a batch of image features (shape: n x d). The trainer computes these features with the AD model.
        @param center: a center or, in general, a reference tensor (shape: d) that can be used to compute the anomaly scores.
        @param kwargs: potential further implementation-specific parameters.
        @return: the batch of anomaly scores (shape: n).
        """
        return None

    @abstractmethod
    def loss(self, features: torch.Tensor, labels: torch.Tensor,  center: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Abstract hook that computes the training loss for a batch of image features.
        The trainer calls it with the AD model's features, the batch labels, the reference tensor and extra
        keyword arguments such as `inputs=` (the raw image batch). The base implementation does nothing and
        returns None.
        @param features: a batch of image features (shape: n x d). The trainer computes these features with the AD model.
        @param labels: a batch of corresponding integer labels (shape: n).
        @param center: a center or, in general, a reference tensor (shape: d) that can be used to compute the anomaly scores.
        @param kwargs: potential further implementation-specific parameters.
        @return: the loss (scalar).
        """
        return None
