import operator
import os
import sys
from typing import Dict

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import (
    BackboneFinetuning,
    EarlyStopping,
    ModelCheckpoint,
)
from ray import tune
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback,
    TuneReportCheckpointCallback,
)
import shutil

from tqdm import tqdm

from callbacks import (
    DisableBN,
    SaveTopWeightEachVoxel,
    FreezeBackbone,
    ModifyBNMoment,
    RemoveCheckpoint,
    SaveFinalFC,
    EmptyCache,
    StageFinetuning,
    LoadBestCheckpointOnEnd,
    LoadBestCheckpointOnVal,
)
from config import AutoConfig
from datamodule import build_dm, AllDatamodule
from models import VEModel
from topyneck import VoxelOutBlock


def get_callbacks_and_loggers(cfg: AutoConfig, sub_dir: str = None):
    """Build the Lightning callbacks and loggers for one training run.

    The log directory is the current Ray Tune trial directory when one is
    available (``tune.get_trial_dir()`` is not None), otherwise ``"tb_logs"``;
    ``sub_dir`` is appended when given. Checkpoints go to ``<log_dir>/ckpt``.

    Args:
        cfg: experiment configuration.
        sub_dir: optional sub-directory appended to the log directory.

    Returns:
        ``(callbacks, loggers, log_dir, ckpt_dir)``.

    Raises:
        ValueError: if LOAD_BEST_ON_VAL, LOAD_BEST_ON_END or SAVE_OUTPUT is
            enabled while ``TRAINER.CALLBACKS.CHECKPOINT.SAVE_TOP_K <= 0``.
    """
    trial_dir = tune.get_trial_dir()
    ray_flag = trial_dir is not None  # running inside a Ray Tune trial
    log_dir = trial_dir if ray_flag else "tb_logs"
    if sub_dir is not None:
        log_dir = os.path.join(log_dir, sub_dir)
    ckpt_dir = os.path.join(log_dir, "ckpt")

    cb_cfg = cfg.TRAINER.CALLBACKS
    save_top_k = cb_cfg.CHECKPOINT.SAVE_TOP_K

    def require_top_k(option: str) -> None:
        # These options depend on the ModelCheckpoint callback registered
        # below, which only exists when SAVE_TOP_K > 0. Raise explicitly
        # (an `assert` would be stripped under `python -O`).
        if save_top_k <= 0:
            raise ValueError(
                f"{option} requires TRAINER.CALLBACKS.CHECKPOINT.SAVE_TOP_K > 0, "
                f"got {save_top_k}"
            )

    metrics_name = "VAL/PearsonCorrCoef/" + cb_cfg.EARLY_STOP.SUBJECT
    callbacks = [EmptyCache()]
    if cfg.MODEL.BACKBONE.BN_MOMENTUM != -1:
        callbacks.append(ModifyBNMoment(cfg.MODEL.BACKBONE.BN_MOMENTUM))
    # Stop when the monitored validation Pearson r has not improved by at
    # least 0.001 for PATIENCE validation checks.
    callbacks.append(
        EarlyStopping(
            monitor=metrics_name,
            min_delta=0.001,
            patience=cb_cfg.EARLY_STOP.PATIENCE,
            verbose=False,
            mode="max",
        )
    )

    if save_top_k > 0:
        # NOTE: `monitor` uses the EARLY_STOP.SUBJECT metric while the filename
        # embeds the *mean* metrics. Because the metric names contain "/" and
        # auto_insert_metric_name=True, Lightning writes the checkpoints into
        # nested sub-directories under ckpt_dir.
        callbacks.append(
            ModelCheckpoint(
                monitor=metrics_name,
                dirpath=ckpt_dir,
                filename="{epoch:d}-{VAL/PearsonCorrCoef/mean:.6f}"
                + "-{VAL/MeanAbsoluteError/mean:.6f}"
                + "-{VAL/MeanSquaredError/mean:.6f}",
                auto_insert_metric_name=True,
                save_weights_only=True,
                save_top_k=save_top_k,
                mode="max",
            )
        )

    if cb_cfg.CHECKPOINT.LOAD_BEST_ON_VAL:
        require_top_k("LOAD_BEST_ON_VAL")
        callbacks.append(LoadBestCheckpointOnVal())

    if cb_cfg.CHECKPOINT.LOAD_BEST_ON_END:
        require_top_k("LOAD_BEST_ON_END")
        callbacks.append(LoadBestCheckpointOnEnd())

    if cb_cfg.SAVE_OUTPUT:
        require_top_k("SAVE_OUTPUT")
        from callbacks import SaveOutput

        # this needs to be registered before any checkpoint-removal callback
        callbacks.append(SaveOutput(log_dir))

    if cfg.MODEL.BACKBONE.DISABLE_BN:
        callbacks.append(DisableBN())

    if cfg.ANALYSIS.SAVE_LAST_LINEAR_LAYER:
        callbacks.append(SaveFinalFC(log_dir))
    if cfg.ANALYSIS.TRANSFER:
        from transfer import Transfer

        callbacks.append(Transfer())

    from callbacks import SaveNeuronLocation

    callbacks.append(
        SaveNeuronLocation(
            log_dir,
            save=cfg.ANALYSIS.SAVE_NEURON_LOCATION,
            draw=cfg.ANALYSIS.DRAW_NEURON_LOCATION,
        )
    )

    # Under Ray the trial directory is already unique, so log straight into
    # it (version ".") instead of letting Lightning create version_N folders.
    version = "." if ray_flag else None
    loggers = [
        pl_loggers.TensorBoardLogger(log_dir, version=version),
        pl_loggers.CSVLogger(log_dir, version=version),
    ]

    return callbacks, loggers, log_dir, ckpt_dir


