# %%
import argparse
import copy
import fnmatch
from functools import partial
import glob
import operator
import os
import sys
from typing import Dict

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from ray import tune
import shutil

from tqdm import tqdm

from config import AutoConfig
from config_utils import load_from_yaml
from datamodule import build_dm, AllDatamodule
from models import VEModel
from topyneck import VoxelOutBlock

from run_utils import (
    greedy_soup_from_runs,
    run_train_one_stage,
    uniform_soup_sh_voxel,
    greedy_soup_sh_voxel,
    uniform_soup_sp_voxel,
)
from config_utils import dict_to_list


# Candidate weight-decay values for the grid searches in main(); each stage
# searches over the first `--num_samples` entries of this list.
WDS = [1e-4, 1e-5, 2e-5, 3e-5, 4e-5, 5e-5, 6e-5, 7e-5, 8e-5, 9e-5]


def get_parser():
    """Build the command-line parser for the two-stage soup tuning script.

    Returns:
        argparse.ArgumentParser: parser exposing ``verbose``, ``progress``,
        ``rm``, ``num_samples``, ``config`` and ``results_dir``.
    """
    cli = argparse.ArgumentParser(description="run with all subjects to cluster")

    # Boolean switches: all default to False and flip on when present.
    switches = (
        (("-v", "--verbose"), "verbose"),
        (("-p", "--progress"), "progress"),
        (("--rm",), "Remove all previous results"),
    )
    for flags, help_text in switches:
        cli.add_argument(*flags, action="store_true", default=False, help=help_text)

    cli.add_argument(
        "--num_samples", type=int, default=3, help="num samples for model soup"
    )
    cli.add_argument(
        "--config",
        type=str,
        default="/workspace/configs/crn_tiny_debug.yaml",
        help="config file",
    )
    cli.add_argument(
        "--results_dir", type=str, default="/data/results/xaaa", help="results dir"
    )
    return cli


def train_fn(cfg, progress=False, stage=1, model_path=None, log_dir=None):
    """Train one stage and build a model soup from its results.

    Runs ``run_train_one_stage`` for ``stage`` and then builds a soup:
    stages 1 and 3 use ``greedy_soup_sh_voxel`` (with ``target="heldout"``),
    and stages 2 and 4 use ``uniform_soup_sp_voxel``. Any other stage value
    skips souping. When ``cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE`` is truthy,
    the checkpoint directory returned by the training helper is deleted at
    the end.

    Args:
        cfg: experiment config, passed through to ``run_train_one_stage``.
        progress: forwarded to ``run_train_one_stage``.
        stage: training stage number; selects the soup strategy.
        model_path: optional model path forwarded to ``run_train_one_stage``.
        log_dir: optional log directory forwarded to ``run_train_one_stage``.
            The helper's returned ``log_dir`` is what the soup step uses.

    Returns:
        tuple: ``(val_score, test_score)`` from the soup step, or
        ``(None, None)`` when ``stage`` has no soup strategy.
    """
    (
        best_k_models,
        _best_model_path,
        trainer,
        dm,
        model,
        log_dir,
        ckpt_dir,
        cbs,
    ) = run_train_one_stage(
        cfg, progress, stage=stage, model_path=model_path, log_dir=log_dir
    )

    # Default the scores so an unrecognised stage cannot leave them unbound.
    val_score, test_score = None, None
    if stage in (1, 3):
        val_score, test_score = greedy_soup_sh_voxel(
            trainer, dm, model, best_k_models, log_dir, target="heldout"
        )
    elif stage in (2, 4):
        val_score, test_score = uniform_soup_sp_voxel(
            trainer,
            dm,
            model,
            cbs,
            log_dir,
        )

    # Optional cleanup; done only after the soup step has finished.
    if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
        shutil.rmtree(ckpt_dir, ignore_errors=True)

    return val_score, test_score


def tune_fn(
    tune_dict, cfg: AutoConfig, progress=False, stage=1, model_path=None, log_dir=None
):
    """Ray Tune trainable: apply one trial's overrides to the config and train.

    Args:
        tune_dict: the trial's config from Ray Tune, keyed by dotted config
            names (e.g. ``"OPTIMIZER.WEIGHT_DECAY"``). It is converted with
            ``dict_to_list`` and merged into a copy of ``cfg``.
        cfg: base experiment config. It is not modified.
        progress: forwarded to ``train_fn``.
        stage: training stage forwarded to ``train_fn``.
        model_path: optional model path forwarded to ``train_fn``.
        log_dir: optional log directory forwarded to ``train_fn``.

    Returns:
        None. The ``(val_score, test_score)`` tuple from ``train_fn`` is
        deliberately dropped. NOTE: Ray Tune's function-trainable API expects
        None, a number or a dict as a return value — check the Ray docs
        before returning anything here.
    """
    # Merge into a private copy so one trial's overrides never leak into the
    # base config object shared with other trials.
    trial_cfg = copy.deepcopy(cfg)
    trial_cfg.merge_from_list(dict_to_list(tune_dict))
    train_fn(
        trial_cfg,
        progress=progress,
        stage=stage,
        model_path=model_path,
        log_dir=log_dir,
    )


