# %%
import argparse
import copy
import fnmatch
from functools import partial
import glob
import operator
import os
import sys
from typing import Dict

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import (
    BackboneFinetuning,
    EarlyStopping,
    ModelCheckpoint,
)
from ray import tune
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback,
    TuneReportCheckpointCallback,
)
import shutil

from tqdm import tqdm

from callbacks import (
    DisableBN,
    EMAModel,
    SaveTopWeightEachVoxel,
    FreezeBackbone,
    ModifyBNMoment,
    RemoveCheckpoint,
    SaveFinalFC,
    EmptyCache,
    StageFinetuning,
    LoadBestCheckpointOnEnd,
    LoadBestCheckpointOnVal,
)
from config import AutoConfig
from config_utils import load_from_yaml
from datamodule import build_dm, AllDatamodule
from models import VEModel
from topyneck import VoxelOutBlock


def get_callbacks_and_loggers(
    cfg: AutoConfig, sub_dir: str = None, log_dir: str = None, version: str = "."
):
    """Assemble the Lightning callbacks and loggers for one training run.

    Args:
        cfg: experiment config; the ``TRAINER.CALLBACKS``, ``MODEL.BACKBONE``
            and ``ANALYSIS`` sections decide which callbacks are attached.
        sub_dir: optional sub-directory appended to the resolved log directory.
        log_dir: base log directory. When ``None`` the Ray Tune trial directory
            is used, and ``"tb_logs"`` if that is also ``None``.
        version: version name handed to both loggers.

    Returns:
        ``(callbacks, loggers, log_dir, ckpt_dir)`` where ``log_dir`` is the
        resolved directory and ``ckpt_dir`` is ``<log_dir>/ckpt``.
    """
    if log_dir is None:
        log_dir = tune.get_trial_dir()
        if log_dir is None:
            log_dir = "tb_logs"
    if sub_dir is not None:
        log_dir = os.path.join(log_dir, sub_dir)
    ckpt_dir = os.path.join(log_dir, "ckpt")

    early_stop_cfg = cfg.TRAINER.CALLBACKS.EARLY_STOP
    monitored = "VAL/PearsonCorrCoef/" + early_stop_cfg.SUBJECT
    ckpt_cfg = cfg.TRAINER.CALLBACKS.CHECKPOINT
    backbone_cfg = cfg.MODEL.BACKBONE

    callbacks = [EmptyCache()]
    bn_momentum = backbone_cfg.BN_MOMENTUM
    if bn_momentum != -1:
        callbacks.append(ModifyBNMoment(bn_momentum))
    callbacks.append(
        EarlyStopping(
            monitor=monitored,
            min_delta=0.0,
            patience=early_stop_cfg.PATIENCE,
            verbose=False,
            mode="max",
        )
    )

    keep_top_k = ckpt_cfg.SAVE_TOP_K
    if keep_top_k > 0:
        checkpoint_name = (
            "{epoch:d}-{VAL/PearsonCorrCoef/mean:.6f}"
            "-{VAL/MeanAbsoluteError/mean:.6f}"
            "-{VAL/MeanSquaredError/mean:.6f}"
        )
        callbacks.append(
            ModelCheckpoint(
                monitor=monitored,
                dirpath=ckpt_dir,
                filename=checkpoint_name,
                auto_insert_metric_name=True,
                save_weights_only=True,
                save_top_k=keep_top_k,
                mode="max",
            )
        )

    # Both "load best" options require checkpoint saving (SAVE_TOP_K > 0).
    if ckpt_cfg.LOAD_BEST_ON_VAL:
        assert keep_top_k > 0
        callbacks.append(LoadBestCheckpointOnVal())
    if ckpt_cfg.LOAD_BEST_ON_END:
        assert keep_top_k > 0
        callbacks.append(LoadBestCheckpointOnEnd())

    if backbone_cfg.DISABLE_BN:
        callbacks.append(DisableBN())

    from callbacks import SaveNeuronLocation

    callbacks.append(
        SaveNeuronLocation(
            log_dir,
            save=cfg.ANALYSIS.SAVE_NEURON_LOCATION,
            draw=cfg.ANALYSIS.DRAW_NEURON_LOCATION,
        )
    )

    loggers = [
        pl_loggers.TensorBoardLogger(log_dir, version=version),
        pl_loggers.CSVLogger(log_dir, version=version),
    ]
    return callbacks, loggers, log_dir, ckpt_dir


def freeze(model):
    """Freeze ``model`` in place with Lightning's ``BaseFinetuning.freeze`` and return it.

    ``train_bn=True`` is passed, so BatchNorm layers are left trainable.
    """
    from pytorch_lightning.callbacks.finetuning import BaseFinetuning as _Finetuning

    _Finetuning.freeze(model, train_bn=True)
    return model


