# %%
import argparse
import copy
import fnmatch
from functools import partial
import glob
import operator
import os
import sys
from typing import Dict

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers

from ray import tune

import shutil

from tqdm import tqdm

from config import AutoConfig
from config_utils import dict_to_list, load_from_yaml
from datamodule import build_dm, AllDatamodule
from models import DVEModel, DarkVEModel, PDVEModel, VEModel
from topyneck import VoxelOutBlock

from run_utils import *


# Command-line interface. The __main__ block turns this into one Ray Tune
# trial per target ROI "<roi_prefix>_<k>" (k = 1..n_roi), restricted to the
# [start, end) slice of that list.
parser = argparse.ArgumentParser(description="run with all subjects to cluster")
for _flag, _spec in (
    (
        "--projector_path",
        dict(
            type=str,
            default="/data/script_base/xdab/state_dict.pth",
            help="bigmodel path",
        ),
    ),
    (
        "--config",
        dict(type=str, default="/workspace/configs/dino_mania.yaml", help="config file"),
    ),
    (
        "--results_dir",
        dict(type=str, default="/data/results/xfaa/mkv", help="results dir"),
    ),
    ("--rm", dict(action="store_true", help="remove old results")),
    (
        "--name",
        dict(type=str, default="veroi_m_gen3", help="name of the experiment"),
    ),
    ("--n_roi", dict(type=int, default=18, help="number of roi")),
    ("--start", dict(type=int, default=0, help="start index")),
    ("--end", dict(type=int, default=18, help="end index")),
    ("--roi_prefix", dict(type=str, default="veroi_m", help="prefix of roi")),
    (
        "--dark_postfix",
        dict(type=str, default=".mania_veroi_m_gen2_darkfull", help="dark postfix"),
    ),
    # Modes handled by run(): "full", "self", or "top<N>".
    ("--dark", dict(type=str, default="full", help="dark model type")),
):
    parser.add_argument(_flag, **_spec)
del _flag, _spec

args = parser.parse_args()

# "top<N>" mode ranks ROIs by the target ROI's row of this matrix (see run()),
# so it is loaded once, up front, only when that mode is requested.
if args.dark.startswith("top"):
    top_n = int(args.dark[len("top"):])
    transfer_mat = np.load("/data/script_base/xeaac/veroi_m/transfer_mat.npy")


def run_stage_one(
    cfg: AutoConfig,
    progress=True,
    stage=1,
    projector_path=None,
    log_dir=None,
    sub_dir=None,
    dark="full",  # not read here; run() documents the dark modes
    n_roi=18,
    roi_prefix="veroi_m",
    **kwargs,
):
    """Train a ``DarkVEModel`` for one stage, then build a greedy soup.

    If ``soup.pth`` already exists in the log directory resolved by
    ``get_callbacks_and_loggers``, no training happens. That path is returned
    immediately (the datamodule has still been built and set up by then).

    Args:
        cfg: Experiment configuration. It is deep-copied, so the caller's
            object is not mutated. The copy is saved to ``cfg.pth`` in the
            log directory.
        progress: Enables the Lightning progress bar.
        stage: Stage id. For stages 2, 4 and 0.1, Lightning checkpointing is
            disabled and ``best_k_models=None`` is handed to
            ``greedy_soup_sh_voxel``.
        projector_path: Passed to ``load_state_dict``. The resulting weights
            are loaded into the model with ``strict=False``.
        log_dir, sub_dir: Forwarded to ``get_callbacks_and_loggers``, which
            returns the log directory actually used.
        dark, n_roi, roi_prefix, **kwargs: Accepted for call-site
            compatibility; not read by this function.

    Returns:
        ``<log_dir>/soup.pth``. The file is presumably written by
        ``greedy_soup_sh_voxel``; this is not verified here.
    """
    cfg = copy.deepcopy(cfg)

    datamodule: AllDatamodule = build_dm(cfg)
    datamodule.setup()

    callbacks, loggers, log_dir, ckpt_dir = get_callbacks_and_loggers(
        cfg, log_dir=log_dir, sub_dir=sub_dir
    )

    soup_path = os.path.join(log_dir, "soup.pth")
    if os.path.exists(soup_path):
        # A finished soup from an earlier run: nothing left to do.
        print("soup exists, skipping")
        return soup_path

    os.makedirs(log_dir, exist_ok=True)

    model_inputs = (
        cfg,
        datamodule.num_voxel_dict,
        datamodule.roi_dict,
        datamodule.neuron_coords_dict,
        datamodule.noise_ceiling_dict,
    )
    torch.save(cfg, os.path.join(log_dir, "cfg.pth"))

    model = DarkVEModel(*model_inputs)
    # Non-strict: the checkpoint may not cover every parameter of this model.
    model.load_state_dict(load_state_dict(projector_path), strict=False)

    # freeze() presumably disables gradients for these submodules (defined in run_utils).
    freeze(model.backbone)
    freeze(model.neck.neuron_projectors)

    # These stages run without Lightning checkpointing.
    keep_checkpoints = stage not in (2, 4, 0.1)

    trainer = pl.Trainer(
        precision=cfg.TRAINER.PRECISION,
        accelerator="cuda",
        gradient_clip_val=cfg.TRAINER.GRADIENT_CLIP_VAL,
        devices=cfg.TRAINER.DEVICES,
        max_epochs=cfg.TRAINER.MAX_EPOCHS,
        max_steps=cfg.TRAINER.MAX_STEPS,
        val_check_interval=cfg.TRAINER.VAL_CHECK_INTERVAL,
        accumulate_grad_batches=cfg.TRAINER.ACCUMULATE_GRAD_BATCHES,
        limit_train_batches=cfg.TRAINER.LIMIT_TRAIN_BATCHES,
        limit_val_batches=cfg.TRAINER.LIMIT_VAL_BATCHES,
        log_every_n_steps=cfg.TRAINER.LOG_TRAIN_N_STEPS,
        callbacks=callbacks,
        logger=loggers,
        enable_checkpointing=keep_checkpoints,
        enable_progress_bar=progress,
    )

    trainer.fit(model, datamodule=datamodule)

    _val_score, _test_score = validate_test_log(
        trainer, model, datamodule, prefix="SOUP_NO/one/"
    )

    best_k_models = None
    if keep_checkpoints:
        checkpoint_cb = trainer.checkpoint_callback
        checkpoint_cb.to_yaml(os.path.join(log_dir, "checkpoint.yaml"))
        best_k_models = checkpoint_cb.best_k_models

    from run_utils import greedy_soup_sh_voxel

    _val_score, _test_score = greedy_soup_sh_voxel(
        trainer, datamodule, model, best_k_models, log_dir
    )

    if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
        shutil.rmtree(ckpt_dir, ignore_errors=True)
    return soup_path


