import argparse
import json
import os
import os.path as pt
import sys
from argparse import ArgumentParser, Namespace
from typing import Callable, Tuple, Dict
from typing import List
from typing import Union, Any
from warnings import filterwarnings

import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from xad.counterfactual import XTRAINER_CHOICES
from xad.counterfactual.dissect_trainer import XTrainer
from xad.datasets import DS_CHOICES as IMG_DS_CHOICES
from xad.datasets import no_classes
from xad.training import TRAINER
from xad.training.ad_trainer import XADTrainer
from xad.utils.logger import Logger
from xad.utils.logger import SetupEncoder
from xad.models.bases import ConditionalGenerator, ConditionalDiscriminator, ConceptNN, ADNN
from xad.utils.training_tools import int_set_to_str, find_permutation_snapshot

# Ignore DeprecationWarnings raised from modules whose name starts with 'torch' (the `module` argument is a regex
# matched against the start of the module name, so this also covers, e.g., torchvision).
filterwarnings(action='ignore', category=DeprecationWarning, module='torch')
# Disable OpenCV's internal threading so that all its functions run sequentially.
# NOTE(review): presumably meant to avoid deadlocks of OpenCV threads inside multi-process data loading workers;
# the original author marked this only as a possible fix — unverified.
cv2.setNumThreads(0)


def _parse_class_sets(class_sets: List[Union[str, int]]) -> List[set]:
    """
    Converts class-set specifications into sets of integer class ids.
    Each element is either a string of '+'-separated ids (e.g., '5+6' -> {5, 6}) or a plain int (e.g., set by a
    `modify_args` hook), which becomes a single-element set.
    """
    return [{c_set} if isinstance(c_set, int) else {int(c) for c in c_set.split('+')} for c_set in class_sets]


