import argparse
import os
import sys
from random import seed, shuffle

import numpy as np
import ray
from ray import tune

from train import run_tune

from datamodule import AllDatamodule
from config_utils import get_cfg_defaults, load_from_yaml


def get_parser():
    """Build the command-line parser for the Ray Tune sweep launcher.

    Returns:
        argparse.ArgumentParser: parser exposing the ``--verbose``,
        ``--progress`` and ``--rm`` switches plus the ``--start``, ``--end``
        and ``--data_dir`` valued options.
    """
    cli = argparse.ArgumentParser(description="Ray Tune")

    # Boolean switches: absent -> False, present -> True.
    switch_specs = (
        (("-v", "--verbose"), "verbose"),
        (("-p", "--progress"), "progress"),
        (("--rm",), "Remove all previous results"),
    )
    for flags, help_text in switch_specs:
        cli.add_argument(
            *flags, action="store_true", default=False, help=help_text
        )

    # Valued options.
    cli.add_argument("--start", type=float, default=0.0, help="start of jobs")
    cli.add_argument("--end", type=float, default=1.0, help="end of jobs")
    cli.add_argument("--data_dir", type=str, default="/data/VWE", help="data dir")

    return cli


def run_tune(
    name,
    cfg,
    tune_config,
    rm=False,
    progress=False,
    verbose=False,
    num_samples=1,
    resources_per_trial=None,
):
    """Launch (or resume) a Ray Tune experiment that trains with ``train.run_tune``.

    This function shares its name with the trainable in ``train``; the
    trainable is therefore imported locally under a distinct name so the two
    are not confused.

    Args:
        name: Experiment name; results live under ``<cfg.RESULTS_DIR>/<name>``.
        cfg: Base config object. It must expose ``RESULTS_DIR`` and is bound
            to every trial via ``tune.with_parameters``.
        tune_config: Search space passed to ``tune.run(config=...)``. Its keys
            here are dotted config paths, e.g. ``"OPTIMIZER.LR"``.
        rm: If True, delete any previous results for ``name`` before launching.
        progress: Forwarded to the trainable as its ``progress`` argument.
        verbose: Forwarded to ``tune.run(verbose=...)``.
        num_samples: Number of times to sample the search space.
        resources_per_trial: Per-trial resource request. Defaults to
            ``{"cpu": 1, "gpu": 1}`` when None.

    Returns:
        The analysis object returned by ``tune.run``.
    """
    import shutil

    # Local import: the module-level name ``run_tune`` refers to this function.
    from train import run_tune as trainable

    if resources_per_trial is None:
        resources_per_trial = {"cpu": 1, "gpu": 1}

    if rm:
        # Best-effort wipe; a missing directory is not an error.
        shutil.rmtree(os.path.join(cfg.RESULTS_DIR, name), ignore_errors=True)

    return tune.run(
        tune.with_parameters(trainable, cfg=cfg, progress=progress),
        local_dir=cfg.RESULTS_DIR,
        config=tune_config,
        resources_per_trial=resources_per_trial,
        num_samples=num_samples,
        name=name,
        verbose=verbose,
        # Per Ray docs: auto-resume an existing experiment and re-run errored trials.
        resume="AUTO+ERRORED",
    )