def log_metric(trainer, metric, keys, prefix=""):
    """Write selected scalars from a metrics dict to the trainer's logger.

    Args:
        trainer: object whose ``logger.experiment`` exposes ``add_scalar`` and
            which has a ``global_step`` attribute used as the step.
        metric: metrics dict, or a one-element list wrapping one (the shape
            ``Trainer.validate`` / ``Trainer.test`` return).
        keys: metric names to log; names absent from ``metric`` are skipped.
        prefix: string prepended to every logged tag.
    """
    if isinstance(metric, list):
        assert len(metric) == 1
        metric = metric[-1]
    present = (name for name in keys if name in metric)
    for name in present:
        trainer.logger.experiment.add_scalar(
            prefix + name, metric[name], trainer.global_step
        )


def validate_test_log(trainer, model, dm, prefix=""):
    """Evaluate ``model`` on the validation split, then the test split, and log both.

    For each split the ``PearsonCorrCoef/mean`` and ``PearsonCorrCoef/challenge``
    metrics (when present) are written under ``prefix`` via ``log_metric``.

    Returns:
        ``(val_score, test_score)``: the ``PearsonCorrCoef/mean`` of each split.
    """
    phases = (
        ("validate", "VAL/PearsonCorrCoef/mean", "VAL/PearsonCorrCoef/challenge"),
        ("test", "TEST/PearsonCorrCoef/mean", "TEST/PearsonCorrCoef/challenge"),
    )
    scores = []
    for method_name, mean_key, challenge_key in phases:
        results = getattr(trainer, method_name)(model, datamodule=dm)[-1]
        log_metric(trainer, results, [mean_key, challenge_key], prefix=prefix)
        scores.append(results[mean_key])
    val_score, test_score = scores
    return val_score, test_score


def uniform_soup_sp_voxel(trainer, dm, model, cbs, log_dir):
    """Uniform per-voxel soup of the voxel-specific output layers; evaluate and save it.

    For every subject tracked by the ``SaveTopWeightEachVoxel`` callback in
    ``cbs``, each voxel's weight and bias (for every output layer ``k``) are
    replaced by the mean over that voxel's saved candidates. Queue slots whose
    weight entry is ``None`` are skipped. The queues are indexed
    ``queue[subject][voxel][candidate][layer]``. The resulting model is
    validated and tested (logged under ``SOUP_SP/uniform/``), then its weights
    and scores are saved into ``log_dir``.

    Returns:
        ``(val_score, test_score)`` of the soup.

    Raises:
        AssertionError: ``cbs`` contains no ``SaveTopWeightEachVoxel``.
        ValueError: a subject has a voxel with no saved candidate to average.
    """
    vcb = next((cb for cb in cbs if isinstance(cb, SaveTopWeightEachVoxel)), None)
    assert vcb is not None, "SaveTopWeightEachVoxel not found"

    for subject in list(vcb.w_queue.keys()):
        # Per voxel: the (weight, bias) candidates whose weight slot is filled.
        candidates = [
            [(w, b) for w, b in zip(voxel_w, voxel_b) if w is not None]
            for voxel_w, voxel_b in zip(vcb.w_queue[subject], vcb.b_queue[subject])
        ]
        if not candidates or not all(candidates):
            raise ValueError(
                f"{subject}: some voxels have no saved weights to average"
            )
        # Read the layer count from a filled slot, never from a possibly-None one.
        n_layers = len(candidates[0][0][0])

        vo: VoxelOutBlock = model.neck.voxel_outs[subject]
        for k in range(n_layers):
            new_w = torch.stack(
                [torch.stack([w[k] for w, _ in c]).mean(dim=0) for c in candidates]
            )
            new_b = torch.stack(
                [torch.stack([b[k] for _, b in c]).mean(dim=0) for c in candidates]
            )
            # Keep each parameter on its current device (the saved candidates
            # may live elsewhere, e.g. on the CPU).
            vo.weight[k].data = new_w.to(vo.weight[k].device)
            vo.bias[k].data = new_b.to(vo.bias[k].device)

    val_score, test_score = validate_test_log(
        trainer, model, dm, prefix="SOUP_SP/uniform/"
    )

    save_model_path = os.path.join(log_dir, "soup.pth")
    state_dict = model.state_dict()
    torch.save(state_dict, save_model_path)
    torch.save(test_score, os.path.join(log_dir, "soup_test_score.pth"))
    torch.save(val_score, os.path.join(log_dir, "soup_val_score.pth"))
    torch.save(test_score, os.path.join(log_dir, f"soup_test_score={test_score:.6f}"))
    torch.save(val_score, os.path.join(log_dir, f"soup_val_score={val_score:.6f}"))

    return val_score, test_score