def default_argsparse(modify_descr: Callable[[str, ], str], modify_parser: Callable[[ArgumentParser], None] = None,
                      modify_args: Callable[[Namespace], None] = None) -> Namespace:
    """
    Creates and applies the argument parser for all training scripts.
    @param modify_descr: function that modifies the default description.
    @param modify_parser: function that modifies the default parser provided by this method.
        Can be used to, e.g., add further arguments or change the default values for arguments.
    @param modify_args: function that modifies the actual arguments retrieved by the parser.
        It is applied before the class sets are parsed, so it may set `args.classes` and `args.oe_classes`
        to lists of '+'-separated strings or plain ints.
    @return: the parsed arguments. `args.classes` (and `args.oe_classes` if given) are lists of integer sets,
        `args.devices` defaults to all visible CUDA devices, and 'none' entries are removed from `args.oe_dataset`.
    @raise ValueError: if OE class sets are given for multiple OE datasets or their number does not match the
        number of normal class sets.
    @raise NotImplementedError: if a DSVDD experiment is to be continued.
    """
    parser = ArgumentParser(
        description=modify_descr(
            "Iterates over a collection of sets of classes found in the dataset and multiple random seeds per class_set. "
            "For each class_set-seed combination, it trains and evaluates a given AD model and objective. "
            "All classes in the current set are considered normal. "
            "It always evaluates using the full test set. "
        )
    )
    parser.add_argument(
        '-ds', '--dataset', type=str, default=None, choices=IMG_DS_CHOICES.keys(),
        help="The dataset for which to train the AD model. All datasets have an established train and test split. "
             "During training, use only normal samples from this dataset. For testing, use all samples. "
    )
    parser.add_argument(
        '-oe', '--oe-dataset', type=str, default=None, choices=set(IMG_DS_CHOICES.keys()).union({'none', }), nargs='+',
        help="Optional Outlier Exposure (OE) datasets. If given, concatenate an equally sized batch of random "
             "samples from these datasets to the batch of normal training samples from the main dataset. "
             "These concatenated samples are used as auxiliary training anomalies. To train a semi-supervised model, "
             "set -oe to the main dataset (e.g., -ds 3dshapes -oe 3dshapes), in which case the OE dataset "
             "will consist of the non-normal classes. Usually you should combine this with --oe-size. "
    )
    parser.add_argument(
        # np.inf (not np.infty, which NumPy 2.0 removed) means "no limit".
        '--oe-size', type=int, default=np.inf,
        help="Optional. If given, uses a random subset of the OE dataset as OE with the subset having the provided size. "
    )
    parser.add_argument(
        '--oe-classes', nargs="*", type=str, default=None,
        help="Optional. Defines a list of a set of classes that are used from the OE dataset. "
             "The classes in each set are separated by '+'. E.g., '2+4 5+3' trains first an AD model "
             "with an OE set made of class 2 and 4 from the OE data and then with 3 and 5. "
             "Defaults to all available classes of the given OE dataset. "
    )
    parser.add_argument(
        '-b', '--batch-size', type=int, default=200,
        help="The batch size. If there is an OE dataset, the overall batch size will be twice as large as an equally "
             "sized batch of OE samples gets concatenated to the batch of normal training samples."
    )
    parser.add_argument(
        '-e', '--epochs', type=int, default=50,
        help="How many full iterations of the dataset are to be trained per class-seed combination."
    )
    parser.add_argument(
        '-lr', '--learning-rate', type=float, default=1e-3,
        help="The initial learning rate."
    )
    parser.add_argument(
        '-wdk', '--weight-decay', type=float, default=1e-4,
        help="The weight decay."
    )
    parser.add_argument(
        '--milestones', type=int, nargs='+', default=[],
        help="Milestones for the learning rate scheduler; at each milestone the learning rate is reduced by 0.1."
    )
    parser.add_argument(
        '-o', '--objective', type=str, default='hsc', choices=TRAINER.keys(),
        help="This defines the objective with which the AD models are trained. It determines both the loss and anomaly score. "
             "Some objectives may require certain network architectures (e.g., autoencoders). "
    )
    parser.add_argument(
        '--classes', type=str, nargs='+', default=None, metavar='INTLIST',
        help='Defines a list of a set of classes that are normal. The classes in each set are separated by "+". '
             'E.g., "5+6 1+5" trains first an AD model for the classes 5 and 6 being normal and then for 1 and 5 being normal. '
             'Defaults to a list of all available classes of the given dataset, '
             'which makes the code train one AD model per class where this class is treated normal. '
    )
    parser.add_argument(
        '-d', '--devices', type=int, nargs='+', metavar='GPU-ID', default=None,
        help="Which devices to use for training. AD and DISSECT run with one GPU. Training DiffEdit requires two GPUs with each "
             "at least 40GB VRAM. "
             "CPU training will be very slow. "
    )
    parser.add_argument(
        '-it', '--iterations', type=int, default=2,
        help="The number of iterations for each class with different random seeds. "
    )
    parser.add_argument(
        '--continue-run', type=str, metavar='FILE-PATH', default=None,
        help="Optional. If provided, needs to be a path to a logging directory of a previous experiment. "
             "Load the configuration from the previous experiment where available. Some configurations need to be matched "
             "manually, such as the data transformation pipelines. "
             "If the same default configuration is used (e.g., `train_cifar10.py`), no matching will be required. "
             "Then, load the model snapshots and training state and continue the experiment. "
             "The trainer will start with reevaluating all completed classes and seeds, "
             "which should yield the same metrics again. "
             "For unfinished and not-yet-started class-seed combinations, train and evaluate as usual. "
             "Create a new logging directory by concatenating `---CNTD` to the old directory name. "
    )
    parser.add_argument(
        '--comment', type=str, default='',
        help="Optional. This string will be concatenated to the default logging directory name, which is `log_YYYYMMDDHHMMSS`, "
             "where the latter is the current datetime."
    )
    parser.add_argument(
        '--superdir', type=str, default=".",
        help='Optional. If this run does not continue a previous run, the script will create a new '
             'logging directory `xad/data/results/log_YYYYMMDDHHMMSS`. `--superdir` will change this to '
             '`xad/data/results/SUPERDIR/log_YYYYMMDDHHMMSS`, where SUPERDIR is the string given via this argument. '
    )
    parser.add_argument(
        '--fix-seeds', type=int, nargs='+', default=None,
        help='Fixes the initial random torch seeds per class and iteration.'
    )
    parser.add_argument(
        '--train-split', type=float, default=0.0,
        help='Defines a ratio of samples that will be split from the training (normal and/or OE) data for further validation.'
    )
    parser.add_argument(
        '--tqdm-mininterval', type=float, default=0.1,
        help='Minimum progress display update interval [default: 0.1] seconds.'
    )
    if modify_parser is not None:
        modify_parser(parser)
    args = parser.parse_args()
    if args.devices is None:
        args.devices = list(range(torch.cuda.device_count()))
    if args.oe_dataset is not None:
        # 'none' is an explicit "no OE" choice that can be used to override a script's default OE dataset.
        args.oe_dataset = [ds for ds in args.oe_dataset if ds != 'none']
    if args.fix_seeds is not None and args.continue_run is not None:
        print("WARNING: --fix-seeds and --continue-run may not work together. The data loader's state is not "
              "saved, and thus the latest class-seed will continue training at epoch x while the fixed seed will make "
              "epoch x have the random state of epoch 1 in the original training!", file=sys.stderr)
    if modify_args is not None:
        modify_args(args)
    if args.classes is not None:
        args.classes = _parse_class_sets(args.classes)
    else:
        # one AD model per class, each time treating just that class as normal
        args.classes = [{i} for i in range(no_classes(args.dataset))]
    if args.oe_classes is not None:
        # guard against a missing OE dataset so that len() does not fail with a TypeError
        if args.oe_dataset is not None and len(args.oe_dataset) > 1:
            raise ValueError('Explicitly specifying OE class sets is not supported for multiple OE sets atm.')
        args.oe_classes = _parse_class_sets(args.oe_classes)
    if args.oe_classes is not None and len(args.oe_classes) != len(args.classes):
        raise ValueError(
            f"The number of the OE class sets ({len(args.oe_classes)}) does "
            f"not match the number of the normal class sets ({len(args.classes)})!"
        )
    if args.continue_run is not None and args.objective == 'dsvdd':
        raise NotImplementedError("As of now the center of DSVDD is not stored in the snapshot!")
    return args


def _int_or_float(value: str) -> Union[int, float]:
    """
    Parses a command-line number, keeping integral values as ints (e.g., '40' -> 40) and returning all other values
    as floats (e.g., '0.5' -> 0.5). Used for `--x-milestones`, which may be fractional for DiffEdit.
    """
    number = float(value)
    return int(number) if number.is_integer() else number