def freeze(model):
    """Freeze ``model`` in place using Lightning's finetuning helper.

    ``train_bn=True`` keeps BatchNorm layers trainable while every other
    parameter gets ``requires_grad=False``.

    Returns:
        The same ``model`` object, for call chaining.
    """
    from pytorch_lightning.callbacks.finetuning import BaseFinetuning

    BaseFinetuning.freeze(modules=model, train_bn=True)
    return model


def _load_finetune_weights(model, source, subject, root="/data/weights"):
    """Load pretrained conv-block and voxel-shared weights into ``model``.

    Files are ``<root>/conv_blocks/<source>/<name>.pth`` and
    ``<root>/voxel_shared/<source>/<name>.pth``, where ``name`` is
    ``"all_nsd"`` when ``source == "all_nsd"`` and ``subject`` otherwise.
    Loading is non-strict, so parameters absent from a file keep their values.
    """
    name = "all_nsd" if source == "all_nsd" else subject
    for kind in ("conv_blocks", "voxel_shared"):
        w_path = os.path.join(f"{root}/{kind}/{source}", f"{name}.pth")
        model.load_state_dict(torch.load(w_path), strict=False)


def _mean_of_ingredients(ingredients, k):
    """Return the element-wise mean of ``ingredient[k]`` over ``ingredients``."""
    return torch.stack([ingredient[k] for ingredient in ingredients]).mean(dim=0)


@torch.no_grad()
def _write_voxel_soup(voxel_out, ingredients_w, ingredients_b, num_layers, progress=False):
    """Set each voxel's read-out weight/bias rows to the mean of its ingredients.

    ``ingredients_w[i]`` / ``ingredients_b[i]`` are lists of per-model entries
    for voxel ``i``, each indexable by layer ``k``. Voxels with an empty list
    keep their current parameters. Results are moved to the parameter's device.
    """
    new_w = [voxel_out.weight[k].data.clone() for k in range(num_layers)]
    new_b = [voxel_out.bias[k].data.clone() for k in range(num_layers)]
    for i in tqdm(range(len(ingredients_w)), disable=not progress):
        if len(ingredients_w[i]) == 0:
            continue  # nothing accepted for this voxel: keep its current rows
        for k in range(num_layers):
            new_w[k][i] = _mean_of_ingredients(ingredients_w[i], k)
            new_b[k][i] = _mean_of_ingredients(ingredients_b[i], k)
    for k in range(num_layers):
        device = voxel_out.weight[k].device
        voxel_out.weight[k].data = new_w[k].to(device)
        voxel_out.bias[k].data = new_b[k].to(device)


