import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'
import argparse
import os.path as pt
from typing import Callable, Tuple, List

import torch.nn as nn
import numpy as np
from torchvision.transforms import Compose

from xad.main import counterfactual_argsparse, load_counterfactual_setup, create_trainers_from_args
from xad.main import load_setup
from xad.models.bases import ADNN, ConditionalGenerator, ConditionalDiscriminator, ConceptNN
from xad.training.ad_trainer import XADTrainer
from xad.utils.training_tools import int_set_to_str, str_to_int_set


def string_format_ignore_missing(s: str, raise_err_if_not_existing=False, **kwargs) -> str:
    """Format ``s`` with ``kwargs`` while leaving unknown ``{name}`` fields in place.

    Every named replacement field of ``s`` that has no matching keyword
    argument is substituted by itself (``"{name}"``), so the result can be
    formatted again later once the remaining values are known.

    Note that only the bare field name is restored: a conversion or format
    spec attached to a missing field (e.g. ``{name!r}`` or ``{name:>5}``) is
    applied to the literal text ``"{name}"`` rather than being preserved.

    Args:
        s: The format string.
        raise_err_if_not_existing: If true, every key of ``kwargs`` must occur
            in ``s`` literally as ``{key}``.
        **kwargs: The values to be inserted.

    Returns:
        The formatted string.

    Raises:
        ValueError: If ``raise_err_if_not_existing`` is set and some key of
            ``kwargs`` does not appear in ``s`` as ``{key}``.
        KeyError: If formatting fails with a ``KeyError`` that is not caused by
            a missing top-level field (e.g. a failing lookup like ``{d[x]}``).
    """
    # The check does not depend on the retry loop below, so do it exactly once.
    if raise_err_if_not_existing:
        for k in kwargs:
            if f'{{{k}}}' not in s:
                raise ValueError(f"Key '{k}' was not found in string '{s}'")

    missing = {}
    # Each pass either succeeds, re-raises, or registers one *new* field name,
    # so the loop terminates after at most one pass per distinct field in `s`.
    while True:
        try:
            return s.format(**kwargs, **missing)
        except KeyError as err:
            # Use the raw key instead of parsing str(err): the repr switches to
            # double quotes when the key itself contains a single quote.
            arg = err.args[0] if err.args else None
            # A key that is not a string or that was already supplied cannot be
            # a missing top-level field; retrying would never make progress.
            if not isinstance(arg, str) or arg in missing or arg in kwargs:
                raise
            missing[arg] = "{" + arg + "}"


def default_argparse_doc_head(script_path: str) -> Callable:
    """Create a function that appends a default-configuration note to a text.

    The note names the configuration after the file name of ``script_path``
    with its last three characters (presumably the ``.py`` extension) cut off
    and every occurrence of ``train_`` removed.

    Args:
        script_path: Path of the script the configuration is named after.

    Returns:
        A one-argument callable mapping a text ``s`` to ``s`` followed by the note.
    """
    def extend_doc(s):
        # Derived on every call, exactly when the text is requested.
        config_name = os.path.basename(script_path)[:-3].replace('train_', '')
        return f"{s} This specific script comes with a default configuration for {config_name}."

    return extend_doc


def default_comment(script_path: str) -> str:
    """Return the default comment template for the script at ``script_path``.

    The template embeds the script's file name, with its last three characters
    (presumably the ``.py`` extension) cut off and every occurrence of
    ``train_`` removed, between the placeholders ``{obj}`` and
    ``{OE}``, ``{CFOE}``, ``{XMTHD}``, ``{NormCls}``. The placeholders are left
    unformatted for a later ``str.format`` call.

    Args:
        script_path: Path of the script the template is named after.

    Returns:
        The comment template string.
    """
    file_name = os.path.basename(script_path)
    config_name = file_name[:-3].replace('train_', '')
    template_parts = ('{obj}_', config_name, '_{OE}OE_{CFOE}_{XMTHD}_{NormCls}')
    return ''.join(template_parts)


def default_comment_format(args: argparse.Namespace) -> dict[str, str]:
    """Collect the values that can be inserted into a comment template.

    The keys of the returned mapping are the placeholder names; the values are
    string renderings of the corresponding attributes of ``args``.

    Args:
        args: Parsed arguments. Every attribute read below must be present.

    Returns:
        A mapping from placeholder name to its string value. ``oelimit`` is the
        empty string unless ``args.oe_size`` is below infinity; ``OE`` joins
        the entries of ``args.oe_dataset`` with ``+`` and is suffixed with
        ``ClassRestricted`` when ``args.oe_classes`` is set; ``NormCls`` joins
        one ``int_set_to_str`` rendering per entry of ``args.classes`` with ``-``.
    """
    return dict(
        # `np.inf` rather than the alias `np.infty`, which NumPy 2.0 removed.
        obj=args.objective, oelimit=f'_OE{args.oe_size}' if args.oe_size < np.inf else '',
        OE=f"{'+'.join(args.oe_dataset)}{'ClassRestricted' if args.oe_classes is not None else ''}",
        CFOE='NoCFOE' if args.x_normal_training_only else 'CFOE',
        NormCls="Norm-" + "-".join([int_set_to_str(clsset) for clsset in args.classes]),
        XGEN=f'{args.x_lamb_gen}', XASC=f'{args.x_lamb_asc}', XCYC=f'{args.x_lamb_cyc}',
        XCONC=f'{args.x_lamb_conc}', XEP=f'{args.x_epochs}', IT=f'{args.iterations}',
        # Learning rate and weight decay are shortened to one significant digit (e.g. 1e-03).
        XLR=f'{args.x_learning_rate:.0e}', XWDK=f'{args.x_weight_decay:.0e}',
        EP=f'{args.epochs}',
        XMASKST=f'{args.x_mask_encode_strength}', XDDIST=f'{args.x_lamb_dist}',
        XMASKTH=f'{args.x_mask_thresholding_ratio}', XDINF=f'{args.x_diffusion_inference_steps}',
        XDRES=f'{args.x_diffusion_resolution}', XMTHD=f'{args.x_method}',
    )


