import argparse
import os
import sys
from random import seed, shuffle

import numpy as np
import ray
from ray import tune

from datamodule import AllDatamodule
from config_utils import dict_to_list, get_cfg_defaults, load_from_yaml


def get_parser():
    """Build the command-line parser for the Ray Tune launcher.

    Returns:
        argparse.ArgumentParser: parser exposing the ``--verbose``,
        ``--progress`` and ``--rm`` switches plus the config, sample-count,
        projector-weights and data-directory options.
    """
    cli = argparse.ArgumentParser(description="Ray Tune")

    # Boolean switches: False unless the flag is given on the command line.
    switches = (
        (("-v", "--verbose"), "verbose"),
        (("-p", "--progress"), "progress"),
        (("--rm",), "Remove all previous results"),
    )
    for flags, text in switches:
        cli.add_argument(*flags, action="store_true", default=False, help=text)

    # Valued options, registered in the same order as before so --help output
    # is unchanged.
    options = (
        ("--config_path", str, "/workspace/configs/dino_mania.yaml", "config file"),
        ("--num_samples", int, 1, "num samples"),
        (
            "--projector_path",
            str,
            "/data/script_base/xdab/state_dict.pth",
            "bigmodel path",
        ),
        ("--data_dir", str, "/data/VWET", "data dir"),
    )
    for flag, kind, fallback, text in options:
        cli.add_argument(flag, type=kind, default=fallback, help=text)

    return cli


from run_utils import *


def run_train_with_boost(
    tune_dict,
    cfg: AutoConfig = None,
    progress=False,
    projector_path=None,
    **kwargs,
):
    """Train a ``VEModel`` warm-started from previously learned weights.

    The trial overrides in ``tune_dict`` are merged into a private deep copy
    of ``cfg``, so the caller's config object is never modified. Weights
    loaded from ``projector_path`` are applied with ``strict=False``. Then
    ``freeze`` is applied to ``model.neck.neuron_projectors`` and
    ``model.backbone``, and the model is fit with a PyTorch Lightning
    ``Trainer`` configured from ``cfg.TRAINER``.

    Args:
        tune_dict: Mapping of config keys to values for this trial. It is
            converted with ``dict_to_list`` and merged into the config copy.
        cfg: Base experiment config (must support ``merge_from_list``).
        progress: Whether to show the Lightning progress bar.
        projector_path: Path handed to ``load_state_dict`` for the warm-start
            weights.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        Tuple ``(best_k_models, best_model_path, trainer, dm, model, log_dir,
        ckpt_dir, cbs)``. If ``soup.pth`` already exists in ``log_dir``, the
        run is skipped and every entry except ``log_dir`` and ``ckpt_dir`` is
        ``None``.
    """
    # Copy BEFORE merging: merging first would write this trial's overrides
    # into the caller's cfg object as a side effect.
    cfg = copy.deepcopy(cfg)
    cfg.merge_from_list(dict_to_list(tune_dict))

    dm: AllDatamodule = build_dm(cfg)
    dm.setup()

    cbs, lgs, log_dir, ckpt_dir = get_callbacks_and_loggers(cfg, sub_dir="stage_1")

    # An existing soup file means this trial already finished (e.g. on resume).
    soup_path = os.path.join(log_dir, "soup.pth")
    if os.path.exists(soup_path):
        print("soup exists, skipping")
        return (None, None, None, None, None, log_dir, ckpt_dir, None)

    os.makedirs(log_dir, exist_ok=True)

    model_args = (
        cfg,
        dm.num_voxel_dict,
        dm.roi_dict,
        dm.neuron_coords_dict,
        dm.noise_ceiling_dict,
    )
    # Persist the fully-merged config next to the logs for reproducibility.
    torch.save(cfg, os.path.join(log_dir, "cfg.pth"))

    model = VEModel(*model_args)

    # Boost speed with learned weights. strict=False tolerates keys that do
    # not match this model.
    # NOTE(review): behavior of load_state_dict(None) is defined in run_utils;
    # callers are expected to pass a real path here — confirm.
    state_dict = load_state_dict(projector_path)
    model.load_state_dict(state_dict, strict=False)

    freeze(model.neck.neuron_projectors)
    freeze(model.backbone)

    trainer = pl.Trainer(
        precision=cfg.TRAINER.PRECISION,
        accelerator="cuda",
        gradient_clip_val=cfg.TRAINER.GRADIENT_CLIP_VAL,
        devices=cfg.TRAINER.DEVICES,
        max_epochs=cfg.TRAINER.MAX_EPOCHS,
        max_steps=cfg.TRAINER.MAX_STEPS,
        val_check_interval=cfg.TRAINER.VAL_CHECK_INTERVAL,
        accumulate_grad_batches=cfg.TRAINER.ACCUMULATE_GRAD_BATCHES,
        limit_train_batches=cfg.TRAINER.LIMIT_TRAIN_BATCHES,
        limit_val_batches=cfg.TRAINER.LIMIT_VAL_BATCHES,
        log_every_n_steps=cfg.TRAINER.LOG_TRAIN_N_STEPS,
        callbacks=cbs,
        logger=lgs,
        enable_checkpointing=True,
        enable_progress_bar=progress,
    )

    trainer.fit(model, datamodule=dm)

    # Record the best-k checkpoint paths/scores alongside the logs.
    ckpt: ModelCheckpoint = trainer.checkpoint_callback
    ckpt.to_yaml(os.path.join(log_dir, "checkpoint.yaml"))

    return (
        ckpt.best_k_models,
        ckpt.best_model_path,
        trainer,
        dm,
        model,
        log_dir,
        ckpt_dir,
        cbs,
    )


