import argparse
import os
import sys
from random import seed, shuffle

import numpy as np
import ray
from ray import tune

from train import run_tune

from datamodule import AllDatamodule
from config_utils import get_cfg_defaults, load_from_yaml


def get_parser():
    """Create the command-line parser for the Ray Tune launcher.

    Flags:
        -v/--verbose, -p/--progress, --rm: boolean switches (default False).
        --start, --end: floats, defaults 0.0 and 1.0.
        --data_dir: string, default "/data/VWET".

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(description="Ray Tune")

    # Boolean switches, registered in this order: (option strings, help text).
    switches = (
        (("-v", "--verbose"), "verbose"),
        (("-p", "--progress"), "progress"),
        (("--rm",), "Remove all previous results"),
    )
    for option_strings, help_text in switches:
        parser.add_argument(
            *option_strings, action="store_true", help=help_text, default=False
        )

    # Valued options.
    parser.add_argument("--start", type=float, default=0.0, help="start of jobs")
    parser.add_argument("--end", type=float, default=1.0, help="end of jobs")
    parser.add_argument("--data_dir", type=str, default="/data/VWET", help="data dir")

    return parser


def run_tune(
    name, cfg, tune_config, rm=False, progress=False, verbose=False, num_samples=1
):
    """Launch a Ray Tune experiment whose trainable is ``train.run_tune``.

    Args:
        name: Experiment name. Results live under
            ``os.path.join(cfg.RESULTS_DIR, name)``.
        cfg: Base config object. It must expose ``RESULTS_DIR`` and is
            forwarded unchanged to the trainable via ``tune.with_parameters``.
        tune_config: Search space passed to ``tune.run(config=...)``.
        rm: If True, delete ``cfg.RESULTS_DIR/name`` (best effort) before
            launching, so the experiment starts fresh instead of resuming.
        progress: Forwarded to the trainable as its ``progress`` keyword.
        verbose: Forwarded to ``tune.run(verbose=...)``.
        num_samples: Forwarded to ``tune.run``. With ``grid_search`` in the
            config, Ray repeats the whole grid this many times.

    Returns:
        The analysis object returned by ``tune.run``.
    """
    # Import under a distinct local name: this function is itself called
    # ``run_tune``, which shadows the trainable imported from ``train``.
    # The function object (and its __name__, used by Ray for trial names)
    # is unchanged.
    from train import run_tune as trainable

    if rm:
        import shutil

        # Best-effort cleanup: a missing directory is not an error.
        shutil.rmtree(os.path.join(cfg.RESULTS_DIR, name), ignore_errors=True)

    return tune.run(
        tune.with_parameters(trainable, cfg=cfg, progress=progress),
        local_dir=cfg.RESULTS_DIR,
        config=tune_config,
        # One CPU and one GPU reserved per trial.
        resources_per_trial={"cpu": 1, "gpu": 1},
        num_samples=num_samples,
        name=name,
        verbose=verbose,
        # Resume an existing experiment of this name if one is found
        # (otherwise start new), and re-run trials that previously errored.
        resume="AUTO+ERRORED",
    )


# -
if __name__ == "__main__":
    # Script entry point: parse CLI flags, load and override the experiment
    # config, and launch one grid-search trial per single-subject list.
    # NOTE(review): --start, --end and --data_dir are parsed but not used below.
    args = get_parser().parse_args()

    cfg = load_from_yaml("/workspace/configs/dino_mania.yaml")
    cfg.RESULTS_DIR = "/data/results/xdaa/dino_mania/"
    cfg.TRAINER.MAX_EPOCHS = 200
    # Patience far exceeds MAX_EPOCHS, so early stopping presumably never fires.
    cfg.TRAINER.CALLBACKS.EARLY_STOP.PATIENCE = 114514
    cfg.ANALYSIS.SAVE_NEURON_LOCATION = True

    # NSD_01 .. NSD_08, followed by HCP, ALG, EEG, MEG; each becomes its own
    # one-element SUBJECT_LIST in the grid.
    subjects = [f"NSD_{idx:02d}" for idx in range(1, 9)] + ["HCP", "ALG", "EEG", "MEG"]
    tune_config = {
        "DATASET.SUBJECT_LIST": tune.grid_search([[subj] for subj in subjects]),
    }

    experiment_name = "suplong"  # previously also run as "head_1x1"

    run_tune(
        experiment_name,
        cfg,
        tune_config,
        rm=args.rm,
        progress=args.progress,
        verbose=args.verbose,
        num_samples=1,
    )