@torch.no_grad()
def _finetune_greedy_soup(trainer, model, dm, vcb, subject, soup_target):
    """Build a per-voxel greedy soup from the weights collected by ``vcb``.

    ``vcb.w_queue[i][j]`` / ``vcb.b_queue[i][j]`` hold the j-th saved read-out
    weights / biases of voxel ``i`` (each indexable by layer). For every rank
    ``j`` each voxel tentatively adds its j-th entry to its soup, the whole
    read-out is evaluated once, and a voxel keeps the new ingredient only if
    its own score improved.

    Args:
        trainer: Lightning trainer used for evaluation.
        model: model whose ``neck.voxel_outs[subject]`` is rewritten in place.
        dm: datamodule passed to ``trainer.validate`` / ``trainer.test``.
        vcb: the ``SaveTopWeightEachVoxel`` callback used during training.
        subject: subject key of the voxel read-out to soup.
        soup_target: ``"val"`` (validation) or ``"heldout"`` (test).

    Returns:
        ``model`` with the final greedy soup loaded.

    Raises:
        ValueError: if ``soup_target`` is not ``"val"`` or ``"heldout"``.
    """
    if soup_target == "val":
        evaluate = trainer.validate
    elif soup_target == "heldout":
        evaluate = trainer.test
    else:
        raise ValueError(f"Unknown greedy target {soup_target}")

    voxel_out = model.neck.voxel_outs[subject]
    n_voxels = len(vcb.w_queue)
    num_layers = len(vcb.w_queue[0][0])
    # Never iterate over more ranks than any voxel actually has stored.
    n_models = min(vcb.top_n, max(len(queue) for queue in vcb.w_queue))

    best_score_so_far = [float("-inf")] * n_voxels
    ingredients_w = [[] for _ in range(n_voxels)]
    ingredients_b = [[] for _ in range(n_voxels)]
    # Ingredient lists start empty, so rank 0 must be offered as well.
    for j in range(n_models):
        tried = [j < len(vcb.w_queue[i]) for i in range(n_voxels)]
        for i in range(n_voxels):
            if tried[i]:
                ingredients_w[i].append(vcb.w_queue[i][j])
                ingredients_b[i].append(vcb.b_queue[i][j])
        _write_voxel_soup(voxel_out, ingredients_w, ingredients_b, num_layers, progress=True)

        evaluate(model, datamodule=dm)
        # NOTE(review): assumes the model stores one score per voxel in
        # `voxel_score` during validate/test — confirm in VEModel.
        current_score = model.voxel_score
        if torch.is_tensor(current_score):
            current_score = current_score.tolist()  # one device sync, not one per voxel

        for i in range(n_voxels):
            if not tried[i]:
                continue
            if current_score[i] > best_score_so_far[i]:
                best_score_so_far[i] = current_score[i]
            else:
                ingredients_w[i].pop()
                ingredients_b[i].pop()

        print(f"finishing greedy soup {j}/{n_models}")

    # load best soup
    _write_voxel_soup(voxel_out, ingredients_w, ingredients_b, num_layers)
    return model


def _uniform_model_soup(ckpt_paths):
    """Return the element-wise mean of the ``state_dict`` of each checkpoint."""
    ckpt_paths = list(ckpt_paths)
    scale = 1.0 / len(ckpt_paths)
    soup = None
    for ckpt_path in ckpt_paths:
        state_dict = torch.load(ckpt_path)["state_dict"]
        if soup is None:
            soup = {k: v * scale for k, v in state_dict.items()}
        else:
            soup = {k: v * scale + soup[k] for k, v in state_dict.items()}
    return soup


