import gc
import os
import os.path as pt
import traceback
import json
import sys
from abc import ABC
from copy import deepcopy
from typing import Callable, List, Mapping, Tuple
from datetime import datetime

import numpy as np
import shutil
import sklearn
import torch
import torch.nn as nn
import torchvision.datasets
from sklearn.metrics import accuracy_score
from torch.nn import Module
from torch.nn.functional import cross_entropy
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.folder import default_loader
from torchvision.transforms.functional import to_tensor
from diffusers import DDIMScheduler, DDIMInverseScheduler
from tqdm import tqdm

from xad.counterfactual.eval import get_roc, compute_fid_scores
from xad.counterfactual.bases import XTrainer, huber_distance, hinge_disc_loss
from xad.datasets.bases import CombinedDataset, TorchvisionDataset
from xad.models.bases import ConceptNN, ConditionalDiscriminator, ConditionalGenerator
from xad.utils.logger import Logger
from xad.utils.training_tools import NanGradientsError, int_set_to_str, weight_reset
from xad.utils.data_tools import random_split_tensor
from xad.counterfactual.diffedit_ce_pipe import CEStableDiffusionDiffEditPipeline

# Chunk size passed to DiffEditTrainer.minibatchify inside DiffEditTrainer.train:
# each loaded batch is split into slices of this many samples before being sent
# through the diffusion pipeline and the generator.
MINIBATCHIFY = 1