def counterfactual_argsparse(modify_descr: Callable[[str, ], str], modify_parser: Callable[[ArgumentParser], None] = None,
                             modify_args: Callable[[Namespace], None] = None):
    """
    Adds counterfactual-training-specific arguments to the default parser of :func:`default_argsparse`
    and validates them after parsing.
    @param modify_descr: function that modifies the default description.
    @param modify_parser: function that modifies the parser; applied after the counterfactual arguments were added.
    @param modify_args: function that modifies the parsed arguments; applied after the counterfactual checks.
    @return: the parsed arguments.
    @raise ValueError: if mutually exclusive loading options are combined or `--x-eval-snapshot` does not point
        to a generator snapshot in `counterfactual/snapshots/`.
    @raise NotImplementedError: if a DSVDD model is to be evaluated from a snapshot.
    """
    def combined_parser_modify(parser):
        group = parser.add_argument_group('counterfactual')
        group.add_argument(
            '--x-ad-load', type=str, default=None,
            help="Optional. If provided, needs to point to a directory with AD model snapshots. "
                 "The filenames need to be `snapshot_cls{CLS_ID}_it{SEED_ID}.pt`. In contrast to `--continue-run`, "
                 "this argument will just load the AD models and start the counterfactual training using those "
                 "without attempting to continue training the AD models. Note that you need to manually match the "
                 "network architectures."
        )
        group.add_argument(
            '--x-eval-snapshot', type=str, metavar='FILE-PATH', default=None,
            help="Optional. If provided, needs to be a path to a counterfactual generator snapshot in LOGDIR/counterfactual/snapshots/ "
                 "for LOGDIR being the logging directory of a previous experiment. "
                 "Load the configuration from this previous experiment where available. Some configurations need to be matched "
                 "manually, such as the data transformation pipelines. "
                 "If the same default configuration is used (e.g., `train_cifar10.py`), no matching will be required. "
                 "Then, load the specified model snapshots and just run the evaluation. "
                 "Create a new logging directory by concatenating `---EVAL` to the old directory name and create a subdir with the "
                 "name of the generator snapshot."
        )
        group.add_argument(
            '--x-concepts', type=int, default=2,
            help="The number of concepts the generator can be conditioned on. "
        )
        group.add_argument(
            '-xb', '--x-batch-size', type=int, default=None,
            help="The batch size for counterfactual training. Defaults to `--batch-size`. "
                 "If there is an OE dataset, the overall batch size will be twice as large as an equally "
                 "sized batch of OE samples gets concatenated to the batch of normal training samples. "
        )
        group.add_argument(
            '-xe', '--x-epochs', type=int, default=50,
            help="How many full iterations of the dataset are to be trained per counterfactual training."
        )
        group.add_argument(
            '-xlr', '--x-learning-rate', type=float, default=1e-3,
            help="The initial learning rate of the counterfactual training."
        )
        group.add_argument(
            '-xwdk', '--x-weight-decay', type=float, default=1e-4,
            help="The weight decay of the counterfactual training."
        )
        group.add_argument(
            # _int_or_float keeps integral milestones as ints but, unlike type=int, accepts the fractional
            # milestones the help text allows for DiffEdit.
            '--x-milestones', type=_int_or_float, nargs='+', default=[40],
            help="Milestones for the learning rate scheduler of the counterfactual trainer; "
                 "at each milestone the learning rate is reduced by 0.1. "
                 "In case of DiffEdit, the milestones can be floats. "
        )
        group.add_argument(
            '--x-milestone-alpha', type=float, default=0.1,
            help="Milestone alpha for the learning rate scheduler of the counterfactual trainer; "
                 "at each milestone the learning rate is reduced by alpha."
        )
        group.add_argument(
            '--x-method', type=str, choices=XTRAINER_CHOICES, default="dissect",
            help="Which CE generation method to use."
        )
        group.add_argument(
            '--x-gen-every', type=int, default=5, help='Optimize generator every nth batch.'
        )
        group.add_argument(
            '--x-disc-every', type=int, default=1, help='Optimize discriminator every nth batch.'
        )
        group.add_argument(
            '--x-lamb-conc', type=float, default=1, help='Constant factor for concept disentanglement losses.'
        )
        group.add_argument(
            '--x-lamb-gen', type=float, default=1, help='Constant factor for generic generator loss.'
        )
        group.add_argument(
            '--x-lamb-asc', type=float, default=1, help='Constant factor for anomaly score deviation losses.'
        )

        dissect_group = parser.add_argument_group('DISSECT')
        dissect_group.add_argument(
            '--x-normal-training-only', action='store_true',
            help="Whether to use only normal training samples for the counterfactual training instead of using "
                 "normal and OE samples."
        )
        dissect_group.add_argument(
            '--x-lamb-cyc', type=float, default=1, help='Constant factor for reconstruction losses.'
        )
        dissect_group.add_argument(
            '--x-cluster-ncc', action='store_true',
            help="Whether to add an additional regularizer for disentanglement that uses an NCC with "
                 "the centers obtained from a kmeans on the AD features of all normal training data."
        )
        dissect_group.add_argument(
            '--x-discrete-anomaly-scores', type=int, default=5,
            help="The number of discrete anomaly scores the generator can be conditioned on. "
        )

        diffedit_group = parser.add_argument_group('DiffEdit')
        diffedit_group.add_argument(
            '--x-lamb-dist', type=float, default=1e-3,
            help="Constant factor for distance loss between DiffEdit's initial "
                 "suggestion and the tuned version for generating disentangled CEs."
        )
        diffedit_group.add_argument(
            '--x-mask-encode-strength', type=float, default=0.4,
            help="The noise strength for generating the mask with DiffEdit."
        )
        diffedit_group.add_argument(
            '--x-mask-thresholding-ratio', type=float, default=2.0,
            help="The threshold for binarizing the mask with DiffEdit."
        )
        diffedit_group.add_argument(
            '--x-diffusion-inference-steps', type=int, default=40,
            help="The number of inference steps for the diffusion model."
        )
        diffedit_group.add_argument(
            '--x-diffusion-resolution', type=int, default=512,
            help="The resolution of the images used for the diffusion model."
        )
        if modify_parser is not None:
            modify_parser(parser)

    def combined_args_modify(args):
        if args.x_batch_size is None:
            args.x_batch_size = args.batch_size
        if args.continue_run is not None and args.x_ad_load is not None:
            raise ValueError('--continue-run and --x-ad-load are mutually exclusive.')
        if args.x_eval_snapshot is not None and (args.x_ad_load is not None or args.continue_run is not None):
            raise ValueError('--x-eval-snapshot, --continue-run, and --x-ad-load are mutually exclusive.')
        # if args.continue_run is not None and not pt.exists(args.continue_run):
        #     raise ValueError(f'--continue-run does not exist ({args.continue_run})')
        # if args.x_ad_load is not None and not pt.exists(args.x_ad_load):
        #     raise ValueError(f'--x-ad-load does not exist ({args.x_ad_load})')
        # if args.x_eval_snapshot is not None and not pt.exists(args.x_eval_snapshot):
        #     raise ValueError(f'--x-eval-snapshot does not exist ({args.x_eval_snapshot})')
        if args.x_eval_snapshot is not None:
            # Expected layout: LOGDIR/counterfactual/snapshots/FILENAME. The length checks turn too-short paths
            # into the ValueErrors below instead of an IndexError.
            parts = args.x_eval_snapshot.split(os.sep)
            if len(parts) < 2 or parts[-2] != "snapshots":
                raise ValueError(f'--x-eval-snapshot is not in subfolder "snapshots/" ({args.x_eval_snapshot})')
            if len(parts) < 3 or parts[-3] != "counterfactual":
                raise ValueError(f'--x-eval-snapshot is not in subfolder "counterfactual/snapshots/" ({args.x_eval_snapshot})')
            # check only the file name so that directory names cannot satisfy the check
            if "_generator_" not in parts[-1]:
                raise ValueError(f'--x-eval-snapshot seems not to be a generator snapshot ({args.x_eval_snapshot})')
            if args.objective == "dsvdd":
                raise NotImplementedError("As of now the center of DSVDD is not stored in the snapshot!")
        if modify_args is not None:
            modify_args(args)

    return default_argsparse(modify_descr, combined_parser_modify, combined_args_modify)