def _greedy_model_soup(trainer, model, dm, best_k_models, greedy_target):
    """Greedy soup over the top-k checkpoints.

    Checkpoints are visited best-first according to the scores stored in
    ``best_k_models`` (path -> score). A checkpoint is kept only if averaging
    it into the current soup improves the chosen metric; the first checkpoint
    is evaluated on its own to set the baseline, so it is counted exactly once.

    Returns:
        The greedy-soup state dict, which is also loaded into ``model``.

    Raises:
        ValueError: if ``greedy_target`` is not ``"val"`` or ``"heldout"``.
    """
    if greedy_target == "val":
        evaluate, metric_key = trainer.validate, "VAL/PearsonCorrCoef/mean"
    elif greedy_target == "heldout":
        evaluate, metric_key = trainer.test, "TEST/PearsonCorrCoef/mean"
    else:
        raise ValueError(f"Invalid cfg.MODEL_SOUP.GREEDY_TARGET: {greedy_target}")

    ranked = sorted(best_k_models.items(), key=operator.itemgetter(1))
    ranked.reverse()
    sorted_models = [ckpt_path for ckpt_path, _ in ranked]
    num_models = len(sorted_models)

    greedy_soup_ingredients = []
    greedy_soup_params = None
    best_score_so_far = float("-inf")
    for j, ckpt_path in enumerate(sorted_models):
        print(f"Greedy soup: {j}/{num_models}")

        # Potential soup = current soup with the new model added.
        new_ingredient_params = torch.load(ckpt_path)["state_dict"]
        num_ingredients = len(greedy_soup_ingredients)
        if greedy_soup_params is None:
            potential_greedy_soup_params = new_ingredient_params
        else:
            potential_greedy_soup_params = {
                k: greedy_soup_params[k].clone()
                * (num_ingredients / (num_ingredients + 1.0))
                + new_ingredient_params[k].clone() * (1.0 / (num_ingredients + 1))
                for k in new_ingredient_params
            }
        model.load_state_dict(potential_greedy_soup_params)
        current_score = evaluate(model, datamodule=dm)[-1][metric_key]
        print(
            f"Current score: {current_score}, best score so far: {best_score_so_far}"
        )

        if current_score > best_score_so_far:
            greedy_soup_ingredients.append(ckpt_path)
            best_score_so_far = current_score
            greedy_soup_params = potential_greedy_soup_params
            print(f"Greedy soup improved to {len(greedy_soup_ingredients)} models.")

    if greedy_soup_params is None:
        # No candidate ever scored (e.g. NaN metrics): fall back to the best checkpoint.
        greedy_soup_params = torch.load(sorted_models[0])["state_dict"]
    model.load_state_dict(greedy_soup_params)
    return greedy_soup_params


