import argparse
import json
import os

import acquire
import dataset
import model
import utils
import training
import result


# Paths are built relative to the directory that contains this file:
# results and data live in sibling folders one level above it.
DIR = os.path.dirname(__file__)
RESULT_DIR, DATA_DIR = (
    os.path.join(DIR, "..", leaf) for leaf in ("results", "data")
)


# Registries of selectable classes. A parameter file names an entry by the
# string its class returns from ``get_name()``; ``_lookup_class`` maps those
# names to classes and instantiates the match. If two entries report the same
# name, the one appearing later in the list wins (dict-comprehension semantics).

# Candidates for params["datapool_creator"].
DATASETS = [
    dataset.Blobs,
    dataset.BlobsThin,
    dataset.BlobsThinBalanced,
    dataset.Quadrant,
    dataset.QuadrantCenteredGauss,
    dataset.QuadrantOffCenterGauss,
    dataset.QuadrantLarge,
    dataset.Mnist,
    dataset.MnistAugment,
    dataset.Cifar10,
    dataset.Cifar10Augmented,
    dataset.Cifar100,
    dataset.Svhn
]

# Candidates for each entry's "type" in params["models"].
MODELS = [
    model.Mlp,
    model.SimpleCnn,
    model.LargeCnn,
    model.Vgg,
    model.VggPretrained
]

# Candidates for each entry's "type" in params["acquire"].
ACQUISITION_FNS = [
    acquire.Random,
    acquire.LeastConfidence,
    acquire.MaxEntropy,
    acquire.ProbabilisticCoreset,
    acquire.GreedyCoreset,
    acquire.LcBeamCoreset,
    acquire.Top2EntropyBeamCoreset,
    acquire.LcBeamPWeightedCoreset,
    acquire.LcBeamPWeightedOptimCore,
    acquire.LcBeamPWeightedRelConfCoreset,
    acquire.LcBeamPWeightedRelConfOptimCore,
    acquire.SelectiveProbabilisticCoreset,
    acquire.SpCoresetGreedyInit
]


def main(param_path, download):
    """Run all experiments described by a JSON parameter file.

    Output goes to ``RESULT_DIR/<parent dir name>/<param file stem>``. The log
    is initialised at ``log.txt`` in that directory, and ``results.json`` there
    is the path handed to ``result.Result``. One ``training.Experiment`` is
    built for every (acquisition function, model) pair. The whole set is then
    passed to ``training.Experiment.repeat`` with ``n=experiment_repeats``.

    Args:
        param_path: Path to the JSON parameter file.
        download: Forwarded to the dataset constructor's ``download`` config
            entry (the CLI passes an int, 0 by default).
    """
    result_name = os.path.splitext(os.path.basename(param_path))[0]
    # The name of the directory holding the parameter file becomes an extra
    # grouping level in the results tree.
    data_name = os.path.basename(os.path.dirname(param_path))

    dname = os.path.join(RESULT_DIR, data_name, result_name)
    # exist_ok replaces the racy isdir()-then-makedirs() check.
    os.makedirs(dname, exist_ok=True)

    utils.log.init(os.path.join(dname, "log.txt"))

    # Load the parameters and release the file handle right away, instead of
    # keeping it open for the entire experiment run.
    with open(param_path, "r", encoding="utf-8") as f:
        params = json.load(f)

    (
        datapool_creator,
        trainer,
        models,
        acquirefs,
        max_labelled,
        init_labelled,
        experiment_repeats,
        reinit_model_per_acquisition_step,
        first_training_epochs,
        min_training_score,
        save_first_labelled_pool_path,
        seed,
    ) = _init_from_params(params, download)

    results_path = os.path.join(dname, "results.json")
    with result.Result(params, results_path) as res:
        # Acquisition functions form the outer loop and models the inner one.
        # The loop variable is ``net`` so it does not shadow the ``model`` module.
        expts = [
            training.Experiment(
                trainer=trainer,
                model=net,
                acquiref=acquiref,
                max_labelled=max_labelled,
                reinit_model_per_acquisition_step=reinit_model_per_acquisition_step,
                first_training_epochs=first_training_epochs,
                min_training_score=min_training_score,
                save_first_labelled_pool_path=save_first_labelled_pool_path,
                result=res,
            )
            for acquiref in acquirefs
            for net in models
        ]
        training.Experiment.repeat(
            datapool_creator=datapool_creator,
            init_labelled=init_labelled,
            n=experiment_repeats,
            expts=expts,
            start_seed=seed,
        )