def create_trainer(trainer: str, comment: str, dataset: str, oe_dataset: List[str], epochs: int,
                   lr: float, wdk: float, milestones: List[int], batch_size: int,
                   gpus: List[int], model: ADNN, train_transform: Compose, val_transform: Compose,
                   oe_limit_samples: Union[int, List[int]] = np.inf, oe_limit_classes: Union[int, List[set]] = np.inf,
                   logpath: str = None, xtrainer: XTrainer = None, **kwargs) -> XADTrainer:
    """
    Parses its parameters to create the trainer registered under `trainer` in `TRAINER`, i.e., the trainer for
    the given objective. It also sets some additional parameters such as the datapath that defaults to
    `xad/data` and creates a logger for the trainer. Returns the created trainer.

    The logging directory is
      - `DATAPATH/results/SUPERDIR` (or `logpath` if given), with `comment` passed to the logger, for a new run,
      - `CONTINUE_RUN---CNTD` if the `continue_run` kwarg is given,
      - `LOGDIR---EVAL/SNAPSHOT_NAME` if the `x_eval_snapshot` kwarg points to `LOGDIR/counterfactual/snapshots/SNAPSHOT_NAME.pt`.
    If `xtrainer` is given, it gets its own logger in the subdirectory `counterfactual` of that logging directory.
    The first device in `gpus` is used; if `gpus` is empty (e.g., no CUDA device is visible), the CPU is used.
    The kwargs `superdir`, `continue_run`, and `x_eval_snapshot` are consumed here; all other kwargs are passed
    on to the trainer.
    For a description of the parameters have a look at :class:`xad.training.ad_trainer.XADTrainer`.
    """
    datapath = pt.abspath(pt.join(__file__, '..', '..', '..', '..', 'data'))
    kwargs = dict(kwargs)
    superdir = kwargs.pop('superdir', '.')
    continue_run = kwargs.pop('continue_run', None)
    x_eval_snapshot = kwargs.pop('x_eval_snapshot', None)

    if continue_run is None and x_eval_snapshot is None:
        logger = Logger(pt.join(datapath, 'results', superdir) if logpath is None else logpath, comment)
    elif continue_run is not None:
        # rstrip handles any number of trailing separators; otherwise the new dir would end up *inside* the old one
        logger = Logger(continue_run.rstrip(os.sep) + '---CNTD', noname=True)
    else:
        snapshot_parts = x_eval_snapshot.split(os.sep)
        run_dir = os.sep.join(snapshot_parts[:-3])  # strip `counterfactual/snapshots/NAME.pt`
        subdir = pt.splitext(snapshot_parts[-1])[0]  # snapshot file name without its extension
        logger = Logger(pt.join(run_dir.rstrip(os.sep) + '---EVAL', subdir), noname=True)

    if xtrainer is not None:
        xtrainer.logger = Logger(pt.join(logger.dir, 'counterfactual'), noname=True)

    # `--devices` defaults to all visible CUDA devices, which is an empty list on CPU-only machines.
    device = torch.device(gpus[0]) if gpus else torch.device('cpu')
    trainer = TRAINER[trainer](
        model, train_transform, val_transform, dataset, oe_dataset, pt.join(datapath, 'datasets'), logger,
        epochs, lr, wdk, milestones, batch_size, device,
        oe_limit_samples, oe_limit_classes, xtrainer=xtrainer, **kwargs
    )
    return trainer


