# %%
import argparse
import copy
import fnmatch
from functools import partial
import glob
import operator
import os
import sys
from typing import Dict

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers

from ray import tune

import shutil

from tqdm import tqdm

from config import AutoConfig
from config_utils import dict_to_list, load_from_yaml
from datamodule import build_dm, AllDatamodule
from models import VEModel
from topyneck import VoxelOutBlock

from run_utils import *


# Learning rates for the repeated stage-one runs: run ``i`` (1-based) in
# ``run_multiple_stage_one`` trains with ``LRS[i - 1]``, so ``--num_runs``
# cannot exceed the number of entries listed here.
LRS = [3e-3, 4e-3, 5e-3]


def run_train_with_boost(
    cfg: AutoConfig,
    progress=False,
    bigmodel_path=None,
    sub_dir="r1s1",
    **kwargs,
):
    """Train one ``VEModel`` warm-started from a previously saved state dict.

    The config is deep-copied, so the caller's object is left untouched. The
    data module, model and trainer are always built and the saved weights are
    always loaded; only the fitting itself is skipped when ``soup.pth`` is
    already present in the run's log directory.

    Args:
        cfg: Experiment configuration.
        progress: Whether the trainer shows a progress bar.
        bigmodel_path: Path handed to ``load_state_dict`` to obtain the
            weights used for the warm start.
        sub_dir: Sub-directory name forwarded to
            ``get_callbacks_and_loggers``.
        **kwargs: Accepted but ignored.

    Returns:
        The 8-tuple ``(best_k_models, best_model_path, trainer, dm, model,
        log_dir, ckpt_dir, cbs)``. When the soup already exists, the first two
        entries and ``cbs`` are ``None``.
    """
    run_cfg = copy.deepcopy(cfg)

    dm: AllDatamodule = build_dm(run_cfg)
    dm.setup()

    cbs, lgs, log_dir, ckpt_dir = get_callbacks_and_loggers(run_cfg, sub_dir=sub_dir)
    os.makedirs(log_dir, exist_ok=True)

    model_inputs = (
        run_cfg,
        dm.num_voxel_dict,
        dm.roi_dict,
        dm.neuron_coords_dict,
        dm.noise_ceiling_dict,
    )

    # Keep a copy of the exact config next to the logs of this run.
    torch.save(run_cfg, os.path.join(log_dir, "cfg.pth"))

    model = VEModel(*model_inputs)

    # Warm start: the saved weights are loaded as-is. ``strict=False`` means
    # keys that do not match the model are skipped silently.
    # NOTE(review): an earlier, commented-out variant filtered the state dict
    # per subject before loading; it was not active code.
    pretrained_weights = load_state_dict(bigmodel_path)
    model.load_state_dict(pretrained_weights, strict=False)

    # Only the neuron projectors are frozen; no other part of the model is
    # frozen in this function.
    freeze(model.neck.neuron_projectors)

    trainer_cfg = run_cfg.TRAINER
    trainer = pl.Trainer(
        precision=trainer_cfg.PRECISION,
        accelerator="cuda",
        gradient_clip_val=trainer_cfg.GRADIENT_CLIP_VAL,
        devices=trainer_cfg.DEVICES,
        max_epochs=trainer_cfg.MAX_EPOCHS,
        max_steps=trainer_cfg.MAX_STEPS,
        val_check_interval=trainer_cfg.VAL_CHECK_INTERVAL,
        accumulate_grad_batches=trainer_cfg.ACCUMULATE_GRAD_BATCHES,
        limit_train_batches=trainer_cfg.LIMIT_TRAIN_BATCHES,
        limit_val_batches=trainer_cfg.LIMIT_VAL_BATCHES,
        log_every_n_steps=trainer_cfg.LOG_TRAIN_N_STEPS,
        callbacks=cbs,
        logger=lgs,
        enable_checkpointing=True,
        enable_progress_bar=progress,
    )

    # A finished soup for this run means there is nothing left to train.
    if os.path.exists(os.path.join(log_dir, "soup.pth")):
        print("soup exists, skipping run_train_with_boost")
        return (None, None, trainer, dm, model, log_dir, ckpt_dir, None)

    trainer.fit(model, datamodule=dm)

    # The scores returned by this call are not used here.
    validate_test_log(trainer, model, dm, prefix="SOUP_NO/one/")

    checkpoint_cb = trainer.checkpoint_callback
    checkpoint_cb.to_yaml(os.path.join(log_dir, "checkpoint.yaml"))

    return (
        checkpoint_cb.best_k_models,
        checkpoint_cb.best_model_path,
        trainer,
        dm,
        model,
        log_dir,
        ckpt_dir,
        cbs,
    )