def run_tune(
    fn,
    name,
    cfg,
    tune_config,
    rm=False,
    verbose=False,
    num_samples=1,
    *,
    resources_per_trial=None,
    **kwargs,
):
    """Launch a Ray Tune experiment stored under ``cfg.RESULTS_DIR/name``.

    Args:
        fn: trainable passed to ``tune.run``.
        name: experiment name; also the sub-directory of ``cfg.RESULTS_DIR``.
        cfg: config providing ``RESULTS_DIR`` (used as ``local_dir``).
        tune_config: search space passed as ``config=`` to ``tune.run``.
        rm: if True, delete any previous results for this experiment first.
        verbose: forwarded to ``tune.run``.
        num_samples: forwarded to ``tune.run``.
        resources_per_trial: per-trial resource request for ``tune.run``;
            defaults to ``{"cpu": 1, "gpu": 1}``.
        **kwargs: accepted for backward compatibility and ignored.

    Returns:
        tuple: ``(analysis, result_dir)`` — the object returned by
        ``tune.run`` and the experiment directory path.
    """
    result_dir = os.path.join(cfg.RESULTS_DIR, name)

    if rm:
        shutil.rmtree(result_dir, ignore_errors=True)

    if resources_per_trial is None:
        resources_per_trial = {"cpu": 1, "gpu": 1}

    analysis = tune.run(
        fn,
        local_dir=cfg.RESULTS_DIR,
        config=tune_config,
        resources_per_trial=resources_per_trial,
        num_samples=num_samples,
        name=name,
        verbose=verbose,
        # Ray Tune resume mode: resume an existing experiment if one is found
        # and re-run errored trials.
        resume="AUTO+ERRORED",
    )

    return analysis, result_dir


def _run_soup_stage(cfg, args, stage, param_key, model_path=None):
    """Grid-search ``param_key`` over WDS for ``stage``; return the soup model path."""
    tune_config = {param_key: tune.grid_search(WDS[: args.num_samples])}

    fn_kwargs = {"cfg": cfg, "progress": args.progress, "stage": stage}
    if model_path is not None:
        # Only forward a starting model when one exists (stage 2 onwards).
        fn_kwargs["model_path"] = model_path

    _, exp_dir = run_tune(
        tune.with_parameters(tune_fn, **fn_kwargs),
        f"stage_{stage}",
        cfg,
        tune_config,
        verbose=args.verbose,
        num_samples=1,
    )

    # Reuse an existing soup from a previous run; otherwise build one from the
    # finished trials.
    soup_path = os.path.join(exp_dir, "soup.pth")
    if not os.path.exists(soup_path):
        soup_path, _val_score, _test_score = greedy_soup_from_runs(exp_dir)
    return soup_path


def main():
    """Run the two-stage weight-decay search with model souping.

    Stage 1 grid-searches ``OPTIMIZER.WEIGHT_DECAY``. Stage 2 starts from the
    stage-1 soup and grid-searches ``OPTIMIZER.VOXEL_WEIGHT_DECAY``. Results
    are written under ``<results_dir>/<config basename>``.
    """
    parser = get_parser()
    args = parser.parse_args()

    # Each sample is one WDS candidate, so the count must fit inside WDS.
    if not 1 <= args.num_samples <= len(WDS):
        parser.error(
            f"--num_samples must be between 1 and {len(WDS)}, got {args.num_samples}"
        )

    # Name the results sub-directory after the config file (text before the
    # first '.').
    cfg_file_basename = os.path.basename(args.config).split(".")[0]

    cfg = load_from_yaml(args.config)
    cfg.RESULTS_DIR = os.path.join(args.results_dir, cfg_file_basename)

    if args.rm:
        shutil.rmtree(cfg.RESULTS_DIR, ignore_errors=True)

    model_path_1 = _run_soup_stage(
        cfg, args, stage=1, param_key="OPTIMIZER.WEIGHT_DECAY"
    )
    model_path_2 = _run_soup_stage(
        cfg,
        args,
        stage=2,
        param_key="OPTIMIZER.VOXEL_WEIGHT_DECAY",
        model_path=model_path_1,
    )

    print("done")
    print("model_path_1: ", model_path_1)
    print("model_path_2: ", model_path_2)


if __name__ == "__main__":
    main()