def create_xtrainer(trainer: str, xmodels: List[nn.Module],
                    n_concepts: int, epochs: int, lr: float, wdk: float,
                    milestones: List[int], batch_size: int, **kwargs) -> XTrainer:
    """
    Counterpart of :func:`create_trainer` for the counterfactual training.
    Instantiates the counterfactual trainer class registered under `trainer` in `XTRAINER_CHOICES` with the given
    models and hyperparameters; all further keyword arguments are forwarded to that class unchanged.
    No logger is created here; :func:`create_trainer` attaches one to the returned instance.
    """
    xtrainer_cls = XTRAINER_CHOICES[trainer]
    positional_args = (xmodels, n_concepts, epochs, lr, wdk, milestones, batch_size)
    return xtrainer_cls(*positional_args, **kwargs)


def create_trainers_from_args(args: argparse.Namespace, xmodels: List[nn.Module],
                              model: ADNN, train_transform: Compose, val_transform: Compose,
                              continue_run: str = None, x_eval_snapshot: str = None) -> Tuple[XTrainer, XADTrainer]:
    """
    Builds the counterfactual trainer and the AD trainer from parsed command-line arguments.
    The counterfactual trainer is created first and handed to the AD trainer.
    @param args: the arguments parsed by :func:`counterfactual_argsparse` (optionally updated by a loaded setup).
    @param xmodels: the models used by the counterfactual trainer.
    @param model: the AD model.
    @param train_transform: the training data transformation pipeline.
    @param val_transform: the test data transformation pipeline.
    @param continue_run: optional logging directory of a previous experiment to continue.
    @param x_eval_snapshot: optional generator snapshot to only evaluate.
    @return: the tuple (counterfactual trainer, AD trainer).
    """
    print('Trainers created with:\n', vars(args))
    arg_dict = vars(args)
    # keyword arguments for the counterfactual trainer; options for a specific CE method are passed regardless
    xtrainer_kwargs = dict(
        devices=args.devices,
        milestone_alpha=args.x_milestone_alpha,
        lamb_gen=args.x_lamb_gen,
        lamb_asc=args.x_lamb_asc,
        lamb_conc=args.x_lamb_conc,
        gen_every=args.x_gen_every,
        disc_every=args.x_disc_every,
        # DISSECT
        oe=not args.x_normal_training_only,
        lamb_cyc=args.x_lamb_cyc,
        cluster_ncc=args.x_cluster_ncc,
        n_discrete_anomaly_scores=args.x_discrete_anomaly_scores,
        # DiffEdit
        lamb_dist=args.x_lamb_dist,
        mask_encode_strength=args.x_mask_encode_strength,
        mask_thresholding_ratio=args.x_mask_thresholding_ratio,
        diffusion_inference_steps=args.x_diffusion_inference_steps,
        diffusion_resolution=args.x_diffusion_resolution,
        # optional flags that are not defined by the default counterfactual parser
        additive_gen=not arg_dict.get('x_gen_non_additive', False),
        gen_use_mask=not arg_dict.get("x_gen_ignore_mask", False),
    )
    xtrainer = create_xtrainer(
        args.x_method, xmodels, args.x_concepts, args.x_epochs, args.x_learning_rate,
        args.x_weight_decay, args.x_milestones, args.x_batch_size, **xtrainer_kwargs
    )
    trainer = create_trainer(
        args.objective, args.comment, args.dataset, args.oe_dataset,
        args.epochs, args.learning_rate, args.weight_decay, args.milestones, args.batch_size,
        args.devices, model, train_transform, val_transform,
        oe_limit_samples=args.oe_size,
        oe_limit_classes=args.oe_classes,
        continue_run=continue_run,
        superdir=args.superdir,
        xtrainer=xtrainer,
        train_split=args.train_split,
        x_eval_snapshot=x_eval_snapshot,
    )
    return xtrainer, trainer


