from models import DVEModel, PDVEModel
from train import *
from run_utils import *


def run_train_one(
    cfg: AutoConfig,
    progress=True,
    stage=1,
    projector_path=None,
    log_dir=None,
    sub_dir=None,
    **kwargs,
):
    """Train a PDVEModel with frozen backbone/projectors, then greedy-soup it.

    Steps:
      1. Deep-copy ``cfg`` so the caller's config is never mutated.
      2. Build the datamodule (``build_dm``) and call ``setup()``.
      3. Resolve callbacks, loggers, the log dir and the checkpoint dir via
         ``get_callbacks_and_loggers``, and save ``cfg`` to ``<log_dir>/cfg.pth``.
      4. Build a ``PDVEModel`` from per-subject datamodule metadata. Load
         weights from ``projector_path`` non-strictly, then freeze
         ``model.backbone`` and ``model.neck.neuron_projectors``.
      5. Fit with a Lightning ``Trainer``. Checkpointing is disabled for
         stages 2, 4 and 0.1.
      6. Log validation/test scores under the ``"SOUP_NO/one/"`` prefix.
      7. Run ``greedy_soup_sh_voxel`` over the best-k checkpoints. These are
         ``None`` when checkpointing was disabled.
      8. Optionally delete the checkpoint directory
         (``cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE``).

    Args:
        cfg: Experiment configuration (copied, not mutated).
        progress: Forwarded to ``Trainer(enable_progress_bar=...)``.
        stage: Stage identifier; ``2``, ``4`` and ``0.1`` turn checkpointing off.
        projector_path: Path handed to ``load_state_dict`` for initial weights.
        log_dir: Forwarded to ``get_callbacks_and_loggers``.
        sub_dir: Forwarded to ``get_callbacks_and_loggers``.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        The path ``<log_dir>/soup.pth``. Presumably ``greedy_soup_sh_voxel``
        writes this file; that is not verified in this function.
    """
    # Work on a private copy so the caller's config is never mutated.
    cfg = copy.deepcopy(cfg)

    # Single source of truth for the stage-dependent checkpoint switch. It
    # configures the Trainer AND decides whether best-k checkpoints can be
    # read back afterwards, so the two uses can no longer drift apart.
    use_checkpointing = stage not in (2, 4, 0.1)

    dm: AllDatamodule = build_dm(cfg)
    dm.setup()

    cbs, lgs, log_dir, ckpt_dir = get_callbacks_and_loggers(
        cfg, log_dir=log_dir, sub_dir=sub_dir
    )
    os.makedirs(log_dir, exist_ok=True)

    torch.save(cfg, os.path.join(log_dir, "cfg.pth"))

    model = PDVEModel(
        cfg,
        dm.num_voxel_dict,
        dm.roi_dict,
        dm.neuron_coords_dict,
        dm.noise_ceiling_dict,
    )

    # NOTE(review): projector_path defaults to None and is passed straight to
    # load_state_dict; confirm that helper tolerates None, or require a path.
    state_dict = load_state_dict(projector_path)
    # strict=False: missing/unexpected keys in the loaded dict are ignored.
    model.load_state_dict(state_dict, strict=False)

    freeze(model.backbone)
    freeze(model.neck.neuron_projectors)

    trainer = pl.Trainer(
        precision=cfg.TRAINER.PRECISION,
        accelerator="cuda",
        gradient_clip_val=cfg.TRAINER.GRADIENT_CLIP_VAL,
        devices=cfg.TRAINER.DEVICES,
        max_epochs=cfg.TRAINER.MAX_EPOCHS,
        max_steps=cfg.TRAINER.MAX_STEPS,
        val_check_interval=cfg.TRAINER.VAL_CHECK_INTERVAL,
        accumulate_grad_batches=cfg.TRAINER.ACCUMULATE_GRAD_BATCHES,
        limit_train_batches=cfg.TRAINER.LIMIT_TRAIN_BATCHES,
        limit_val_batches=cfg.TRAINER.LIMIT_VAL_BATCHES,
        log_every_n_steps=cfg.TRAINER.LOG_TRAIN_N_STEPS,
        callbacks=cbs,
        logger=lgs,
        enable_checkpointing=use_checkpointing,
        enable_progress_bar=progress,
    )

    trainer.fit(model, datamodule=dm)

    # Called for its logging side effect. The scores it returns were never
    # used (the soup step below superseded them), so they are not bound.
    validate_test_log(trainer, model, dm, prefix="SOUP_NO/one/")

    best_k_models = None
    if use_checkpointing:
        trainer.checkpoint_callback.to_yaml(os.path.join(log_dir, "checkpoint.yaml"))
        ckpt: ModelCheckpoint = trainer.checkpoint_callback
        best_k_models = ckpt.best_k_models

    # Kept as a local import in case run_utils' star-export omits this name.
    from run_utils import greedy_soup_sh_voxel

    greedy_soup_sh_voxel(trainer, dm, model, best_k_models, log_dir)

    if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
        shutil.rmtree(ckpt_dir, ignore_errors=True)

    return os.path.join(log_dir, "soup.pth")