class DiffEditTrainer(XTrainer):
    def __init__(self, xmodels: List[Module],
                 n_concepts: int, epochs: int, lr: float, wdk: float, milestones: List[int],
                 batch_size: int, logger: Logger = None, oe=True,
                 lamb_dist: float = 1e-3, mask_encode_strength: float = 0.5, mask_thresholding_ratio: float = 2.0,
                 diffusion_inference_steps: int = 40, diffusion_resolution: int = 512,
                 devices: List[torch.device] = None, lamb_conc: float = 0.5, lamb_gen: float = 1e-1,
                 lamb_asc: float = 1.0, gen_every: int = 2, disc_every: int = 1,
                 milestone_alpha: float = 0.5, additive_gen=True, gen_use_mask=True,
                 **kwargs):
        """
        Store the training configuration and load the frozen DiffEdit (Stable Diffusion) pipeline.

        :param xmodels: ``[generator, concept_classifier]`` or ``[generator, concept_classifier, discriminator]``.
            The entries must be a ConditionalGenerator, a ConceptNN and (optionally) a ConditionalDiscriminator.
        :param n_concepts: number of concepts; random conditions are drawn from ``[0, n_concepts)`` during training.
        :param epochs: number of training epochs.
        :param lr: learning rate of the Adam optimizers.
        :param wdk: weight decay of the Adam optimizers.
        :param milestones: epochs (may be fractional) at which the learning rate is multiplied by `milestone_alpha`.
        :param batch_size: batch size of the data loaders.
        :param logger: optional logger; it is usually assigned after construction (see :meth:`__setattr__`).
        :param oe: stored on the trainer; not otherwise used in this constructor.
        :param lamb_dist: weight of the loss penalizing the distance to the DiffEdit suggestion.
        :param mask_encode_strength: forwarded to the pipeline's ``generate_mask``.
        :param mask_thresholding_ratio: forwarded to the pipeline's ``generate_mask``.
        :param diffusion_inference_steps: number of inference steps used for mask generation, inversion and editing.
        :param diffusion_resolution: side length to which images are upscaled before entering the pipeline.
        :param devices: devices to use; defaults to all visible CUDA devices. The pipeline is placed on the first
            one, the discriminator (if any) is trained on the last one.
        :param lamb_conc: weight of the concept classification loss.
        :param lamb_gen: weight of the discriminator-related (hinge) losses.
        :param lamb_asc: weight of the anomaly score loss.
        :param gen_every: the generator is trained every `gen_every`-th iteration.
        :param disc_every: the discriminator is trained every `disc_every`-th iteration.
        :param milestone_alpha: multiplicative learning rate decay applied at each milestone.
        :param additive_gen: if True, the generator output is added to the DiffEdit latent instead of replacing it.
        :param gen_use_mask: if True, the generator output is restricted to the DiffEdit mask.
        :param kwargs: ignored apart from being recorded in the logged setup.
        :raises ValueError: if the models have unexpected types or no device is available.
        """
        # NOTE: this has to remain the first statement so that locals() only contains the constructor arguments.
        self.__setup = {f'x_{k}': v for k, v in locals().items() if k not in ['self', 'generator', 'discriminator', 'concept_classifier']}
        if not isinstance(xmodels[0], ConditionalGenerator):
            raise ValueError(f"Expected xmodels[0] to be a ConditionalGenerator, but it is {xmodels[0].__class__}")
        if not isinstance(xmodels[1], ConceptNN):
            raise ValueError(f"Expected xmodels[1] to be a ConceptNN, but it is {xmodels[1].__class__}")
        self.generator = xmodels[0]
        self.concept_classifier = xmodels[1]
        self.discriminator = None
        if len(xmodels) > 2 and not isinstance(xmodels[2], ConditionalDiscriminator):
            raise ValueError(f"Expected xmodels[2] to be a ConditionalDiscriminator, but it is {xmodels[2].__class__}")
        elif len(xmodels) > 2:
            self.discriminator = xmodels[2]
        self.n_concepts = n_concepts
        self.epochs = epochs
        self.lr = lr
        self.wdk = wdk
        self.milestones = milestones  # can be floats for intermediate epochs
        self.milestone_alpha = milestone_alpha
        self.batch_size = batch_size
        self.oe = oe
        self.lamb_dist = lamb_dist
        self.lamb_conc = lamb_conc
        self.lamb_gen = lamb_gen
        self.lamb_asc = lamb_asc
        self.gen_every = gen_every
        self.disc_every = disc_every
        self.mask_encode_strength = mask_encode_strength
        self.mask_thresholding_ratio = mask_thresholding_ratio
        self.diffusion_inference_steps = diffusion_inference_steps
        self.diffusion_resolution = diffusion_resolution
        self.additive_gen = additive_gen
        self.gen_use_mask = gen_use_mask
        self.devices = devices if devices is not None else [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())]
        if len(self.devices) <= 0:
            # Use self.devices here: `devices` is None whenever the default (all visible GPUs) is requested,
            # and len(None) would raise a TypeError instead of this ValueError.
            raise ValueError(
                f"DiffEdit trainer requires at least one GPU, but there are only {len(self.devices)}! "
                "Check the CUDA_VISIBLE_DEVICES environment variable and the --devices argument."
            )
        self.logger: Logger = logger  # will usually be set later in :func:`xad.main.create_trainer`.

        self.ce_pipe = CEStableDiffusionDiffEditPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            torch_dtype=torch.float16,
            # device_map='auto',
            low_cpu_mem_usage=True
        )
        self.ce_pipe.to(self.devices[0])
        self.ce_pipe.scheduler = DDIMScheduler.from_config(self.ce_pipe.scheduler.config)
        self.ce_pipe.inverse_scheduler = DDIMInverseScheduler.from_config(self.ce_pipe.scheduler.config)
        self.ce_pipe.set_progress_bar_config(disable=True)
        # The diffusion pipeline is never trained; disable gradients for all of its networks.
        for p in self.ce_pipe.unet.parameters():
            p.requires_grad_(False)
        for p in self.ce_pipe.text_encoder.parameters():
            p.requires_grad_(False)
        for p in self.ce_pipe.vae.parameters():
            p.requires_grad_(False)

    def __setattr__(self, key, value):
        if key == 'logger' and value is not None and hasattr(self, 'logger') and self.logger is None:
            value.logsetup(self.__setup, nets=[self.generator, self.concept_classifier])  # log training setup
            del self.__setup
        super().__setattr__(key, value)

    def minibatchify(self, batch: Tuple, size: int):
        """
        Lazily split all elements of a batch into aligned chunks.

        :param batch: tuple of sliceable sequences (e.g., images, labels, indices).
        :param size: maximum number of samples per chunk.
        :return: generator yielding tuples that contain, for each element of `batch`, the slice
            ``[start:start + size]``. The number of chunks is determined by the longest element.
        """
        longest = max(len(part) for part in batch)
        for lo in range(0, longest, size):
            hi = lo + size
            yield tuple(part[lo:hi] for part in batch)

    def diffedit_latents_to_img(self, latents: torch.Tensor) -> torch.Tensor:
        """
        Decode latents with the pipeline's VAE and post-process the result into image tensors.

        :param latents: latents as produced by the diffusion pipeline (still scaled by the VAE scaling factor).
        :return: the decoded images, post-processed by the pipeline's image processor with output type 'pt'
            and denormalization enabled for every sample.
        """
        pipe = self.ce_pipe
        # undo the latent scaling before handing the latents to the VAE decoder
        unscaled = latents / pipe.vae.config.scaling_factor
        decoded = pipe.vae.decode(unscaled, return_dict=False)[0]
        denormalize_all = [True] * decoded.shape[0]
        return pipe.image_processor.postprocess(
            decoded, output_type='pt', do_denormalize=denormalize_all
        )

    def diffedit(self, img: torch.Tensor, query_prompt: str, ad_device: torch.device, mask: bool = True) -> Tuple[
        torch.Tensor, np.ndarray, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the DiffEdit pipeline (mask generation, inversion, masked editing) without tracking gradients.

        :param img: batch of images; its first dimension is the batch size.
        :param query_prompt: prompt used as source prompt for the mask and as prompt for the edit.
            The target prompt of the mask and the inversion prompt are always the empty string.
        :param ad_device: device to which the edited latents and the decoded suggestion are moved.
        :param mask: if False, the generated mask is replaced by an all-ones mask of the same shape.
        :return: tuple of (upscaled input images, mask, inverted image latents, ``latents.squeeze(0)[0]`` of the
            inverted latents, latents of the DiffEdit suggestion, decoded DiffEdit suggestion).
        """
        n_imgs = img.shape[0]
        query_prompt = [query_prompt] * n_imgs
        target_prompt = [""] * n_imgs
        steps = self.diffusion_inference_steps
        side = self.diffusion_resolution
        with torch.no_grad():
            upscaled_img = nn.functional.interpolate(img, (side, side), mode="bilinear")
            mask_image = self.ce_pipe.generate_mask(
                image=upscaled_img,
                source_prompt=query_prompt,
                target_prompt=target_prompt,
                num_maps_per_mask=10,
                mask_encode_strength=self.mask_encode_strength,
                mask_thresholding_ratio=self.mask_thresholding_ratio,
                num_inference_steps=steps,
                guidance_scale=7.5,
            )
            if not mask:
                # turn the mask into all ones while keeping its shape and type
                mask_image = mask_image * 0 + 1
            inversion = self.ce_pipe.invert(
                image=upscaled_img,
                prompt=target_prompt,
                num_inference_steps=steps,
                guidance_scale=7.5,
            )
            # batch_size x time_steps x latent_channels x latent_height x latent_width
            image_latents = inversion.latents
            image_latents_latest = image_latents.squeeze(0)[0]
            edited = self.ce_pipe(
                prompt=query_prompt,
                mask_image=mask_image,
                image_latents=image_latents,
                num_inference_steps=steps,
                output_type='latent',
            )
            diffedit_suggestion_latent = edited[0].to(ad_device)
            diffedit_suggestion = self.diffedit_latents_to_img(diffedit_suggestion_latent).to(ad_device)
        return (
            upscaled_img, mask_image, image_latents, image_latents_latest,
            diffedit_suggestion_latent, diffedit_suggestion,
        )

    def generate_image(self, generator: ConditionalGenerator, diffedit_suggestion_latent: torch.Tensor, image_latents: torch.Tensor, 
                       mask_image: np.ndarray, condition: torch.Tensor, query_prompt: str, final_res: int, 
                       ad_device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Let the generator modify the DiffEdit suggestion in latent space and decode the result.

        The generator is moved to the device of the latents. Depending on ``self.additive_gen`` and
        ``self.gen_use_mask`` its output is added to or replaces the suggestion, either everywhere or
        only within the mask.

        :param generator: conditional generator, called with the latents and the condition.
        :param diffedit_suggestion_latent: latents of the DiffEdit suggestion.
        :param image_latents: not used by this method.
        :param mask_image: DiffEdit mask as numpy array; a channel dimension is inserted at position 1.
        :param condition: condition passed to the generator.
        :param query_prompt: not used by this method.
        :param final_res: side length of the downscaled output image.
        :param ad_device: device to which the decoded image is moved.
        :return: tuple of (decoded image bilinearly resized to `final_res` x `final_res` as float,
            decoded image at full resolution).
        """
        latent_device = diffedit_suggestion_latent.device
        generator.to(latent_device)
        modification = generator(diffedit_suggestion_latent, condition.to(latent_device))
        mask = torch.from_numpy(mask_image).unsqueeze(1).to(modification.device)

        if self.gen_use_mask:
            # the generator may only touch the masked region
            if self.additive_gen:
                kept = diffedit_suggestion_latent
            else:
                kept = (1 - mask) * diffedit_suggestion_latent
            modified_suggestion_latent = kept + modification * mask
        elif self.additive_gen:
            modified_suggestion_latent = diffedit_suggestion_latent + modification
        else:
            modified_suggestion_latent = modification

        generated_image = self.diffedit_latents_to_img(modified_suggestion_latent).to(ad_device)
        target_size = (final_res, final_res)
        generated_image_small = nn.functional.interpolate(generated_image, target_size, mode="bilinear").float()
        return generated_image_small, generated_image

    def get_query_prompt(self, cstr: str) -> str:
        """
        Build the text prompt describing the given class names.

        :param cstr: class names joined by "+"; underscores within a name are replaced by spaces.
        :return: "a photo of a(n) <name>" for each class name, joined by " or ". The article "an" is used
            if the name starts with a vowel letter, irrespective of its case (so "Apple" becomes "an Apple").
        """
        vowels = ("a", "e", "i", "o", "u")
        prompts = []
        for cs in cstr.split("+"):
            # compare case-insensitively so that capitalized class names get the correct article, too
            article = "an" if cs.lower().startswith(vowels) else "a"
            prompts.append(f'a photo of {article} {cs.replace("_", " ")}')
        return " or ".join(prompts)

    def run(self, ad_model: torch.nn.Module, ad_feature_to_ascore: Callable, dataset: TorchvisionDataset,
            cset: set[int], clsstr: str, seed: int, workers: int = 4,
            load_models: str = None, ad_device: torch.device = torch.device('cuda:0'),
            tqdm_mininterval: float = 0.1, ad_trainer_parent=None) -> dict:
        """
        Train fresh copies of the explanation models and evaluate them.

        Training plus evaluation is attempted up to five times; an attempt is repeated with newly
        initialized model copies whenever it raises a NanGradientsError.

        :param ad_model: the anomaly detection model; forwarded to :meth:`train` and :meth:`eval`.
        :param ad_feature_to_ascore: callable mapping the output of `ad_model` to anomaly scores.
        :param dataset: the dataset to train and evaluate on.
        :param cset: set of class ids; used for logging and forwarded to training and evaluation.
        :param clsstr: string describing the classes; used for logging and forwarded.
        :param seed: the current seed; used for logging and forwarded.
        :param workers: number of data loader workers.
        :param load_models: optional snapshot to load; recorded in the logged setup and forwarded to :meth:`train`.
        :param ad_device: device forwarded to :meth:`train` and :meth:`eval`.
        :param tqdm_mininterval: minimum update interval of the progress tracker.
        :param ad_trainer_parent: forwarded to :meth:`eval`.
        :return: the metrics returned by :meth:`eval`, or None if every attempt failed with NaN gradients.
        :raises ValueError: if no logger has been set.
        """
        if self.logger is None:
            raise ValueError('Logger has not been set.')
        self.logger.logsetup({'load': {int_set_to_str(cset): {seed: load_models}}, })

        def copy_models():
            # Create re-initialized copies so that each attempt starts from scratch
            # and the models stored on the trainer are left untouched.
            generator = deepcopy(self.generator)
            generator.apply(weight_reset)
            concept_clf = deepcopy(self.concept_classifier)
            concept_clf.parameterize()
            concept_clf.apply(weight_reset)
            assert all([p.is_leaf for p in self.generator.parameters()])
            assert all([p.is_leaf for p in self.concept_classifier.parameters()])
            for n, p in generator.named_parameters():
                p.detach_().requires_grad_()  # otherwise jit models don't work due to grad_fn=clone_backward
            for n, p in concept_clf.named_parameters():
                p.detach_().requires_grad_()  # otherwise jit models don't work due to grad_fn=clone_backward
            discriminator = None
            if self.discriminator is not None:
                discriminator = deepcopy(self.discriminator)
                discriminator.parameterize()
                discriminator.apply(weight_reset)
                assert all([p.is_leaf for p in self.discriminator.parameters()])
                for n, p in discriminator.named_parameters():
                    p.detach_().requires_grad_()  # otherwise jit models don't work due to grad_fn=clone_backward
            return generator, discriminator, concept_clf

        # Single source of truth for the retry limit; the loop and the log messages below both rely on it.
        max_tries = 5
        eval_metrics = None
        for i in range(max_tries):
            try:
                generator, discriminator, concept_clf = copy_models()
                generator, discriminator, concept_clf, metrics = self.train(
                    generator, concept_clf, ad_model, ad_feature_to_ascore, dataset, cset, clsstr, seed,
                    workers, load_models, ad_device, tqdm_mininterval, discriminator=discriminator
                )
                gc.collect()
                eval_metrics = self.eval(
                    generator, concept_clf, ad_model, ad_feature_to_ascore, dataset, cset, clsstr, seed,
                    workers, ad_device, tqdm_mininterval, ad_trainer_parent, discriminator=discriminator
                )
                gc.collect()
                break
            except NanGradientsError:
                if i < max_tries - 1:  # try once more
                    self.logger.warning(
                        f'Gradients got NaN for class {int_set_to_str(cset)} "{clsstr}" and seed {seed}. '
                        f'Happened {i + 1} times so far. Try once more.'
                    )
                else:  # last attempt failed as well; give up and return None
                    self.logger.warning(
                        f'Gradients got NaN for class {int_set_to_str(cset)} "{clsstr}" and seed {seed}. '
                        f'Happened {i + 1} times so far. Try no more. Set model and roc to None.'
                    )
        return eval_metrics

    def train(self, generator: ConditionalGenerator, concept_clf: ConceptNN,
              ad_model: Module, ad_feature_to_ascore: Callable,
              ds: TorchvisionDataset, cset: set[int], cstr: str, seed: int,
              workers: int, load_models: str = None,
              ad_device: torch.device = torch.device('cuda:0'),
              tqdm_mininterval: float = 0.1, discriminator: ConditionalDiscriminator = None,
              ) -> Tuple[ConditionalGenerator, ConditionalDiscriminator, Module, Mapping]:
        # ---- prepare model and variables
        original_ds = ds
        disc_device = self.devices[-1]
        generator = generator.to(ad_device).train()
        discriminator = discriminator.to(disc_device).train() if discriminator is not None else None
        concept_clf.to(ad_device).train()
        ad_model.to(ad_device).eval()
        ad_model_requires_grad = list(ad_model.parameters())[0].requires_grad
        for p in ad_model.parameters():
            p.requires_grad_(False)
        epochs = self.epochs
        n_log = 20
        cs = int_set_to_str(cset)
        query_prompt = self.get_query_prompt(cstr)
        asc_guide_scaler = torch.cuda.amp.GradScaler(2.**9)
        metrics = {}

        
        # ---- optimizers and loaders
        opt = torch.optim.Adam(
            [{'params': generator.parameters()}, {'params': concept_clf.parameters()},],
            lr=self.lr, weight_decay=self.wdk, betas=(0.0, 0.9)
        )
        if discriminator is None:
            discopt = None
        else:
            discopt = torch.optim.Adam(
                [{'params': discriminator.parameters()}],
                lr=self.lr, weight_decay=self.wdk, betas=(0.0, 0.9)
            )
        if not isinstance(ds, CombinedDataset):
            raise ValueError("Training DiffEdit requires OE, but the dataset contains no OE!")
        else:
            if discriminator is None:
                ds = ds.oe  # exclude normal samples
        loader, _ = ds.loaders(self.batch_size, num_workers=workers, persistent=True, device=ad_device)
        sched = torch.optim.lr_scheduler.MultiStepLR(
            opt, [int(float(m) * len(loader)) for m in self.milestones], self.milestone_alpha
        )
        if discriminator is None:
            discsched = None
        else:
            discsched = torch.optim.lr_scheduler.MultiStepLR(
                discopt, [int(float(m) * len(loader)) for m in self.milestones], self.milestone_alpha
            )

        # ---- get test examples for logging each epoch
        def get_test_examples(n: int, shuffle: bool):
            tst_loader = iter(original_ds.loaders(
                self.batch_size, num_workers=0, persistent=False, device=ad_device, shuffle_test=shuffle
            )[1])
            examples, example_labels, it = torch.Tensor([]).to(ad_device), torch.Tensor([]).to(ad_device), 0
            try:
                while examples.shape[0] < n:
                    imgs, lbls, idcs = next(tst_loader)
                    if (lbls == 1).sum() > 0:
                        examples = torch.cat([examples, imgs[lbls == 1][:n - examples.shape[0]]])
                        example_labels = torch.cat([example_labels, lbls[lbls == 1][:n - example_labels.shape[0]]])
                        if shuffle:
                            self.logger.logjson(
                                'misc', {
                                    f'test_examples_xtraining_it{it}': {f'cls_{cs}': {f'seed_{seed}': idcs}}
                                }, update=True
                            )
                        it += 1
            except StopIteration:
                pass
            del tst_loader
            with torch.no_grad():
                true_ascores = ad_feature_to_ascore(ad_model(examples))
            return examples, example_labels, true_ascores
        test_examples, test_example_labels, test_true_ascores = get_test_examples(n_log, True)

        def concat_train_examples(imgs, lbls, idcs, train_examples, train_example_labels, n: int, it: int):
            if train_examples.shape[0] < n:
                perm = torch.randperm(imgs.shape[0])
                train_examples = torch.cat([
                    train_examples, imgs[perm][:n - train_examples.shape[0]].detach().clone()
                ])
                self.logger.logjson(
                    'misc', {f'train_examples_xtraining_it{it}': {f'cls_{cs}': {f'seed_{seed}': idcs[perm][:n].cpu()}}},
                    update=True
                )
                train_example_labels = torch.cat([
                    train_example_labels, lbls[perm][:n - train_example_labels.shape[0]].detach().clone()
                ])
                with torch.no_grad():
                    train_true_ascores = ad_feature_to_ascore(ad_model(train_examples))
                return train_examples, train_example_labels, train_true_ascores
        train_examples, train_example_labels = torch.Tensor([]).to(ad_device), torch.Tensor([]).to(ad_device)

        # ---- prepare trackers and loggers
        reconloss, ascloss, genloss, discloss, conloss = None, None, None, None, None
        overall_max_it = len(loader) * epochs
        start_overall_it = self.load(load_models if isinstance(load_models, str) else None, generator, concept_clf, discriminator, opt, sched)
        overall_it = start_overall_it
        start_ep = start_overall_it // len(loader)
        ep = start_ep
        to_track = {
            'ep': lambda: f'{ep+1:{len(str(epochs))}d}/{epochs}',
            'recl': lambda: reconloss.item() if reconloss is not None else None,
            'ascl': lambda: ascloss.item() if ascloss is not None else None,
            'conl': lambda: conloss.item() if conloss is not None else None,
            'genl': lambda: genloss.item() if genloss is not None else None,
            'lr': lambda: sched.get_last_lr()
        }

        tracker = self.logger.track(
            [epochs, len(loader)], to_track, f'Xtraining {cstr}', mininterval=tqdm_mininterval
        )
        if start_ep > 0 and epochs > 0:  # will unforunately mess with the ETA a bit...
            for _ in range(start_ep):
                for dummy_it in range(len(loader)):
                    tracker.update([0, 1])
                    if discriminator is not None and dummy_it % self.disc_every == 0:
                        # discopt ?
                        discsched.step()
                        pass
            tracker.update([1, 0])
        try:
            # ---- loop over epochs
            for ep in range(start_ep, epochs):
                last_overall_it_this_ep = len(loader) * (ep + 1) - 1
                intermediate_start = (start_overall_it % len(loader)) if (start_overall_it > 0 and start_ep == ep) else 0
                if intermediate_start > 0:
                    for dummy_it in range(intermediate_start):
                        tracker.update([0, 1])  # will unforunately mess with the ETA a bit... 
                        if discriminator is not None and dummy_it % self.disc_every == 0:
                            # discopt ?
                            discsched.step()
                            pass

                # ---- loop over batches
                for it, (imgs, lbls, idcs) in enumerate(loader):
                    if train_examples.shape[0] < n_log:
                        train_examples, train_example_labels, train_true_ascores = concat_train_examples(
                            imgs, lbls, idcs, train_examples, train_example_labels, n_log, it
                        )
                    if len(lbls) <= 1:
                        continue

                    train_gen = it % self.gen_every == 0
                    train_disc = it % self.disc_every == 0
                    generated_image_small_batch = []
                    with torch.no_grad():
                        true_ascores = ad_feature_to_ascore(ad_model(imgs.to(ad_device)))

                    if train_gen:
                        # ---- loop over mini batches
                        for sit, (img, lbl, idx) in enumerate(self.minibatchify((imgs, lbls, idcs), MINIBATCHIFY)):
                            # skip normal samples
                            if (lbl == 0).sum().item() <= 0:
                                continue 
                            img, lbl, idx = img[lbl == 0], lbl[lbl == 0], idx[lbl == 0]

                            with torch.autocast(device_type='cuda', dtype=torch.float16):
                                upscaled_img, mask_image, image_latents, image_latents_latest, diffedit_suggestion_latent, diffedit_suggestion = self.diffedit(
                                    img, query_prompt, ad_device
                                )

                                random_condition = torch.randint(0, self.n_concepts, (img.shape[0], )).to(diffedit_suggestion_latent.device)
                                generated_image_small, generated_image = self.generate_image(
                                    generator, diffedit_suggestion_latent, image_latents, mask_image, random_condition, 
                                    query_prompt, img.shape[-1], ad_device
                                )
                                generated_image_small_batch.append(generated_image_small.detach())

                                # penalize large anomaly score       
                                ad_score = ad_feature_to_ascore(ad_model(generated_image_small))
                                ascloss = self.lamb_asc * ad_score.mean()

                                # penalize distance to original suggestion
                                reconloss = nn.functional.mse_loss(diffedit_suggestion, generated_image)
                                reconloss = self.lamb_dist * reconloss

                                # penalize concept misclassification
                                conloss = cross_entropy(concept_clf(img, generated_image_small), random_condition.to(ad_device))
                                conloss = self.lamb_conc * conloss

                                # fool discriminator: generated CEs --> real (target==1)
                                if discriminator is not None:
                                    disc_logits = discriminator(  # assumes discrete_anomaly_scores==1
                                        generated_image_small.to(disc_device), torch.IntTensor((0, )).unsqueeze(0).to(disc_device) 
                                    )
                                    genloss = hinge_disc_loss(
                                        disc_logits, 
                                        torch.ones(torch.Size((disc_logits.shape[0], ))).to(disc_device)
                                    ) * self.lamb_gen
                                    genloss = genloss.to(ad_device)
                        
                                # together
                                loss = reconloss + ascloss + conloss
                                loss = loss + (genloss if genloss is not None else 0)

                            scaled_loss = asc_guide_scaler.scale(loss)
                            scaled_loss.backward()

                        for p in opt.param_groups:
                            for w in p['params']:
                                w.grad /= (lbls != 0).sum().item()
                        asc_guide_scaler.step(opt)
                        asc_guide_scaler.update()
                        opt.zero_grad()
                        sched.step()  # we do a step per batch iteration here!

                    if discriminator is not None and train_disc:
                        # normal ground_truth samples --> real (target==1)
                        org_disc_logits = discriminator(
                            imgs[lbls==0].to(disc_device), 
                            torch.IntTensor((0, )).unsqueeze(0).repeat( # assumes discrete_anomaly_scores==1
                                (lbls==0).sum(), 1
                            ).to(disc_device) 
                        )
                        org_disc_loss = hinge_disc_loss(
                            org_disc_logits, 
                            torch.ones(torch.Size((org_disc_logits.shape[0], ))).to(disc_device)
                        ) * self.lamb_gen
                        if not train_gen:
                            with torch.no_grad():
                                for sit, (img, lbl, idx) in enumerate(self.minibatchify((imgs, lbls, idcs), MINIBATCHIFY)):
                                    if (lbl == 0).sum().item() <= 0:
                                        continue 
                                    img, lbl, idx = img[lbl == 0], lbl[lbl == 0], idx[lbl == 0]
                                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                                        upscaled_img, mask_image, image_latents, image_latents_latest, diffedit_suggestion_latent, diffedit_suggestion = self.diffedit(
                                            img, query_prompt, ad_device
                                        )
                                        random_condition = torch.randint(0, self.n_concepts, (img.shape[0], )).to(diffedit_suggestion_latent.device)
                                        generated_image_small, generated_image = self.generate_image(
                                            generator, diffedit_suggestion_latent, image_latents, mask_image, random_condition, 
                                            query_prompt, img.shape[-1], ad_device
                                        )
                                        generated_image_small_batch.append(generated_image_small.detach())
                        generated_image_small_batch = torch.cat(generated_image_small_batch)
                        # generated CEs --> fake (target==0)
                        disc_logits = discriminator(  
                            generated_image_small_batch.to(disc_device), 
                            torch.IntTensor((0, )).unsqueeze(0).repeat( # assumes discrete_anomaly_scores==1
                                generated_image_small_batch.shape[0], 1
                            ).to(disc_device) 
                        )
                        ce_disc_loss = hinge_disc_loss(
                            disc_logits, 
                            torch.zeros(torch.Size((disc_logits.shape[0], ))).to(disc_device)
                        ) * self.lamb_gen
                        discloss = org_disc_loss + ce_disc_loss
                        discloss.backward()
                        discopt.step()
                        discopt.zero_grad()
                        discsched.step()  # we do a step per batch iteration here!

                    reconloss, ascloss, conloss, genloss = reconloss.cpu(), ascloss.cpu(), conloss.cpu(), genloss.cpu()
                    discloss = discloss.cpu()


                    # # ---- log stuff
                    if seed == 0:
                        self.logger.add_scalar(f'training__C{cs}-{cstr}__ep', ep, tracker.n)
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__xloss', loss.item(), tracker.n,)
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__rec_loss', reconloss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__asc_loss', ascloss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__con_loss', conloss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__disc_loss', discloss.item(), tracker.n, )
                    self.logger.add_scalar(f'training__C{cs}-{cstr}-S{seed}__gen_loss', genloss.item(), tracker.n, )
                    self.logger.flush()
                    tracker.update([0, 1])

                    # ---- log intermediate results at end of certain iterations stuff
                    overall_it = it + len(loader) * ep + intermediate_start
                    if overall_it in [0, 10, 50, 100, 250,] or overall_it % 500 == 0:
                        if len(train_examples) >= n_log:
                            self.log_images(
                                generator, f'train__C{cs}-{cstr}-S{seed}', train_examples, train_example_labels, train_true_ascores,
                                ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log, ep=overall_it
                            )
                            self.log_images(
                                generator, f'test__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                                ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log, ep=overall_it
                            )
                            if overall_it >= 1000:
                                self.logger.snapshot(f'intermediate_epit{overall_it}_snapshot_generator_cls{cs}_it{seed}', generator, opt, sched, overall_it)
                                if discriminator is not None:
                                    self.logger.snapshot(f'intermediate_epit{overall_it}_snapshot_discriminator_cls{cs}_it{seed}', discriminator, opt, sched, overall_it)
                                self.logger.snapshot(f'intermediate_epit{overall_it}_snapshot_concept-clf_cls{cs}_it{seed}', concept_clf, opt, sched, overall_it)
                                intermediate_generator_snapshots = sorted([
                                    snapshot_file for snapshot_file in os.listdir(pt.join(self.logger.dir, 'snapshots')) 
                                    if snapshot_file.endswith(f'_generator_cls{cs}_it{seed}.pt') and snapshot_file.startswith('intermediate_epit')
                                ], key=lambda x: int(x[len('intermediate_epit'):-len(f'_snapshot_generator_cls{cs}_it{seed}.pt')]))
                                if len(intermediate_generator_snapshots) >= 4:
                                    oldest_gen_snapshot = intermediate_generator_snapshots[0]
                                    os.remove(pt.join(self.logger.dir, 'snapshots', oldest_gen_snapshot))
                                    os.remove(pt.join(self.logger.dir, 'snapshots', oldest_gen_snapshot.replace('_generator_', '_discriminator_')))
                                    os.remove(pt.join(self.logger.dir, 'snapshots', oldest_gen_snapshot.replace('_generator_', '_concept-clf_')))
                                
                    self.logger.print(
                        f"{datetime.now().strftime('%Y/%m/%d %H:%M:%S')}: Xtraining for {cstr} (seed {seed}) "
                        f"epoch {overall_it:06d}/{overall_max_it:06d} finished.", only_file=True
                    )

                    if intermediate_start > 0:  # break epoch "early" to match loaded snapshots' intermediate iteration
                        if overall_it >= last_overall_it_this_ep:  
                            break

                if reconloss.isnan().sum() > 0:
                    raise NanGradientsError()

                # ---- update tracker and scheduler
                tracker.update([1, 0])

            # ---- log end-of-training stuff
            tracker.close()
            if train_examples is None or len(train_examples) == 0:
                train_loader = enumerate(iter(ds.loaders(self.batch_size, num_workers=0, persistent=False, device=ad_device, )[0]))
                while len(train_examples) < n_log:
                    it, (imgs, lbls, idcs) = next(train_loader)
                    train_examples, train_example_labels, train_true_ascores = concat_train_examples(
                        imgs, lbls, idcs, train_examples, train_example_labels, n_log, it
                    )
                del train_loader
            self.log_images(
                generator, f'train__C{cs}-{cstr}-S{seed}', train_examples, train_example_labels, train_true_ascores,
                ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log
            )
            self.log_images(
                generator, f'test__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log,
            )
            original_ids = deepcopy(original_ds.test_set.indices)
            tail = (np.asarray(original_ids).shape[0] % (n_log * 4))
            original_ds.test_set.indices = [
                i[len(i)//2]
                for i in np.split((np.asarray(original_ids)[:-tail]) if tail > 0 else np.asarray(original_ids), n_log * 4)
            ]
            test_examples, test_example_labels, test_true_ascores = get_test_examples(n_log * 2, False)
            self.log_images(
                generator, f'testFix__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log * 2,
            )
            original_ds.test_set.indices = original_ids
            test_examples, test_example_labels, test_true_ascores = get_test_examples(n_log * 4, True)
            self.log_images(
                generator, f'testLong__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log,
            )
            self.logger.flush(timeout=20)
        except KeyboardInterrupt as err:
            self.logger.warning(f"Diffedit training. Keyboard interrupt.")
        except Exception as err:
            self.logger.warning(''.join(traceback.format_exception(err.__class__, err, err.__traceback__)))
        finally:
            # reset requires_grad of AD model to original state
            for p in ad_model.parameters():
                p.requires_grad_(ad_model_requires_grad)

            cs = int_set_to_str(cset)
            if epochs > 0 and start_ep != epochs:
                self.logger.snapshot(f'snapshot_generator_cls{cs}_it{seed}', generator, opt, sched, overall_it)
                if discriminator is not None:
                    self.logger.snapshot(f'snapshot_discriminator_cls{cs}_it{seed}', discriminator, opt, sched, overall_it)
                self.logger.snapshot(f'snapshot_concept-clf_cls{cs}_it{seed}', concept_clf, opt, sched, overall_it)
                if overall_it < overall_max_it - 1:  
                    self.logger.warning(f'An error occurred. Intermediate snapshots for overall iterations {overall_it}/{overall_max_it} saved.')
                    if train_examples is not None:
                        self.log_images(
                            generator, f'train_EP{overall_it}__C{cs}-{cstr}-S{seed}', train_examples, train_example_labels,
                            train_true_ascores, ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log
                        )
                    self.log_images(
                        generator, f'test_EP{overall_it}__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                        ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log,
                    )
                    test_examples, test_example_labels, test_true_ascores = get_test_examples(n_log * 4, True)
                    self.log_images(
                        generator, f'testLong_EP{overall_it}__C{cs}-{cstr}-S{seed}', test_examples, test_example_labels, test_true_ascores,
                        ad_model, ad_feature_to_ascore, query_prompt, concept_clf, n_log,
                    )
                    self.logger.flush(timeout=20)

        return generator, discriminator, concept_clf, metrics

    def log_images(self, generator: ConditionalGenerator, name: str, examples: torch.Tensor, true_labels: torch.Tensor,
                   true_ascores: torch.Tensor, ad_model: Module, ad_feature_to_ascore: Callable, query_prompt: str, 
                   concept_clf: ConceptNN, n_log: int, ep: int = None):
        """
        Logs a counterfactual overview figure for the given examples.
        For each example, `self.diffedit` is run twice (once with `mask=False`, once with its default), and then
        `self.generate_image` is called once per concept in `range(self.n_concepts)` using the latents and mask of the
        second run. The originals followed by all generated images are passed to `self.logger.log_fig` under
        `conditional_generation/<name>`, with row headers 'original' and one 'a=0.0 c=<concept>' header per concept.
        The generator and the concept classifier are put in eval mode for the duration of this method and are always
        put in train mode afterwards (also if an exception is raised), regardless of the mode they were in before.
        @param generator: the generator handed to `self.generate_image`.
        @param name: name of the figure. If `ep` is given, it becomes `intermediate/<name>_ep<ep>`.
        @param examples: batch of images, iterated along its first dimension. The size of the last dimension is used
            as the (square) target resolution when downscaling generated images.
        @param true_labels: labels of the examples. Only used to color the titles of the originals
            (0 -> 'olive', anything else -> 'saddlebrown').
        @param true_ascores: anomaly scores of the originals, shown in their titles.
        @param ad_model: takes an image and returns the anomaly feature vector. The device of its first parameter is
            handed to `self.diffedit` and `self.generate_image`.
        @param ad_feature_to_ascore: takes the output of `ad_model` and returns an anomaly score.
        @param query_prompt: prompt handed to `self.diffedit` and `self.generate_image`.
        @param concept_clf: called with (original, generated image); the argmax of its output is shown as concept.
        @param n_log: number of leading figure entries whose title is colored by true label
            (capped at the number of examples).
        @param ep: if not None, the figure is treated as intermediate: it is stored in the 'intermediate' subfolder,
            rendered with maxres=64, and not stored as pdf.
        """
        is_intermediate = ep is not None
        ad_device = next(ad_model.parameters()).device
        name = name if not is_intermediate else pt.join("intermediate", f"{name}_ep{ep:03d}")
        n_log = min(n_log, examples.shape[0])
        spatial_size = examples.shape[-1]
        generator.eval()
        concept_clf.eval()
        try:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                with torch.no_grad():
                    # get diffedit outputs
                    # NOTE(review): several of these lists are collected but not rendered by the figure below;
                    # the corresponding forward passes are kept so that the sequence of model calls is unchanged.
                    unmasked_diffedit_suggestions = []
                    unmasked_diffedit_suggestion_scores = []
                    unmasked_diffedit_suggestion_concepts = []
                    masks = []
                    diffedit_suggestions = []
                    diffedit_suggestion_scores = []
                    diffedit_suggestion_concepts = []
                    diffedit_suggestion_latent_list = []
                    image_latents_list = []
                    for single_example in examples:  # TODO minibatchify here as well
                        single_example = single_example.unsqueeze(0)

                        # first pass: diffedit without mask; only the final suggestion is used
                        _, _, _, _, _, unmasked_diffedit_suggestion = self.diffedit(
                            single_example, query_prompt, ad_device, mask=False
                        )
                        unmasked_diffedit_suggestion_small = nn.functional.interpolate(
                            unmasked_diffedit_suggestion, (spatial_size, spatial_size), mode="bilinear"
                        ).float()
                        unmasked_diffedit_suggestion_scores.append(
                            ad_feature_to_ascore(ad_model(unmasked_diffedit_suggestion_small)).cpu()
                        )
                        unmasked_diffedit_suggestion_concepts.append(
                            concept_clf(single_example, unmasked_diffedit_suggestion_small).cpu()
                        )
                        unmasked_diffedit_suggestions.append(unmasked_diffedit_suggestion_small.cpu())

                        # second pass: diffedit with its default masking; mask and latents are reused by the generator
                        _, mask_image, image_latents, _, diffedit_suggestion_latent, diffedit_suggestion = self.diffedit(
                            single_example, query_prompt, ad_device
                        )
                        diffedit_suggestion_latent_list.append(diffedit_suggestion_latent)
                        image_latents_list.append(image_latents)
                        masks.append(mask_image)
                        diffedit_suggestion_small = nn.functional.interpolate(
                            diffedit_suggestion, (spatial_size, spatial_size), mode="bilinear"
                        ).float()
                        diffedit_suggestion_scores.append(ad_feature_to_ascore(ad_model(diffedit_suggestion_small)).cpu())
                        diffedit_suggestion_concepts.append(concept_clf(single_example, diffedit_suggestion_small).cpu())
                        diffedit_suggestions.append(diffedit_suggestion_small.cpu())

                    # get generated images
                    images = []
                    images_masked = []
                    # the first entries of the score/concept lists describe the originals themselves
                    predicted_scores = [true_ascores.cpu()]
                    predicted_concepts = [concept_clf(examples, examples).cpu()]
                    target_scores, target_concepts = [predicted_scores[0]], [predicted_concepts[0].argmax(1)]
                    concepts = [torch.ones(examples.shape[0]).int() * c for c in range(self.n_concepts)]
                    for c in concepts:
                        zped = zip(c, examples, masks, diffedit_suggestion_latent_list, image_latents_list)
                        for cc, single_example, mask_image, diffedit_suggestion_latent, image_latents in zped:
                            single_example = single_example.unsqueeze(0)
                            cc = cc.unsqueeze(0)
                            generated_image_small, _ = self.generate_image(
                                generator, diffedit_suggestion_latent, image_latents, mask_image, cc,
                                query_prompt, spatial_size, ad_device
                            )
                            if not self.gen_use_mask:
                                # temporarily enable masking for a second generation; always restore the
                                # trainer's configuration, even if the generation fails
                                self.gen_use_mask = True
                                try:
                                    generated_image_small_masked, _ = self.generate_image(
                                        generator, diffedit_suggestion_latent, image_latents, mask_image, cc,
                                        query_prompt, spatial_size, ad_device,
                                    )
                                finally:
                                    self.gen_use_mask = False
                                images_masked.append(generated_image_small_masked.cpu())
                            predicted_scores.append(ad_feature_to_ascore(ad_model(generated_image_small)).cpu())
                            predicted_concepts.append(concept_clf(single_example, generated_image_small).cpu())
                            images.append(generated_image_small.cpu())
                            # the generator is always asked for anomaly score 0 and concept cc
                            target_scores.append(torch.zeros_like(cc).cpu())
                            target_concepts.append(cc.cpu())

                    # prepare tensors for visualizing
                    masks = [
                        nn.functional.interpolate(
                            torch.from_numpy(mask_image).repeat(3, 1, 1).unsqueeze(0).float(),
                            (spatial_size, spatial_size),  # not bilinear because we're upsampling here
                        ) for mask_image in masks
                    ]
                    predicted_scores, predicted_concepts = torch.cat(predicted_scores), torch.cat(predicted_concepts)
                    target_scores, target_concepts = torch.cat(target_scores), torch.cat(target_concepts)
                    unmasked_diffedit_suggestion_scores = torch.cat(unmasked_diffedit_suggestion_scores)
                    unmasked_diffedit_suggestion_concepts = torch.cat(unmasked_diffedit_suggestion_concepts)
                    diffedit_suggestion_scores = torch.cat(diffedit_suggestion_scores)
                    diffedit_suggestion_concepts = torch.cat(diffedit_suggestion_concepts)

                    # ---- counterfactual overview figure (first row is original, the others with CEs for each concept)
                    titles, title_colors = [], []
                    zped = zip(predicted_scores, target_scores, predicted_concepts, target_concepts)
                    for i, (a, ta, c, tc) in enumerate(zped):
                        c = c.argmax()
                        titles.append(f'^a{a:0.2f}_c{c}')
                        if i < n_log:
                            # originals: color encodes the true label
                            title_colors.append('olive' if true_labels[i] == 0 else 'saddlebrown')
                        elif c != tc or (a - ta)**2 > 0.1:
                            # wrong concept or anomaly score too far from its target
                            title_colors.append('red')
                        else:
                            title_colors.append('green')
                    this_name = pt.join('conditional_generation', name)
                    self.logger.log_fig(
                        torch.cat([examples.cpu(), *images, ]), this_name,
                        nrow=examples.shape[0],
                        rowheaders=[
                            'original', *[
                                f'a={a} c={c}' for c in range(self.n_concepts)
                                for a in ["0.0", ]
                            ]
                        ],
                        titles=titles, title_colors=title_colors,
                        maxres=64 if is_intermediate else examples.shape[-1],
                        pdf=not is_intermediate,
                    )
        finally:
            # hand the models back in training mode even if logging failed; callers continue to use them
            generator.train()
            concept_clf.train()

    def load(self, path: str, generator: ConditionalGenerator, 
             concept_clf: torch.nn.Module, 
             discriminator: ConditionalDiscriminator = None,
             opt: torch.optim.Optimizer = None, sched: _LRScheduler = None) -> int:
        """
        Loads a snapshot of the model including training state.
        Sibling snapshots of the concept classifier and the discriminator are looked up by replacing
        '_generator_' in `path` with '_concept-clf_' and '_discriminator_', respectively, and are loaded if they exist.
        @param path: the filepath where the generator snapshot is stored. If None, nothing is loaded.
        @param generator: the model instance into which the parameters of the found generator snapshot are loaded.
            Hence, the architectures need to match.
        @param concept_clf: see generator. If None, the concept classifier snapshot is ignored.
        @param discriminator: see generator. If None, the discriminator snapshot is ignored.
        @param opt: the optimizer instance into which the training state is loaded.
        @param sched: the learning rate scheduler into which the training state is loaded.
        @return: the epoch stored in the snapshot plus one. 0 if `path` is None or the snapshot stores no epoch.
        """
        epoch = 0
        if path is None:
            return epoch

        msg = ""
        snapshot = torch.load(path)
        net_state = snapshot.pop('net', None)
        opt_state = snapshot.pop('opt', None)
        sched_state = snapshot.pop('sched', None)
        epoch = snapshot.pop('epoch', -1) + 1
        if net_state is None:
            # snapshot without network weights: nothing else is restored, only the epoch is reported
            return epoch
        generator.load_state_dict(net_state)
        if opt_state is not None and opt is not None:
            opt.load_state_dict(opt_state)
            msg += 'Also loaded optimizer. '
        if sched_state is not None and sched is not None:
            sched.load_state_dict(sched_state)
            msg += 'Also loaded scheduler. '

        siblings = (
            ('_concept-clf_', concept_clf, 'concept classifier'),
            ('_discriminator_', discriminator, 'discriminator'),
        )
        for marker, model, label in siblings:
            if model is None:
                continue
            sibling_path = path.replace('_generator_', marker)
            # If `path` contains no '_generator_' marker, the replacement yields `path` itself; loading it
            # would push the generator's weights into the sibling model, so skip it in that case.
            if sibling_path == path or not pt.exists(sibling_path):
                continue
            sibling_snapshot = torch.load(sibling_path)
            model.load_state_dict(sibling_snapshot.pop('net', None))
            msg += f'Also loaded {label}. '
        self.logger.print(f'Loaded counterfactual generator snapshot at epoch {epoch}. {msg}')
        return epoch

    def eval(self, generator: ConditionalGenerator, concept_clf: ConceptNN,
              ad_model: Module, ad_feature_to_ascore: Callable,
              ds: TorchvisionDataset, cset: set[int], cstr: str, seed: int,
              workers: int, ad_device: torch.device = torch.device('cuda:0'),
              tqdm_mininterval: float = 0.1, ad_trainer_parent = None, 
              discriminator: ConditionalDiscriminator = None,) -> Mapping:
        """
        @param generator: Trained generator that takes a test image and a target condition
            (made of concept and target anomaly score, see :method:`self.get_conditions`).
            Outputs a new image that is similar to the test image but hopefully
            aligns with the conditioned concept and anomaly score. If the test image is anomalous and the conditioned
            anomaly score 0, the generate image is a counterfactual example.
        @param discriminator: Trained discriminator (GAN style). Takes a test image and a supposed anomaly score.
            The latter in ordinal format (see :method:`self.to_ordinal_scores`). Returns fake (value 0) or true (value 1).
            Note that the discriminator may use the supposed anomaly score, so if the image is not aligned with the
            score it might detect this, but it also generally rates the realisticity of the image as in typical GANs.
        @param concept_clf: Concept classifier. Takes two images. The second one is supposed to be a generated version
            of the first one with the concept changed to c. The classifier predicts this c.
        @param ad_model: Takes an image and returns the anomaly feature vector.
        @param ad_feature_to_ascore: Takes an anomaly feature vector and returns an anomaly score.
        @param ds: The dataset.
        @param cset: The set of normal class ids.
        @param cstr: The string representation of cset.
        @param seed: Determines the current iteration of the training (e.g., 3 for the third random seed).
        @param workers: number of workers for the data loading.
        @param device: torch device for evaluation.
        """
        generator.eval()
        concept_clf.eval()
        query_prompt = self.get_query_prompt(cstr)
        _, loader = ds.loaders(self.batch_size, num_workers=workers, shuffle_test=False, device=ad_device, )

        procbar = tqdm(desc=f'Xeval {cstr}', total=len(loader), mininterval=tqdm_mininterval)
        ascores_true_anomalies = []
        ascores_true_normals = []
        ascores_counterfactuals = []
        concepts_ground_truth = []
        concepts_predictions = []
        
        for it, (imgs, lbls, idcs) in enumerate(loader):
            # if it not in (0, 21, 199, 299): # for debugging
            #     continue

            true_normal_imgs = imgs[lbls == 0]
            true_anomalous_imgs = imgs[lbls == 1]
            if len(true_normal_imgs) > 0:
                self.logger.log_batch_img(
                    tensor=true_normal_imgs, foldername='actual_imgs', prefix=f'{it}_'
                )
            if len(true_anomalous_imgs) > 0:
                self.logger.log_batch_img(
                    tensor=true_anomalous_imgs, foldername='anomalous_imgs', prefix=f'{it}_'
                )
            random_batch_splits = random_split_tensor(
                input=imgs, chunks=2, device=ad_device
            )
            self.logger.log_batch_img(
                tensor=random_batch_splits[0], foldername='test_subset_1', prefix=f'{it}_'
            )
            self.logger.log_batch_img(
                tensor=random_batch_splits[1], foldername='test_subset_2', prefix=f'{it}_'
            )
            counterfactual_example_batch = []
            for sit, (img, lbl, idx) in enumerate(zip(imgs, lbls, idcs)):
                with torch.no_grad():
                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        img, lbl, idx = img.unsqueeze(0), lbl.unsqueeze(0), idx.unsqueeze(0)
                        orig_anomaly_features = ad_model(img)
                        orig_true_anomaly_scores = ad_feature_to_ascore(orig_anomaly_features)

                        if lbl.item() == 0:
                            ascores_true_normals.append(orig_true_anomaly_scores.cpu())
                        else:
                            ascores_true_anomalies.append(orig_true_anomaly_scores.cpu())
                            upscaled_img, mask_image, image_latents, image_latents_latest, diffedit_suggestion_latent, diffedit_suggestion = self.diffedit(
                                img, query_prompt, ad_device
                            )
                            concepts = torch.arange(self.n_concepts)
                            for concept in concepts:
                                concept = concept.unsqueeze(0)
                                generated_image_small, generated_image = self.generate_image(
                                    generator, diffedit_suggestion_latent, image_latents, mask_image, concept, 
                                    query_prompt, img.shape[-1], ad_device
                                )
                                counterfactual_example = generated_image_small
                                counterfactual_example_batch.append(counterfactual_example.cpu())

                                counterfactual_anomaly_features = ad_model(counterfactual_example)
                                counterfactual_ascores = ad_feature_to_ascore(counterfactual_anomaly_features)
                                predicted_concepts_logits = concept_clf(img, counterfactual_example)
                                predicted_concepts = predicted_concepts_logits.max(dim=1)[1]

                                concepts_ground_truth.append(concept)
                                concepts_predictions.append(predicted_concepts)
                                ascores_counterfactuals.append(counterfactual_ascores.cpu())

            if len(counterfactual_example_batch) > 0:
                counterfactual_example_batch = torch.cat(counterfactual_example_batch)
                self.logger.log_batch_img(
                    tensor=counterfactual_example_batch, foldername='counterfactual_imgs', prefix=f'{it}_'
                )

            procbar.update()
        procbar.close()

        # -------------- CORRECTNESS ------------
        ascores_true_normals = torch.cat(ascores_true_normals)
        ascores_true_anomalies = torch.cat(ascores_true_anomalies)
        ascores_counterfactuals = torch.cat(ascores_counterfactuals)
        roc_norm_vs_counterfact = get_roc(ascores_true_normals, ascores_counterfactuals)
        roc_norm_vs_anomalous = get_roc(ascores_true_normals, ascores_true_anomalies)
        auc_norm_vs_counterfact_normalized = 1 - 2 * (roc_norm_vs_counterfact.auc - 0.5)

        self.logger.print(f"Average anomaly score for the normal samples: {ascores_true_normals.mean():.6f}")
        self.logger.print(f"Average anomaly score for the anomalous samples: {ascores_true_anomalies.mean():.6f}")
        self.logger.print(
            f"Average anomaly score for the counterfactual samples, generated from truly anomalous samples"
            f" {ascores_counterfactuals.mean():.6f}"
        )
        self.logger.print(f"AUC for normal samples vs. couterfactuals: {roc_norm_vs_counterfact.auc:0.3f}")
        self.logger.print(f"AUC for true normal vs. true anomalies: {roc_norm_vs_anomalous.auc:0.3f}")
        self.logger.print(
            f"AUC for true normal vs. couterfactuals, reversed and normalized "
            f"(usually in [0,1] with 1 being best at an AUC of 0.5, might move up to 2 when AUC is worse than 0.5):"
            f" {auc_norm_vs_counterfact_normalized:0.3f}"
        )
        self.logger.logjson('results', {
            'avg_ascore_normal': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': ascores_true_normals.mean().item()}
            },
            'avg_ascore_anomalous': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': ascores_true_anomalies.mean().item()}
            },
            'avg_ascore_counterfactual': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': ascores_counterfactuals.mean().item()}
            },
            'auc_normal_anomalous': {f'cls_{int_set_to_str(cset)}': {f'seed_{seed}': roc_norm_vs_anomalous.auc}},
            'auc_normal_counterfactual': {f'cls_{int_set_to_str(cset)}': {f'seed_{seed}': roc_norm_vs_counterfact.auc}},
        }, update=True)
        self.logger.plot_many(
            [roc_norm_vs_anomalous, roc_norm_vs_counterfact],
            [f'normal_vs_anomalous', 'normal_vs_counterfactual'], f"rocs_{int_set_to_str(cset)}_seed{seed}",
            False
        )
        self.logger.hist_ascores(
            [ascores_true_normals, ascores_true_anomalies, ascores_counterfactuals, ],
            ['test_normal', 'test_anomalous', 'test_counterfactual', ],
            f'C{int_set_to_str(cset)}-{cstr}-S{seed}__anomaly_scores',
            colors=['green', 'red', 'blue'],
        )

        # -------------- REALISM ------------
        true_normal_path = pt.join(self.logger.dir, 'actual_imgs')
        cf_path = pt.join(self.logger.dir, 'counterfactual_imgs')
        anamalous_path = pt.join(self.logger.dir, 'anomalous_imgs')
        test_subset_1_path = pt.join(self.logger.dir, 'test_subset_1')
        test_subset_2_path = pt.join(self.logger.dir, 'test_subset_2')

        self.logger.print("Computing FID scores (normal, anomalous)")
        fid_score, fid_score_lower_bound, fid_score_upper_bound = compute_fid_scores(
            path_actual_imgs=true_normal_path, path_counterfactual_imgs=cf_path,
            path_anamalous_imgs=anamalous_path, path_test_subset_1=test_subset_1_path,
            path_test_subset_2=test_subset_2_path, device=ad_device, xtrainer=self,
            cstr=cstr, seed=seed
        )
        normalized_fid_score = 1 - (fid_score - fid_score_lower_bound) / (fid_score_upper_bound - fid_score_lower_bound)
        fid_score_paper = fid_score / fid_score_upper_bound

        self.logger.print(f"FID score for true normal samples vs. counterfactuals: {fid_score:.6f}")
        self.logger.print(f"FID score for true nomal samples vs. true anomalous (upper bound): {fid_score_upper_bound:.6f}")
        self.logger.print(f"FID score for a random half testset vs. the remaining half (lower bound): {fid_score_lower_bound:.6f}")
        self.logger.print(
            f"FID score for true normal samples vs. counterfactuals, "
            f"reversed and normalized with bounds "
            f"(in [0,1] as long as stays within bounds, 1 is best as it meets lower bound, "
            f"above 1 when below lower bound, below 0 when exceeds upper bound): "
            f"{normalized_fid_score:.6f}"
        )
        self.logger.print(f"FID score paper: {fid_score_lower_bound:.6f}")
        self.logger.logjson('results', {
            'fid_score_true_normal_vs_counterfactual': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': fid_score}
            },
            'fid_score_true_normal_vs_anamalous': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': fid_score_upper_bound}
            },
            'fid_score_first_half_vs_second_half': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': fid_score_lower_bound}
            },
            'fid_score_true_normal_vs_counterfactual_normalized': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': normalized_fid_score}
            },
            'fid_score_paper': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': fid_score_paper}
            },
        }, update=True)

        # -------------- DISENTANGLEMENT ------------
        concepts_ground_truth = torch.cat(concepts_ground_truth).numpy()
        concepts_predictions = torch.cat(concepts_predictions).cpu().numpy()
        acc_concept_clf = accuracy_score(concepts_ground_truth, concepts_predictions)
        self.logger.print(f"Accuracy of the concept classifier over the generated samples: {acc_concept_clf:.6f}")
        self.logger.logjson('results', {
            'concept_clf_accuracy': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': acc_concept_clf}
            },
        }, update=True)

        # -------------- SUBSTITUTABILITY ------------
        def reset_model(old_ad_model):
            model = deepcopy(old_ad_model)
            model.apply(weight_reset)
            return model

        sub_auc = None
        if ad_trainer_parent is not None:
            if isinstance(ds, CombinedDataset):
                train_normal_ids = ds.normal.train_set.indices
                train_ds = ds.normal.train_set.dataset
                train_subset = ds.normal.train_set
            else:
                train_normal_ids = ds.train_set.indices
                train_ds = ds.train_set.dataset
                train_subset = ds.train_set
            original_size = len(train_ds)
            if hasattr(train_ds, 'data'):
                attr = 'data'
            elif hasattr(train_ds, 'samples'):
                attr = 'samples'
            elif hasattr(train_ds, 'imgs'): 
                attr = 'imgs'
            elif hasattr(train_ds, 'images'):
                attr = 'images'
            else:
                attr = None
                self.logger.warning(
                    f"CF evaluation. Substitutability evaluation failed. "
                    f"The data in dataset {ds} has none of the known attribute names."
                )
            if attr is not None:
                cf_imgs = None
                if not hasattr(train_ds, 'targets'):
                    self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Dataset {ds} has no targets!")
                if isinstance(train_ds.__getattribute__(attr), str):  # ?
                    self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Dataset format not implemented!")
                elif isinstance(train_ds.__getattribute__(attr), np.ndarray): 
                    if isinstance(train_ds.__getattribute__(attr)[0], str):  # e.g., INN
                        cf_samples = []
                        train_ds.targets = np.asarray(train_ds.targets)
                        normal_label = train_ds.targets[train_normal_ids][0]
                        for img_path in os.listdir(cf_path):
                            cf_samples.append(pt.join(cf_path, img_path))
                        train_ds.__setattr__(
                            attr, np.concatenate([train_ds.__getattribute__(attr), cf_samples])
                        )
                        train_ds.__setattr__(
                            'imgs', np.concatenate([train_ds.__getattribute__(attr), cf_samples])  # imgs == samples...
                        )
                        train_ds.targets = np.concatenate([
                            train_ds.targets, train_ds.targets[train_normal_ids][0].repeat(len(cf_samples))
                        ])
                        cf_imgs = np.asarray(cf_samples)
                    else:  # e.g., CIFAR-10
                        cf_imgs = []
                        for img in os.listdir(cf_path):
                            cf_imgs.append(np.asarray(default_loader(pt.join(cf_path, img))))
                        cf_imgs = np.stack(cf_imgs)
                        if train_ds.__getattribute__(attr).ndim == 3:
                            cf_imgs = cf_imgs[:, :, :, 0]
                        if train_ds.__getattribute__(attr).dtype != np.uint8:
                            cf_imgs = cf_imgs.astype(np.float32) / 255
                        if cf_imgs.shape[1:] != train_ds.__getattribute__(attr).shape[1:]:
                            self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Dataset data shape doesn't match!")
                            cf_imgs = None
                        else:
                            train_ds.__setattr__(
                                attr, np.concatenate([train_ds.__getattribute__(attr), cf_imgs])
                            )
                            train_ds.targets = np.asarray(train_ds.targets)
                            train_ds.targets = np.concatenate([
                                train_ds.targets, train_ds.targets[train_normal_ids][0].repeat(cf_imgs.shape[0])
                            ])
                elif isinstance(train_ds.__getattribute__(attr), torch.Tensor):  # e.g., MNIST
                    cf_imgs = []
                    for img in os.listdir(cf_path):
                        cf_imgs.append(to_tensor(default_loader(pt.join(cf_path, img))))
                    cf_imgs = torch.stack(cf_imgs)
                    if train_ds.__getattribute__(attr).dim() == 3:
                        cf_imgs = cf_imgs[:, 0, :, :]
                    if train_ds.__getattribute__(attr).dtype == torch.uint8:
                        cf_imgs = cf_imgs.mul(255).type(torch.uint8)
                    if cf_imgs.shape[1:] != train_ds.__getattribute__(attr).shape[1:]:
                        self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Dataset data shape doesn't match!")
                        cf_imgs = None
                    else:
                        train_ds.__setattr__(
                            attr, torch.cat([train_ds.__getattribute__(attr), cf_imgs])
                        )
                        train_ds.targets = torch.cat([
                            train_ds.targets, train_ds.targets[train_normal_ids][0].repeat(cf_imgs.shape[0]).detach()
                        ])
                elif isinstance(train_ds.__getattribute__(attr), List):  # e.g., ImageNet.samples: List[Tuple[str, int]]
                    if (
                        not isinstance(train_ds.__getattribute__(attr)[0], Tuple) 
                        or not isinstance(train_ds.__getattribute__(attr)[0][0], str) 
                        or not isinstance(train_ds.__getattribute__(attr)[0][1], int)
                    ):
                        self.logger.warning(f"CF evaluation. Substitutability evaluation failed. Attribute type is unknown (samples, but not List[Tuple[str, int]])!")
                        cf_imgs = None
                    else:
                        cf_samples = []
                        train_ds.targets = np.asarray(train_ds.targets)
                        normal_label = train_ds.targets[train_normal_ids][0]
                        for img_path in os.listdir(cf_path):
                            cf_samples.append(((pt.join(cf_path, img_path)), normal_label))
                        train_ds.__setattr__(
                            attr, train_ds.__getattribute__(attr) + cf_samples
                        )
                        train_ds.__setattr__(
                            'imgs', train_ds.__getattribute__(attr) + cf_samples  # imgs == samples...
                        )
                        train_ds.targets = np.concatenate([
                            train_ds.targets, train_ds.targets[train_normal_ids][0].repeat(len(cf_samples))
                        ])
                        cf_imgs = np.asarray(cf_samples)
                else:
                    self.logger.warning(
                        f"CF evaluation. Substitutability evaluation failed. "
                        f"The data in dataset {ds}.{attr} is of unknown type."
                    )
                if cf_imgs is not None:
                    train_subset.indices = list(range(original_size, original_size + cf_imgs.shape[0]))
                    new_ad_model = reset_model(ad_model)
                    _, roc, train_labels, train_ascores = ad_trainer_parent.train_cls(
                        new_ad_model, ds, cset, cstr + "-CFSubstitution", seed, None, tqdm_mininterval, logging=False
                    )
                    roc_subst, prc, eval_labels, eval_ascores = ad_trainer_parent.eval_cls(
                        new_ad_model, ds, cset, cstr + "-CFSubstitution", seed, tqdm_mininterval, logging=False
                    )
                    auc_subst_normalized = (roc_subst.auc - 0.5) / (roc_norm_vs_anomalous.auc - 0.5)
                    self.logger.print(
                        f"AD test AuROC with normal training set substituted with counterfactuals: {roc_subst.auc:.6f}"
                    )
                    self.logger.print(
                        f"AD test AuROC, normalized, with normal training set substituted with counterfactuals "
                        f"(usually in [0, 1] with 1 being best, larger than 1 when substituted AUC exceeds original, "
                        f"becomes 0 when substituted AUC is 0.5, less than 0 when substituted AUC is less than 0.5): "
                        f"{auc_subst_normalized:.6f}"
                    )

                    # train_subset.indices = train_normal_ids + list(range(original_size, original_size + cf_imgs.shape[0]))
                    # new_ad_model = reset_model(ad_model)
                    # _, roc, train_labels, train_ascores = ad_trainer_parent.train_cls(
                    #     new_ad_model, ds, cset, cstr + "-CFExtension", seed, None, tqdm_mininterval, logging=False
                    # )
                    # roc_extnd, prc, eval_labels, eval_ascores = ad_trainer_parent.eval_cls(
                    #     new_ad_model, ds, cset, cstr + "-CFExtension", seed, tqdm_mininterval, logging=False
                    # )
                    # auc_extnd_normalized = (roc_extnd.auc - 0.5) / (roc_norm_vs_anomalous.auc - 0.5)
                    # self.logger.print(f"AD test AuROC with extended normal training set using counterfactuals: {roc_extnd.auc:.6f}")
                    # self.logger.print(
                    #     f"AD test AuROC, normalized, with normal training set extended with counterfactuals "
                    #     f"(usually in [0, 1] with 1 being best, larger than 1 when extended AUC exceeds original, "
                    #     f"becomes 0 when extended AUC is 0.5, less than 0 when extended AUC is less than 0.5): "
                    #     f"{auc_extnd_normalized:.6f}"
                    # )

                    sub_auc = roc_subst.auc
                    self.logger.logjson('results', {
                        'substitutability': {f'cls_{int_set_to_str(cset)}': {
                            f'seed_{seed}': {
                                'substituted': roc_subst.auc,
                                # 'extended': roc_extnd.auc,
                                'substituted_normalized': auc_subst_normalized,
                                # 'extended_normalized': auc_extnd_normalized
                            }}
                        },
                    }, update=True)

        self.logger.print("Cleaning temporarily saved generated images off the disk...")
        shutil.rmtree(true_normal_path)
        shutil.rmtree(cf_path)
        shutil.rmtree(anamalous_path)
        shutil.rmtree(test_subset_1_path)
        shutil.rmtree(test_subset_2_path)
        generator.train()
        concept_clf.train()

        jsonfile = self.logger.logjson('results', {
            'summary': {f'cls_{int_set_to_str(cset)}': {
                f'seed_{seed}': {
                    'fid_cf_div_anom': fid_score_paper,
                    'cf_vs_normal_auc': roc_norm_vs_counterfact.auc,
                    'substituted_ad_auc': sub_auc,
                    'concept_clf_acc': acc_concept_clf
                }
            }},
        }, update=True)

        with open(jsonfile, 'r') as reader:
            res = json.load(reader)
        criteria = ("fid_cf_div_anom", "cf_vs_normal_auc", "substituted_ad_auc", "concept_clf_acc")
        for crit in criteria:
            for ckey in res['summary']:
                if ckey in criteria:
                    continue
                res['summary'][ckey][crit] = np.mean([res['summary'][ckey][skey][crit] for skey in res['summary'][ckey] if skey not in criteria])
            res['summary'][crit] = np.mean([res['summary'][ckey][crit] for ckey in res['summary'] if ckey not in criteria])
        self.logger.logjson('results', res, update=True)

        return {
            cstr: {
                seed: dict(
                    roc_norm_vs_counterfact=roc_norm_vs_counterfact.auc, ascores_counterfactuals=ascores_counterfactuals,
                    ascores_true_normals=ascores_true_normals, ascores_true_anomalies=ascores_true_anomalies,
                    fid_score_true_normal_vs_counterfactual=fid_score, concept_clf_accuracy=acc_concept_clf,
                    auc_norm_vs_counterfact_normalized=auc_norm_vs_counterfact_normalized,
                    normalized_fid_score=normalized_fid_score,
                )
            }
        }
