import logging
import os
from copy import copy, deepcopy
from typing import Any, Dict, List, Optional, Tuple
from matplotlib import pyplot as plt

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F

# +
from einops import rearrange, repeat

from torch import Tensor, nn

from backbone import build_backbone
from config import AutoConfig
from config_utils import convert_to_dict, save_to_yaml
from datamodule import AllDatamodule, build_dm
from datasets import Data
from head import Head
from loss import build_loss
from neck import build_neck

from torchmetrics import (
    MetricCollection,
    R2Score,
    MeanSquaredError,
    MeanAbsoluteError,
    PearsonCorrCoef,
)

from abc import ABC, abstractmethod

from topyneck import ConvBlocks, ConvBlocksWithPool, TopyNeck

from image_shifter import ImageShifter

logger = logging.getLogger(__name__)


class EMAMetric(nn.Module):
    """Track an exponential moving average (EMA) of a tensor-valued metric.

    Two EMAs with decay ``beta`` are maintained: one of the values passed to
    :meth:`update` and one of their step-to-step differences.

    With ``bias_correction=True`` the published averages are the internal
    accumulators divided by ``1 - beta ** t`` (``t`` = number of updates).
    The accumulators themselves are kept uncorrected, so the correction is
    applied exactly once per published value and never compounds.

    The tracked tensors are plain attributes (not registered buffers), so
    they are not included in ``state_dict()``.

    Args:
        beta: Decay factor; larger values give smoother, slower averages.
        bias_correction: Whether to publish bias-corrected averages.
    """

    def __init__(self, beta=0.9, bias_correction=False):
        super().__init__()
        self.beta = beta
        self.bias_correction = bias_correction
        self.t = 0  # number of update() calls so far

        # Published state; all of it stays None until the first update().
        self.running_v = None
        self.prev_running_v = None
        self.prev_v = None
        self.running_grad = None
        self.prev_grad = None

        # Uncorrected accumulators that the EMA recursion actually runs on.
        self._raw_v = None
        self._raw_grad = None

    def update(self, v):
        """Fold a new observation ``v`` into the averages.

        Args:
            v: Tensor of metric values; every call must use the same shape.

        Returns:
            The published running average (bias-corrected if enabled).
        """
        if self._raw_v is None:
            # Lazily size all state from the first observation.
            self._raw_v = torch.zeros_like(v)
            self._raw_grad = torch.zeros_like(v)
            self.running_v = torch.zeros_like(v)
            self.running_grad = torch.zeros_like(v)
            self.prev_v = torch.zeros_like(v)
            self.prev_grad = torch.zeros_like(v)

        # Step-to-step difference of the raw observations.
        g = v - self.prev_v
        # Snapshot, so later in-place edits of ``v`` by the caller cannot
        # rewrite the stored history.
        self.prev_v = v.detach().clone()
        self._raw_grad = (1 - self.beta) * g + self.beta * self._raw_grad

        # Remember the previously published average for get_gradient().
        self.prev_running_v = self.running_v.clone()
        self._raw_v = (1 - self.beta) * v + self.beta * self._raw_v

        self.t += 1
        if self.bias_correction:
            correction = 1 - self.beta**self.t
            self.running_v = self._raw_v / correction
            self.running_grad = self._raw_grad / correction
        else:
            self.running_v = self._raw_v
            self.running_grad = self._raw_grad
        return self.running_v

    def get_gradient(self):
        """Return the change of the published average over the last update.

        Raises:
            RuntimeError: If called before the first :meth:`update`.
        """
        if self.running_v is None:
            raise RuntimeError(
                "EMAMetric.get_gradient() called before the first update()"
            )
        if self.prev_running_v is None:
            self.prev_running_v = torch.zeros_like(self.running_v)
        return self.running_v - self.prev_running_v

    @staticmethod
    def normalize(x):
        """Standardise ``x`` to zero mean and unit (``Tensor.std``) deviation."""
        # NOTE(review): a constant ``x`` has zero std and yields NaN/inf here.
        return (x - x.mean()) / x.std()

    def get_status(self):
        """Return the quantities derived from the current state, keyed by name."""
        grad = self.get_gradient()
        return {
            "running_v": self.running_v,
            "running_grad": self.running_grad,
            "running_v_grad": grad,
            "vgrad": self.running_v * grad,
            "nvgrad": self.normalize(self.running_v) * grad,
        }


class PLModelABC(pl.LightningModule, ABC):
    """Base LightningModule that keeps the run configuration with the logs.

    The config is converted to a dict and registered as hyper-parameters at
    construction time; when fitting starts it is written to ``config.yaml``
    inside the logger's ``log_dir``.
    """

    def __init__(self, cfg: AutoConfig):
        super().__init__()
        self.cfg = cfg
        # Convert a clone so the conversion cannot touch the live config.
        hparams_as_dict = convert_to_dict(self.cfg.clone())
        self._cfg_hparams = hparams_as_dict
        self.save_hyperparameters(hparams_as_dict)

    def on_fit_start(self) -> None:
        """Write the configuration to ``<log_dir>/config.yaml``."""
        target = os.path.join(self.logger.log_dir, "config.yaml")
        save_to_yaml(self.cfg, target)