def _apply_voxel_soup(voxel_out, ingredients_w, ingredients_b, num_layers, progress=False):
    """Overwrite each voxel's row of ``voxel_out`` with the mean of its ingredients.

    ``ingredients_w[i]`` / ``ingredients_b[i]`` are lists of candidates for voxel
    ``i``; each candidate is indexable by output-layer ``k``. Rows are written
    into clones of the current parameters, which are then re-bound on the
    parameters' devices.
    """
    new_w = [voxel_out.weight[k].data.clone() for k in range(num_layers)]
    new_b = [voxel_out.bias[k].data.clone() for k in range(num_layers)]
    voxels = range(len(ingredients_w))
    for i in (tqdm(voxels) if progress else voxels):
        for k in range(num_layers):
            new_w[k][i] = torch.stack([cand[k] for cand in ingredients_w[i]]).mean(dim=0)
            new_b[k][i] = torch.stack([cand[k] for cand in ingredients_b[i]]).mean(dim=0)
    for k in range(num_layers):
        device = voxel_out.weight[k].device
        voxel_out.weight[k].data = new_w[k].to(device)
        voxel_out.bias[k].data = new_b[k].to(device)


def greedy_soup_sp_voxel(trainer, dm, model: VEModel, cbs, log_dir, target="heldout"):
    """Per-voxel greedy soup of the voxel-specific output layers; evaluate and save it.

    For each subject, the ``vcb.top_n`` candidates recorded by the
    ``SaveTopWeightEachVoxel`` callback (queues indexed
    ``queue[subject][voxel][candidate][layer]``) are tried in queue order. A
    candidate is tentatively averaged into every voxel's soup and the subject is
    evaluated once. Each voxel keeps the candidate only if its score strictly
    improved. Per-voxel scores are read from ``model.voxel_score[subject]``,
    which is assumed to be refreshed by the evaluation run (TODO confirm).

    Args:
        trainer: Lightning trainer used for evaluation.
        dm: datamodule exposing ``val_dataloader(subject=...)`` and
            ``test_dataloader(subject=...)``.
        model: model whose ``neck.voxel_outs`` are replaced in place.
        cbs: callbacks list; must contain a ``SaveTopWeightEachVoxel``.
        log_dir: directory where the soup weights and scores are saved.
        target: ``"val"`` or ``"heldout"``, the split that drives selection.

    Returns:
        ``(val_score, test_score)`` of the final soup.

    Raises:
        ValueError: ``target`` is not ``"val"`` or ``"heldout"``.
        AssertionError: ``cbs`` contains no ``SaveTopWeightEachVoxel``.
    """
    # Validate up front so a bad target fails before any expensive evaluation.
    if target not in ("val", "heldout"):
        raise ValueError(f"Unknown greedy target {target}")

    vcb = next((cb for cb in cbs if isinstance(cb, SaveTopWeightEachVoxel)), None)
    assert vcb is not None, "SaveTopWeightEachVoxel not found"

    n_models = vcb.top_n

    for subject in list(model.neck.voxel_outs.keys()):
        print(f"greedy soup {subject}")
        voxel_out = model.neck.voxel_outs[subject]
        w_queue = vcb.w_queue[subject]
        b_queue = vcb.b_queue[subject]
        num_vo_layers = len(w_queue[0][0])
        n_voxels = len(w_queue)
        best_score_so_far = [float("-inf")] * n_voxels
        soup_w = [[] for _ in range(n_voxels)]
        soup_b = [[] for _ in range(n_voxels)]

        for j in range(n_models):
            # Tentatively add candidate j to every voxel's soup and load it.
            for i in range(n_voxels):
                soup_w[i].append(w_queue[i][j])
                soup_b[i].append(b_queue[i][j])
            _apply_voxel_soup(voxel_out, soup_w, soup_b, num_vo_layers, progress=True)

            if target == "val":
                trainer.validate(model, dataloaders=dm.val_dataloader(subject=subject))
            else:
                trainer.test(model, dataloaders=dm.test_dataloader(subject=subject))
            current_score = model.voxel_score[subject]

            # Keep candidate j only for voxels whose score improved.
            for i in range(n_voxels):
                if current_score[i] > best_score_so_far[i]:
                    best_score_so_far[i] = current_score[i]
                else:
                    soup_w[i].pop()
                    soup_b[i].pop()

            print(f"finishing greedy soup {j}/{n_models}")

        # Load the final soup (rejected candidates already removed).
        _apply_voxel_soup(voxel_out, soup_w, soup_b, num_vo_layers)

    val_score, test_score = validate_test_log(
        trainer, model, dm, prefix="SOUP_SP/greedy/"
    )

    save_model_path = os.path.join(log_dir, "soup.pth")
    state_dict = model.state_dict()
    torch.save(state_dict, save_model_path)
    torch.save(test_score, os.path.join(log_dir, "soup_test_score.pth"))
    torch.save(val_score, os.path.join(log_dir, "soup_val_score.pth"))
    torch.save(test_score, os.path.join(log_dir, f"soup_test_score={test_score:.6f}"))
    torch.save(val_score, os.path.join(log_dir, f"soup_val_score={val_score:.6f}"))

    return val_score, test_score