def _init_from_params(params, download):
    """Build the experiment components described by a parameter dict.

    Side effect: configures the global progress bar via
    ``utils.Bar.config(disable=params["disable_tqdm"])``.

    Args:
        params: Parsed parameter file.
            Required keys: ``datapool_creator``, ``label_smooth``,
            ``disable_tqdm``, ``device``, ``max_batch_size``, ``budget``,
            ``max_labelled``, ``init_labelled``, ``experiment_repeats``,
            ``reinit_model_per_acquisition_step``, ``first_training_epochs``,
            ``min_training_score``, ``seed``, ``trainer`` (a dict of keyword
            arguments), ``models`` and ``acquire``.
            Each entry of ``models`` and ``acquire`` is a dict with a
            ``"type"`` name and an optional ``"config"`` dict; a missing or
            null config is treated as empty.
            Optional key: ``save_first_labelled_pool_path``, which defaults to
            None.
        download: Passed through to the dataset constructor.

    Returns:
        A 12-tuple ``(datapool_creator, trainer, models, acquirefs,
        max_labelled, init_labelled, experiment_repeats,
        reinit_model_per_acquisition_step, first_training_epochs,
        min_training_score, save_first_labelled_pool_path, seed)``.

    Raises:
        KeyError: if a required key is missing, or if a ``type`` name matches
            no registered class.
    """
    datapool_creator = _lookup_class(
        {
            "type": params["datapool_creator"],
            "config": {
                "path": DATA_DIR,
                "download": download,
                "label_smooth": params["label_smooth"],
            },
        },
        options=DATASETS,
    )

    utils.Bar.config(disable=params["disable_tqdm"])

    device = params["device"]
    max_batch_size = params["max_batch_size"]
    budget = params["budget"]
    max_labelled = params["max_labelled"]
    init_labelled = params["init_labelled"]
    experiment_repeats = params["experiment_repeats"]
    reinit_model_per_acquisition_step = params["reinit_model_per_acquisition_step"]
    first_training_epochs = params["first_training_epochs"]
    min_training_score = params["min_training_score"]
    # Optional key; None when absent.
    save_first_labelled_pool_path = params.get("save_first_labelled_pool_path")
    seed = params["seed"]

    trainer = training.Trainer.from_params(
        device=device,
        test_batchsize=max_batch_size,
        **params["trainer"],
    )

    # Dataset-derived defaults shared by every model. They are computed once
    # rather than per model. A model's own config entries override them
    # because they are unpacked last.
    input_size = datapool_creator.get_input_dim()
    targets = datapool_creator.get_classes()
    models = [
        _lookup_class(
            {
                "type": p["type"],
                "config": {
                    "input_size": input_size,
                    "targets": targets,
                    **(p.get("config") or {}),
                },
            },
            options=MODELS,
        )
        for p in params["models"]
    ]

    # Shared defaults for every acquisition function; per-entry config wins.
    acquirefs = [
        _lookup_class(
            {
                "type": p["type"],
                "config": {
                    "budget": budget,
                    "device": device,
                    "batch_size": max_batch_size,
                    **(p.get("config") or {}),
                },
            },
            options=ACQUISITION_FNS,
        )
        for p in params["acquire"]
    ]

    return (
        datapool_creator,
        trainer,
        models,
        acquirefs,
        max_labelled,
        init_labelled,
        experiment_repeats,
        reinit_model_per_acquisition_step,
        first_training_epochs,
        min_training_score,
        save_first_labelled_pool_path,
        seed,
    )


def _lookup_class(params, options):
    return {
        Class.get_name(): Class
        for Class in options
    }[params["type"]](**params["config"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the experiments described by a JSON parameter file."
    )
    parser.add_argument(
        "--param_path",
        required=True,
        help="path to the JSON parameter file; output is written under "
             "RESULT_DIR/<parent dir name>/<file stem>",
    )
    parser.add_argument(
        "--download",
        type=int,
        default=0,
        help="forwarded to the dataset constructor as its 'download' option "
             "(default: 0)",
    )
    args = parser.parse_args()

    main(**vars(args))