def main(modify_parser: Callable[[argparse.ArgumentParser], None], get_transforms: Callable[[], Tuple[Compose, Compose]],
         get_models: Callable[[argparse.Namespace], List[nn.Module]],
         modify_args: Callable[[argparse.Namespace], None] = None,
         run=True, xtraining=True) \
        -> Tuple[Tuple[List[List[ADNN]], dict], XADTrainer, argparse.Namespace]:
    """Parse the arguments, load snapshots, build the trainers and optionally run them.

    Args:
        modify_parser: Forwarded to ``counterfactual_argsparse``.
        get_transforms: Called without arguments; must return the pair
            ``(train_transform, val_transform)``.
        get_models: Called with the final ``args``; the first returned model is
            passed on as ``model``, all remaining ones as ``xmodels``.
        modify_args: Forwarded to ``counterfactual_argsparse``.
        run: If true, ``trainer.run`` is executed; otherwise the result is
            the placeholder ``([], {})``.
        xtraining: If false, ``trainer.xtrainer`` is set to ``None`` before running.

    Returns:
        The tuple ``(res, trainer, args)`` where ``res`` is the return value of
        ``trainer.run`` (or ``([], {})`` if ``run`` is false).

    Raises:
        ValueError: If ``args.x_eval_snapshot`` is given but the snapshot file
            does not exist at the path rebuilt from the experiment directory.
    """
    # NOTE(review): `__file__` is the path of this module, so the doc head is derived from this
    # file's name rather than from the calling script — confirm that this is intended.
    args = counterfactual_argsparse(default_argparse_doc_head(__file__), modify_parser, modify_args)
    # Fill the placeholders of the comment; str.format raises KeyError for unknown placeholders.
    args.comment = args.comment.format(**default_comment_format(args))

    train_transform, val_transform = get_transforms()
    snapshots, xsnapshots, continue_run = load_counterfactual_setup(args.continue_run, args, train_transform, val_transform)
    if args.x_ad_load is not None:
        # Replace the snapshots by the ones loaded from `x_ad_load` and set the epochs to zero.
        snapshots, _ = load_setup(args.x_ad_load, args, train_transform, val_transform)
        args.epochs = 0
    if args.x_eval_snapshot is not None:
        # Drop the last three path components to obtain the experiment directory; this mirrors
        # the `<exp_dir>/counterfactual/snapshots/<file>` path that is rebuilt below.
        exp_dir = os.sep.join(args.x_eval_snapshot.split(os.sep)[:-3])
        specific_snapshot = args.x_eval_snapshot.split(os.sep)[-1]
        # The part after "_generator_" minus its last three characters (presumably the file
        # extension — TODO confirm) must consist of exactly two "_"-separated tokens.
        csetstr, itstr = specific_snapshot.split("_generator_")[1][:-3].split("_")
        # Skip a three-character prefix of the class-set token and a two-character prefix of the
        # iteration token (presumably tags like "cls" and "it" — TODO confirm).
        cset, it = str_to_int_set(csetstr[3:]), int(itstr[2:])
        snapshots, xsnapshots, exp_dir = load_counterfactual_setup(exp_dir, args, train_transform, val_transform)
        specific_snapshot = pt.join(exp_dir, 'counterfactual', 'snapshots', specific_snapshot)
        if not pt.exists(specific_snapshot):
            raise ValueError(f"Specific snapshot does not exist! ({specific_snapshot})")
        # Narrow everything down to the single class set and iteration encoded in the file name.
        snapshots = [[snapshots[args.classes.index(cset)][it]]]
        xsnapshots = [[specific_snapshot]]
        args.classes = [cset]
        args.iterations = 1
        args.epochs = 0
        args.x_epochs = 0
        args.x_eval_snapshot = specific_snapshot

    # The models are created only after `args` has received all modifications above.
    model, *xmodels = get_models(args)
    xtrainer, trainer = create_trainers_from_args(
        args, xmodels, model, train_transform, val_transform, continue_run, args.x_eval_snapshot
    )
    if not xtraining:
        # Detach the xtrainer from the trainer and drop the local reference to it.
        trainer.xtrainer = None
        del xtrainer
    if run:
        # `label_smoothing` is forwarded only if `args` has this attribute and it is not None.
        res = trainer.run(
            args.classes, args.iterations, snapshots, xsnapshots, seeds=args.fix_seeds,
            tqdm_mininterval=args.tqdm_mininterval, **{k: v for k, v in [
                ("label_smoothing", args.label_smoothing) if "label_smoothing" in vars(args) else (None, None)
            ] if v is not None}
        )
    else:
        res = ([], {})
    return res, trainer, args