def get_c_dict(backbone, cfg):
    """Probe ``backbone`` with a random image and report channels per output.

    The probe image has 3 channels and the spatial size of
    ``cfg.DATASET.RESOLUTION`` enlarged by ``cfg.DATASET.PADDING``
    (entries 2 and 3 are added to the first spatial dim, entries 0 and 1 to
    the second). The forward pass runs under ``torch.no_grad()``.

    When CUDA is available the backbone is moved to the GPU (as a side
    effect on the passed module) and probed there; otherwise it is probed on
    whatever device it already lives on, so CPU-only hosts work too.

    Args:
        backbone: Callable module returning a dict of 4-D feature tensors.
        cfg: Config exposing ``DATASET.RESOLUTION`` and ``DATASET.PADDING``.

    Returns:
        Dict mapping each output name to its channel count.

    Raises:
        AssertionError: If an output is not spatially square, even after
            being reinterpreted as channels-last.
    """
    r1, r2 = cfg.DATASET.RESOLUTION
    p2d = cfg.DATASET.PADDING
    r1 += p2d[2] + p2d[3]
    r2 += p2d[0] + p2d[1]
    input_shape = (3, r1, r2)

    with torch.no_grad():
        # Sample on the CPU first (as before) so the CPU RNG stream consumed
        # by this probe does not depend on where the backbone runs.
        probe = torch.rand(*input_shape).unsqueeze(0)
        if torch.cuda.is_available():
            backbone = backbone.cuda()
            probe = probe.cuda()
        out = backbone(probe)

    c_dict = {}
    for k, v in out.items():
        if v.shape[-1] != v.shape[-2]:
            # Non-square trailing dims: treat as channels-last
            # ("b h w c") and convert to "b c h w".
            v = v.permute(0, 3, 1, 2)
        assert v.shape[-1] == v.shape[-2], (
            f"backbone output {k!r} is not square: {tuple(v.shape)}"
        )
        c_dict[k] = v.shape[1]
    return c_dict