def load_setup(path: str, args: Namespace, check_train_transform: Compose,
               check_val_transform: Compose) -> Tuple[List[List[str]], str]:
    """
    Loads the setup/configuration from given path, including all model snapshots.
    Can be used to repeat or continue a previous experiment.
    The loaded configuration (OE setup, training hyperparameters, class sets, number of seeds, ...) is written to
    `args` in place. The dataset, objective, and data transformations cannot be loaded and are only checked.
    @param path: the path to the logging directory of the experiment from which the configuration is to be loaded.
        A path of the form `sftp://HOST/PATH` is reduced to `/PATH`.
    @param args: the args namespace where the setup is to be loaded to.
    @param check_train_transform: since the transforms cannot be automatically loaded,
        check if their logged string representation matches this parameter's string representation.
    @param check_val_transform: since the transforms cannot be automatically loaded,
        check if their logged string representation matches this parameter's string representation.
    @return: a tuple of
        - a list (len = #classes) with each element again being a list (len = #seeds) of filepaths to model snapshots.
          Some can be None. The snapshots may also contain the training state such as the last epoch trained.
        - the path from which the configuration was loaded.
        Returns (None, None) if `path` is None.
    @raise AssertionError: if a checked configuration does not match or the setup contains unexpected entries.
    """
    if path is None:
        return None, None
    elif path.startswith('sftp://'):  # 7 chars
        path = path[7:][path[7:].index('/'):]  # sftp://foo@bar.far.com/PATH -> /PATH
    print(f'Load setup from {path}')
    with open(pt.join(path, 'setup.json'), 'r') as reader:
        setup = json.load(reader)
    # The transforms are logged via SetupEncoder; encode the given ones the same way and compare.
    assert [x.replace("'normalize'", 'normalize') for x in json.loads(json.dumps(check_train_transform, cls=SetupEncoder))] == \
           setup.pop('train_transform'), \
           'The loaded train transformation string representation does not match the set one. Please match manually. '
    assert [x.replace("'normalize'", 'normalize') for x in json.loads(json.dumps(check_val_transform, cls=SetupEncoder))] == \
           setup.pop('test_transform'), \
           'The loaded test transformation string representation does not match the set one. Please match manually. '
    setup_load: List[List[str]] = setup.pop('load')  # len = #class_sets, each element has len = #seeds of model snapshots.
    assert setup_load is None or all([isinstance(seed, str) or seed is None for cls in setup_load for seed in cls])
    assert setup.pop('dataset') == args.dataset, \
        f'It seems like the set dataset ({args.dataset}) is not the one found in the loaded experiment. Please match manually. '
    assert f'_{args.objective}_' in path, \
        f'It seems like the set objective ({args.objective}) is not the one found in the loaded experiment. ' \
        f'Please match manually. '
    args.oe_dataset = setup.pop('oe_dataset')
    args.epochs = setup.pop('epochs')
    args.learning_rate = setup.pop('lr')
    args.weight_decay = setup.pop('wdk')
    args.milestones = setup.pop('milestones')
    args.batch_size = setup.pop('batch_size')
    assert setup.pop('ad_mode', 'ovr') == 'ovr'  # legacy
    # np.inf (np.infty was removed in NumPy 2.0); the default is evaluated on every call, not only if the key is missing
    args.oe_size = setup.pop('oe_limit_samples', np.inf)
    args.oe_classes = setup.pop('oe_limit_classes', None)
    args.model = setup.pop('model', None)
    run_classes = setup.pop('run_classes')
    # NOTE(review): `eval` executes code taken from setup.json; only load setups from trusted sources.
    args.classes = [eval(cset) for cset in run_classes] if run_classes is not None else None
    args.iterations = setup.pop('run_seeds')
    # args.fix_seeds = setup.pop('seeds')
    setup.pop('seeds')
    args.train_split = setup.pop('train_split', 0.0)
    setup.pop('workers')
    setup.pop('device')
    setup.pop('datapath')
    setup.pop('logger')
    assert len(setup) == 0, f'There are unexpected arguments in the loaded setup: {setup.keys()}.'
    classes = args.classes if args.classes is not None else [{c} for c in range(no_classes(args.dataset))]
    snapshots = []
    for cid, cset in enumerate(classes):
        snapshots.append([])
        for i in range(args.iterations):
            snapshot = pt.join(path, 'snapshots', f'snapshot_cls{int_set_to_str(cset)}_it{i}.pt')
            snapshot = find_permutation_snapshot(snapshot, cset)
            if not pt.exists(snapshot):
                snapshot = None
                # Fall back to the snapshot the previous experiment itself loaded; the bounds checks avoid an
                # IndexError if that experiment listed fewer class sets or seeds.
                if setup_load is not None and cid < len(setup_load) and i < len(setup_load[cid]):
                    snapshot = setup_load[cid][i]
            snapshots[-1].append(snapshot)
    return snapshots, path


def _pop_first(setup: dict, keys: Tuple[str, ...], **kwargs) -> Any:
    """
    Pops the first of several alternative setup keys (e.g., the current name followed by its legacy name) via
    `safe_pop`. All keys but the last are skipped if absent or None; the last key is popped with `kwargs`
    (e.g., `default=...`), so a KeyError is raised if it is missing as well and no default is given.
    In contrast to chaining `safe_pop(...) or safe_pop(...)`, falsy loaded values such as 0 or False are kept.
    """
    *preferred, last = keys
    for key in preferred:
        value = safe_pop(setup, key, default=None)
        if value is not None:
            return value
    return safe_pop(setup, last, **kwargs)