def run_multiple_stage_one(cfg: AutoConfig, args):
    """Run stage one ``args.num_runs`` times, one learning rate per run.

    Run ``i`` (1-based) trains with ``LRS[i - 1]`` in sub-directory
    ``r1s1_{i}`` and is then condensed into a per-run ``soup.pth`` using the
    recipe named by ``model.cfg.MODEL_SOUP.RECIPE``. Runs whose soup already
    exists are not re-souped; their stored validation score is loaded instead.

    Args:
        cfg: Experiment configuration. It is not modified; every run works on
            its own copy.
        args: Parsed CLI namespace; ``num_runs``, ``progress`` and
            ``bigmodel_path`` are read.

    Returns:
        ``(best_models_run, trainer, dm, model)`` where ``best_models_run``
        maps each run's soup path to its validation score and the other three
        values come from the last run. ``(None, None, None, None)`` if the
        trial-level ``soup.pth`` already exists.

    Raises:
        ValueError: If ``args.num_runs`` is not between 1 and ``len(LRS)``, or
            if the soup recipe is neither ``"greedy"`` nor ``"uniform"``.
    """
    trial_dir = tune.get_trial_dir()
    if os.path.exists(os.path.join(trial_dir, "soup.pth")):
        print("soup exists, skipping run_multiple_stage_one")
        return None, None, None, None

    # Validate before any training starts: indexing LRS inside the loop would
    # only fail after the earlier runs had already been trained, and zero runs
    # would leave trainer/dm/model undefined at the return below.
    if not 1 <= args.num_runs <= len(LRS):
        raise ValueError(
            f"num_runs must be between 1 and {len(LRS)} "
            f"(one learning rate per run), got {args.num_runs}"
        )

    best_models_run = {}
    for i, lr in enumerate(LRS[: args.num_runs], start=1):
        # Set the learning rate on a private copy so the caller's cfg keeps
        # its original value.
        run_cfg = copy.deepcopy(cfg)
        run_cfg.OPTIMIZER.LR = lr

        (
            best_k_models,
            _best_model_path,
            trainer,
            dm,
            model,
            log_dir,
            ckpt_dir,
            _cbs,
        ) = run_train_with_boost(
            run_cfg,
            progress=args.progress,
            bigmodel_path=args.bigmodel_path,
            sub_dir=f"r1s1_{i}",
        )

        run_soup_path = os.path.join(log_dir, "soup.pth")
        if os.path.exists(run_soup_path):
            # A soup already exists for this run: reuse its stored score.
            val_score = torch.load(os.path.join(log_dir, "soup_val_score.pth"))
        else:
            recipe = model.cfg.MODEL_SOUP.RECIPE
            if recipe == "greedy":
                val_score, _ = greedy_soup_sh_voxel(
                    trainer, dm, model, best_k_models, log_dir, target="heldout"
                )
            elif recipe == "uniform":
                val_score, _ = uniform_soup_sh_voxel(
                    trainer,
                    dm,
                    model,
                    best_k_models,
                    log_dir,
                )
            else:
                raise ValueError("invalid recipe")
            if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
                shutil.rmtree(ckpt_dir, ignore_errors=True)

        best_models_run[run_soup_path] = val_score

    return best_models_run, trainer, dm, model