class VEModel(PLModelABC):
    def __init__(
        self,
        cfg: AutoConfig,
        num_voxel_dict: Dict[str, int],
        roi_dict: Dict[str, Dict[str, Tensor]],
        neuron_coords_dict: Dict[str, Tensor],
        noise_ceiling_dict: Dict[str, Tensor],
    ):
        """Build the backbone, optional image shifter, conv blocks, neck and loss.

        Args:
            cfg: Run configuration.
            num_voxel_dict: Number of voxels per subject; its keys define
                ``self.subject_list``.
            roi_dict: Per subject, ROI name -> voxel indices of that ROI.
            neuron_coords_dict: Per subject coordinates handed to the neck.
            noise_ceiling_dict: Per subject noise ceiling (entries may be
                ``None``); used when computing the challenge score.
        """
        super().__init__(cfg)
        self.num_voxel_dict = num_voxel_dict
        self.roi_dict = roi_dict
        self.neuron_coords_dict = neuron_coords_dict
        self.noise_ceiling_dict = noise_ceiling_dict
        self.subject_list = list(self.num_voxel_dict.keys())

        self.backbone = build_backbone(self.cfg)

        # Channel counts must exist before any module that consumes them is
        # built; previously the image shifter read self.c_dict before it was
        # assigned, which raised AttributeError when IMAGE_SHIFTER.USE was on.
        self.c_dict = get_c_dict(self.backbone, self.cfg)

        self.image_shifter = None
        if self.cfg.MODEL.NECK.IMAGE_SHIFTER.USE:
            self.image_shifter = ImageShifter(self.cfg, self.c_dict)

        if not self.cfg.MODEL.NECK.POOL_HEAD.USE:
            self.conv_blocks = ConvBlocks(self.cfg, self.c_dict)
        else:
            self.conv_blocks = ConvBlocksWithPool(self.cfg, self.c_dict)

        self.neck: TopyNeck = build_neck(
            self.cfg, self.c_dict, self.num_voxel_dict, self.neuron_coords_dict
        )

        self.loss = build_loss(self.cfg)

        # {'TRAIN': {"NSD_01": {"early": MetricCollection}}, 'VAL': {}, 'TEST': {}}
        self.metrics = nn.ModuleDict()
        self.init_metrics()

        # {"TRAIN": {"NSD_01": EMAMetric}, "VAL": {}}
        self.ema_score = nn.ModuleDict()
        self.init_emas()

        # Per-subject loss weights: the float 1.0 means "no reweighting";
        # update_voxel_weight_by_ema() may replace it with a per-voxel tensor.
        self.voxel_weight = {}
        for s in self.subject_list:
            self.voxel_weight[s] = 1.0

        # for FinetuneEachVoxelCallback
        self.voxel_score = {}

        # for prediction_step
        self.predict_vi_dict = None

    def init_metrics(self):
        """(Re)create the nested metric collections.

        Layout: ``self.metrics[stage][subject][roi]`` with stage in
        TRAIN / VAL / TEST. Each leaf is a ``MetricCollection`` of MSE, MAE
        and a per-voxel Pearson correlation, whose keys are prefixed with
        ``"{stage}/"`` and suffixed with ``"/{subject}/{roi}"``.
        """
        registry = nn.ModuleDict()
        self.metrics = registry
        for stage in ("TRAIN", "VAL", "TEST"):
            per_stage = nn.ModuleDict()
            registry[stage] = per_stage
            for s in self.subject_list:
                per_subject = nn.ModuleDict()
                per_stage[s] = per_subject
                for roi in self.roi_dict[s]:
                    if roi == "all":
                        n_outputs = self.num_voxel_dict[s]
                    else:
                        n_outputs = self.roi_dict[s][roi].shape[0]
                    if stage == "TRAIN":
                        # Training may only see a capped subset of voxels.
                        n_outputs = min(n_outputs, self.cfg.MODEL.MAX_TRAIN_VOXELS)
                    template = MetricCollection(
                        [
                            MeanSquaredError(),
                            MeanAbsoluteError(),
                            PearsonCorrCoef(num_outputs=n_outputs),
                        ]
                    )
                    per_subject[roi] = template.clone(
                        prefix=f"{stage}/", postfix=f"/{s}/{roi}"
                    )

    def init_emas(self):
        """(Re)create one ``EMAMetric`` per subject for TRAIN and VAL.

        Layout: ``self.ema_score[stage][subject]``; decay and bias-correction
        settings come from ``cfg.LOSS.SYNC``.
        """
        self.ema_score = nn.ModuleDict()
        for stage in ("TRAIN", "VAL"):
            trackers = nn.ModuleDict()
            for s in self.subject_list:
                trackers[s] = EMAMetric(
                    beta=self.cfg.LOSS.SYNC.EMA_BETA,
                    bias_correction=self.cfg.LOSS.SYNC.EMA_BIAS_CORRECTION,
                )
            self.ema_score[stage] = trackers

    def move_device(self) -> None:
        self.neck.neuron_coords_dict = {
            k: v.to(self.device) for k, v in self.neck.neuron_coords_dict.items()
        }
        self.noise_ceiling_dict = {
            k: v.to(self.device) if v is not None else None
            for k, v in self.noise_ceiling_dict.items()
        }

    def on_fit_start(self) -> None:
        return self.move_device()

    def on_test_start(self) -> None:
        return self.move_device()

    def on_validation_start(self) -> None:
        return self.move_device()

    def from_batch(self, batch):
        img, y, subject_ids, session_ids, eye_coords, darkness = batch
        return img, y, subject_ids, session_ids, eye_coords, darkness

    def get_intermidiate_outputs(self, x):
        x = self.backbone(x)
        
        if not next(self.conv_blocks.parameters()).requires_grad:
            with torch.no_grad():
                x = self.conv_blocks(x)
        else:
            x = self.conv_blocks(x)
        return x

    def forward(
        self,
        img,  # [B, C, H, W]
        subject_ids: List[str],  # [B]
        session_ids: List[str],  # [B]
        eye_coords: Tensor = None,  # [B, 2]
        voxel_indices_dict: Dict[str, Tensor] = None,  # [N]
    ) -> Tuple[Dict[str, Tensor], Tensor]:
        b = img.shape[0]

        x = self.backbone(img)

        if self.cfg.MODEL.NECK.CONV_HEAD.USE:
            if not next(self.conv_blocks.parameters()).requires_grad:
                with torch.no_grad():
                    x = self.conv_blocks(x)
            else:
                x = self.conv_blocks(x)
                self.neck.previous_layer_requires_grad = True
        # {'layer1': [B, C, H, W], 'layer2': [B, C, H, W], ...}

        x_shift, x_shift_reg = None, [0.0] * b
        if self.image_shifter is not None:
            x_shift, x_shift_reg = self.image_shifter(x)

        out, reg, x_shift = self.neck(
            x=x,
            subject_ids=subject_ids,
            session_ids=session_ids,
            eye_coords=eye_coords,
            voxel_indices_dict=voxel_indices_dict,
            x_shift=x_shift,
        )
        # out: List[Tensor]: [B, N]

        for i in range(b):
            reg[i] += x_shift_reg[i]

        return out, reg, x_shift

    def training_step(self, batch, batch_idx):
        stage = "TRAIN"
        img, ys, subject_ids, session_ids, eye_coords, darkness = self.from_batch(batch)
        voxel_indices_dict = {}  # {subject_id: [N]} reduce memory usage
        for s in self.subject_list:
            n = self.num_voxel_dict[s]
            voxel_indices = ...
            if n > self.cfg.MODEL.MAX_TRAIN_VOXELS:
                voxel_indices = torch.randperm(n)[: self.cfg.MODEL.MAX_TRAIN_VOXELS]
            voxel_indices_dict[s] = voxel_indices
        out, reg, x_shift = self(
            img,
            subject_ids,
            session_ids,
            eye_coords,
            voxel_indices_dict=voxel_indices_dict,
        )

        # legacy shifter logging
        if self.global_step % 10 == 0:
            if x_shift is not None:
                vmax = 0.5
                fig, axs = plt.subplots(1, 2, figsize=(10, 5))
                axs[0].imshow(
                    x_shift[0, 0, ...].detach().cpu().numpy(),
                    cmap="bwr",
                    vmin=-vmax,
                    vmax=vmax,
                )
                axs[1].imshow(
                    x_shift[0, 1, ...].detach().cpu().numpy(),
                    cmap="bwr",
                    vmin=-vmax,
                    vmax=vmax,
                )
                self.logger.experiment.add_figure("x_shift", fig, self.global_step)
                plt.close(fig)

        b = img.shape[0]
        batch_loss = []
        n_voxels = []
        for i, (s, o, y) in enumerate(zip(subject_ids, out, ys)):
            vi = voxel_indices_dict[s]
            n_voxels.append(vi.shape[0] if vi != ... else y.shape[0])
        total_voxels = sum(n_voxels)
        # start_of_loop = True
        for i, (s, o, y) in enumerate(zip(subject_ids, out, ys)):
            vi = voxel_indices_dict[s]
            y = y[vi].unsqueeze(0)
            o = o.unsqueeze(0)
            voxel_loss = self.loss(o, y).squeeze(0)  # [N]

            w_v = self.voxel_weight[s]  # reweight by ema_score
            w_v = 1.0 if isinstance(w_v, float) else w_v[vi]
            voxel_loss = voxel_loss * w_v
            voxel_loss = voxel_loss.mean() * n_voxels[i] / total_voxels
            # scalar #TODO: mean is not good for imbalanced data
            # if torch.isnan(voxel_loss) or torch.isinf(voxel_loss) or voxel_loss > 20:
            #     if hasattr(self, "nanloss") and self.nanloss and start_of_loop:
            #         raise KeyboardInterrupt
            #     logging.warning(
            #         f"loss is invalid: {voxel_loss}, batch_idx: {batch_idx}, i: {i}, subject_id: {s}, skippping..."
            #     )
            #     # continue
            #     # raise KeyboardInterrupt # stop training
            #     best_path = self.trainer.checkpoint_callback.best_model_path
            #     logging.warning(f"reloading best checkpoint: {best_path}")
            #     state_dict = torch.load(best_path)["state_dict"]
            #     self.load_state_dict(state_dict)
            #     self.nanloss = True
            # start_of_loop = False

            voxel_loss += reg[i]  # regularization
            batch_loss.append(voxel_loss)

            self.metrics[stage][s]["all"].update(o.float(), y.float())

        # voxel_loss = None  # don't work for AMP, but usually it won't be None
        # if len(batch_loss) > 0:
        voxel_loss = torch.sum(torch.stack(batch_loss))

        return voxel_loss

    def _shared_eval_step(  # TODO
        self, batch, batch_idx, stage, is_log=True
    ) -> Tuple[List[Tensor], Tensor, List[Tensor]]:
        img, ys, subject_ids, session_ids, eye_coords, darkness = self.from_batch(batch)
        out, reg, x_shift = self(img, subject_ids, session_ids, eye_coords)

        for i, (s, o, y) in enumerate(zip(subject_ids, out, ys)):
            y = y.unsqueeze(0)
            o = o.unsqueeze(0)
            for roi in self.roi_dict[s].keys():
                if roi == "all":
                    self.metrics[stage][s][roi].update(o.float(), y.float())
                else:
                    roi_idx = self.roi_dict[s][roi]
                    self.metrics[stage][s][roi].update(
                        o[:, roi_idx].float(), y[:, roi_idx].float()
                    )

    def validation_step(self, batch, batch_idx):
        self._shared_eval_step(batch, batch_idx, "VAL")

    def test_step(self, batch, batch_idx):
        self._shared_eval_step(batch, batch_idx, "TEST")

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        img, ys, subject_ids, session_ids, eye_coords, darkness = self.from_batch(batch)
        out, reg, x_shift = self(
            img,
            subject_ids,
            session_ids,
            eye_coords,
            voxel_indices_dict=self.predict_vi_dict,
        )
        return out

    @torch.no_grad()
    def update_voxel_weight_by_ema(self):
        stage = self.cfg.LOSS.SYNC.STAGE
        v_dict = {}
        all_v = []
        for s in self.subject_list:
            ema: EMAMetric = self.ema_score[stage][s]
            v = ema.get_status()[self.cfg.LOSS.SYNC.EMA_KEY]
            if self.cfg.LOSS.SYNC.UPDATE_RULE == "exp":
                v = torch.exp(
                    v * self.cfg.LOSS.SYNC.EXP_SCALE + self.cfg.LOSS.SYNC.EXP_SHIFT
                )
            elif self.cfg.LOSS.SYNC.UPDATE_RULE == "square":
                v = v**2
            elif self.cfg.LOSS.SYNC.UPDATE_RULE == "raw":
                v = v
            elif self.cfg.LOSS.SYNC.UPDATE_RULE == "log":  # recommended
                v = torch.log(v + self.cfg.LOSS.SYNC.LOG_SHIFT)
            elif self.cfg.LOSS.SYNC.UPDATE_RULE == "norm":
                std = v.std()
                mean = v.mean()
                v = (v - mean) / std
                v = torch.clamp(v, -3, 3)
                # grad = (grad - grad.min()) / (grad.max() - grad.min())
            elif self.cfg.LOSS.SYNC.UPDATE_RULE == "none":
                return  # do nothing
            else:
                raise NotImplementedError
            v_dict[s] = v
            all_v.append(v)
        all_v = torch.cat(all_v, dim=0)
        vmax, vmin = all_v.max(), all_v.min()

        for s in self.subject_list:
            v = v_dict[s]
            v = (v - vmin) / (vmax - vmin)
            self.voxel_weight[s] = v

    @torch.no_grad()
    def _shared_epoch_end(self, outputs, stage):
        """Compute, log, persist and reset the metrics of ``stage``.

        Steps, in order:
          1. compute every ROI metric per subject (NaN entries become 0),
             log the mean of each metric, and feed the per-voxel Pearson
             scores of ROI "all" into the EMA trackers / ``voxel_score``;
          2. for subjects whose id contains "NSD" (non-TRAIN stages), log the
             median noise-normalised squared correlation and save the
             per-voxel values under ``<log_dir>/challenge``;
          3. log the mean per-voxel Pearson score over all subjects;
          4. save the per-voxel metrics of ROI "all" under
             ``<log_dir>/voxel_metric``;
          5. reset all metrics of ``stage`` and, if enabled for this stage,
             refresh the voxel weights from the EMA scores.

        Args:
            outputs: Unused here.
            stage: One of "TRAIN", "VAL", "TEST".

        Returns:
            The mean Pearson score from step 3, or ``None`` when the epoch
            is skipped (see the first check below).
        """
        # Skip epochs that end before any optimisation step, unless the
        # instance carries a ``zero_flag`` attribute.
        # NOTE(review): presumably this filters the pre-training sanity
        # validation run - confirm.
        if self.global_step == 0:
            if not hasattr(self, "zero_flag"):
                return

        voxel_metric_dict = {}
        nsd_all_v = []
        all_v = []
        for subject_id in self.subject_list:
            # Voxels that count towards the logged summary scores. Ellipsis
            # selects all voxels; if the instance has a non-empty
            # ``dark_gt_vis`` mapping, its indices for this subject are used.
            log_vi = ...
            if hasattr(self, "dark_gt_vis") and self.dark_gt_vis:
                if subject_id not in self.dark_gt_vis:
                    logging.warning(f"subject_id {subject_id} not in dark_gt_vis")
                else:
                    log_vi = self.dark_gt_vis[subject_id]

            voxel_metric_dict[subject_id] = {}  # for saving
            for roi in self.roi_dict[subject_id].keys():
                if stage == "TRAIN" and roi != "all":
                    # skip roi for training
                    continue
                metric_dict = self.metrics[stage][subject_id][roi].compute()
                # Replace NaN metric entries by 0 (modifies the tensors in place).
                for k, v in metric_dict.items():
                    v[torch.isnan(v)] = 0
                    metric_dict[k] = v

                mean_d = {}
                for k, v in metric_dict.items():
                    if roi == "all":
                        voxel_metric_dict[subject_id][k] = v.detach().cpu().numpy()
                        if stage != "TEST":
                            if "PearsonCorrCoef" in k:
                                # Per-voxel correlations drive the EMA scores.
                                self.ema_score[stage][subject_id].update(v)
                                if stage == "VAL":  # for FinetuneEachVoxelCallback
                                    self.voxel_score[subject_id] = (
                                        v.detach().cpu().numpy()
                                    )
                    if roi == "all" and stage == "TEST" and "PearsonCorrCoef" in k:
                        self.voxel_score[subject_id] = v.detach().cpu().numpy()
                    mean_d[k] = torch.mean(v)

                self.log_dict(mean_d)

                # for early stopping: collect per-voxel Pearson scores
                # (all voxels for TRAIN, the ``log_vi`` subset otherwise)
                if roi == "all":
                    vs = metric_dict[f"{stage}/PearsonCorrCoef/{subject_id}/{roi}"]
                    if stage == "TRAIN":
                        all_v.append(vs)
                    else:
                        all_v.append(vs[log_vi])

                # for nsd algonauts leaderboard: squared correlation divided
                # by the noise ceiling (1e-6 guards against division by zero)
                if "NSD" in subject_id and roi == "all" and stage != "TRAIN":
                    vs = metric_dict[f"{stage}/PearsonCorrCoef/{subject_id}/{roi}"]
                    vs = vs**2 / (self.noise_ceiling_dict[subject_id] + 1e-6)
                    nsd_all_v.append(vs[log_vi])

        if len(nsd_all_v) > 0:
            nsd_all_v = torch.cat(nsd_all_v)
            median_vertex = torch.median(nsd_all_v)
            self.log(f"{stage}/PearsonCorrCoef/challenge", median_vertex)
            save_dir = os.path.join(self.logger.log_dir, "challenge")
            os.makedirs(save_dir, exist_ok=True)
            np.save(
                os.path.join(save_dir, f"{stage}.npy"), nsd_all_v.detach().cpu().numpy()
            )

        # for early stopping
        all_v_mean = torch.mean(torch.cat(all_v))
        self.log(f"{stage}/PearsonCorrCoef/mean", all_v_mean)

        # save to disk
        # NOTE(review): np.save is documented to append ".npy" to names that
        # lack it, so the file on disk is presumably "...pkl.npy" - confirm
        # that readers expect that name.
        log_dir = self.logger.log_dir
        epoch = self.current_epoch
        step = self.global_step
        save_dir = os.path.join(log_dir, f"voxel_metric")
        os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, f"stage={stage}.step={step:012d}.pkl")
        np.save(path, voxel_metric_dict, allow_pickle=True)

        # Start the next epoch with empty metric state.
        for subject_id in self.subject_list:
            for roi in self.roi_dict[subject_id].keys():
                self.metrics[stage][subject_id][roi].reset()

        # update voxel_weight by ema
        if (
            self.cfg.LOSS.SYNC.USE
            and stage == self.cfg.LOSS.SYNC.STAGE
            and self.global_step > 0
        ):
            self.update_voxel_weight_by_ema()

        return all_v_mean

    def training_epoch_end(self, outputs):
        stage = "TRAIN"
        self._shared_epoch_end(outputs, stage)

    def validation_epoch_end(self, outputs):
        stage = "VAL"
        s = self._shared_epoch_end(outputs, stage)
        if s is not None:
            self.log("hp_metric", s)

    def test_epoch_end(self, outputs):
        stage = "TEST"
        s = self._shared_epoch_end(outputs, stage)
        if s is not None:
            self.log("hp_metric", s)

    def configure_optimizers(self):
        """Assemble per-module parameter groups and build the optimizer.

        Every listed sub-module that exists (and is not ``None``) contributes
        two groups with the same learning rate: parameters whose name
        contains none of the ``no_decay`` markers, using the module's weight
        decay, and the remaining parameters with weight decay 0. Missing
        sub-modules are skipped with a warning.
        """
        from optimizers import build_optimizer

        base_lr = self.cfg.OPTIMIZER.LR
        base_wd = self.cfg.OPTIMIZER.WEIGHT_DECAY

        no_decay = [
            "bias",
            "BatchNorm3D.weight",
            "BatchNorm1D.weight",
            "BatchNorm2D.weight",
        ]

        param_groups = []

        def skips_decay(param_name):
            return any(marker in param_name for marker in no_decay)

        def add_param(configs, model):
            for name, lr, weight_decay in configs:
                module = getattr(model, name, None)
                if module is None:
                    logger.warning(f"no {name} in model, skipping parameter group")
                    continue
                with_decay, without_decay = [], []
                for param_name, param in module.named_parameters():
                    bucket = without_decay if skips_decay(param_name) else with_decay
                    bucket.append(param)
                param_groups.append(
                    {"params": with_decay, "lr": lr, "weight_decay": weight_decay}
                )
                param_groups.append(
                    {"params": without_decay, "lr": lr, "weight_decay": 0.0}
                )

        # for stablediffusion only
        add_param([("mlp", base_lr, base_wd)], self.backbone)

        add_param(
            [
                ("image_shifter", base_lr, base_wd),
                ("conv_blocks", base_lr, base_wd),
            ],
            self,
        )

        projector_lr = base_lr * self.cfg.OPTIMIZER.NEURON_PROJECTOR_LR_RATIO
        voxel_wd = self.cfg.OPTIMIZER.VOXEL_WEIGHT_DECAY
        add_param(
            [
                ("neuron_projectors", projector_lr, base_wd),
                ("eye_shifters", projector_lr, base_wd),
                ("image_shifter", base_lr, base_wd),
                ("neuron_shifters", base_lr, base_wd),
                ("layer_gates", base_lr * 1.0, base_wd),
                ("weight", base_lr * 1.0, voxel_wd),
                ("bias", base_lr * 1.0, 0.0),
                ("voxel_outs", base_lr * 1.0, voxel_wd),
            ],
            self.neck,
        )

        return build_optimizer(self.cfg, param_groups)

    def lr_scheduler_step(self, scheduler, *args, **kwargs):
        scheduler.step(
            epoch=self.current_epoch
        )  # timm's scheduler need the epoch value