def uniform_soup_sh_voxel(
    trainer,
    dm,
    model,
    best_k_models: Dict[str, float],
    log_dir: str,
    prefix="SOUP_RUNS/uniform",
):
    """Average every checkpoint in ``best_k_models`` with equal weight; evaluate and save.

    Each file may be a bare state dict or a checkpoint wrapping one under
    ``"state_dict"``; weights are loaded onto the CPU. The averaged weights are
    loaded into ``model``, validated and tested, and written to
    ``<log_dir>/soup.pth`` together with the scores.

    Args:
        trainer: Lightning trainer used for evaluation.
        dm: datamodule passed to the evaluation.
        model: model the soup is loaded into (modified in place).
        best_k_models: mapping of checkpoint path to score; only the paths are used.
        log_dir: output directory for the soup and its scores.
        prefix: tag prefix for the logged metrics.
            NOTE(review): unlike the other soup prefixes this default has no
            trailing "/", so tags concatenate directly (e.g.
            "SOUP_RUNS/uniformVAL/..."). Confirm this is intended.

    Returns:
        ``(val_score, test_score)`` of the soup.

    Raises:
        ValueError: ``best_k_models`` is empty.
    """
    if not best_k_models:
        raise ValueError("uniform_soup_sh_voxel needs at least one checkpoint")

    num_models = len(best_k_models)
    uniform_soup = None
    for path in best_k_models:
        state_dict = load_state_dict(path)
        scaled = {k: v * (1.0 / num_models) for k, v in state_dict.items()}
        if uniform_soup is None:
            uniform_soup = scaled
        else:
            uniform_soup = {k: v + uniform_soup[k] for k, v in scaled.items()}

    model.load_state_dict(uniform_soup)
    val_score, test_score = validate_test_log(
        trainer, model, dm, prefix=prefix
    )

    save_path = os.path.join(log_dir, "soup.pth")
    torch.save(uniform_soup, save_path)
    torch.save(test_score, os.path.join(log_dir, "soup_test_score.pth"))
    torch.save(val_score, os.path.join(log_dir, "soup_val_score.pth"))
    torch.save(test_score, os.path.join(log_dir, f"soup_test_score={test_score:.6f}"))
    torch.save(val_score, os.path.join(log_dir, f"soup_val_score={val_score:.6f}"))

    return val_score, test_score


def greedy_soup_sh_voxel(
    trainer,
    dm,
    model,
    best_k_models: Dict[str, float],
    log_dir: str,
    target="heldout",
    prefix="SOUP_SH/greedy/",
):
    """Greedy soup over whole checkpoints; evaluate and save the result.

    Checkpoints are ranked by their score in ``best_k_models``, highest first.
    The top one seeds the soup (as a single ingredient) and its score on
    ``target`` becomes the baseline. Each following checkpoint is averaged in,
    uniformly with the ingredients kept so far, and kept only if the ``target``
    score strictly improves.

    Args:
        trainer: Lightning trainer used for evaluation.
        dm: datamodule passed to ``trainer.validate`` / ``trainer.test``.
        model: model the candidate soups are loaded into (modified in place).
        best_k_models: mapping of checkpoint path to ranking score. Files may be
            bare state dicts or checkpoints with a ``"state_dict"`` key.
        log_dir: directory where the soup weights and scores are written.
        target: ``"val"`` (validation mean Pearson) or ``"heldout"`` (test mean
            Pearson), the metric the greedy selection maximises.
        prefix: tag prefix for the final logged metrics.

    Returns:
        ``(val_score, test_score)`` of the final soup.

    Raises:
        ValueError: ``best_k_models`` is empty or ``target`` is unknown.
    """
    if not best_k_models:
        raise ValueError("greedy_soup_sh_voxel needs at least one checkpoint")
    if target not in ("val", "heldout"):
        raise ValueError(f"Invalid target: {target}")

    def _target_score(params):
        """Load ``params`` into ``model`` and return its mean Pearson on ``target``."""
        model.load_state_dict(params)
        if target == "val":
            return trainer.validate(model, datamodule=dm)[-1]["VAL/PearsonCorrCoef/mean"]
        return trainer.test(model, datamodule=dm)[-1]["TEST/PearsonCorrCoef/mean"]

    # Sort ascending, then reverse (not reverse=True), so equal scores keep
    # the same relative order as before.
    ranked = sorted(best_k_models.items(), key=operator.itemgetter(1))
    ranked.reverse()
    sorted_models = [path for path, _ in ranked]
    num_models = len(sorted_models)

    # Seed with the top-ranked checkpoint. Its score on the target split is the
    # baseline, because the ranking scores may come from a different split.
    greedy_soup_ingredients = [sorted_models[0]]
    greedy_soup_params = load_state_dict(sorted_models[0])
    print(f"Greedy soup: 0/{num_models}")
    best_score_so_far = _target_score(greedy_soup_params)
    print(f"Seed score: {best_score_so_far}")

    for j in range(1, num_models):
        print(f"Greedy soup: {j}/{num_models}")
        new_ingredient_params = load_state_dict(sorted_models[j])
        n = len(greedy_soup_ingredients)
        # Running uniform mean: current soup weighted n/(n+1), newcomer 1/(n+1).
        potential_greedy_soup_params = {
            k: greedy_soup_params[k] * (n / (n + 1.0))
            + new_ingredient_params[k] * (1.0 / (n + 1))
            for k in new_ingredient_params
        }
        current_score = _target_score(potential_greedy_soup_params)
        print(f"Current score: {current_score}, best score so far: {best_score_so_far}")

        if current_score > best_score_so_far:
            greedy_soup_ingredients.append(sorted_models[j])
            best_score_so_far = current_score
            greedy_soup_params = potential_greedy_soup_params
            print(f"Greedy soup improved to {len(greedy_soup_ingredients)} models.")

    model.load_state_dict(greedy_soup_params)
    val_score, test_score = validate_test_log(trainer, model, dm, prefix=prefix)

    save_path = os.path.join(log_dir, "soup.pth")
    torch.save(greedy_soup_params, save_path)
    torch.save(test_score, os.path.join(log_dir, "soup_test_score.pth"))
    torch.save(val_score, os.path.join(log_dir, "soup_val_score.pth"))
    torch.save(test_score, os.path.join(log_dir, f"soup_test_score={test_score:.6f}"))
    torch.save(val_score, os.path.join(log_dir, f"soup_val_score={val_score:.6f}"))

    return val_score, test_score