def run_train(cfg: AutoConfig, trainer: pl.Trainer = None, **kwargs):
    """Train a VEModel according to ``cfg`` and optionally build weight soups.

    Stage overrides are applied to ``cfg`` in place first. "pretrain" and
    "projector" set ``MODEL.BACKBONE.FREEZE``. "finetune" requires exactly one
    subject, disables top-k checkpointing, loads pretrained weights from
    ``/data/weights`` and freezes the backbone, conv blocks and several neck
    sub-modules.

    After ``trainer.fit`` the model is validated and tested. Then:

    * ``finetune``: a per-voxel soup (``cfg.FINETUNE.SOUP`` is ``"uniform"``
      or ``"greedy"``) is built from the weights collected by
      ``SaveTopWeightEachVoxel`` and evaluated.
    * otherwise, if ``cfg.MODEL_SOUP.USE``: uniform and/or greedy model soups
      (``cfg.MODEL_SOUP.RECIPE``) are built from the top-k checkpoints.

    Args:
        cfg: experiment configuration (mutated in place).
        trainer: optional pre-built trainer; built from ``cfg`` when None.
            NOTE: a caller-supplied trainer does not receive the callbacks
            built here (including the finetune weight collector).
        **kwargs: ``progress`` (bool, default True) toggles the progress bar.

    Returns:
        None for the "finetune" stage, otherwise ``(model, best_model_path)``.

    Raises:
        ValueError: if the finetune stage is given more than one subject, or
            a soup target option is invalid.
    """
    if cfg.STAGE in ("pretrain", "projector"):
        cfg.MODEL.BACKBONE.FREEZE = True
    if cfg.STAGE == "finetune":
        if len(cfg.DATASET.SUBJECT_LIST) != 1:
            raise ValueError(
                "finetune stage expects exactly one subject, got "
                f"{cfg.DATASET.SUBJECT_LIST}"
            )
        cfg.MODEL.BACKBONE.FREEZE = True
        cfg.TRAINER.CALLBACKS.CHECKPOINT.SAVE_TOP_K = 0
        cfg.TRAINER.CALLBACKS.CHECKPOINT.LOAD_BEST_ON_END = False
        cfg.TRAINER.CALLBACKS.EARLY_STOP.PATIENCE = 100
        # Very large cap — presumably "no limit on training voxels".
        cfg.MODEL.MAX_TRAIN_VOXELS = 1145141919810
        cfg.LOSS.SYNC.USE = False

    dm: AllDatamodule = build_dm(cfg)
    dm.setup()

    callbacks, loggers, log_dir, ckpt_dir = get_callbacks_and_loggers(cfg)
    os.makedirs(log_dir, exist_ok=True)

    model_args = (
        cfg,
        dm.num_voxel_dict,
        dm.roi_dict,
        dm.neuron_coords_dict,
        dm.noise_ceiling_dict,
    )
    torch.save(cfg, os.path.join(log_dir, "cfg.pth"))

    model = VEModel(*model_args)

    if cfg.STAGE == "finetune":
        subject = cfg.DATASET.SUBJECT_LIST[0]
        vcb = SaveTopWeightEachVoxel(top_n=cfg.FINETUNE.TOP_N)
        callbacks.append(vcb)
        model.neck.add_subject(
            subject,
            dm.neuron_coords_dict[subject],
            overwrite=True,
            use_linear=cfg.FINETUNE.USE_LINEAR,
        )
        _load_finetune_weights(model, cfg.FINETUNE.SOURCE, subject)

        freeze(model.backbone)
        freeze(model.conv_blocks)
        if model.image_shifter is not None:
            freeze(model.image_shifter)
        freeze(model.neck.neuron_projectors)
        freeze(model.neck.layer_gates)
        freeze(model.neck.eye_shifters)
        freeze(model.neck.neuron_shifters)

    ### TODO: load pretrained model, and projector

    if trainer is None:
        trainer = pl.Trainer(
            precision=cfg.TRAINER.PRECISION,
            accelerator="cuda",
            gradient_clip_val=cfg.TRAINER.GRADIENT_CLIP_VAL,
            # strategy=DDPStrategy(find_unused_parameters=False),  # pytorch_lightning.strategies
            devices=cfg.TRAINER.DEVICES,
            max_epochs=cfg.TRAINER.MAX_EPOCHS,
            val_check_interval=cfg.TRAINER.VAL_CHECK_INTERVAL,
            accumulate_grad_batches=cfg.TRAINER.ACCUMULATE_GRAD_BATCHES,
            limit_train_batches=cfg.TRAINER.LIMIT_TRAIN_BATCHES,
            limit_val_batches=cfg.TRAINER.LIMIT_VAL_BATCHES,
            log_every_n_steps=cfg.TRAINER.LOG_TRAIN_N_STEPS,
            callbacks=callbacks,
            logger=loggers,
            enable_checkpointing=cfg.STAGE != "finetune",
            enable_progress_bar=kwargs.get("progress", True),
        )

    print("Length of training dataloader: ")
    print(len(dm.train_dataloader()))
    print()

    trainer.fit(model, datamodule=dm)

    def log_metric(metric, keys, prefix=""):
        # Writes via trainer.logger.experiment.add_scalar, i.e. assumes the
        # first logger is TensorBoard (true for loggers built above).
        if isinstance(metric, list):
            assert len(metric) == 1
            metric = metric[-1]
        for key in keys:
            if key in metric:
                logger = trainer.logger.experiment
                logger.add_scalar(prefix + key, metric[key], trainer.global_step)

    def validate_test_log(model, dm, prefix=""):
        # Validate and test `model`, log the Pearson metrics under `prefix`,
        # and return the mean validation Pearson r.
        metric = trainer.validate(model, datamodule=dm)[-1]
        log_metric(
            metric,
            ["VAL/PearsonCorrCoef/mean", "VAL/PearsonCorrCoef/challenge"],
            prefix=prefix,
        )
        val_score = metric["VAL/PearsonCorrCoef/mean"]
        metric = trainer.test(model, datamodule=dm)[-1]
        log_metric(
            metric,
            ["TEST/PearsonCorrCoef/mean", "TEST/PearsonCorrCoef/challenge"],
            prefix=prefix,
        )
        return val_score

    validate_test_log(model, dm, prefix="SOUP/single/")

    if cfg.STAGE == "finetune":
        if len(vcb.w_queue) == 0:
            raise RuntimeError(
                "SaveTopWeightEachVoxel collected no weights; was it registered "
                "with the trainer?"
            )
        voxel_out: VoxelOutBlock = model.neck.voxel_outs[subject]
        num_layers = len(vcb.w_queue[0][0])
        if cfg.FINETUNE.SOUP == "uniform":
            # Mean over every saved model, per voxel.
            _write_voxel_soup(voxel_out, vcb.w_queue, vcb.b_queue, num_layers)
            validate_test_log(model, dm, prefix="FINETUNE/uniform/")
        elif cfg.FINETUNE.SOUP == "greedy":
            model = _finetune_greedy_soup(
                trainer, model, dm, vcb, subject, cfg.FINETUNE.SOUP_TARGET
            )
            validate_test_log(model, dm, prefix="FINETUNE/greedy/")
        return

    ckpt_callback: ModelCheckpoint = trainer.checkpoint_callback
    ckpt_callback.to_yaml(os.path.join(log_dir, "checkpoint.yaml"))
    best_path = ckpt_callback.best_model_path

    best_k_models = dict(ckpt_callback.best_k_models) if cfg.MODEL_SOUP.USE else {}
    if cfg.MODEL_SOUP.USE and not best_k_models:
        print(
            "MODEL_SOUP.USE is set but no top-k checkpoints were recorded; "
            "skipping model soup."
        )

    ### do model soup
    if best_k_models:
        recipe = cfg.MODEL_SOUP.RECIPE

        if "uniform" in recipe:
            uniform_soup = _uniform_model_soup(best_k_models)
            model.load_state_dict(uniform_soup)
            torch.save(uniform_soup, os.path.join(log_dir, "uniform_soup.pth"))
            validate_test_log(model, dm, prefix="SOUP/uniform/")

        if "greedy" in recipe:
            greedy_soup_params = _greedy_model_soup(
                trainer, model, dm, best_k_models, cfg.MODEL_SOUP.GREEDY_TARGET
            )
            torch.save(greedy_soup_params, os.path.join(log_dir, "greedy_soup.pth"))
            validate_test_log(model, dm, prefix="SOUP/greedy/")

    if cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE:
        shutil.rmtree(ckpt_dir, ignore_errors=True)
    return model, best_path