class DarkVEModel(VEModel):
    # power of darkness!
    def __init__(
        self,
        cfg: AutoConfig,
        num_voxel_dict: Dict[str, int],
        roi_dict: Dict[str, Dict[str, Tensor]],
        neuron_coords_dict: Dict[str, Tensor],
        noise_ceiling_dict: Dict[str, Tensor],
    ):
        """Build the base model and collect the ground-truth voxel indices.

        After validating the DARK loss configuration, ``self.dark_gt_vis``
        maps each subject present in ``roi_dict`` to the concatenated voxel
        indices of the ROIs listed in ``cfg.LOSS.DARK.GT_ROIS`` (ROIs a
        subject lacks are skipped with a printed notice).
        """
        super().__init__(
            cfg, num_voxel_dict, roi_dict, neuron_coords_dict, noise_ceiling_dict
        )

        dark_cfg = self.cfg.LOSS.DARK
        assert dark_cfg.USE == True
        for roi in dark_cfg.GT_ROIS:
            assert roi in self.cfg.DATASET.ROIS
        assert len(self.cfg.DATASET.DARK_POSTFIX) > 0

        # voxel index for gt
        gt_indices = {}
        for subject_id in self.subject_list:
            if subject_id not in self.roi_dict:
                continue
            print(f"loading {subject_id}...", self.cfg.LOSS.DARK.GT_ROIS)
            subject_rois = self.roi_dict[subject_id]
            collected = []
            for roi in self.cfg.LOSS.DARK.GT_ROIS:
                if roi in subject_rois:
                    collected.append(subject_rois[roi])
                else:
                    print(f"roi {roi} not in {subject_id}, skipping...")
            merged = np.concatenate(collected)
            gt_indices[subject_id] = torch.from_numpy(merged).long()
        self.dark_gt_vis = gt_indices

    @property
    def darkness_weight(self):
        max_epoch = self.cfg.LOSS.DARK.MAX_EPOCH
        epoch = self.current_epoch
        rate = 1.0 - (epoch / max_epoch)
        rate = max(rate, 0.0)
        return rate

    @property
    def gt_weight(self):
        max_epoch = self.cfg.LOSS.DARK.MAX_EPOCH
        epoch = self.current_epoch
        rate = epoch / max_epoch
        rate = min(rate, 1.0)
        return rate

    def training_step(self, batch, batch_idx):
        stage = "TRAIN"
        img, ys, subject_ids, session_ids, eye_coords, darkness = self.from_batch(batch)
        voxel_indices_dict = {}  # {subject_id: [N]} reduce memory usage
        for s in self.subject_list:
            n = self.num_voxel_dict[s]
            voxel_indices = ...
            if n > self.cfg.MODEL.MAX_TRAIN_VOXELS:
                voxel_indices = torch.randperm(n)[: self.cfg.MODEL.MAX_TRAIN_VOXELS]
                voxel_indices = voxel_indices.sort()[0]
            voxel_indices_dict[s] = voxel_indices
        out, reg, x_shift = self(
            img,
            subject_ids,
            session_ids,
            eye_coords,
            voxel_indices_dict=voxel_indices_dict,
        )

        loss = 0.0

        ### gt part
        b = img.shape[0]
        batch_loss = []
        n_voxels = []
        out_masks = []
        gt_vis = []
        inter_vis = []
        targets = ys
        for i, (s, o, y) in enumerate(zip(subject_ids, out, targets)):
            vi = voxel_indices_dict[s]
            if vi == ...:
                vi = torch.arange(y.shape[0])
            gt_vi = self.dark_gt_vis[s]
            out_mask = torch.isin(vi, gt_vi)  # gt vi on output vi
            out_masks.append(out_mask)
            intersection = vi[out_mask]
            inter_vis.append(intersection)
            gt_vis.append(gt_vi)
            n_voxels.append(out_mask.sum())
        total_voxels = sum(n_voxels)
        for i, (s, o, y, om, vi) in enumerate(
            zip(subject_ids, out, targets, out_masks, inter_vis)
        ):
            y = y[vi].unsqueeze(0)
            o = o[om].unsqueeze(0)
            voxel_loss = self.loss(o, y).squeeze(0)  # [N]

            w_v = self.voxel_weight[s]  # reweight by ema_score
            w_v = 1.0 if isinstance(w_v, float) else w_v[vi]
            voxel_loss = voxel_loss * w_v
            voxel_loss = voxel_loss.mean() * n_voxels[i] / total_voxels
            # scalar

            batch_loss.append(voxel_loss)

        if not self.cfg.LOSS.DARK.IGNORE_GT:
            loss += torch.sum(torch.stack(batch_loss)) * self.gt_weight

        ### darkness part
        assert self.cfg.LOSS.DARK.IGNORE_OTHER_ROIS == False

        targets = darkness
        b = img.shape[0]
        batch_loss = []
        n_voxels = []
        for i, (s, o, y) in enumerate(zip(subject_ids, out, targets)):
            vi = voxel_indices_dict[s]
            n_voxels.append(vi.shape[0] if vi != ... else y.shape[0])
        total_voxels = sum(n_voxels)
        for i, (s, o, y, om) in enumerate(zip(subject_ids, out, targets, out_masks)):
            vi = voxel_indices_dict[s]
            y = y[vi].unsqueeze(0)
            o = o.unsqueeze(0)
            voxel_loss = self.loss(o, y).squeeze(0)  # [N]

            voxel_loss[om] *= self.cfg.LOSS.DARK.GT_SCALE_UP_COEF  # 3.0

            # don't reweight darkness
            w_v = 1.0
            voxel_loss = voxel_loss * w_v
            voxel_loss = voxel_loss.mean() * n_voxels[i] / total_voxels

            batch_loss.append(voxel_loss)

            self.metrics[stage][s]["all"].update(o.float(), y.float())

        loss += torch.sum(torch.stack(batch_loss)) * self.darkness_weight

        # regularization
        loss += torch.mean(torch.stack(reg))

        return loss