# def run_stage_two(cfg: AutoConfig, stage_1_soup_path):
#     (
#         best_k_models,
#         best_model_path,
#         trainer,
#         dm,
#         model,
#         log_dir,
#         ckpt_dir,
#         cbs,
#     ) = run_train_one_stage(
#         copy.deepcopy(cfg),
#         progress=True,
#         stage=2,
#         model_path=stage_1_soup_path,
#     )
#     model_path_2 = os.path.join(log_dir, "soup.pth")
#     if not os.path.exists(model_path_2):
#         val_score, test_score = uniform_soup_sp_voxel(trainer, dm, model, cbs, log_dir)
#         if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
#             shutil.rmtree(ckpt_dir, ignore_errors=True)

#     return model_path_2


def run(cfg: AutoConfig, args):
    """Select the dark-loss ROI set for one trial, then run stage one.

    ``cfg.LOSS.DARK.GT_ROIS`` is expected to already hold the trial's target
    ROI list (the __main__ block grid-searches over it). ``args.dark`` picks
    what goes into ``cfg.DATASET.ROIS``:

    * ``"full"``   -- ``<roi_prefix>_1`` .. ``<roi_prefix>_<n_roi>``;
    * ``"self"``   -- the GT ROIs themselves;
    * ``"top<N>"`` -- the N ROIs with the largest values in the module-level
      ``transfer_mat`` row of the first GT ROI. That row index is the ROI
      name's numeric suffix minus one.

    Every valid mode enables ``cfg.LOSS.DARK.USE``. ``cfg.DATASET.DARK_POSTFIX``
    is then set from ``args.dark_postfix``. ``cfg`` is deep-copied first, so
    the caller's object is untouched.

    Raises:
        ValueError: if ``args.dark`` matches none of the modes above.
    """
    cfg = copy.deepcopy(cfg)
    mode = args.dark

    if mode == "full":
        cfg.LOSS.DARK.USE = True
        cfg.DATASET.ROIS = [
            f"{args.roi_prefix}_{idx}" for idx in range(1, args.n_roi + 1)
        ]
        print(cfg.DATASET.ROIS)
        print(cfg.LOSS.DARK.GT_ROIS)
    elif mode == "self":
        cfg.DATASET.ROIS = cfg.LOSS.DARK.GT_ROIS
        cfg.LOSS.DARK.USE = True
    elif mode.startswith("top"):
        # ROI names are 1-based ("<prefix>_<k>"); matrix indices are 0-based.
        target_idx = int(cfg.LOSS.DARK.GT_ROIS[0].split("_")[-1]) - 1
        ranked_desc = np.argsort(transfer_mat[target_idx])[::-1]
        chosen = ranked_desc[: int(mode[3:])]
        cfg.DATASET.ROIS = [f"{args.roi_prefix}_{idx + 1}" for idx in chosen]
        print(cfg.DATASET.ROIS, cfg.LOSS.DARK.GT_ROIS)
        cfg.LOSS.DARK.USE = True
    else:
        raise ValueError(
            f"dark must be one of 'full', 'self', 'topn', but got {args.dark}"
        )

    cfg.DATASET.DARK_POSTFIX = args.dark_postfix
    run_stage_one(
        cfg,
        dark=args.dark,
        sub_dir="step1",
        projector_path=args.projector_path,
        n_roi=args.n_roi,
        roi_prefix=args.roi_prefix,
    )
    # A second stage (dark loss off, ROIS = GT_ROIS, empty dark postfix) is
    # currently disabled:
    # cfg.LOSS.DARK.USE = False
    # cfg.DATASET.ROIS = cfg.LOSS.DARK.GT_ROIS
    # cfg.DATASET.DARK_POSTFIX = ""
    # soup2 = run_stage_two(cfg, soup1)