# -
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()

    # NOTE(review): --start, --end and --data_dir are parsed but not used by
    # any live code below; data_dir only appears in the commented-out line next.
    # all_subject_list = AllDatamodule.get_all_subject_list(args.data_dir)

    cfg = load_from_yaml("/workspace/configs/dino_bb.yaml")

    # Only referenced by the commented-out backbone overrides below; kept so
    # re-enabling them needs no extra edit (and import side effects are unchanged).
    from backbone import LAYER_DICT, RESOLUTION_DICT

    # Fixed (non-swept) overrides applied on top of the YAML config.
    # b = "resnet50"
    # cfg.DATASET.SUBJECT_LIST = ["all"]
    cfg.DATASET.SUBJECT_LIST = ["NSD_04"]
    cfg.MODEL.NECK.CONV_HEAD.KERNEL_SIZE = 5
    cfg.MODEL.NECK.CONV_HEAD.DEPTH = 3
    cfg.MODEL.NECK.CONV_HEAD.BN = False
    cfg.MODEL.NECK.CONV_HEAD.LN = True
    cfg.MODEL.NECK.CONV_HEAD.SKIP_CONNECTION = True

    # cfg.TRAINER.LIMIT_TRAIN_BATCHES = 1.
    # layers = LAYER_DICT[b]
    # cfg.MODEL.BACKBONE.NAME = b
    # cfg.MODEL.BACKBONE.LAYERS = LAYER_DICT[b]
    # cfg.DATASET.RESOLUTION = RESOLUTION_DICT[b]
    # cfg.ANALYSIS.SAVE_NEURON_LOCATION = False
    # cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE = True

    # cfg.TRAINER.MAX_EPOCHS = 100

    # cfg.ANALYSIS.DRAW_NEURON_LOCATION = False
    # cfg.ANALYSIS.SAVE_NEURON_LOCATION = False

    # cfg.TRAINER.CALLBACKS.EARLY_STOP.PATIENCE = 40

    # jobs = [[x] for x in all_subject_list][::-1]

    # Search space: dotted config keys -> Ray Tune search primitives.
    # Commented entries are earlier sweeps kept for reference.
    tune_config = {
        "OPTIMIZER.LR": tune.grid_search([0.01]),
        # NOTE(review): 100.0 is orders of magnitude above the other values —
        # confirm it is an intentional extreme and not a typo (e.g. for 1e-5).
        "OPTIMIZER.GATE_REGULARIZER": tune.grid_search([0.0, 1e-4, 3e-5, 100.0]),
        # "MODEL.NECK.CONV_HEAD.BN": tune.grid_search([False]),
        # "MODEL.NECK.CONV_HEAD.LN": tune.sample_from(lambda spec: not spec.config["MODEL.NECK.CONV_HEAD.BN"]),
        # "MODEL.BACKBONE.LAYERS": tune.grid_search(
        #     [
        #         ["layer2", "layer5", "layer8", "layer11"],
        #         ["layer3", "layer7", "layer11"],
        #         ["layer5", "layer8", "layer11"],
        #     ]
        # ),
        # "DATASET.RESOLUTION": tune.grid_search([[224, 224], [126, 126], [182, 182], [280, 280]]),
        # "MODEL.NECK.CONV_HEAD.KERNELS": tune.grid_search([[5, 5], [7, 7]]),
        # "MODEL.NECK.CONV_HEAD.LAST_KERNELS": tune.sample_from(lambda spec: spec.config["MODEL.NECK.CONV_HEAD.KERNELS"]),
        # "DATASET.SUBJECT_LIST": tune.grid_search([["NSD_01"], ["NSD_08"]]),
        # # "DATASET.ROIS": tune.grid_search([['st_1'], ['st_7']]),
        # "DATASET.ROIS": tune.sample_from(
        #     lambda spec: ["st_2"]
        #     if spec.config["DATASET.SUBJECT_LIST"] == ["NSD_08"]
        #     else ["st_7"]
        # ),
        # "OPTIMIZER.LR": tune.sample_from(
        #     lambda spec: 3e-3
        #     if spec.config["DATASET.SUBJECT_LIST"] == ["NSD_08"]
        #     else 3e-3
        # )
        # "OPTIMIZER.GATE_REGULARIZER": tune.grid_search([3e-4, 1e-4, 3e-5, 1e-5]),
        # "OPTIMIZER.MU_REGULARIZER_MCENTER": tune.grid_search([0., 1e-4, 3e-5, 1e-5]),
        # "OPTIMIZER.MU_REGULARIZER_PDIST": tune.sample_from(lambda spec: spec.config["OPTIMIZER.MU_REGULARIZER_MCENTER"]),
        # "OPTIMIZER.MU_REGULARIZER_PCENTER": tune.sample_from(lambda spec: spec.config["OPTIMIZER.MU_REGULARIZER_MCENTER"]),
        # "DATAMODULE.BATCH_SIZE": tune.grid_search([32, 64]),
        # "LOSS.SYNC.EMA_KEY": tune.grid_search(["vgrad", 'running_v_grad']),
        # "MODEL.NECK.CONV_HEAD.MAX_DIM": tune.grid_search([4096, 1024, 512, 256]),
        # "MODEL.BACKBONE.NAME": tune.grid_search(['CLIP-RN50', 'CLIP-RN50x64', 'CLIP-RN50x16', 'CLIP-RN50x4']),
        # "LOSS.SYNC.EMA_BETA": tune.grid_search([0.99, 0.9,]),
        # "OPTIMIZER.GATE_REGULARIZER": tune.grid_search([1e-3, 1e-4]),
        # "MODEL.NEURON_PROJECTOR.SEPARATE_LAYERS": tune.grid_search([True, False]),
        # "OPTIMIZER.GATE_REGULARIZER": tune.grid_search([1e-2, 3e-3, 1e-3, 3e-4]),
        # "DATASET.SUBJECT_LIST": tune.grid_search([['all_nsd']]),
        # "DATASET.SUBJECT_LIST": tune.grid_search(jobs),
        # "DATASET.SUBJECT_LIST": tune.grid_search(
        #     [
        #         ["ALG"],
        #         ["NSD_01"],
        #     ]
        # ),
        # "OPTIMIZER.MU_REGULARIZER_CENTER": tune.grid_search([1e-2]),
        # "OPTIMIZER.MU_REGULARIZER_MEAN": tune.grid_search([3e-3, 1e-3, 3e-4, 1e-4]),
        # "OPTIMIZER.MU_REGULARIZER": tune.grid_search([1e-4]),
        # "MODEL.NECK.CONV_HEAD.WIDTH": tune.grid_search([16]),
        # "FINETUNE.USE_LINEAR": tune.grid_search([True]),
        # "FINETUNE.SOUP": tune.grid_search(["uniform"]),
        # "OPTIMIZER.VOXEL_WEIGHT_DECAY": tune.grid_search([1e-3]),
        # "OPTIMIZER.LR": tune.grid_search([1e-2]),
        # "TRAINER.MAX_EPOCHS": tune.grid_search([100]),
        # "STAGE": tune.grid_search(["finetune"]),
        # "FINETUNE.SOURCE": tune.grid_search(["all_nsd"]),
        # "MODEL.BACKBONE.LAYERS": tune.grid_search([
        # ['layer4', 'layer3', 'layer2', 'layer1'], ['layer1', 'layer3', 'layer2']]),
        # "MODEL.NECK.CONV_HEAD.KERNELS": tune.grid_search([[5, 3], [5, 5]]),
        # "OPTIMIZER.MU_REGULARIZER": tune.grid_search([0.0, 1e-2, 1e-4, 1e-6]),
    }

    # Experiment name doubles as the results sub-directory under cfg.RESULTS_DIR.
    name = "GATE_REGULARIZER"
    # name = "head_1x1"

    # Keyword arguments so the boolean flags cannot be silently mis-ordered.
    run_tune(
        name,
        cfg,
        tune_config,
        rm=args.rm,
        progress=args.progress,
        verbose=args.verbose,
        num_samples=1,
    )