def greedy_soup_stage_one_runs(
    trainer, dm, model, best_k_models, log_dir, target="heldout"
):
    """Combine the per-run soups into one soup via ``greedy_soup_sh_voxel``.

    Returns ``(None, None)`` without doing anything when ``soup.pth`` already
    exists in ``log_dir``. Otherwise sets ``model.zero_flag``, builds the soup
    under the ``"SOUP_RUNS/greedy"`` prefix and afterwards deletes every
    existing file named by a key of ``best_k_models``.

    NOTE(review): ``target`` is accepted but never forwarded to
    ``greedy_soup_sh_voxel`` -- confirm whether that is intended.

    Args:
        trainer, dm, model: Passed through to ``greedy_soup_sh_voxel``.
        best_k_models: Mapping whose keys are the candidate model paths.
        log_dir: Directory checked for, and passed on as the location of,
            the combined soup.
        target: Currently unused (see note above).

    Returns:
        ``(val_score, test_score)`` as returned by ``greedy_soup_sh_voxel``,
        or ``(None, None)`` if the soup already existed.
    """
    if os.path.exists(os.path.join(log_dir, "soup.pth")):
        print("soup exists, skipping greedy_soup_stage_one_runs")
        return None, None

    model.zero_flag = True
    val_score, test_score = greedy_soup_sh_voxel(
        trainer, dm, model, best_k_models, log_dir, prefix="SOUP_RUNS/greedy"
    )

    # The ingredients are no longer needed once the combined soup is built.
    for ingredient_path in best_k_models:
        if os.path.exists(ingredient_path):
            os.remove(ingredient_path)

    return val_score, test_score


def run_stage_two(cfg: AutoConfig, args, stage_1_soup_path):
    """Run stage two starting from the stage-one soup and return its soup path.

    Training is delegated to ``run_train_one_stage`` with ``stage=2`` and
    ``model_path=stage_1_soup_path``, on a deep copy of ``cfg``. If the
    stage-two ``soup.pth`` is not yet present in the returned log directory,
    it is produced here and the checkpoint directory is optionally removed.

    Args:
        cfg: Experiment configuration.
        args: Parsed CLI namespace; only ``progress`` is read.
        stage_1_soup_path: Path of the stage-one soup to start from.

    Returns:
        Path of the stage-two ``soup.pth``.

    Raises:
        NotImplementedError: If EMA is off and ``cfg.FINETUNE.SOUP`` is
            neither ``"uniform"`` nor ``"greedy"``.
    """
    (
        _best_k_models,
        _best_model_path,
        trainer,
        dm,
        model,
        log_dir,
        ckpt_dir,
        cbs,
    ) = run_train_one_stage(
        copy.deepcopy(cfg),
        progress=args.progress,
        stage=2,
        model_path=stage_1_soup_path,
    )

    model_path_2 = os.path.join(log_dir, "soup.pth")
    if os.path.exists(model_path_2):
        return model_path_2

    if cfg.TRAINER.STAGE_2_EMA:
        # With EMA enabled no souping is done here; the model's current
        # weights are stored as the soup.
        torch.save(model.state_dict(), model_path_2)
    else:
        soup_recipe = cfg.FINETUNE.SOUP
        if soup_recipe == "uniform":
            uniform_soup_sp_voxel(trainer, dm, model, cbs, log_dir)
        elif soup_recipe == "greedy":
            greedy_soup_sp_voxel(trainer, dm, model, cbs, log_dir)
        else:
            raise NotImplementedError

    if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
        shutil.rmtree(ckpt_dir, ignore_errors=True)

    return model_path_2


def main(args, cfg: AutoConfig):
    """Run stage one for the current Ray Tune trial unless its soup exists.

    Works on a deep copy of ``cfg``. With ``args.debug`` set, the subject
    list and the epoch limits are overridden with small debug values. Stage
    one is the repeated training of ``run_multiple_stage_one`` followed by
    ``greedy_soup_stage_one_runs``; stage two is currently disabled.

    Args:
        args: Parsed CLI namespace.
        cfg: Experiment configuration for this trial.
    """
    cfg = copy.deepcopy(cfg)

    # Not used below; an empty ROI list raises here.
    roi = cfg.DATASET.ROIS[0]

    if args.debug:
        cfg.DATASET.SUBJECT_LIST = ["ALG"]
        cfg.TRAINER.MAX_EPOCHS = 12
        cfg.TRAINER.STAGE_2_MAX_EPOCHS = 12

    stage_1_soup_path = os.path.join(tune.get_trial_dir(), "soup.pth")
    if os.path.exists(stage_1_soup_path):
        print("soup exists, skipping stage 1")
    else:
        best_models_run, trainer, dm, model = run_multiple_stage_one(cfg, args)
        greedy_soup_stage_one_runs(
            trainer, dm, model, best_models_run, tune.get_trial_dir()
        )

    # Stage two is disabled for now:
    # stage_2_soup_path = run_stage_two(cfg, args, stage_1_soup_path)