def _loaded_generator_snapshot(setup_load: Union[Dict[str, Dict[str, str]], None], cls_str: str,
                               it: int) -> Union[str, None]:
    """
    Returns the generator snapshot path that the previous experiment loaded for class set `cls_str` and
    iteration `it`, or None. JSON object keys are always strings, so the iteration is looked up as int and as str.
    """
    if setup_load is None:
        return None
    per_cls = setup_load.get(cls_str) or {}
    snapshot = per_cls.get(it)
    return snapshot if snapshot is not None else per_cls.get(str(it))


def _find_generator_snapshot(path: str, cset: set, it: int,
                             setup_load: Union[Dict[str, Dict[str, str]], None]) -> Union[str, None]:
    """
    Finds the counterfactual generator snapshot for class set `cset` and iteration `it` of the experiment in `path`.
    Tries, in this order: the final snapshot, an intermediate snapshot without epoch info (legacy?), the intermediate
    snapshot `intermediate_epit{EPOCH}_...` with the highest epoch, and the snapshot loaded by the previous experiment.
    Returns None if none is found.
    """
    cls_str = int_set_to_str(cset)
    snapshot_dir = pt.join(path, 'counterfactual', 'snapshots')
    snapshot = find_permutation_snapshot(pt.join(snapshot_dir, f'snapshot_generator_cls{cls_str}_it{it}.pt'), cset)
    if pt.exists(snapshot):
        return snapshot
    # maybe there's an intermediate snapshot without epit (legacy?)
    snapshot = find_permutation_snapshot(
        pt.join(snapshot_dir, f'intermediate_snapshot_generator_cls{cls_str}_it{it}.pt'), cset
    )
    if pt.exists(snapshot):
        return snapshot
    # maybe there's an intermediate snapshot with epit; the snapshot dir may not exist if none was saved yet
    prefix, suffix = 'intermediate_epit', f'_snapshot_generator_cls{cls_str}_it{it}.pt'
    existing = os.listdir(snapshot_dir) if pt.isdir(snapshot_dir) else []
    intermediates = sorted(
        [f for f in existing if f.endswith(f'_generator_cls{cls_str}_it{it}.pt') and f.startswith(prefix)],
        key=lambda f: int(f[len(prefix):-len(suffix)])  # sort by epoch
    )
    if intermediates:
        return pt.join(snapshot_dir, intermediates[-1])
    return _loaded_generator_snapshot(setup_load, cls_str, it)


