import argparse
import json
import os
from multiprocessing import Pool, Queue, current_process
from pathlib import Path
from subprocess import call

import torch

# Names of the pykt models this script accepts for --models.
# Kept as a list because its repr is embedded in the command-line help text.
AVAILABLE_MODELS = [
    "akt",
    "atdkt",
    "atkt",
    "deep_irt",
    "dimkt",
    "dkt",
    "dkt_forget",
    "dkt_plus",
    "dkvmn",
    "dtransformer",
    "folibikt",
    "gkt",
    "hawkes",
    "iekt",
    "kqn",
    "lpkt",
    "qdkt",
    "qikt",
    "saint",
    "sakt",
    "simplekt",
    "skvmn",
]
# Names of the datasets this script accepts for --datasets.
# Kept as a list because its repr is embedded in the command-line help text.
AVAILABLE_DATASETS = [
    "assist2009",
    "algebra2005",
    "bridge2algebra2006",
    "nips_task34",
    "statics2011",
    "assist2015",
    "poj",
    "ednet",
]


def eval_pykt_model(save_dir: Path) -> None:
    """Evaluate one trained model in a subprocess pinned to a free GPU.

    Takes a GPU id from the module-level ``queue`` (created by the
    ``__main__`` block before the worker pool starts), runs
    ``pykt-toolkit/examples/wandb_predict.py`` with ``CUDA_VISIBLE_DEVICES``
    set to that id, and redirects the subprocess' stdout to
    ``<save_dir>/log_eval``. When the subprocess exits with code 0, the last
    non-empty line of that log is parsed and written to
    ``<save_dir>/test_results.json``. The GPU id is always handed back to the
    queue.

    Args:
        save_dir: Directory of the model checkpoint to evaluate (a ``Path``
            or a plain string path). It must already exist.

    Raises:
        ValueError: If the subprocess succeeded but wrote no output, so there
            is no result line to parse.
        json.JSONDecodeError: If the last output line cannot be parsed.
    """
    gpu_id = queue.get()
    try:
        process_id = current_process().ident
        examples_dir = f"{os.getcwd()}/pykt-toolkit/examples"
        log_path = f"{save_dir}/log_eval"

        # Build the command as an argument list so that paths containing
        # spaces reach the script as a single argument.
        cmd = [
            "python",
            f"{examples_dir}/wandb_predict.py",
            "--use_wandb=0",
            f"--save_dir={save_dir}",
            "--bz=64",
        ]

        print(f"{process_id}: starting process on GPU {gpu_id}")
        print(f"{process_id}: {' '.join(cmd)}")

        # Start computation in subprocess
        with open(log_path, "w") as log_file:
            code = call(
                cmd,
                cwd=examples_dir,
                env=dict(os.environ) | {"CUDA_VISIBLE_DEVICES": f"{gpu_id}"},
                stdout=log_file,
            )

        if code == 0:
            # The result is taken from the last non-empty line of the log
            # (presumably the metrics dict printed by wandb_predict.py).
            with open(log_path, "r") as f:
                lines = [line.strip() for line in f]
            last_line = next((line for line in reversed(lines) if line), None)
            if last_line is None:
                raise ValueError(f"{log_path} is empty, no results to parse")

            # The line uses single-quoted keys; swap the quotes so that it
            # can be read as JSON.
            results = json.loads(last_line.replace("'", '"'))
            with open(f"{save_dir}/test_results.json", "w") as f:
                json.dump(results, f)
        else:
            print(f"{process_id}: evaluation exited with code {code}, see {log_path}")

        print(f"{process_id}: finished")
    finally:
        # Make the GPU available to the next task, whatever happened above.
        queue.put(gpu_id)