def run_tune_job(tune_dict, args, cfg):
    """Trial entry point: merge the trial's overrides into ``cfg``, run ``main``.

    Args:
        tune_dict: Per-trial config dict; converted with ``dict_to_list`` and
            merged into ``cfg`` in place.
        args: Parsed CLI namespace, forwarded to ``main``.
        cfg: Base configuration to update and run with.
    """
    overrides = dict_to_list(tune_dict)
    cfg.merge_from_list(overrides)
    main(args, cfg)


def run_tune(name, args, tune_config, rm=False, verbose=False, num_samples=1):
    """Launch the Ray Tune experiment ``name`` over ``tune_config``.

    Results go to ``<args.results_dir>/<config file name up to its first
    dot>``. Every trial runs ``run_tune_job`` with one CPU and one GPU, and
    the experiment is started with ``resume="AUTO+ERRORED"``.

    Args:
        name: Experiment name (sub-directory of the results directory).
        args: Parsed CLI namespace; ``config`` and ``results_dir`` are read
            here and the whole namespace is passed to every trial.
        tune_config: Search space handed to ``tune.run``.
        rm: If true, delete the experiment's previous results first.
        verbose: Forwarded to ``tune.run``.
        num_samples: Forwarded to ``tune.run``.
    """
    config_stem = os.path.basename(args.config).split(".")[0]

    cfg = load_from_yaml(args.config)
    cfg.RESULTS_DIR = os.path.join(args.results_dir, config_stem)

    if rm:
        previous_results = os.path.join(cfg.RESULTS_DIR, name)
        shutil.rmtree(previous_results, ignore_errors=True)

    trainable = tune.with_parameters(run_tune_job, args=args, cfg=cfg)
    tune.run(
        trainable,
        local_dir=cfg.RESULTS_DIR,
        config=tune_config,
        resources_per_trial={"cpu": 1, "gpu": 1},
        num_samples=num_samples,
        name=name,
        verbose=verbose,
        resume="AUTO+ERRORED",
    )


def get_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with four boolean switches
        (``--verbose``, ``--progress``, ``--debug``, ``--rm``) followed by the
        string and integer options listed below.
    """
    parser = argparse.ArgumentParser(description="run with all subjects to cluster")

    # Boolean switches: (option strings, help text). All default to False.
    switches = (
        (("-v", "--verbose"), "verbose"),
        (("-p", "--progress"), "progress"),
        (("--debug",), "debug"),
        (("--rm",), "Remove all previous results"),
    )
    for option_strings, help_text in switches:
        parser.add_argument(
            *option_strings, action="store_true", help=help_text, default=False
        )

    # Valued options: (option string, type, default, help text).
    valued_options = (
        (
            "--bigmodel_path",
            str,
            "/data/script_base/xdab/state_dict.pth",
            "bigmodel path",
        ),
        ("--roi_prefix", str, "veroi_m", "prefix of roi"),
        ("--config", str, "/workspace/configs/dino_mania.yaml", "config file"),
        ("--results_dir", str, "/data/results/xeaa_mkv", "results dir"),
        ("--num_samples", int, 1, "num_samples"),
        ("--num_runs", int, 1, "num_runs"),
        ("--name", str, "veroi_m", "name of the experiment"),
        ("--n_rois", int, 18, "num rois"),
        ("--start", int, 0, "start index"),
        ("--end", int, 18, "end index"),
    )
    for option_string, value_type, default_value, help_text in valued_options:
        parser.add_argument(
            option_string, type=value_type, default=default_value, help=help_text
        )

    return parser


if __name__ == "__main__":
    args = get_parser().parse_args()

    # One single-ROI job per index 1..n_rois, then keep the slice [start, end).
    all_roi_jobs = [[f"{args.roi_prefix}_{idx}"] for idx in range(1, args.n_rois + 1)]
    selected_roi_jobs = all_roi_jobs[args.start : args.end]

    # Tag the experiment name with the slice so different slices do not share
    # an experiment directory.
    args.name = f"{args.name}_{args.start}_{args.end}"

    run_tune(
        args.name,
        args,
        {"DATASET.ROIS": tune.grid_search(selected_roi_jobs)},
        rm=args.rm,
        verbose=args.verbose,
        num_samples=args.num_samples,
    )


# TODO: make a better job scheduler