def run_tune(tune_dict: Dict, cfg: AutoConfig, **kwargs):
    """Merge one trial's hyper-parameters into ``cfg`` and then train.

    Args:
        tune_dict: hyper-parameter overrides (presumably sampled by Ray Tune),
            flattened into a key/value list by ``config_utils.dict_to_list``.
        cfg: base configuration, updated in place via ``merge_from_list``.
        **kwargs: forwarded unchanged to ``run_train``.
    """
    from config_utils import dict_to_list

    overrides = dict_to_list(tune_dict)
    cfg.merge_from_list(overrides)
    run_train(cfg, **kwargs)


if __name__ == "__main__":
    # Local debugging entry point: train from a YAML config plus a few
    # hand-set overrides. The config path defaults to the previous
    # hard-coded value and can be changed with --config.
    import argparse

    from config_utils import load_from_yaml

    parser = argparse.ArgumentParser(description="Train a VEModel from a YAML config.")
    parser.add_argument(
        "--config",
        default="/workspace/configs/dino_mania.yaml",
        help="path to the YAML experiment config",
    )
    # parse_known_args so that unrelated launcher arguments do not abort the run.
    args, _unknown = parser.parse_known_args()

    cfg = load_from_yaml(args.config)

    # Debug overrides.
    cfg.DATASET.SUBJECT_LIST = ["NSD_01"]
    cfg.DATAMODULE.BATCH_SIZE = 32
    cfg.TRAINER.DEVICES = [0]

    run_train(cfg)