def run(
    tune_dict,
    cfg: AutoConfig = None,
    progress=False,
    projector_path=None,
    **kwargs,
):
    """Ray Tune trainable: train one trial, then build its greedy soup.

    Args:
        tune_dict: Trial hyper-parameters supplied by Ray Tune; merged into a
            private copy of ``cfg``.
        cfg: Base experiment config (must support ``merge_from_list``).
        progress: Whether to show the Lightning progress bar.
        projector_path: Warm-start weights forwarded to
            ``run_train_with_boost``.
        **kwargs: Forwarded to ``run_train_with_boost``.

    When ``run_train_with_boost`` reports that the soup already exists (it
    returns ``None`` for the trainer), the soup step is skipped. The optional
    checkpoint cleanup below still runs.
    """
    import shutil

    # Copy BEFORE merging so the shared cfg is never mutated by this trial.
    cfg = copy.deepcopy(cfg)
    cfg.merge_from_list(dict_to_list(tune_dict))

    (
        best_k_models,
        _best_model_path,
        trainer,
        dm,
        model,
        log_dir,
        ckpt_dir,
        _cbs,
    ) = run_train_with_boost(
        tune_dict,
        cfg=cfg,
        progress=progress,
        projector_path=projector_path,
        **kwargs,
    )

    # trainer is None only on the "soup exists, skipping" path. In that case
    # there are no trained models to soup, so do not call the soup builder
    # with None arguments.
    if trainer is not None:
        from run_utils import greedy_soup_sh_voxel

        val_score, test_score = greedy_soup_sh_voxel(
            trainer, dm, model, best_k_models, log_dir, target="heldout"
        )

    if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
        shutil.rmtree(ckpt_dir, ignore_errors=True)


def run_tune(
    name,
    cfg,
    tune_config,
    rm=False,
    progress=False,
    verbose=False,
    num_samples=1,
    projector_path=None,
):
    """Launch a Ray Tune experiment that executes ``run`` for each trial.

    Args:
        name: Experiment name. Results go under ``cfg.RESULTS_DIR/name``.
        cfg: Base experiment config, shared with every trial through
            ``tune.with_parameters``.
        tune_config: Ray Tune search space (e.g. ``tune.grid_search`` values).
        rm: If True, delete ``cfg.RESULTS_DIR/name`` before launching.
        progress: Forwarded to ``run`` to toggle the Lightning progress bar.
        verbose: Passed to ``tune.run(verbose=...)``. A bool is treated as the
            integer verbosity 0 or 1.
        num_samples: Number of times to sample the search space.
        projector_path: Warm-start weights forwarded to each trial.

    Returns:
        The analysis object returned by ``tune.run``. It was previously
        discarded; returning it is backward-compatible.
    """
    if rm:
        import shutil

        # Wipe the experiment directory (local_dir/name) so that
        # resume="AUTO+ERRORED" has nothing left to resume from.
        shutil.rmtree(os.path.join(cfg.RESULTS_DIR, name), ignore_errors=True)

    trainable = tune.with_parameters(
        run,
        cfg=cfg,
        progress=progress,
        projector_path=projector_path,
    )

    return tune.run(
        trainable,
        local_dir=cfg.RESULTS_DIR,
        config=tune_config,
        resources_per_trial={"cpu": 1, "gpu": 1},
        num_samples=num_samples,
        name=name,
        verbose=verbose,
        resume="AUTO+ERRORED",
    )


# Script entry point: parse CLI options, load the config, launch the sweep.
if __name__ == "__main__":
    cli_args = get_parser().parse_args()

    experiment_cfg = load_from_yaml(cli_args.config_path)
    experiment_cfg.RESULTS_DIR = "/data/results/xdabb/dino_mania/"
    experiment_name = "big_model"

    # Single-point grid over the learning rate; extend the list to sweep.
    search_space = {"OPTIMIZER.LR": tune.grid_search([0.003])}

    # NOTE: --data_dir is parsed by get_parser() but not used here.
    run_tune(
        experiment_name,
        experiment_cfg,
        search_space,
        rm=cli_args.rm,
        progress=cli_args.progress,
        verbose=cli_args.verbose,
        num_samples=cli_args.num_samples,
        projector_path=cli_args.projector_path,
    )