def greedy_soup_from_runs(exp_dir, rm_after_use=True):
    """Combine the per-run soups of a tuning experiment into one greedy soup.

    Every ``tune*`` sub-directory of ``exp_dir`` is treated as one run. Each run
    must contain exactly one ``soup.pth`` and one ``soup_val_score.pth``
    (searched recursively). The model and datamodule are rebuilt from the first
    run's ``hparams.yaml`` and ``model_args.pth`` (runs are sorted by path). The
    greedy soup, selected on the held-out split, is written to
    ``<exp_dir>/soup.pth``.

    Args:
        exp_dir: experiment directory containing the ``tune*`` run folders.
        rm_after_use: delete the per-run ``soup.pth`` files afterwards
            (best effort; failures are printed, not raised).

    Returns:
        ``(soup_path, val_score, test_score)``.

    Raises:
        ValueError: ``exp_dir`` contains no ``tune*`` run directory.
        AssertionError: a required file is missing or ambiguous within a run.
    """

    def _find_one(root, filename):
        """Return the single ``filename`` found anywhere under ``root``."""
        matches = glob.glob(os.path.join(root, "**", filename), recursive=True)
        assert len(matches) == 1, (
            f"expected exactly one {filename} under {root}, found {len(matches)}"
        )
        return matches[0]

    # Sorted so the run that supplies hparams/model args is deterministic.
    runs = sorted(
        os.path.join(exp_dir, name)
        for name in os.listdir(exp_dir)
        if name.startswith("tune") and os.path.isdir(os.path.join(exp_dir, name))
    )
    if not runs:
        raise ValueError(f"No 'tune*' run directories found in {exp_dir}")

    cfg = load_from_yaml(_find_one(runs[0], "hparams.yaml"))
    model_args = torch.load(_find_one(runs[0], "model_args.pth"))

    model = VEModel(*model_args)

    dm = build_dm(cfg)
    dm.setup()

    callbacks, loggers, _, _ = get_callbacks_and_loggers(
        cfg, log_dir=os.path.join(exp_dir, "soup_logs")
    )

    model.zero_flag = True

    trainer = pl.Trainer(
        precision=cfg.TRAINER.PRECISION,
        accelerator="cuda",
        devices=cfg.TRAINER.DEVICES,
        enable_progress_bar=False,
        callbacks=callbacks,
        logger=loggers,
    )

    # Map each run's soup weights to that run's validation score.
    best_k_models = {}
    for run in runs:
        model_path = _find_one(run, "soup.pth")
        score_path = _find_one(run, "soup_val_score.pth")
        best_k_models[model_path] = torch.load(score_path)

    print(best_k_models)

    val_score, test_score = greedy_soup_sh_voxel(
        trainer, dm, model, best_k_models, exp_dir, target="heldout"
    )

    if rm_after_use:
        for path in best_k_models:
            try:
                os.remove(path)
            except OSError as e:
                # Best-effort cleanup: report and keep going.
                print(e)

    return os.path.join(exp_dir, "soup.pth"), val_score, test_score