def train_pykt_model(config: tuple) -> None:
    """Train one model on one dataset fold in a subprocess pinned to a free GPU.

    Takes a GPU id from the module-level ``queue`` (created by the
    ``__main__`` block before the worker pool starts), loads the
    hyperparameters of the requested fold from
    ``pykt-hyperparameters/<model>.json`` (relative to the current working
    directory) and runs the matching ``wandb_<model>_train.py`` script of
    ``pykt-toolkit/examples`` with ``CUDA_VISIBLE_DEVICES`` set to that id.
    The subprocess' stdout goes to ``<save_dir>/<model>/<dataset>/<fold>/log_train``;
    when it exits with code 0 a ``done`` marker file is written next to it.
    The GPU id is always handed back to the queue.

    Args:
        config: ``(model, dataset, fold, save_dir)``.

    Raises:
        KeyError: If the hyperparameter file has no entry for the dataset or
            for the fold.
    """
    gpu_id = queue.get()
    # Everything after acquiring the GPU id runs inside try/finally so that a
    # failure while loading the hyperparameters cannot leak the GPU slot.
    try:
        process_id = current_process().ident
        model, dataset, fold, save_dir = config

        # Load hyperparameter
        with open(f"pykt-hyperparameters/{model}.json") as f:
            all_configs = json.load(f)
        try:
            hp_config = next(
                (c for c in all_configs[dataset] if c["fold"] == fold), None
            )
        except KeyError:
            print(config)
            raise
        if hp_config is None:
            # Raise a KeyError rather than letting StopIteration escape: a
            # StopIteration re-raised by the pool's result iterator would end
            # the driver's for-loop without any error being reported.
            print(config)
            raise KeyError(
                f"no hyperparameters for fold {fold} of dataset {dataset!r}"
                f" in pykt-hyperparameters/{model}.json"
            )

        examples_dir = f"{os.getcwd()}/pykt-toolkit/examples"
        run_dir = f"{save_dir}/{model}/{dataset}/{fold}"
        script_stem = "sparsekt" if model.startswith("sparsekt") else model

        # Build the command as an argument list so that paths or values
        # containing spaces reach the script as a single argument.
        cmd = [
            "python",
            f"{examples_dir}/wandb_{script_stem}_train.py",
            f"--dataset_name={dataset}",
        ]
        cmd += [f"--{k}={v}" for k, v in hp_config.items()]
        cmd += ["--use_wandb=0", "--add_uuid=0", f"--save_dir={run_dir}"]

        print(f"{process_id}: starting process on GPU {gpu_id}")
        print(f"{process_id}: {' '.join(cmd)}")

        Path(run_dir).mkdir(parents=True, exist_ok=True)

        # Start computation in subprocess
        with open(f"{run_dir}/log_train", "w") as log_file:
            code = call(
                cmd,
                cwd=examples_dir,
                env=dict(os.environ) | {"CUDA_VISIBLE_DEVICES": f"{gpu_id}"},
                stdout=log_file,
            )

        # Create file to mark finished computations
        if code == 0:
            with open(f"{run_dir}/done", "w") as f:
                f.write("done")
        else:
            print(
                f"{process_id}: training exited with code {code}, see {run_dir}/log_train"
            )

        print(f"{process_id}: finished")
    finally:
        # Make the GPU available to the next task, whatever happened above.
        queue.put(gpu_id)


