import gc
import os
import os.path as pt
import traceback
import json
from abc import ABC
from copy import deepcopy
from typing import Callable, List, Mapping, Tuple
from datetime import datetime

import numpy as np
import shutil
import sklearn
import torch.nn as nn
import torch.nn
import torchvision.datasets
from sklearn.metrics import accuracy_score
from torch.nn import Module
from torch.nn.functional import cross_entropy
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.folder import default_loader
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm

from xad.counterfactual.eval import get_roc, compute_fid_scores
from xad.counterfactual.bases import XTrainer, huber_distance, hinge_disc_loss
from xad.datasets.bases import CombinedDataset, TorchvisionDataset
from xad.models.bases import ConceptNN, ConditionalDiscriminator, ConditionalGenerator
from xad.utils.logger import Logger
from xad.utils.training_tools import NanGradientsError, int_set_to_str, weight_reset
from xad.utils.data_tools import random_split_tensor


class DISSECTTrainer(XTrainer):
    def __init__(self, xmodels: List[nn.Module],
                 n_concepts: int,
                 epochs: int, lr: float, wdk: float, milestones: List[int],
                 batch_size: int, logger: Logger = None, oe=True, n_discrete_anomaly_scores: int = 2,
                 lamb_gen=1, lamb_asc=1, lamb_cyc=1, lamb_conc=1e2, gen_every=1, disc_every=1, cluster_ncc=False,
                 milestone_alpha: float = 0.1,
                 **kwargs):
        """
        Configure a DISSECT trainer.

        @param xmodels: models to train; the first three entries must be a ConditionalGenerator,
            a ConditionalDiscriminator, and a ConceptNN, in this order.
        @param n_concepts: number of concepts the generator is conditioned on.
        @param epochs: number of training epochs.
        @param lr: learning rate of the Adam optimizers.
        @param wdk: weight decay of the Adam optimizers.
        @param milestones: epochs at which the MultiStepLR schedulers decay the learning rate.
        @param batch_size: batch size of the data loaders.
        @param logger: logger; may also be assigned later via the ``logger`` attribute.
        @param oe: if False and the dataset is a CombinedDataset, only its normal part is used for training.
        @param n_discrete_anomaly_scores: number of bins the anomaly scores are discretized into.
        @param lamb_gen: weight of the adversarial (discriminator) losses.
        @param lamb_asc: weight of the anomaly-score consistency losses.
        @param lamb_cyc: weight of the reconstruction / cycle losses.
        @param lamb_conc: weight of the concept losses.
        @param gen_every: the generator is updated every ``gen_every``-th iteration.
        @param disc_every: the discriminator is updated every ``disc_every``-th iteration.
        @param cluster_ncc: whether to add a loss tying concepts to k-means clusters of the AD model's encodings.
        @param milestone_alpha: multiplicative learning-rate decay factor applied at each milestone.
        @param kwargs: not used by this trainer; only recorded in the logged setup.
        @raises ValueError: if one of the first three entries of `xmodels` has an unexpected type.
        """
        # Capture the constructor arguments before any other local variable exists, so that the logged setup
        # contains exactly the arguments.
        self.__setup = {f'x_{k}': v for k, v in locals().items() if k not in ('self', 'generator', 'concept_classifier')}

        # Validate the model list in order: generator, discriminator, concept classifier.
        required = (
            (ConditionalGenerator, 'ConditionalGenerator'),
            (ConditionalDiscriminator, 'ConditionalDiscriminator'),
            (ConceptNN, 'ConceptNN'),
        )
        for idx, (model_cls, cls_name) in enumerate(required):
            if not isinstance(xmodels[idx], model_cls):
                raise ValueError(f"Expected xmodels[{idx}] to be a {cls_name}, but it is {xmodels[idx].__class__}")

        # The models must be set before the logger, because assigning a logger logs them.
        self.generator = xmodels[0]
        self.discriminator = xmodels[1]
        self.concept_classifier = xmodels[2]

        # Conditioning.
        self.n_discrete_anomaly_scores = n_discrete_anomaly_scores
        self.n_concepts = n_concepts

        # Optimization and schedule.
        self.epochs = epochs
        self.lr = lr
        self.wdk = wdk
        self.milestones = milestones
        self.milestone_alpha = milestone_alpha
        self.batch_size = batch_size
        self.oe = oe

        # Loss weights and update frequencies.
        self.lamb_gen = lamb_gen
        self.lamb_asc = lamb_asc
        self.lamb_cyc = lamb_cyc
        self.lamb_conc = lamb_conc
        self.gen_every = gen_every
        self.disc_every = disc_every
        self.cluster_ncc = cluster_ncc

        # Assigned last. Usually the logger is only set later, in :func:`xad.main.create_trainer`.
        self.logger: Logger = logger

    def __setattr__(self, key, value):
        """
        Set an attribute and log the training setup the first time a logger becomes available.

        When ``logger`` is assigned a non-None value while no logger is set yet, the constructor arguments captured
        in ``__init__`` are passed to ``logger.logsetup``, together with the generator and the concept classifier.
        They are logged at most once and are deleted afterwards.
        """
        # getattr with a default (instead of hasattr + attribute access): a logger passed to __init__ arrives
        # before the `logger` attribute exists, and must still trigger the setup logging.
        if key == 'logger' and value is not None and getattr(self, 'logger', None) is None:
            # The setup is deleted after logging. Guard against a missing setup so that setting a new logger
            # after resetting it to None does not raise an AttributeError.
            setup = getattr(self, '_DISSECTTrainer__setup', None)
            if setup is not None:
                value.logsetup(setup, nets=[self.generator, self.concept_classifier])  # log training setup
                del self.__setup
        super().__setattr__(key, value)

    def run(self, ad_model: torch.nn.Module, ad_feature_to_ascore: Callable, dataset: TorchvisionDataset,
            cset: set[int], clsstr: str, seed: int, workers: int = 4,
            load_models: str = None, ad_device: torch.device = torch.device('cuda:0'),
            tqdm_mininterval: float = 0.1, ad_trainer_parent=None) -> dict:
        """
        Train freshly initialized copies of the generator, discriminator, and concept classifier for one class set
        and seed, then evaluate them via :meth:`eval`.

        If a :class:`NanGradientsError` reaches this method, training is retried with new model copies, up to
        ``max_tries`` attempts in total.

        @param ad_model: anomaly detection model whose scores condition the generation.
        @param ad_feature_to_ascore: maps the output of `ad_model` to anomaly scores.
        @param dataset: dataset providing the training and test loaders.
        @param cset: set of class indices this run is for.
        @param clsstr: human-readable name of the class set, used in log messages.
        @param seed: seed/iteration index, used in log messages.
        @param workers: number of data loader workers.
        @param load_models: forwarded to :meth:`train`, which hands it to ``self.load``.
        @param ad_device: device used for training and evaluation.
        @param tqdm_mininterval: minimum progress-bar update interval in seconds.
        @param ad_trainer_parent: forwarded to :meth:`eval`.
        @returns: the metrics returned by :meth:`eval`, or None if every attempt ran into NaN gradients.
        @raises ValueError: if no logger has been set.
        """
        max_tries = 5  # total number of training attempts before giving up on NaN gradients
        device = ad_device
        if self.logger is None:
            raise ValueError('Logger has not been set.')
        self.logger.logsetup({'load': {int_set_to_str(cset): {seed: load_models}}, })

        def copy_models():
            """
            Deep-copy the template models, re-initialize their weights via `weight_reset`, and turn all their
            parameters into leaf tensors that require gradients.
            """
            generator = deepcopy(self.generator)
            generator.apply(weight_reset)
            discriminator = deepcopy(self.discriminator)
            discriminator.parameterize()
            discriminator.apply(weight_reset)
            concept_clf = deepcopy(self.concept_classifier)
            concept_clf.parameterize()
            concept_clf.apply(weight_reset)
            assert all([p.is_leaf for p in self.generator.parameters()])
            assert all([p.is_leaf for p in self.discriminator.parameters()])
            assert all([p.is_leaf for p in self.concept_classifier.parameters()])
            for model in (generator, discriminator, concept_clf):
                for p in model.parameters():
                    p.detach_().requires_grad_()  # otherwise jit models don't work due to grad_fn=clone_backward
            return generator, discriminator, concept_clf

        eval_metrics = None
        for attempt in range(max_tries):
            try:
                generator, discriminator, concept_clf = copy_models()
                generator, discriminator, concept_clf, _ = self.train(
                    generator, discriminator, concept_clf, ad_model, ad_feature_to_ascore, dataset, cset, clsstr, seed,
                    workers, load_models, device, tqdm_mininterval
                )
                gc.collect()
                eval_metrics = self.eval(
                    generator, discriminator, concept_clf, ad_model, ad_feature_to_ascore, dataset, cset, clsstr, seed,
                    workers, device, tqdm_mininterval, ad_trainer_parent
                )
                gc.collect()
                break
            except NanGradientsError:
                # `attempt` is 0-based, so attempt + 1 failures have happened so far.
                if attempt < max_tries - 1:
                    self.logger.warning(
                        f'Gradients got NaN for class {int_set_to_str(cset)} "{clsstr}" and seed {seed}. '
                        f'Happened {attempt + 1} times so far. Try once more.'
                    )
                else:
                    self.logger.warning(
                        f'Gradients got NaN for class {int_set_to_str(cset)} "{clsstr}" and seed {seed}. '
                        f'Happened {attempt + 1} times so far. Try no more. Set model and roc to None.'
                    )
        return eval_metrics

    def train(self, generator: ConditionalGenerator, discriminator: ConditionalDiscriminator, concept_clf: ConceptNN,
              ad_model: Module, ad_feature_to_ascore: Callable,
              ds: TorchvisionDataset, cset: set[int], cstr: str, seed: int,
              workers: int, load_models: str = None,
              device: torch.device = torch.device('cuda:0'),
              tqdm_mininterval: float = 0.1,
              ) -> Tuple[ConditionalGenerator, ConditionalDiscriminator, Module, Mapping]:
        """
        Train the conditional generator, discriminator, and concept classifier for one class set and seed.

        In each iteration, the generator and concept classifier are optionally updated jointly (every
        :attr:`self.gen_every`-th iteration). Their loss combines:
            - reconstruction and cycle losses (``huber_distance``),
            - anomaly-score consistency losses (binary cross-entropy between the AD model's score of generated
              images and the target score),
            - an adversarial ``hinge_disc_loss`` on the discriminator's logits,
            - a concept loss from the concept classifier,
            - optionally, a nearest-cluster-center loss tying each target concept to a k-means cluster of the AD
              model's encodings.
        The discriminator is optionally updated every :attr:`self.disc_every`-th iteration, using real vs. generated
        images. Both are conditioned on ordinal anomaly scores.
        Counterfactual figures are logged during and after training, and snapshots are written at the end.

        @param generator: conditional generator; moved to `device` and set to train mode.
        @param discriminator: conditional discriminator; moved to `device` and set to train mode.
        @param concept_clf: concept classifier, called as ``concept_clf(original, generated)``; optimized together
            with the generator.
        @param ad_model: anomaly detection model; moved to `device` and set to eval mode. Its parameters are not
            part of either optimizer.
        @param ad_feature_to_ascore: maps the output of `ad_model` to anomaly scores. These are expected in [0, 1],
            since they enter a binary cross-entropy.
        @param ds: dataset. If it is a CombinedDataset and outlier exposure is disabled, only its normal part is
            used. Its test set indices are temporarily modified at the end and then restored.
        @param cset: set of class indices (used in log names).
        @param cstr: human-readable class name (used in log names).
        @param seed: seed/iteration index (used in log and snapshot names).
        @param workers: number of training data loader workers.
        @param load_models: if a str, passed to ``self.load`` to resume from it; otherwise None is passed.
        @param device: training device.
        @param tqdm_mininterval: minimum progress-bar update interval in seconds.
        @returns: the generator, discriminator, concept classifier, and a (currently empty) metrics dict.
        @raises NanGradientsError: if, at the end of an epoch, the last reconstruction loss is NaN. It is propagated
            so that the caller can retry.
        @raises KeyboardInterrupt: re-raised. All other exceptions are logged as warnings and swallowed, so that the
            partially trained models are still snapshotted and returned.
        """
        # ---- prepare model and variables
        generator = generator.to(device).train()
        discriminator = discriminator.to(device).train()
        concept_clf.to(device).train()
        ad_model.to(device).eval()
        epochs = self.epochs
        metrics = {}  # TODO implement metrics
        n_log = 20
        cs = int_set_to_str(cset)

        # ---- optimizers and loaders
        opt = torch.optim.Adam(
            [{'params': generator.parameters()}, {'params': concept_clf.parameters()}],
            lr=self.lr, weight_decay=self.wdk, betas=(0.0, 0.9)
        )
        discopt = torch.optim.Adam(
            [{'params': discriminator.parameters()}],
            lr=self.lr, weight_decay=self.wdk, betas=(0.0, 0.9)
        )
        sched = torch.optim.lr_scheduler.MultiStepLR(opt, self.milestones, self.milestone_alpha)
        discsched = torch.optim.lr_scheduler.MultiStepLR(discopt, self.milestones, self.milestone_alpha)
        if isinstance(ds, CombinedDataset) and not self.oe:
            ds = ds.normal  # exclude Outlier Exposure
        loader, _ = ds.loaders(self.batch_size, num_workers=workers, persistent=True, device=device)

        # ---- cluster AD features
        if self.cluster_ncc:
            # Import the submodule explicitly; a bare `import sklearn` is not guaranteed to load `sklearn.cluster`.
            from sklearn.cluster import KMeans
            samples, features = [], []
            for imgs, lbls, idcs in tqdm(loader, desc='Collecting AD features...', mininterval=tqdm_mininterval):
                with torch.no_grad():
                    features.append(ad_model(imgs[lbls == ds.nominal_label], return_encoding=True)[1].cpu())
                    samples.append(imgs[lbls == ds.nominal_label].cpu())
            samples, features = torch.cat(samples), torch.cat(features)
            kmeans = KMeans(self.n_concepts)
            cluster_ids = kmeans.fit_predict(features)
            clusters = torch.from_numpy(kmeans.cluster_centers_).to(device)
            # Log up to 20 samples of every cluster (not only clusters 0 and 1).
            self.logger.logimg(
                f'train-cluster__C{cs}-{cstr}-S{seed}',
                torch.cat([samples[cluster_ids == c][:20] for c in range(self.n_concepts)]),
                nrow=20
            )
            del samples, features, cluster_ids

        # ---- get test examples for logging each epoch
        def get_test_examples(n: int, shuffle: bool):
            """Collect up to `n` test images with label 1, their labels, and the AD model's scores for them."""
            tst_loader = iter(ds.loaders(
                self.batch_size, num_workers=0, persistent=False, device=device, shuffle_test=shuffle
            )[1])
            examples, example_labels, it = torch.Tensor([]).to(device), torch.Tensor([]).to(device), 0
            try:
                while examples.shape[0] < n:
                    imgs, lbls, idcs = next(tst_loader)
                    if (lbls == 1).sum() > 0:
                        examples = torch.cat([examples, imgs[lbls == 1][:n - examples.shape[0]]])
                        example_labels = torch.cat([example_labels, lbls[lbls == 1][:n - example_labels.shape[0]]])
                        if shuffle:
                            self.logger.logjson(
                                'misc', {
                                    f'test_examples_xtraining_it{it}': {f'cls_{cs}': {f'seed_{seed}': idcs}}
                                }, update=True
                            )
                        it += 1
            except StopIteration:
                pass
            del tst_loader
            with torch.no_grad():
                true_ascores = ad_feature_to_ascore(ad_model(examples))
            return examples, example_labels, true_ascores
        test_examples, test_example_labels, test_true_ascores = get_test_examples(n_log, True)

        def concat_train_examples(imgs, lbls, idcs, train_examples, train_example_labels, n: int, it: int):
            """
            Append randomly chosen samples of the batch until `n` examples are collected.

            Returns the examples, their labels, and the AD model's scores for all collected examples. It is only
            called while fewer than `n` examples have been collected.
            """
            if train_examples.shape[0] < n:
                perm = torch.randperm(imgs.shape[0])
                n_missing = n - train_examples.shape[0]
                train_examples = torch.cat([
                    train_examples, imgs[perm][:n_missing].detach().clone()
                ])
                # Log exactly the indices of the samples that were taken.
                self.logger.logjson(
                    'misc',
                    {f'train_examples_xtraining_it{it}': {f'cls_{cs}': {f'seed_{seed}': idcs[perm][:n_missing].cpu()}}},
                    update=True
                )
                train_example_labels = torch.cat([
                    train_example_labels, lbls[perm][:n_missing].detach().clone()
                ])
                with torch.no_grad():
                    train_true_ascores = ad_feature_to_ascore(ad_model(train_examples))
                return train_examples, train_example_labels, train_true_ascores
        train_examples, train_example_labels = torch.Tensor([]).to(device), torch.Tensor([]).to(device)

        # ---- prepare trackers and loggers
        reconloss, ascloss, genloss, discloss, conloss = None, None, None, None, None
        ep = self.load(load_models if isinstance(load_models, str) else None, generator, discriminator, concept_clf, opt, sched)
        start_ep = ep
        to_track = {
            'ep': lambda: f'{ep+1:{len(str(epochs))}d}/{epochs}',
            'recl': lambda: reconloss.item() if reconloss is not None else None,
            'ascl': lambda: ascloss.item() if ascloss is not None else None,
            'conl': lambda: conloss.item() if conloss is not None else None,
            'genl': lambda: genloss.item() if genloss is not None else None,
            'lr': lambda: sched.get_last_lr()
        }

        tracker = self.logger.track(
            [epochs, len(loader)], to_track, f'Xtraining {cstr}', mininterval=tqdm_mininterval
        )
        try:
            # ---- loop over epochs
            for ep in range(ep, epochs):

                # ---- loop over batches
                for it, (imgs, lbls, idcs) in enumerate(loader):
                    if train_examples.shape[0] < n_log:
                        train_examples, train_example_labels, train_true_ascores = concat_train_examples(
                            imgs, lbls, idcs, train_examples, train_example_labels, n_log, it
                        )
                    if len(lbls) <= 1:
                        continue

                    train_gen = it % self.gen_every == 0
                    train_disc = it % self.disc_every == 0

                    # ---- compute loss and optimize
                    if train_gen:
                        opt.zero_grad()

                        # reconstruction for true anomaly scores and random concepts
                        recon_concepts = torch.randint(self.n_concepts, torch.Size([imgs.shape[0], ]))
                        with torch.no_grad():
                            true_ascores = ad_feature_to_ascore(ad_model(imgs))  # TODO shift to the means of the discrete thresholds?
                        true_conditions = self.get_conditions(true_ascores, recon_concepts.to(device)).to(device)
                        rec_imgs = generator(imgs, true_conditions)
                        rec_recon_loss = huber_distance(imgs, rec_imgs).mean() * self.lamb_cyc

                        # for random scores and concepts, penalize if generated images have different scores
                        asc_ascores = torch.rand(imgs.shape[0])
                        asc_ascores = torch.from_numpy(
                            np.linspace(0, 1, self.n_discrete_anomaly_scores).astype(
                                np.float32
                            )[self.discretize_anomaly_scores(asc_ascores)]
                        )
                        # asc_ascores[(lbls == 0) * (asc_ascores == 0).to(device)] = 1.0  # exclude normal -> normal generation
                        # asc_ascores[(lbls == 1) * (asc_ascores == 1).to(device)] = 0.0  # exclude oe -> oe generation
                        asc_concepts = torch.randint(self.n_concepts, torch.Size([imgs.shape[0], ]))
                        asc_conditions = self.get_conditions(asc_ascores, asc_concepts).to(device)
                        asc_imgs = generator(imgs, asc_conditions)
                        asc_asc_loss = torch.nn.functional.binary_cross_entropy(
                            ad_feature_to_ascore(ad_model(asc_imgs)),
                            asc_ascores.to(device)
                        ) * self.lamb_asc
                        asc_disc_logits = discriminator(asc_imgs, self.to_ordinal_scores(asc_ascores).to(device))
                        asc_gen_loss = hinge_disc_loss(
                            asc_disc_logits, torch.ones(torch.Size((imgs.shape[0], ))).to(device)
                        ) * self.lamb_gen
                        genloss = asc_gen_loss

                        # reconstruction for true scores, random concepts, and generated images as conditions
                        cyc_concepts = asc_concepts
                        cyc_true_conditions = self.get_conditions(true_ascores, cyc_concepts.to(device)).to(device)
                        cyc_imgs = generator(asc_imgs, cyc_true_conditions)
                        cyc_recon_loss = huber_distance(imgs, cyc_imgs).mean() * self.lamb_cyc
                        cyc_asc_loss = torch.nn.functional.binary_cross_entropy(
                            ad_feature_to_ascore(ad_model(cyc_imgs)),
                            true_ascores
                        ) * self.lamb_asc

                        # Sum out-of-place. `reconloss = rec_recon_loss; reconloss += ...` would mutate
                        # rec_recon_loss itself, and its logged value would then include the cycle part.
                        reconloss = rec_recon_loss + cyc_recon_loss
                        ascloss = asc_asc_loss + cyc_asc_loss

                        # concept disentanglement
                        if self.n_concepts >= 2:
                            con_org_asc_img_loss = cross_entropy(concept_clf(imgs, asc_imgs), asc_concepts.to(device))
                            con_asc_cyc_img_loss = cross_entropy(concept_clf(asc_imgs, cyc_imgs), cyc_concepts.to(device))
                            con_org_cyc_img_loss = cross_entropy(
                                concept_clf(imgs, cyc_imgs), torch.zeros_like(cyc_concepts).to(device)
                            )
                            con_org_rec_img_loss = cross_entropy(
                                concept_clf(imgs, rec_imgs), torch.zeros_like(recon_concepts).to(device)
                            )
                            conloss = con_org_asc_img_loss + con_asc_cyc_img_loss + con_org_cyc_img_loss + con_org_rec_img_loss
                            conloss = conloss * self.lamb_conc
                        else:
                            conloss = torch.FloatTensor([0]).to(ascloss.device)
                            con_org_asc_img_loss = con_asc_cyc_img_loss = con_org_cyc_img_loss = con_org_rec_img_loss = conloss
                        if self.cluster_ncc:
                            asc_imgs_features = ad_model(asc_imgs, return_encoding=True)[1]
                            cyc_imgs_features = ad_model(cyc_imgs, return_encoding=True)[1]
                            ad_features = torch.cat([asc_imgs_features, cyc_imgs_features])
                            ad_concepts = torch.cat([asc_concepts, cyc_concepts])
                            cluster_distances = torch.stack([huber_distance(ad_features, c).mean(1) for c in clusters], dim=1)
                            # Negate the distances so that the softmax favors the *nearest* cluster center. Using raw
                            # distances as logits would push features away from the cluster of their target concept.
                            cluster_nll = torch.nn.functional.log_softmax(-cluster_distances, dim=1).mul(-1)
                            cluster_nll = cluster_nll[np.arange(len(cluster_nll)), ad_concepts]
                            conloss = conloss + cluster_nll.mean() * self.lamb_conc

                        # together
                        loss = reconloss + ascloss + conloss + genloss
                        loss.backward()
                        opt.step()
                        opt.zero_grad()

                    # discriminator
                    if train_disc:
                        discopt.zero_grad()
                        if not train_gen:
                            with torch.no_grad():
                                true_ascores = ad_feature_to_ascore(ad_model(imgs))
                        org_disc_logits = discriminator(imgs, self.to_ordinal_scores(true_ascores).to(device))
                        org_disc_loss = hinge_disc_loss(
                            org_disc_logits, torch.ones(torch.Size((imgs.shape[0], ))).to(device)
                        ) * self.lamb_gen
                        if not train_gen:
                            with torch.no_grad():
                                asc_ascores = torch.rand(imgs.shape[0])
                                asc_ascores = torch.from_numpy(
                                    np.linspace(0, 1, self.n_discrete_anomaly_scores).astype(
                                        np.float32
                                    )[self.discretize_anomaly_scores(asc_ascores)]
                                )
                                asc_concepts = torch.randint(self.n_concepts, torch.Size([imgs.shape[0], ]))
                                asc_conditions = self.get_conditions(asc_ascores, asc_concepts).to(device)
                                asc_imgs = generator(imgs, asc_conditions)
                        asc_disc_logits = discriminator(asc_imgs.detach(), self.to_ordinal_scores(asc_ascores).to(device))
                        asc_disc_loss = hinge_disc_loss(
                            asc_disc_logits, torch.zeros(torch.Size((imgs.shape[0], ))).to(device)
                        ) * self.lamb_gen
                        discloss = org_disc_loss + asc_disc_loss
                        discloss.backward()
                        discopt.step()
                        discopt.zero_grad()

                    reconloss, ascloss, conloss, genloss = reconloss.cpu(), ascloss.cpu(), conloss.cpu(), genloss.cpu()
                    discloss = discloss.cpu()

                    # # ---- log stuff
                    if seed == 0:
                        self.logger.add_scalar(f'training__C{cs}-{cstr}__ep', ep, tracker.n)
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__xloss', loss.item(), tracker.n,)
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__rec_loss', reconloss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__asc_loss', ascloss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__con_loss', conloss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__gen_loss', genloss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__rec_rec_loss', rec_recon_loss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__rec_cyc_loss', cyc_recon_loss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__asc_asc_loss', asc_asc_loss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__asc_cyc_loss', cyc_asc_loss.item(), tracker.n, )
                    self.logger.add_scalar(
                        f'training__C{cs}-{cstr}-S{seed}__con_org_asc_img_loss', con_org_asc_img_loss.item(), tracker.n,
                    )
                    self.logger.add_scalar(
                        f'training__C{cs}-{cstr}-S{seed}__con_asc_cyc_img_loss', con_asc_cyc_img_loss.item(), tracker.n,
                    )
                    self.logger.add_scalar(
                        f'training__C{cs}-{cstr}-S{seed}__con_org_cyc_img_loss', con_org_cyc_img_loss.item(), tracker.n,
                    )
                    self.logger.add_scalar(
                        f'training__C{cs}-{cstr}-S{seed}__con_org_rec_img_loss', con_org_rec_img_loss.item(), tracker.n,
                    )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__disc_loss', discloss.item(), tracker.n, )
                    if self.cluster_ncc:
                        self.logger.add_scalar(
                            f'training__C{cs}-{cstr}-S{seed}__con_cluster_loss', cluster_nll.mean().item(), tracker.n,
                        )
                    self.logger.flush()
                    tracker.update([0, 1])

                if reconloss.isnan().sum() > 0:
                    raise NanGradientsError()

                # ---- log epoch stuff
                if ep in [0, 1, 2, 4, 8, 16, 32] or ep % 50 == 0:
                    self.log_images(
                        generator, f'train__C{cs}-{cstr}-S{seed}', train_examples, train_example_labels, train_true_ascores,
                        ad_model, ad_feature_to_ascore, concept_clf, n_log, ep=ep
                    )
                    self.log_images(
                        generator, f'test__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                        ad_model, ad_feature_to_ascore, concept_clf, n_log, ep=ep
                    )
                self.logger.print(
                    f"{datetime.now().strftime('%Y/%m/%d %H:%M:%S')}: Xtraining for {cstr} (seed {seed}) "
                    f"epoch {ep:03d}/{epochs:03d} finished.", only_file=True
                )

                # ---- update tracker and scheduler
                sched.step()
                discsched.step()
                tracker.update([1, 0])

            # ---- log end-of-training stuff
            tracker.close()
            if train_examples is None or len(train_examples) == 0:
                train_loader = enumerate(iter(ds.loaders(self.batch_size, num_workers=0, persistent=False, device=device, )[0]))
                while len(train_examples) < n_log:
                    it, (imgs, lbls, idcs) = next(train_loader)
                    train_examples, train_example_labels, train_true_ascores = concat_train_examples(
                        imgs, lbls, idcs, train_examples, train_example_labels, n_log, it
                    )
                del train_loader
            self.log_images(
                generator, f'train__C{cs}-{cstr}-S{seed}', train_examples, train_example_labels, train_true_ascores,
                ad_model, ad_feature_to_ascore, concept_clf, n_log
            )
            self.log_images(
                generator, f'test__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                ad_model, ad_feature_to_ascore, concept_clf, n_log,
            )
            # Temporarily restrict the test set to n_log * 4 evenly spaced samples (the middle of each split).
            original_ids = deepcopy(ds.test_set.indices)
            tail = (np.asarray(original_ids).shape[0] % (n_log * 4))
            ds.test_set.indices = [
                i[len(i)//2]
                for i in np.split((np.asarray(original_ids)[:-tail]) if tail > 0 else np.asarray(original_ids), n_log * 4)
            ]
            try:
                test_examples, test_example_labels, test_true_ascores = get_test_examples(n_log * 2, False)
                self.log_images(
                    generator, f'testFix__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                    ad_model, ad_feature_to_ascore, concept_clf, n_log * 2,
                )
            finally:
                # Always restore the indices. The dataset object outlives this method (run() passes it on to eval).
                ds.test_set.indices = original_ids
            self.logger.flush(timeout=20)
        except (KeyboardInterrupt, NanGradientsError):
            # NaN gradients must reach the caller, which retries with fresh models. Swallowing them here would make
            # that retry unreachable.
            raise
        except Exception as err:
            self.logger.warning(''.join(traceback.format_exception(err.__class__, err, err.__traceback__)))
        finally:
            if epochs > 0 and start_ep != epochs:
                self.logger.snapshot(f'snapshot_generator_cls{cs}_it{seed}', generator, opt, sched, ep)
                # NOTE(review): the discriminator snapshot stores the generator's optimizer/scheduler (opt, sched),
                # not discopt/discsched. Kept as-is because self.load (not visible here) may rely on this layout.
                self.logger.snapshot(f'snapshot_discriminator_cls{cs}_it{seed}', discriminator, opt, sched, ep)
                self.logger.snapshot(f'snapshot_concept-clf_cls{cs}_it{seed}', concept_clf, opt, sched, ep)
                if ep != epochs - 1:
                    self.logger.warning(f'An error occurred. Intermediate snapshots for epoch {ep}/{epochs} saved.')
                    # train_true_ascores only exists once at least one training example has been collected.
                    if train_examples is not None and len(train_examples) > 0:
                        self.log_images(
                            generator, f'train_EP{ep}__C{cs}-{cstr}-S{seed}', train_examples, train_example_labels,
                            train_true_ascores, ad_model, ad_feature_to_ascore, concept_clf, n_log
                        )
                    self.log_images(
                        generator, f'test_EP{ep}__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                        ad_model, ad_feature_to_ascore, concept_clf, n_log,
                    )
                    self.logger.flush(timeout=20)

        return generator, discriminator, concept_clf, metrics

    def discretize_anomaly_scores(self, anomaly_scores: torch.Tensor) -> torch.Tensor:
        """
        Map anomaly scores in [0, 1] to bin indices in {0, ..., :attr:`self.n_discrete_anomaly_scores` - 1}.

        [0, 1] is split into :attr:`self.n_discrete_anomaly_scores` equally wide bins. Subtracting 1e-7 before scaling
        makes a score of exactly 1.0 fall into the last bin instead of opening an extra one. Negative scores end up
        in the first bin, and scores above 1 in the last bin.

        @param anomaly_scores: floating-point tensor of anomaly scores (any shape).
        @returns: int32 tensor of bin indices with the same shape as the input.
        """
        scaled = torch.clip(anomaly_scores - 1e-7, 0) * self.n_discrete_anomaly_scores  # -1e-7 -> 1.0 falls into last bin
        # The upper clamp guards against scores slightly above 1, and against a 1.0 that the 1e-7 offset cannot
        # reach in low precision (e.g. float16). Both would otherwise yield the out-of-range bin
        # n_discrete_anomaly_scores, which breaks the score lookup tables and the condition range.
        return scaled.int().clamp(max=self.n_discrete_anomaly_scores - 1)

    def to_ordinal_scores(self, anomaly_scores: torch.Tensor) -> torch.Tensor:
        """
        Transform anomaly scores into ordinal encodings, which we use for conditioning the discriminator.
        See https://en.wikipedia.org/wiki/Ordinal_regression.

        The scores are first discretized via :meth:`discretize_anomaly_scores` into bins
        {0, ..., :attr:`self.n_discrete_anomaly_scores` - 1}. Each discretized score d is then encoded as a binary
        vector with :attr:`self.n_discrete_anomaly_scores` - 1 elements, where element i (0-based) is 1 iff d >= i + 1.
        For example, with 3 bins: 0 -> [0, 0], 1 -> [1, 0], 2 -> [1, 1].

        @param anomaly_scores: FloatTensor in [0, 1] with shape [BATCH_SIZE].
        @returns: int32 tensor of shape [BATCH_SIZE, :attr:`self.n_discrete_anomaly_scores` - 1], on the same device
            as the input. Building it on the input's device, rather than masking a CPU tensor with a possibly-CUDA
            mask, makes CUDA inputs work.
        """
        discrete = self.discretize_anomaly_scores(anomaly_scores)
        thresholds = torch.arange(1, self.n_discrete_anomaly_scores, device=discrete.device)
        # Broadcast [BATCH_SIZE, 1] against [n - 1] -> [BATCH_SIZE, n - 1].
        return (discrete.unsqueeze(1) >= thresholds).int()

    def get_conditions(self, anomaly_scores: torch.Tensor, concepts: torch.Tensor) -> torch.IntTensor:
        """
        Merge target anomaly scores and target concepts into a single integer condition per sample.

        @param anomaly_scores: target anomaly scores with shape [BATCH_SIZE]. Either continuous (floating-point
            tensor in [0, 1]), which is discretized via :meth:`discretize_anomaly_scores`, or already discretized
            (integer tensor in {0, ..., :attr:`self.n_discrete_anomaly_scores` - 1}).
        @param concepts: integer tensor of target concepts in {0, ..., :attr:`self.n_concepts` - 1}, shape [BATCH_SIZE].
        @returns: integer tensor with shape [BATCH_SIZE] in
            {0, ..., :attr:`self.n_discrete_anomaly_scores` * :attr:`self.n_concepts` - 1}, computed as
            ``discrete_score * n_concepts + concept``. Therefore, each condition is merely an int.
        """
        # assert 0 <= anomaly_scores <= 1, f'Anomaly score {anomaly_scores} out of range!'
        with torch.no_grad():
            if anomaly_scores.is_floating_point():
                a = self.discretize_anomaly_scores(anomaly_scores)
            else:
                # Already discretized. Checking the dtype, rather than isinstance(..., torch.IntTensor), also
                # recognizes integer tensors on CUDA devices, which the legacy tensor-type check does not.
                a = anomaly_scores
            cond = a * self.n_concepts + concepts
        return cond

    def log_images(self, generator: ConditionalGenerator, name: str, examples: torch.Tensor, true_labels: torch.Tensor,
                   true_ascores: torch.Tensor, ad_model: Module, ad_feature_to_ascore: Callable, concept_clf: ConceptNN,
                   n_log: int, ep: int = None):
        """
        Log a figure of counterfactuals generated from `examples` for every concept and a few target anomaly scores.

        The first row shows the original examples. Each further row shows the generator's output for one
        (concept, target score) pair. The target scores are {0.0, 1.0}, or {0.0, 0.5, 1.0} if there are at least
        three discrete anomaly scores. Each image is titled with the AD model's score ("^a") and the argmax of the
        concept classifier's output ("c").
        Title colors:
            - originals are colored by their true label;
            - generated images are green if both the predicted concept and the discretized predicted score match
              the target, otherwise red.
        For non-intermediate calls, a second figure ('<name>_3rows') that only contains the target-score-0 rows is
        logged as well.

        @param generator: conditional generator producing the counterfactuals; switched to eval mode here.
        @param name: base name of the figure(s), placed under 'conditional_generation'.
        @param examples: batch of input images.
        @param true_labels: labels of `examples`; label 0 gets an 'olive' title, any other label 'saddlebrown'.
        @param true_ascores: anomaly scores of `examples`, shown in the originals' titles.
        @param ad_model: anomaly detector used to score the generated images (via `ad_feature_to_ascore`).
        @param ad_feature_to_ascore: maps the output of `ad_model` to anomaly scores.
        @param concept_clf: concept classifier, called as concept_clf(original, generated); switched to eval mode here.
        @param n_log: number of leading titles treated as originals when choosing title colors.
            NOTE(review): presumably equal to examples.shape[0]; TODO confirm.
        @param ep: epoch for intermediate logging. If given, the figure name is prefixed with 'intermediate' and
            suffixed with `_ep{ep:03d}`, and log_fig is called with maxres=28 and pdf=False.

        NOTE(review): this method switches `generator` and `concept_clf` to eval mode. The lines visible here do not
        switch them back to train mode, although :meth:`train` calls this method between epochs. Confirm that train
        mode is restored, either later in this method or by the caller.
        """
        is_intermediate = ep is not None
        name = name if not is_intermediate else pt.join("intermediate", f"{name}_ep{ep:03d}")
        generator.eval()
        concept_clf.eval()
        with torch.no_grad():
            # --- overview of counterfactuals with all concepts and 2-3 anomaly score conditions
            threescores = self.n_discrete_anomaly_scores >= 3
            if threescores:
                ascores = [torch.zeros(examples.shape[0]), torch.ones(examples.shape[0]) * 0.5, torch.ones(examples.shape[0]), ]
            else:
                ascores = [torch.zeros(examples.shape[0]), torch.ones(examples.shape[0]), ]
            concepts = [torch.ones(examples.shape[0]).int() * c for c in range(self.n_concepts)]
            images = []
            # Entry 0 describes the originals: their given scores, and the concept predicted for the pair (x, x),
            # which also serves as their "target" concept.
            predicted_scores = [true_ascores.cpu()]
            predicted_concepts = [concept_clf(examples, examples).cpu()]
            target_scores, target_concepts = [predicted_scores[0]], [predicted_concepts[0].argmax(1)]
            # One generated row per (concept, target score) pair; concepts form the outer loop.
            for c in concepts:
                for a in ascores:
                    img = generator(examples, self.get_conditions(a, c).to(examples.device))
                    predicted_scores.append(ad_feature_to_ascore(ad_model(img)).cpu())
                    predicted_concepts.append(concept_clf(examples, img).cpu())
                    images.append(img.cpu())
                    target_scores.append(a.cpu())
                    target_concepts.append(c.cpu())
            # Flatten to one entry per image, in the same row-major order as the figure.
            predicted_scores, predicted_concepts = torch.cat(predicted_scores), torch.cat(predicted_concepts)
            target_scores, target_concepts = torch.cat(target_scores), torch.cat(target_concepts)

            titles, title_colors = [], []
            zped = zip(predicted_scores, target_scores, predicted_concepts, target_concepts)
            for i, (a, ta, c, tc) in enumerate(zped):
                c = c.argmax()
                titles.append(f'^a{a:0.2f}_c{c}')
                if i < n_log:
                    # originals: the color encodes the true label
                    title_colors.append('olive' if true_labels[i] == 0 else 'saddlebrown')
                elif c != tc or self.discretize_anomaly_scores(a) != self.discretize_anomaly_scores(ta):
                    # generated: the predicted concept or the discretized score misses the target
                    title_colors.append('red')
                else:
                    title_colors.append('green')

            this_name = pt.join('conditional_generation', name)
            self.logger.log_fig(
                torch.cat([examples.cpu(), *images, ]), this_name,
                nrow=examples.shape[0],
                rowheaders=[
                    'original', *[
                        f'a={a} c={c}' for c in range(self.n_concepts)
                        for a in (["0.0", "0.5", "1.0"] if threescores else ["0.0", "1.0"])
                    ]
                ],
                titles=titles, title_colors=title_colors,
                maxres=28 if is_intermediate else examples.shape[-1],
                pdf=not is_intermediate,
            )

            # ---- version with rows for target anomaly score being 0 only
            if not is_intermediate:
                this_name = pt.join('conditional_generation', f'{name}_3rows')
                self.logger.log_fig(
                    torch.cat([
                        # keep a generated row if the target score of its first image (offset by the originals' row) is 0
                        examples.cpu(), *[
                            imgs for i, imgs in enumerate(images) if (target_scores[examples.shape[0]*(i+1)] == 0)
                        ],
                    ]),
                    this_name,
                    nrow=examples.shape[0],
                    rowheaders=['original', *[f'a={a} c={c}' for c in range(self.n_concepts) for a in (["0.0"])]],
                    titles=[tit for i, tit in enumerate(titles) if (i < examples.shape[0] or target_scores[i] == 0)],
                    title_colors=[col for i, col in enumerate(title_colors) if (i < examples.shape[0] or target_scores[i] == 0)],
                )

            # ---- raw image tensor with rows for target anomaly score being 0 only
            if not is_intermediate:
                self.logger.log_tensor(torch.stack([
                        examples.cpu(), *[
                            imgs for i, imgs in enumerate(images) if (target_scores[examples.shape[0]*(i+1)] == 0)
                        ],
                    ]), this_name
                )

            # --- reconstruction with true anomaly score and random concept
            target_scores = true_ascores
            target_concepts = torch.cat([torch.ones(examples.shape[0]).int() * c for c in range(self.n_concepts)])
            target_conditions = self.get_conditions(target_scores.repeat(self.n_concepts), target_concepts.to(examples.device))
            img = generator(examples.repeat(self.n_concepts, *([1] * (examples.dim() - 1))), target_conditions)
            predicted_scores = torch.cat([target_scores, (ad_feature_to_ascore(ad_model(img)))])
            predicted_concepts = concept_clf(
                examples.repeat(self.n_concepts + 1, *([1] * (examples.dim() - 1))), torch.cat([examples, img])
            ).argmax(1)
            this_name = pt.join('reconstruction', name)
            self.logger.log_fig(
                torch.cat([examples, img]).cpu(),
                this_name,
                nrow=examples.shape[0],
                rowheaders=[
                    'original', *[f'a`=a c={c}' for c in range(self.n_concepts)],
                ],
                titles=[f'^a{a:0.2f}_c{c}' for a, c in zip(predicted_scores, predicted_concepts)],
                title_colors=[
                    *['olive' if lbl == 0 else 'saddlebrown' for lbl in true_labels],
                    *[
                        'red' if c != tc or self.discretize_anomaly_scores(a) != self.discretize_anomaly_scores(ta) else 'green'
                        for a, c, ta, tc in zip(
                            predicted_scores[examples.shape[0]:], predicted_concepts[examples.shape[0]:],
                            target_scores.repeat(self.n_concepts), target_concepts
                        )
                    ],
                ],
                maxres=28 if is_intermediate else examples.shape[-1],
                pdf=not is_intermediate,
            )

            # --- cycle image generation
            # first row: original
            # second row: generation conditioned on original, "opposite" ascores, and random concepts
            # further rows for each concept: generation conditioned on second row and true ascores of original (--> cyc images)
            # Note that the reported predicted concepts always use the original as a reference in the concept classifier
            if 'train_' in name:
                target_scores = (1 - true_labels)
                target_concepts = torch.randint(self.n_concepts, torch.Size([examples.shape[0], ])).to(examples.device)
                target_conditions = self.get_conditions(target_scores, target_concepts)
                asc_img = generator(examples, target_conditions)
                cyc_scores = true_ascores.repeat(self.n_concepts)
                cyc_concepts = torch.cat([torch.ones(examples.shape[0]).int() * c for c in range(self.n_concepts)])
                cyc_conditions = self.get_conditions(cyc_scores, cyc_concepts.to(examples.device))
                cyc_img = generator(asc_img.repeat(self.n_concepts, *([1] * (asc_img.dim() - 1))), cyc_conditions)
                predicted_scores = torch.cat([true_ascores, (ad_feature_to_ascore(ad_model(torch.cat([asc_img, cyc_img]))))])
                predicted_concepts = concept_clf(
                    examples.repeat(2 + self.n_concepts, *([1] * (asc_img.dim() - 1))), torch.cat([examples, asc_img, cyc_img])
                ).argmax(1)
                this_name = pt.join('cycle_generation', name)
                self.logger.log_fig(
                    torch.cat([examples.cpu(), asc_img.cpu(), cyc_img.cpu()]),
                    this_name,
                    nrow=examples.shape[0],
                    rowheaders=[
                        'original', f'a`=(1-a) c=?',
                        *[f'a`=a c={c}' for c in range(self.n_concepts)],
                    ],
                    titles=[f'^a{a:0.2f}_c{c}' for a, c in zip(predicted_scores, predicted_concepts)],
                    title_colors=[
                        *['olive' if lbl == 0 else 'saddlebrown' for lbl in true_labels],
                        *[
                            'red' if c != tc or self.discretize_anomaly_scores(a) != self.discretize_anomaly_scores(ta)
                            else 'green'
                            for a, c, ta, tc in zip(
                                predicted_scores[examples.shape[0]:], predicted_concepts[examples.shape[0]:],
                                target_scores, target_concepts
                            )
                        ],
                        *[
                            'red' if c != tc or self.discretize_anomaly_scores(a) != self.discretize_anomaly_scores(ta)
                            else 'green'
                            for a, c, ta, tc in zip(
                                predicted_scores[examples.shape[0]*2:], predicted_concepts[examples.shape[0]*2:],
                                cyc_scores, cyc_concepts
                            )
                        ]
                    ],
                    maxres=28 if is_intermediate else examples.shape[-1],
                    pdf=not is_intermediate,
                )

        generator.train()
        concept_clf.train()

    def load(self, path: str, generator: ConditionalGenerator, discriminator: ConditionalDiscriminator,
             concept_clf: torch.nn.Module, opt: torch.optim.Optimizer = None, sched: _LRScheduler = None) -> int:
        """
        Loads a snapshot of the model including training state.
        The discriminator and concept classifier snapshots are looked up next to the generator snapshot, using the
        generator's filepath with '_generator_' replaced by '_discriminator_' and '_concept-clf_', respectively.
        Each of them is only loaded if the corresponding model instance is given, the derived filepath actually
        differs from `path` (i.e., `path` contains '_generator_'), and a file exists at the derived filepath.
        @param path: the filepath where the generator snapshot is stored. May be None, in which case nothing is
            loaded and 0 is returned.
        @param generator: the model instance into which the parameters of the found generator snapshot are loaded.
            Hence, the architectures need to match.
        @param discriminator: see generator. May be None to skip loading the discriminator.
        @param concept_clf: see generator. May be None to skip loading the concept classifier.
        @param opt: the optimizer instance into which the training state is loaded (skipped if None).
        @param sched: the learning rate scheduler into which the training state is loaded (skipped if None).
        @return: the epoch to resume training from, i.e., the last epoch stored in the snapshot plus one.
            This is 0 if `path` is None or the snapshot stores no epoch.
        """
        if path is None:
            return 0

        msg = ""
        snapshot = torch.load(path)
        net_state = snapshot.pop('net', None)
        opt_state = snapshot.pop('opt', None)
        sched_state = snapshot.pop('sched', None)
        # The snapshot stores the last finished epoch; training resumes with the next one.
        epoch = snapshot.pop('epoch', -1) + 1
        if net_state is not None:
            generator.load_state_dict(net_state)
        if opt_state is not None and opt is not None:
            opt.load_state_dict(opt_state)
        if sched_state is not None and sched is not None:
            sched.load_state_dict(sched_state)

        companions = (
            ('_discriminator_', discriminator, 'discriminator'),
            ('_concept-clf_', concept_clf, 'concept classifier'),
        )
        for tag, model, label in companions:
            if model is None:
                continue
            companion_path = path.replace('_generator_', tag)
            # If `path` lacks '_generator_', str.replace returns `path` unchanged. Loading it would push the
            # generator's weights into a different model, so such a "companion" is skipped.
            if companion_path == path or not pt.exists(companion_path):
                continue
            companion_state = torch.load(companion_path).pop('net', None)
            if companion_state is None:
                self.logger.warning(f"Snapshot {companion_path} contains no 'net' state. Skipped loading {label}.")
                continue
            model.load_state_dict(companion_state)
            msg += f'Also loaded {label}. '

        self.logger.print(f'Loaded counterfactual generator snapshot at epoch {epoch}. {msg}')
        return epoch

    def eval(self, generator: ConditionalGenerator, discriminator: ConditionalDiscriminator, concept_clf: ConceptNN,
              ad_model: Module, ad_feature_to_ascore: Callable,
              ds: TorchvisionDataset, cset: set[int], cstr: str, seed: int,
              workers: int, device: torch.device = torch.device('cuda:0'),
              tqdm_mininterval: float = 0.1, ad_trainer_parent = None) -> Mapping:
        """
        @param generator: Trained generator that takes a test image and a target condition
            (made of concept and target anomaly score, see :method:`self.get_conditions`).
            Outputs a new image that is similar to the test image but hopefully
            aligns with the conditioned concept and anomaly score. If the test image is anomalous and the conditioned
            anomaly score 0, the generate image is a counterfactual example.
        @param discriminator: Trained discriminator (GAN style). Takes a test image and a supposed anomaly score.
            The latter in ordinal format (see :method:`self.to_ordinal_scores`). Returns fake (value 0) or true (value 1).
            Note that the discriminator may use the supposed anomaly score, so if the image is not aligned with the
            score it might detect this, but it also generally rates the realisticity of the image as in typical GANs.
        @param concept_clf: Concept classifier. Takes two images. The second one is supposed to be a generated version
            of the first one with the concept changed to c. The classifier predicts this c.
        @param ad_model: Takes an image and returns the anomaly feature vector.
        @param ad_feature_to_ascore: Takes an anomaly feature vector and returns an anomaly score.
        @param ds: The dataset.
        @param cset: The set of normal class ids.
        @param cstr: The string representation of cset.
        @param seed: Determines the current iteration of the training (e.g., 3 for the third random seed).
        @param workers: number of workers for the data loading.
        @param device: torch device for evaluation.
        """
        generator.eval()
        discriminator.eval()
        concept_clf.eval()
        _, loader = ds.loaders(self.batch_size, num_workers=workers, shuffle_test=False, device=device, )

        procbar = tqdm(desc=f'Xeval {cstr}', total=len(loader), mininterval=tqdm_mininterval)
        ascores_true_anomalies = []
        ascores_true_normals = []
        ascores_counterfactuals = []
        concepts_ground_truth = []
        concepts_predictions = []
        
        for it, (imgs, lbls, idcs) in enumerate(loader):
            with torch.no_grad():
                orig_anomaly_features = ad_model(imgs)
                orig_true_anomaly_scores = ad_feature_to_ascore(orig_anomaly_features)

                # get generated images for target anomaly scores being 0 (normal) and all concepts
                concepts = torch.arange(self.n_concepts).repeat_interleave(imgs.size(0))  # 0,0,...,1,1,...,2,2,...,....
                normal_ascores = torch.zeros_like(orig_true_anomaly_scores).repeat(self.n_concepts)
                normal_conditions = self.get_conditions(normal_ascores, concepts.to(device)).to(device)
                repeated_images = imgs.repeat(self.n_concepts, *(1 for _ in range(imgs.ndim - 1)))
                repeated_labels = lbls.repeat(self.n_concepts)
                counterfactual_examples = generator(repeated_images, normal_conditions)
                counterfactual_anomaly_features = ad_model(counterfactual_examples)
                counterfactual_ascores = ad_feature_to_ascore(counterfactual_anomaly_features)
                predicted_realisticity_logits = discriminator(repeated_images, self.to_ordinal_scores(normal_ascores).to(device))
                predicted_concepts_logits = concept_clf(repeated_images, counterfactual_examples)
                predicted_concepts = predicted_concepts_logits.max(dim=1)[1]
                true_normal_imgs = imgs[lbls == 0]
                true_anomalous_imgs = imgs[lbls == 1]
                counterfactuals_from_anomaly = counterfactual_examples[repeated_labels == 1]
                random_batch_splits = random_split_tensor(input=imgs,
                                                          chunks=2,
                                                          device=device) 
                
                concepts_ground_truth.append(concepts)
                concepts_predictions.append(predicted_concepts)
                ascores_true_normals.append(orig_true_anomaly_scores[lbls == 0].cpu())
                ascores_true_anomalies.append(orig_true_anomaly_scores[lbls == 1].cpu())
                ascores_counterfactuals.append(counterfactual_ascores[repeated_labels == 1].cpu())
                
                self.logger.log_batch_img(tensor=true_normal_imgs,
                                          foldername='actual_imgs',
                                          prefix=f'{it}_')
                self.logger.log_batch_img(tensor=true_anomalous_imgs,
                                          foldername='anomalous_imgs',
                                          prefix=f'{it}_')
                self.logger.log_batch_img(tensor=counterfactuals_from_anomaly,
                                          foldername='counterfactual_imgs',
                                          prefix=f'{it}_')
                self.logger.log_batch_img(tensor=random_batch_splits[0],
                                          foldername='test_subset_1',
                                          prefix=f'{it}_')
                self.logger.log_batch_img(tensor=random_batch_splits[1],
                                          foldername='test_subset_2',
                                          prefix=f'{it}_')
                
            procbar.update()
        procbar.close()

        # -------------- CORRECTNESS ------------
        ascores_true_normals = torch.cat(ascores_true_normals)
        ascores_true_anomalies = torch.cat(ascores_true_anomalies)
        ascores_counterfactuals = torch.cat(ascores_counterfactuals)
        roc_norm_vs_counterfact = get_roc(ascores_true_normals, ascores_counterfactuals)
        roc_norm_vs_anomalous = get_roc(ascores_true_normals, ascores_true_anomalies)
        auc_norm_vs_counterfact_normalized = 1 - 2 * (roc_norm_vs_counterfact.auc - 0.5)

        self.logger.print(f"Average anomaly score for the normal samples: {ascores_true_normals.mean():.6f}")
        self.logger.print(f"Average anomaly score for the anomalous samples: {ascores_true_anomalies.mean():.6f}")
        self.logger.print(
            f"Average anomaly score for the counterfactual samples, generated from truly anomalous samples"
            f" {ascores_counterfactuals.mean():.6f}"
        )
        self.logger.print(f"AUC for normal samples vs. couterfactuals: {roc_norm_vs_counterfact.auc:0.3f}")
        self.logger.print(f"AUC for true normal vs. true anomalies: {roc_norm_vs_anomalous.auc:0.3f}")
        self.logger.print(
            f"AUC for true normal vs. couterfactuals, reversed and normalized "
            f"(usually in [0,1] with 1 being best at an AUC of 0.5, might move up to 2 when AUC is worse than 0.5):"
            f" {auc_norm_vs_counterfact_normalized:0.3f}"
        )
        self.logger.logjson('results', {
            'avg_ascore_normal': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': ascores_true_normals.mean().item()}
            },
            'avg_ascore_anomalous': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': ascores_true_anomalies.mean().item()}
            },
            'avg_ascore_counterfactual': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': ascores_counterfactuals.mean().item()}
            },
            'auc_normal_anomalous': {f'cls_{int_set_to_str(cset)}': {f'seed_{seed}': roc_norm_vs_anomalous.auc}},
            'auc_normal_counterfactual': {f'cls_{int_set_to_str(cset)}': {f'seed_{seed}': roc_norm_vs_counterfact.auc}},
            'auc_normal_counterfactual_normalized': {
                f'cls_{int_set_to_str(cset)}': {f'seed_{seed}': auc_norm_vs_counterfact_normalized}
            },
        }, update=True)
        self.logger.plot_many(
            [roc_norm_vs_anomalous, roc_norm_vs_counterfact],
            [f'normal_vs_anomalous', 'normal_vs_counterfactual'], f"rocs_{int_set_to_str(cset)}_seed{seed}",
            False
        )
        self.logger.hist_ascores(
            [ascores_true_normals, ascores_true_anomalies, ascores_counterfactuals, ],
            ['test_normal', 'test_anomalous', 'test_counterfactual', ],
            f'C{int_set_to_str(cset)}-{cstr}-S{seed}__anomaly_scores',
            colors=['green', 'red', 'blue'],
        )

        # -------------- REALISM ------------
        true_normal_path = pt.join(self.logger.dir, 'actual_imgs')
        cf_path = pt.join(self.logger.dir, 'counterfactual_imgs')
        anamalous_path = pt.join(self.logger.dir, 'anomalous_imgs')
        test_subset_1_path = pt.join(self.logger.dir, 'test_subset_1')
        test_subset_2_path = pt.join(self.logger.dir, 'test_subset_2')
        
        self.logger.print("Computing FID scores (normal, anomalous)")
        fid_score, fid_score_lower_bound, fid_score_upper_bound = compute_fid_scores(
            path_actual_imgs=true_normal_path, path_counterfactual_imgs=cf_path,
            path_anamalous_imgs=anamalous_path, path_test_subset_1=test_subset_1_path,
            path_test_subset_2=test_subset_2_path, device=device, xtrainer=self,
            cstr=cstr,  seed=seed
        )
        normalized_fid_score = 1 - (fid_score - fid_score_lower_bound) / (fid_score_upper_bound - fid_score_lower_bound)
        fid_score_paper = fid_score / fid_score_upper_bound

        self.logger.print(f"FID score for true normal samples vs. counterfactuals: {fid_score:.6f}")
        self.logger.print(f"FID score for true nomal samples vs. true anomalous (upper bound): {fid_score_upper_bound:.6f}")
        self.logger.print(f"FID score for a random half testset vs. the remaining half (lower bound): {fid_score_lower_bound:.6f}")
        self.logger.print(
            f"FID score for true normal samples vs. counterfactuals, "
            f"reversed and normalized with bounds "
            f"(in [0,1] as long as stays within bounds, 1 is best as it meets lower bound, "
            f"above 1 when below lower bound, below 0 when exceeds upper bound): "
            f"{normalized_fid_score:.6f}"
        )
        self.logger.print(f"FID score paper: {fid_score_lower_bound:.6f}")
        self.logger.logjson('results', {
            'fid_score_true_normal_vs_counterfactual': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': fid_score}
            },
            'fid_score_true_normal_vs_anamalous': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': fid_score_upper_bound}
            },
            'fid_score_first_half_vs_second_half': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': fid_score_lower_bound}
            },
            'fid_score_true_normal_vs_counterfactual_normalized': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': normalized_fid_score}
            },
            'fid_score_paper': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': fid_score_paper}
            },
        }, update=True)

        # -------------- DISENTANGLEMENT ------------
        concepts_ground_truth = torch.cat(concepts_ground_truth).numpy()
        concepts_predictions = torch.cat(concepts_predictions).cpu().numpy()
        acc_concept_clf = accuracy_score(concepts_ground_truth, concepts_predictions)
        self.logger.print(f"Accuracy of the concept classifier over the generated samples: {acc_concept_clf:.6f}")
        self.logger.logjson('results', {
            'concept_clf_accuracy': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': acc_concept_clf}
            },
        }, update=True)

        # -------------- SUBSTITUTABILITY ------------
        def reset_model(old_ad_model):
            model = deepcopy(old_ad_model)
            model.apply(weight_reset)
            return model

        sub_auc = None
        if ad_trainer_parent is not None:
            if isinstance(ds, CombinedDataset):
                train_normal_ids = ds.normal.train_set.indices
                train_ds = ds.normal.train_set.dataset
                train_subset = ds.normal.train_set
            else:
                train_normal_ids = ds.train_set.indices
                train_ds = ds.train_set.dataset
                train_subset = ds.train_set
            original_size = len(train_ds)
            if hasattr(train_ds, 'data'):
                attr = 'data'
            elif hasattr(train_ds, 'samples'):
                attr = 'samples'
            elif hasattr(train_ds, 'imgs'):
                attr = 'imgs'
            elif hasattr(train_ds, 'images'):
                attr = 'images'
            else:
                attr = None
                self.logger.warning(
                    f"CF evaluation. Substitutability evaluation failed. "
                    f"The data in dataset {ds} has none of the known attribute names."
                )
            if attr is not None:
                cf_imgs = None
                if not hasattr(train_ds, 'targets'):
                    self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Dataset {ds} has no targets!")
                if isinstance(train_ds.__getattribute__(attr), str):  # ?
                    self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Dataset format not implemented!")
                elif isinstance(train_ds.__getattribute__(attr), np.ndarray):
                    if isinstance(train_ds.__getattribute__(attr)[0], str):  # e.g., INN
                        cf_samples = []
                        train_ds.targets = np.asarray(train_ds.targets)
                        normal_label = train_ds.targets[train_normal_ids][0]
                        for img_path in os.listdir(cf_path):
                            cf_samples.append(pt.join(cf_path, img_path))
                        train_ds.__setattr__(
                            attr, np.concatenate([train_ds.__getattribute__(attr), cf_samples])
                        )
                        train_ds.__setattr__(
                            'imgs', np.concatenate([train_ds.__getattribute__(attr), cf_samples])  # imgs == samples...
                        )
                        train_ds.targets = np.concatenate([
                            train_ds.targets, train_ds.targets[train_normal_ids][0].repeat(len(cf_samples))
                        ])
                        cf_imgs = np.asarray(cf_samples)
                    else:  # e.g., CIFAR-10
                        cf_imgs = []
                        for img in os.listdir(cf_path):
                            cf_imgs.append(np.asarray(default_loader(pt.join(cf_path, img))))
                        cf_imgs = np.stack(cf_imgs)
                        if train_ds.__getattribute__(attr).ndim == 3:
                            cf_imgs = cf_imgs[:, :, :, 0]
                        if train_ds.__getattribute__(attr).dtype != np.uint8:
                            cf_imgs = cf_imgs.astype(np.float32) / 255
                        if cf_imgs.shape[1:] != train_ds.__getattribute__(attr).shape[1:]:
                            self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Dataset data shape doesn't match!")
                        train_ds.__setattr__(
                            attr, np.concatenate([train_ds.__getattribute__(attr), cf_imgs])
                        )
                        train_ds.targets = np.asarray(train_ds.targets)
                        train_ds.targets = np.concatenate([
                            train_ds.targets, train_ds.targets[train_normal_ids][0].repeat(cf_imgs.shape[0])
                        ])
                elif isinstance(train_ds.__getattribute__(attr), torch.Tensor):  # e.g., MNIST
                    cf_imgs = []
                    for img in os.listdir(cf_path):
                        cf_imgs.append(to_tensor(default_loader(pt.join(cf_path, img))))
                    cf_imgs = torch.stack(cf_imgs)
                    if train_ds.__getattribute__(attr).dim() == 3:
                        cf_imgs = cf_imgs[:, 0, :, :]
                    if train_ds.__getattribute__(attr).dtype == torch.uint8:
                        cf_imgs = cf_imgs.mul(255).type(torch.uint8)
                    if cf_imgs.shape[-1] == 224 and train_ds.__getattribute__(attr).shape[-1] == 256:  # RandomCrop
                        cf_imgs = nn.functional.interpolate(cf_imgs, (256, 256), mode='bilinear')
                    if cf_imgs.shape[1:] != train_ds.__getattribute__(attr).shape[1:]:
                        self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Dataset data shape doesn't match!")
                    else:
                        train_ds.__setattr__(
                            attr, torch.cat([train_ds.__getattribute__(attr), cf_imgs])
                        )
                        train_ds.targets = torch.cat([
                            train_ds.targets, train_ds.targets[train_normal_ids][0].repeat(cf_imgs.shape[0]).detach()
                        ])
                elif isinstance(train_ds.__getattribute__(attr), List):  # e.g., ImageNet.samples: List[Tuple[str, int]]
                    if (
                        not isinstance(train_ds.__getattribute__(attr)[0], Tuple)
                        or not isinstance(train_ds.__getattribute__(attr)[0][0], str)
                        or not isinstance(train_ds.__getattribute__(attr)[0][1], int)
                    ):
                        self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Attribute type is unknown (samples, but not List[Tuple[str, int]])!")
                        cf_imgs = None
                    else:
                        cf_samples = []
                        train_ds.targets = np.asarray(train_ds.targets)
                        normal_label = train_ds.targets[train_normal_ids][0]
                        for img_path in os.listdir(cf_path):
                            cf_samples.append(((pt.join(cf_path, img_path)), normal_label))
                        train_ds.__setattr__(
                            attr, train_ds.__getattribute__(attr) + cf_samples
                        )
                        train_ds.__setattr__(
                            'imgs', train_ds.__getattribute__(attr) + cf_samples  # imgs == samples...
                        )
                        train_ds.targets = np.concatenate([
                            train_ds.targets, train_ds.targets[train_normal_ids][0].repeat(len(cf_samples))
                        ])
                        cf_imgs = np.asarray(cf_samples)
                else:
                    self.logger.warning(
                        f"CF evaluation. Substitutability evaluation failed. "
                        f"The data in dataset {ds}.{attr} is of unknown type."
                    )
                if cf_imgs is not None:
                    train_subset.indices = list(range(original_size, original_size + cf_imgs.shape[0]))
                    new_ad_model = reset_model(ad_model)
                    _, roc, train_labels, train_ascores = ad_trainer_parent.train_cls(
                        new_ad_model, ds, cset, cstr + "-CFSubstitution", seed, None, tqdm_mininterval, logging=False
                    )
                    roc_subst, prc, eval_labels, eval_ascores = ad_trainer_parent.eval_cls(
                        new_ad_model, ds, cset, cstr + "-CFSubstitution", seed, tqdm_mininterval, logging=False
                    )
                    auc_subst_normalized = (roc_subst.auc - 0.5) / (roc_norm_vs_anomalous.auc - 0.5)
                    self.logger.print(
                        f"AD test AuROC with normal training set substituted with counterfactuals: {roc_subst.auc:.6f}"
                    )
                    self.logger.print(
                        f"AD test AuROC, normalized, with normal training set substituted with counterfactuals "
                        f"(usually in [0, 1] with 1 being best, larger than 1 when substituted AUC exceeds original, "
                        f"becomes 0 when substituted AUC is 0.5, less than 0 when substituted AUC is less than 0.5): "
                        f"{auc_subst_normalized:.6f}"
                    )

                    # train_subset.indices = train_normal_ids + list(range(original_size, original_size + cf_imgs.shape[0]))
                    # new_ad_model = reset_model(ad_model)
                    # _, roc, train_labels, train_ascores = ad_trainer_parent.train_cls(
                    #     new_ad_model, ds, cset, cstr + "-CFExtension", seed, None, tqdm_mininterval, logging=False
                    # )
                    # roc_extnd, prc, eval_labels, eval_ascores = ad_trainer_parent.eval_cls(
                    #     new_ad_model, ds, cset, cstr + "-CFExtension", seed, tqdm_mininterval, logging=False
                    # )
                    # auc_extnd_normalized = (roc_extnd.auc - 0.5) / (roc_norm_vs_anomalous.auc - 0.5)
                    # self.logger.print(f"AD test AuROC with extended normal training set using counterfactuals: {roc_extnd.auc:.6f}")
                    # self.logger.print(
                    #     f"AD test AuROC, normalized, with normal training set extended with counterfactuals "
                    #     f"(usually in [0, 1] with 1 being best, larger than 1 when extended AUC exceeds original, "
                    #     f"becomes 0 when extended AUC is 0.5, less than 0 when extended AUC is less than 0.5): "
                    #     f"{auc_extnd_normalized:.6f}"
                    # )

                    sub_auc = roc_subst.auc
                    self.logger.logjson('results', {
                        'substitutability': {f'cls_{int_set_to_str(cset)}': {
                            f'seed_{seed}': {
                                'substituted': roc_subst.auc,
                                # 'extended': roc_extnd.auc,
                                'substituted_normalized': auc_subst_normalized,
                                # 'extended_normalized': auc_extnd_normalized
                            }}
                        },
                    }, update=True)

        self.logger.print("Cleaning temporarily saved generated images off the disk...")
        shutil.rmtree(true_normal_path)
        shutil.rmtree(cf_path)
        shutil.rmtree(anamalous_path)
        shutil.rmtree(test_subset_1_path)
        shutil.rmtree(test_subset_2_path)
        generator.train()
        discriminator.train()
        concept_clf.train()

        jsonfile = self.logger.logjson('results', {
            'summary': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': {
                    'fid_cf_div_anom': fid_score_paper,
                    'cf_vs_normal_auc': roc_norm_vs_counterfact.auc,
                    'substituted_ad_auc': sub_auc,
                    'concept_clf_acc': acc_concept_clf
                }
            }},
        }, update=True)

        with open(jsonfile, 'r') as reader:
            res = json.load(reader)
        criteria = ("fid_cf_div_anom", "cf_vs_normal_auc", "substituted_ad_auc", "concept_clf_acc")
        for crit in criteria:
            for ckey in res['summary']:
                if ckey in criteria:
                    continue
                res['summary'][ckey][crit] = np.mean([res['summary'][ckey][skey][crit] for skey in res['summary'][ckey] if skey not in criteria])
            res['summary'][crit] = np.mean([res['summary'][ckey][crit] for ckey in res['summary'] if ckey not in criteria])
        self.logger.logjson('results', res, update=True)

        return {
            cstr: {
                seed: dict(
                    roc_norm_vs_counterfact=roc_norm_vs_counterfact.auc, ascores_counterfactuals=ascores_counterfactuals,
                    ascores_true_normals=ascores_true_normals, ascores_true_anomalies=ascores_true_anomalies,
                    fid_score_true_normal_vs_counterfactual=fid_score, concept_clf_accuracy=acc_concept_clf,
                    auc_norm_vs_counterfact_normalized=auc_norm_vs_counterfact_normalized,
                    normalized_fid_score=normalized_fid_score,
                )
            }
        }