# Ad-hoc smoke run: build a default config, override a few values, create the
# datamodule and model, and fit for a couple of epochs.
if __name__ == "__main__":
    from config_utils import get_cfg_defaults

    cfg = get_cfg_defaults()

    # Config overrides for this run.
    cfg.MODEL.NECK.REGULARIZATION.MEAN_WEIGHT = 0.0
    cfg.MODEL.NECK.REGULARIZATION.STD_WEIGHT = 0.0
    cfg.MODEL.NECK.POOLING_MODE = "avg"
    cfg.DATASET.NAME = "MEEG"
    cfg.DATASET.SUBJECT_LIST = ["EEG01"]
    cfg.ANALYSIS.SAVE_LAST_LINEAR_LAYER = False
    cfg.TRAINER.CALLBACKS.CHECKPOINT.SAVE_TOP_K = 1
    cfg.TRAINER.CALLBACKS.CHECKPOINT.REMOVE = True
    cfg.TRAINER.CALLBACKS.SAVE_OUTPUT = True
    cfg.TRAINER.PRECISION = 16
    cfg.DATAMODULE.BATCH_SIZE = 64
    cfg.OPTIMIZER.LR = 1e-4
    cfg.TRAINER.CALLBACKS.EARLY_STOP.PATIENCE = 20
    cfg.TRAINER.MAX_EPOCHS = 0
    cfg.TRAINER.LIMIT_TRAIN_BATCHES = 0.01
    cfg.TRAINER.LIMIT_VAL_BATCHES = 0.3
    cfg.MODEL.NECK.FC_HIDDEN_DIM = 512
    cfg.MODEL.NECK.POOLING_SIZE = 7
    cfg.TRAINER.VAL_CHECK_INTERVAL = 1.0
    # cfg.TRAINER.VAL_CHECK_INTERVAL = 100
    # ratio of number of data points EEG / MEG
    cfg.MODEL.LOSS.MEG_WEIGHT = 0.743445038206063
    cfg.OPTIMIZER.NAME = "AdamBelief"

    dm = build_dm(cfg)
    dm.setup()

    # NOTE(review): VEModel.__init__ in this file also requires roi_dict,
    # neuron_coords_dict and noise_ceiling_dict, so this call raises
    # TypeError as written - the script needs updating before it can run.
    model = VEModel(cfg, dm.num_voxel_dict)

    from train import get_callbacks_and_loggers

    # NOTE: callbacks and loggers are created but not handed to the Trainer
    # below (the corresponding arguments are commented out).
    callbacks, loggers, log_dir = get_callbacks_and_loggers(cfg)

    # The Trainer uses the literal values below, not the cfg.TRAINER.*
    # overrides set above.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=[1],
        precision=16,
        max_epochs=2,
        limit_train_batches=0.1,
        limit_val_batches=0.5,
        profiler="simple",
        # callbacks=callbacks,
        # logger=loggers,
        # enable_progress_bar=False,
    )

    trainer.fit(model, datamodule=dm)