def run_tune(tune_dict, cfg: AutoConfig, args):
    """Ray Tune trainable: apply one trial's overrides to a copy of ``cfg`` and run it.

    Args:
        tune_dict: Trial config sampled by Ray Tune, keyed by dotted config
            paths (e.g. ``"LOSS.DARK.GT_ROIS"``). It is flattened with
            ``dict_to_list`` before merging.
        cfg: Base configuration shared across trials via
            ``tune.with_parameters``. It is deep-copied before the merge, so
            one trial's overrides cannot leak into another trial that happens
            to receive the same object (e.g. with actor reuse).
        args: Parsed command-line namespace, forwarded unchanged to ``run``.
    """
    trial_cfg = copy.deepcopy(cfg)
    trial_cfg.merge_from_list(dict_to_list(tune_dict))

    run(trial_cfg, args)


def run_ray(name, args, tune_config, rm=False, verbose=True, num_samples=1):
    """Launch a Ray Tune sweep of ``run_tune`` over ``tune_config``.

    Results go under ``<args.results_dir>/<config stem>/<name>``. Here
    ``<config stem>`` is the file name of ``args.config`` up to its first dot.

    Args:
        name: Ray Tune experiment name (the experiment's sub-directory).
        args: Parsed CLI namespace. ``config`` and ``results_dir`` are read
            here, and the whole namespace is forwarded to every trial.
        tune_config: Ray Tune search space passed as ``config=``.
        rm: If True, delete any existing results directory for ``name``
            first.
        verbose: Forwarded to ``tune.run``.
        num_samples: Forwarded to ``tune.run``.

    Returns:
        The analysis object returned by ``tune.run``.
    """
    # Only the part before the first dot is kept ("a.b.yaml" -> "a").
    config_stem = os.path.basename(args.config).split(".")[0]

    cfg = load_from_yaml(args.config)
    cfg.RESULTS_DIR = os.path.join(args.results_dir, config_stem)

    if rm:
        shutil.rmtree(os.path.join(cfg.RESULTS_DIR, name), ignore_errors=True)

    return tune.run(
        tune.with_parameters(run_tune, args=args, cfg=cfg),
        local_dir=cfg.RESULTS_DIR,
        config=tune_config,
        resources_per_trial={"cpu": 1, "gpu": 1},
        num_samples=num_samples,
        name=name,
        verbose=verbose,
        # Resume an existing experiment of the same name and re-run errored trials.
        resume="AUTO+ERRORED",
    )


if __name__ == "__main__":
    # One grid-search trial per target ROI. Each trial's GT_ROIS is a
    # one-element list. Slicing a range behaves exactly like slicing the list.
    roi_ids = range(1, args.n_roi + 1)[args.start : args.end]
    roi_job_list = [[f"{args.roi_prefix}_{k}"] for k in roi_ids]

    # Encode the dark mode and the ROI slice in the experiment name, so that
    # different slices get separate Ray experiment directories.
    args.name = f"{args.name}_dark{args.dark}_{args.start}_{args.end}"

    run_ray(
        args.name,
        args,
        {"LOSS.DARK.GT_ROIS": tune.grid_search(roi_job_list)},
        rm=args.rm,
    )