if __name__ == "__main__":
    # Command-line entry point. Parses the requested models/datasets/folds,
    # collects the training or evaluation tasks that are still outstanding
    # under --save_dir, and either lists them (--dry) or runs them with one
    # worker process per selected CUDA device.
    from multiprocessing import get_context

    if not torch.cuda.is_available():
        raise SystemExit("no CUDA device available, cannot run pykt benchmarks")
    device_count = torch.cuda.device_count()

    description = f"""Run benchmarks of the pykt-toolkit.
    Available models: {AVAILABLE_MODELS}
    Available datasets: {AVAILABLE_DATASETS}\n\n"""

    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "mode",
        type=str,
        choices=("train", "eval"),
        help="mode of execution (`train` or `eval`)",
    )
    parser.add_argument(
        "--models",
        type=str,
        default="",
        help="comma-separated model names (default: all models)",
    )
    parser.add_argument(
        "--datasets",
        type=str,
        default="",
        help="comma-separated dataset names (default: all datasets)",
    )
    parser.add_argument(
        "--folds",
        type=str,
        default="",
        help="comma-separated list of folds to compute (default: 0,1,2,3,4)",
    )
    parser.add_argument(
        "--devices",
        type=str,
        default="",
        help="comma-separated list of CUDA device ids that can be used for computations (default: all available devices)",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default=f"{os.getcwd()}/outputs/pykt-models",
        help=f"file path to results folder (default: {os.getcwd()}/outputs/pykt-models)",
    )
    # No `type=` here: BooleanOptionalAction takes no value, and passing
    # `type` to it is deprecated since Python 3.12 and rejected from 3.14 on.
    parser.add_argument(
        "--dry",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="list computations that configuration would trigger (default: --no-dry)",
    )

    args = parser.parse_args()

    # Parse comma separated inputs
    mode = args.mode
    models = args.models.split(",") if args.models else AVAILABLE_MODELS
    datasets = args.datasets.split(",") if args.datasets else AVAILABLE_DATASETS
    try:
        folds = [int(x) for x in (args.folds.split(",") if args.folds else range(5))]
        devices = [
            int(x)
            for x in (args.devices.split(",") if args.devices else range(device_count))
        ]
    except ValueError as err:
        parser.error(f"--folds and --devices take comma-separated integers ({err})")

    # Validate input. Explicit checks are used instead of `assert` so that
    # validation still happens when Python runs with -O.
    for option, given, allowed in (
        ("--models", models, AVAILABLE_MODELS),
        ("--datasets", datasets, AVAILABLE_DATASETS),
        ("--folds", folds, range(5)),
        ("--devices", devices, range(device_count)),
    ):
        rejected = [value for value in given if value not in allowed]
        if rejected:
            parser.error(f"invalid value(s) for {option}: {rejected}")

    # Determine list of outstanding compute tasks
    tasks = []

    if mode == "train":
        # Model/dataset combinations that are skipped as invalid configs.
        skipped_datasets = {"dkt_forget": ("assist2009", "assist2015")}
        for name in ("lpkt", "iekt", "dimkt", "atdkt", "qikt", "qdkt", "hawkes"):
            skipped_datasets[name] = ("statics2011", "assist2015", "poj")

        for m in models:
            for d in datasets:
                if d in skipped_datasets.get(m, ()):
                    continue
                for f in folds:
                    # Folds with a `done` marker have already been trained.
                    if not Path(f"{args.save_dir}/{m}/{d}/{f}/done").is_file():
                        tasks.append((m, d, f, args.save_dir))
    else:
        included_path_patterns = [
            f"/{m}/{d}/{f}/" for f in folds for d in datasets for m in models
        ]
        # Every `done` marker denotes a finished training run; evaluate those
        # that match the selection, have a checkpoint in a sub-directory and
        # no stored test results yet.
        for root_path in Path(args.save_dir).glob("**/done"):
            marker = str(root_path)
            if not any(pattern in marker for pattern in included_path_patterns):
                continue
            model_checkpoint_path = next(root_path.parent.glob("*/*.ckpt"), None)
            if model_checkpoint_path is None:
                continue
            checkpoint_dir = model_checkpoint_path.parent
            if not Path(f"{checkpoint_dir}/test_results.json").is_file():
                tasks.append(checkpoint_dir)

    if args.dry:
        for i, task in enumerate(tasks):
            print(f"{i:3}: {task}")
    else:
        # The worker functions read the module-level `queue` created below.
        # That name only exists in workers that are forked from this process,
        # so the "fork" start method is requested explicitly instead of
        # relying on the platform default.
        ctx = get_context("fork")

        # Initialize the queue with the device ids
        queue = ctx.Queue()
        for d in devices:
            queue.put(d)

        # Initialize pool of workers, add tasks and start computation
        worker = train_pykt_model if mode == "train" else eval_pykt_model
        pool = ctx.Pool(processes=len(devices))
        for _ in pool.imap_unordered(worker, tasks):
            pass
        pool.close()
        pool.join()