def load_state_dict(model_path):
    """Load weights from ``model_path`` onto the CPU and return the state dict.

    Accepts either a bare state dict or a checkpoint that wraps it under the
    ``"state_dict"`` key (as Lightning checkpoints do).
    """
    checkpoint = torch.load(model_path, map_location="cpu")
    return checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint


def run_train_one_stage(
    cfg: AutoConfig,
    progress=False,
    stage=1,
    model_path=None,
    conv_path=None,
    projector_path=None,
    log_dir=None,
    **kwargs,
):
    """Train and evaluate one stage of the multi-stage fitting pipeline.

    All config overrides below are applied to a deep copy of ``cfg``; the
    caller's config object is not modified.

    Stages (selected by ``stage``):

    * ``1``: no weights are loaded here (``model_path`` must be ``None``).
    * ``2`` / ``4``: voxel-specific part.
      - Checkpointing and load-best are disabled.
      - Early-stopping patience is raised to 1000.
      - The ``STAGE_2_*`` epochs / LR / weight decay are used.
      - The per-subject voxel heads are rebuilt with ``use_linear=True``;
        neuron projectors and layer gates are restored afterwards.
      - Backbone, conv blocks and image shifter are frozen, plus the shared
        neck parts unless ``cfg.FINETUNE.TRAIN_SHARED``.
      - A ``SaveTopWeightEachVoxel`` callback is attached, or ``EMAModel``
        when ``cfg.TRAINER.STAGE_2_EMA``.
      - Requires ``model_path``.
    * ``3``: LR divided by 10.
      - A fresh ``VEModel`` is built, and only ``voxel_outs``,
        ``neuron_projectors`` and ``layer_gates`` are copied over from the
        loaded weights.
      - Backbone, neuron projectors and voxel heads are frozen.
      - Requires ``model_path``.
    * ``0.1``: transfer probe.
      - Conv blocks come from ``conv_path``; all other non-``voxel_outs``
        weights come from ``projector_path``.
      - Freezing matches stages 2/4.
      - ``SaveTopWeightEachVoxel(top_n=10)`` is attached.

    Args:
        cfg: experiment config.
        progress: show Lightning's progress bar.
        stage: stage id (1, 2, 3, 4 or 0.1).
        model_path: weights to initialise from (bare state dict or a checkpoint
            with a ``"state_dict"`` key, see ``load_state_dict``).
        conv_path: weight file for the conv blocks (stage 0.1 only).
        projector_path: weight file for the remaining non-voxel parts
            (stage 0.1 only).
        log_dir: base log directory forwarded to ``get_callbacks_and_loggers``;
            this stage logs into its ``stage_<stage>`` sub-directory.
        **kwargs: accepted and ignored.

    Returns:
        ``(best_k_models, best_model_path, trainer, dm, model, log_dir,
        ckpt_dir, cbs)``. ``best_k_models`` / ``best_model_path`` come from
        the trainer's ``ModelCheckpoint`` and are ``None`` for stages 2, 4
        and 0.1. If ``soup.pth`` already exists in the stage's log directory
        the stage is skipped, and only ``log_dir`` and ``ckpt_dir`` are
        non-``None``.
    """
    cfg = copy.deepcopy(cfg)
    if stage == 1:  # voxel shared part
        assert model_path is None
    if stage == 2 or stage == 4:  # voxel specific part
        cfg.TRAINER.CALLBACKS.CHECKPOINT.SAVE_TOP_K = 0
        cfg.TRAINER.CALLBACKS.CHECKPOINT.LOAD_BEST_ON_END = False
        cfg.TRAINER.CALLBACKS.EARLY_STOP.PATIENCE = 1000  # TODO: add to inf
        cfg.TRAINER.MAX_EPOCHS = cfg.TRAINER.STAGE_2_MAX_EPOCHS
        cfg.OPTIMIZER.LR = cfg.TRAINER.STAGE_2_LR
        # cfg.OPTIMIZER.SCHEDULER.CYCLE_LIMIT = 5
        # cfg.OPTIMIZER.SCHEDULER.K_DECAY = 1.0
        # cfg.OPTIMIZER.SCHEDULER.LR_MIN = cfg.TRAINER.STAGE_2_LR / 100
        # cfg.OPTIMIZER.SCHEDULER.T_INITIAL = 20
        cfg.OPTIMIZER.VOXEL_WEIGHT_DECAY = cfg.TRAINER.STAGE_2_WD
        # cfg.TRAINER.LR = cfg.TRAINER.STAGE_2_LR
        # cfg.TRAINER.LIMIT_TRAIN_BATCHES = 1.0
        # Huge sentinel; presumably means "no cap on training voxels" (confirm).
        cfg.MODEL.MAX_TRAIN_VOXELS = 1145141919810
        # if cfg.FINETUNE.TRAIN_SHARED:
        #     cfg.LOSS.SYNC.USE = True
        # else:
        cfg.LOSS.SYNC.USE = False
        cfg.ANALYSIS.DRAW_NEURON_LOCATION = False
        assert model_path is not None
    if stage == 3:  # voxel shared part
        cfg.OPTIMIZER.LR /= 10
        assert model_path is not None
    if stage == 0.1:  # transfer
        cfg.TRAINER.CALLBACKS.CHECKPOINT.SAVE_TOP_K = 0
        cfg.TRAINER.CALLBACKS.CHECKPOINT.LOAD_BEST_ON_END = False
        # cfg.TRAINER.MAX_EPOCHS = 20
        cfg.MODEL.MAX_TRAIN_VOXELS = 1145141919810
        cfg.FINETUNE.TOP_N = 10
        cfg.LOSS.SYNC.USE = False
        cfg.ANALYSIS.DRAW_NEURON_LOCATION = False
        assert conv_path is not None
        assert projector_path is not None

    dm: AllDatamodule = build_dm(cfg)
    dm.setup()

    # if stage == 2 or stage == 4:
    #     if cfg.STAGE_2.FIT_TO_VALIDATION:
    #         dm.dss[0] = dm.dss[1] # use validation set for training
    #         dm.dss[1] = dm.dss[2] # use test set for validation
    #         dm.dss[2] = dm.dss[0] # use validation set for test
    #         cfg.TRAINER.LIMIT_TRAIN_BATCHES = 1.0

    sub_dir = f"stage_{stage}"

    cbs, lgs, log_dir, ckpt_dir = get_callbacks_and_loggers(
        cfg, log_dir=log_dir, sub_dir=sub_dir
    )

    # An existing soup means this stage already finished; skip retraining.
    # NOTE(review): this check runs after build_dm/dm.setup() above, so the
    # datamodule is still built even when the stage is skipped.
    soup_path = os.path.join(log_dir, "soup.pth")
    if os.path.exists(soup_path):
        print("soup exists, skipping")
        return (None, None, None, None, None, log_dir, ckpt_dir, None)

    os.makedirs(log_dir, exist_ok=True)

    model_args = (
        cfg,
        dm.num_voxel_dict,
        dm.roi_dict,
        dm.neuron_coords_dict,
        dm.noise_ceiling_dict,
    )
    if stage != 0.1:
        # NOTE(review): greedy_soup_from_runs globs for model_args.pth, which is
        # no longer written here (only cfg.pth is). Confirm it is produced elsewhere.
        # torch.save(model_args, os.path.join(log_dir, "model_args.pth"))
        torch.save(cfg, os.path.join(log_dir, "cfg.pth"))

    model = VEModel(*model_args)

    if model_path is not None:
        state_dict = load_state_dict(model_path)
        # strict=False tolerates missing/unexpected keys. If loading still fails
        # (presumably a voxel-head shape mismatch), retry without voxel_outs.
        try:
            model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            print(e)
            print("load state dict failed, trying to load backbone only")
            state_dict = {k: v for k, v in state_dict.items() if "voxel_outs" not in k}
            model.load_state_dict(state_dict, strict=False)

    if stage == 2 or stage == 4:
        # fix conv blocks and shared voxel part, only train voxel specific part
        if not cfg.TRAINER.STAGE_2_EMA:
            vcb = SaveTopWeightEachVoxel(top_n=cfg.FINETUNE.TOP_N)
            cbs.append(vcb)
        else:
            cb = EMAModel(model, beta=cfg.TRAINER.STAGE_2_EMA_BETA)
            cbs.append(cb)

        # Rebuild every subject's voxel head (overwrite=True, use_linear=True),
        # restoring the projectors and gates that add_subject may replace.
        saved_neuron_projectors = copy.deepcopy(
            model.neck.neuron_projectors.state_dict()
        )
        saved_layer_gates = copy.deepcopy(model.neck.layer_gates.state_dict())
        for subject in model.subject_list:
            model.neck.add_subject(
                subject,
                dm.neuron_coords_dict[subject],
                overwrite=True,
                use_linear=True,
            )
        model.neck.neuron_projectors.load_state_dict(saved_neuron_projectors)
        model.neck.layer_gates.load_state_dict(saved_layer_gates)

        freeze(model.backbone)
        freeze(model.conv_blocks)
        # freeze(model.neck.neuron_projectors)
        if model.image_shifter is not None:
            freeze(model.image_shifter)
        if not cfg.FINETUNE.TRAIN_SHARED:
            freeze(model.neck.neuron_projectors)
            freeze(model.neck.layer_gates)
            if hasattr(model.neck, "eye_shifters"):
                freeze(model.neck.eye_shifters)
            if hasattr(model.neck, "neuron_shifters"):
                freeze(model.neck.neuron_shifters)
    if stage == 3:
        # fix voxel specific part, only train voxel shared part
        # NOTE(review): the rebuilt model keeps only voxel_outs, neuron_projectors
        # and layer_gates from the loaded weights (conv_blocks and the loaded
        # backbone weights are not copied), and layer_gates is not frozen.
        # Confirm both are intended.
        saved_voxel_outs = copy.deepcopy(model.neck.voxel_outs.state_dict())
        saved_neuron_projectors = copy.deepcopy(
            model.neck.neuron_projectors.state_dict()
        )
        saved_layer_gates = copy.deepcopy(model.neck.layer_gates.state_dict())
        # saved_neck_params = copy.deepcopy(model.neck.state_dict())
        model = VEModel(*model_args)
        model.neck.voxel_outs.load_state_dict(saved_voxel_outs)
        model.neck.neuron_projectors.load_state_dict(saved_neuron_projectors)
        model.neck.layer_gates.load_state_dict(saved_layer_gates)
        # model.neck.load_state_dict(saved_neck_params)
        freeze(model.backbone)
        freeze(model.neck.neuron_projectors)
        freeze(model.neck.voxel_outs)

    if stage == 0.1:
        vcb = SaveTopWeightEachVoxel(top_n=cfg.FINETUNE.TOP_N)
        cbs.append(vcb)
        # measure transfer score between voxels
        # Conv blocks come from conv_path. They are saved before loading
        # projector_path (which may contain conv_blocks too) and restored after.
        state_dict = load_state_dict(conv_path)
        state_dict = {k: v for k, v in state_dict.items() if "conv_blocks" in k}
        model.load_state_dict(state_dict, strict=False)
        saved_conv_blocks = copy.deepcopy(model.conv_blocks.state_dict())
        state_dict = load_state_dict(projector_path)
        state_dict = {k: v for k, v in state_dict.items() if "voxel_outs" not in k}
        model.load_state_dict(state_dict, strict=False)
        # saved_neuron_projectors = copy.deepcopy(model.neck.neuron_projectors.state_dict())
        # saved_layer_gates = copy.deepcopy(model.neck.layer_gates.state_dict())
        model.conv_blocks.load_state_dict(saved_conv_blocks)
        freeze(model.backbone)
        freeze(model.conv_blocks)
        # freeze(model.neck.neuron_projectors)
        if model.image_shifter is not None:
            freeze(model.image_shifter)
        if not cfg.FINETUNE.TRAIN_SHARED:
            freeze(model.neck.neuron_projectors)
            freeze(model.neck.layer_gates)
            if hasattr(model.neck, "eye_shifters"):
                freeze(model.neck.eye_shifters)
            if hasattr(model.neck, "neuron_shifters"):
                freeze(model.neck.neuron_shifters)

    # Checkpointing is off for stages 2, 4 and 0.1; those stages rely on the
    # callbacks attached above instead of ModelCheckpoint.
    trainer = pl.Trainer(
        precision=cfg.TRAINER.PRECISION,
        accelerator="cuda",
        gradient_clip_val=cfg.TRAINER.GRADIENT_CLIP_VAL,
        # strategy=DDPStrategy(find_unused_parameters=False),
        devices=cfg.TRAINER.DEVICES,
        max_epochs=cfg.TRAINER.MAX_EPOCHS,
        max_steps=cfg.TRAINER.MAX_STEPS,
        val_check_interval=cfg.TRAINER.VAL_CHECK_INTERVAL,
        accumulate_grad_batches=cfg.TRAINER.ACCUMULATE_GRAD_BATCHES,
        limit_train_batches=cfg.TRAINER.LIMIT_TRAIN_BATCHES,
        limit_val_batches=cfg.TRAINER.LIMIT_VAL_BATCHES,
        log_every_n_steps=cfg.TRAINER.LOG_TRAIN_N_STEPS,
        callbacks=cbs,
        logger=lgs,
        enable_checkpointing=False
        if (stage == 2 or stage == 4 or stage == 0.1)
        else True,
        enable_progress_bar=progress,
    )

    trainer.fit(model, datamodule=dm)

    val_score, test_score = validate_test_log(trainer, model, dm, prefix="SOUP_NO/one/")

    # Only stages with checkpointing enabled can report their best checkpoints.
    best_k_models = None
    best_model_path = None
    if stage != 2 and stage != 4 and stage != 0.1:
        trainer.checkpoint_callback.to_yaml(os.path.join(log_dir, "checkpoint.yaml"))

        ckpt: ModelCheckpoint = trainer.checkpoint_callback
        best_k_models = ckpt.best_k_models
        best_model_path = ckpt.best_model_path

    return (best_k_models, best_model_path, trainer, dm, model, log_dir, ckpt_dir, cbs)