def run_stage_two(cfg: AutoConfig, stage_1_soup_path):
    """Run the stage-2 training pass initialised from a stage-1 soup.

    Training is delegated to ``run_train_one_stage``, which is defined outside
    this file view. It receives a deep copy of ``cfg``, ``stage=2`` and
    ``model_path=stage_1_soup_path``, and is expected to return an 8-tuple.

    If ``<log_dir>/soup.pth`` does not exist afterwards, this function calls
    ``uniform_soup_sp_voxel``, which presumably writes that file (TODO
    confirm). It then optionally deletes the checkpoint directory. When the
    soup file already exists, neither step runs.

    Returns:
        The path ``<log_dir>/soup.pth``.
    """
    stage_two_outputs = run_train_one_stage(
        copy.deepcopy(cfg),
        progress=True,
        stage=2,
        model_path=stage_1_soup_path,
    )
    (
        _best_k,
        _best_path,
        trainer,
        dm,
        model,
        log_dir,
        ckpt_dir,
        cbs,
    ) = stage_two_outputs

    stage_two_soup = os.path.join(log_dir, "soup.pth")

    # Soup already produced (e.g. by an earlier run): nothing left to do.
    if os.path.exists(stage_two_soup):
        return stage_two_soup

    _val, _test = uniform_soup_sp_voxel(trainer, dm, model, cbs, log_dir)
    if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
        shutil.rmtree(ckpt_dir, ignore_errors=True)
    return stage_two_soup

if __name__ == "__main__":
    # Start from the DINO-base config, then apply experiment-specific overrides.
    cfg: AutoConfig = load_from_yaml("/workspace/configs/dino_base.yaml")

    # --- dataset ---
    cfg.DATASET.SUBJECT_LIST = ["NSD_01", "NSD_04", "NSD_07", "HCP"]
    cfg.DATASET.ROIS = ["htroi_11"]
    cfg.DATASET.DARK_POSTFIX = ".htroi_gen1"

    # --- "dark" loss: reuse the dataset ROI list object as its ground-truth ROIs ---
    cfg.LOSS.DARK.USE = True
    cfg.LOSS.DARK.GT_ROIS = cfg.DATASET.ROIS
    cfg.LOSS.DARK.MAX_EPOCH = 100
    cfg.LOSS.DARK.IGNORE_OTHER_ROIS = True

    # --- optimisation ---
    cfg.OPTIMIZER.LR = 3e-3

    # Optional: pin to a specific GPU.
    # cfg.TRAINER.DEVICES = [2]

    run_name = "pd_nobug_4sub_htroi11_run2"
    # run_name = "debug"

    stage_one_soup = run_train_one(
        cfg,
        projector_path="/data/script_base/xdab/state_dict.pth",
        sub_dir=run_name,
    )
    run_stage_two(cfg, stage_one_soup)