def load_counterfactual_setup(path: str, args: Namespace, check_train_transform: Compose,
                              check_val_transform: Compose, ) -> Tuple[List[List[str]], List[List[str]], str]:
    """
    Loads the setup/configuration from given path, including all AD model and counterfactual generator snapshots.
    Can be used to repeat or continue a previous experiment. Loads the AD setup via :func:`load_setup` and then the
    counterfactual setup from `path/counterfactual/setup.json`, writing the configuration to `args` in place.
    @param path: the path to the logging directory of the experiment from which the configuration is to be loaded.
    @param args: the args namespace where the setup is to be loaded to.
    @param check_train_transform: since the transforms cannot be automatically loaded,
        check if their logged string representation matches this parameter's string representation.
    @param check_val_transform: since the transforms cannot be automatically loaded,
        check if their logged string representation matches this parameter's string representation.
    @return: a tuple of
        - a list (len = #classes) with each element again being a list (len = #seeds) of filepaths to AD model
          snapshots. Some can be None.
        - a list of the same structure with filepaths to counterfactual generator snapshots. Some can be None;
          all are None if the previous experiment has no counterfactual setup.
        - the path from which the configuration was loaded.
        Returns (None, None, None) if `path` is None.
    @raise AssertionError: if a checked configuration does not match or the setup contains unexpected entries.
    """
    if path is None:
        return None, None, None
    elif path.startswith('sftp://'):  # 7 chars
        path = path[7:][path[7:].index('/'):]  # sftp://foo@bar.far.com/PATH -> /PATH
    snapshots, path = load_setup(path, args, check_train_transform, check_val_transform)
    counterfactual_path = pt.join(path, 'counterfactual', 'setup.json')
    if not pt.exists(counterfactual_path):
        # no counterfactual setup logged (yet): keep the documented 3-tuple with no generator snapshots
        return snapshots, [[None for _ in cls_snapshots] for cls_snapshots in snapshots], path
    with open(counterfactual_path, 'r') as reader:
        setup = json.load(reader)
    setup_load: Dict[str, Dict[str, str]] = setup.pop('load')  # maps class_set_str_representation -> iteration -> snapshot path
    assert setup_load is None or all(
        isinstance(snapshot, str) or snapshot is None
        for per_cls in setup_load.values() for snapshot in per_cls.values()
    )
    args.x_epochs = setup.pop('x_epochs')
    args.x_learning_rate = setup.pop('x_lr')
    args.x_weight_decay = setup.pop('x_wdk')
    args.x_milestones = setup.pop('x_milestones')
    args.x_batch_size = setup.pop('x_batch_size')
    args.x_concepts = setup.pop('x_n_concepts')
    args.x_milestone_alpha = _pop_first(setup, ('x_milestone_alpha', 'milestone_alpha'), default=args.x_milestone_alpha)
    if any([f"_{xmthd}_" in path for xmthd in XTRAINER_CHOICES.keys()]):
        assert f'_{args.x_method}_' in path, \
            f'It seems like the set x_method ({args.x_method}) is not the one found in the loaded experiment. ' \
            f'Please match manually. '
    else:
        print(
            f"Could not verify whether the set x_method ({args.x_method}) matches the loaded experiment. Please verify manually!",
            file=sys.stderr
        )
    args.x_gen_every = _pop_first(setup, ('x_gen_every', 'gen_every'))
    args.x_disc_every = _pop_first(setup, ('x_disc_every', 'disc_every'))
    args.x_lamb_conc = _pop_first(setup, ('x_lamb_conc', 'lamb_conc'))
    args.x_lamb_gen = _pop_first(setup, ('x_lamb_gen', 'lamb_gen'))
    args.x_lamb_asc = _pop_first(setup, ('x_lamb_asc', 'lamb_asc'))
    setup.pop('x_discriminator', None)  # is hard-coded, cannot be checked; legacy
    setup.pop('x_xmodels', None)  # is hard-coded, cannot be checked
    # DISSECT
    args.x_normal_training_only = not safe_pop(setup, 'x_oe')
    args.x_lamb_cyc = _pop_first(setup, ('x_lamb_cyc', 'lamb_cyc'))
    args.x_cluster_ncc = _pop_first(setup, ('x_cluster_ncc', 'cluster_ncc'), default=False)
    args.x_discrete_anomaly_scores = _pop_first(setup, ('x_n_discrete_anomaly_scores', 'n_discrete_anomaly_scores'))
    # DiffEdit
    args.x_lamb_dist = _pop_first(setup, ('x_lamb_dist', 'lamb_dist'), default=args.x_lamb_dist)
    args.x_mask_encode_strength = _pop_first(
        setup, ('x_mask_encode_strength', 'mask_encode_strength'), default=args.x_mask_encode_strength
    )
    # must be stored as `x_mask_thresholding_ratio`, the attribute create_trainers_from_args reads
    args.x_mask_thresholding_ratio = _pop_first(
        setup, ('x_mask_thresholding_ratio', 'mask_thresholding_ratio'), default=args.x_mask_thresholding_ratio
    )
    args.x_diffusion_inference_steps = _pop_first(
        setup, ('x_diffusion_inference_steps', 'diffusion_inference_steps'), default=args.x_diffusion_inference_steps
    )
    args.x_diffusion_resolution = _pop_first(
        setup, ('x_diffusion_resolution', 'diffusion_resolution'), default=args.x_diffusion_resolution
    )
    # `x_devices` is not defined by the parser, hence getattr; create_trainers_from_args uses `args.devices`.
    args.x_devices = _pop_first(setup, ('x_devices', 'devices'), default=getattr(args, 'x_devices', None))
    # The CLI flags are optional and negated w.r.t. the logged values; if nothing was logged, keep the current flags.
    arg_dict = vars(args)
    args.x_gen_non_additive = not _pop_first(
        setup, ('x_additive_gen', 'additive_gen'), default=not arg_dict.get('x_gen_non_additive', False)
    )
    args.x_gen_ignore_mask = not _pop_first(
        setup, ('x_gen_use_mask', 'gen_use_mask'), default=not arg_dict.get('x_gen_ignore_mask', False)
    )
    # Snapshots etc.
    kwargs_left = setup.pop('kwargs', {})
    assert len(kwargs_left) == 0, f"There are unexpected arguments in the loaded setup's kwargs {kwargs_left.keys()}."
    x_kwargs_left = setup.pop('x_kwargs', {})
    assert len(x_kwargs_left) == 0, f"There are unexpected arguments in the loaded setup's x_kwargs {x_kwargs_left.keys()}."
    setup.pop('x_logger')
    assert len(setup) == 0, f'There are unexpected arguments in the loaded setup: {setup.keys()}.'
    classes = args.classes if args.classes is not None else [{c} for c in range(no_classes(args.dataset))]
    xsnapshots = [
        [_find_generator_snapshot(path, cset, i, setup_load) for i in range(args.iterations)]
        for cset in classes
    ]
    return snapshots, xsnapshots, path

def safe_pop(setup: dict, item: str, **kwargs) -> Any:
    """
    Removes and returns `item` from a loaded setup dictionary.
    The key is searched, in this order, at the top level of `setup`, in its nested 'kwargs' dict, and in its
    nested 'x_kwargs' dict; only the first occurrence is removed.
    @param setup: the loaded setup; modified in place.
    @param item: the key to pop.
    @param kwargs: optionally `default=VALUE`, which is returned if the key is found nowhere.
    @return: the popped value, or the default.
    @raise KeyError: if the key is found nowhere and no default is given.
    """
    nested_kwargs = setup.get("kwargs", {})
    nested_x_kwargs = setup.get("x_kwargs", {})
    for container in (setup, nested_kwargs, nested_x_kwargs):
        try:
            return container.pop(item)
        except KeyError:
            continue
    if 'default' in kwargs:
        return kwargs.get('default')
    raise KeyError(item)
    