import math
import shutil
import statistics
import sys
from pathlib import Path
from typing import Any, Literal

import delu
import numpy as np
import rtdl_num_embeddings
import scipy
import torch
import torch.nn as nn
import torch.utils.tensorboard
from loguru import logger
from torch import Tensor
from tqdm import tqdm
from typing_extensions import NotRequired, TypedDict
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (
    StandardScaler,
    PowerTransformer,
    QuantileTransformer,
    FunctionTransformer,
)

if __name__ == "__main__":
    # When executed as a script, the current working directory must be the
    # repository root (detected by the presence of `.git`). The root is appended
    # to `sys.path` so that the `lib` imports below can be resolved.
    _repo_root = Path.cwd()
    assert (_repo_root / ".git").exists(), (
        "The script must be run from the root of the repository"
    )
    sys.path.append(str(_repo_root))
    del _repo_root

import lib
import lib.data
import lib.deep
import lib.env
from lib import KWArgs, PartKey
from lib.tabpfn.model.bar_distribution import (
    BarDistribution,
    FullSupportBarDistribution,
)
from lib.tabpfn.utils import _transform_borders_one, translate_probs_across_borders


def _get_first_ensemble_layer(
    backbone: lib.deep.MLP,
) -> lib.deep.LinearEfficientEnsemble:
    """Return the first layer of the first block of an MLP backbone.

    Raises:
        RuntimeError: if `backbone` is not an instance of `lib.deep.MLP`.
    """
    # Guard clause: only `lib.deep.MLP` backbones are supported.
    if not isinstance(backbone, lib.deep.MLP):
        raise RuntimeError(f"Unsupported backbone: {backbone}")
    first_block = backbone.blocks[0]
    return first_block[0]  # type: ignore[code]


@torch.inference_mode()
def _init_first_adapter(
    weight: Tensor,
    distribution: Literal["normal", "random-signs"],
    init_sections: list[int],
) -> None:
    """Initialize the first adapter in-place, one random value per section.

    The columns of `weight` are split into consecutive sections whose widths
    are given by `init_sections`. Within each row, all columns of one section
    receive the same value sampled from `distribution`.

    NOTE
    The `init_sections` argument is a historical artifact that accidentally leaked
    from irrelevant experiments to the final models. Perhaps, the code related
    to `init_sections` can be simply removed, but this was not tested.

    Args:
        weight: a 2D tensor; it is modified in-place.
        distribution: the name of the initialization distribution.
        init_sections: the section widths; they must sum to `weight.shape[1]`.

    Raises:
        ValueError: if `distribution` is unknown.
    """
    assert weight.ndim == 2
    assert weight.shape[1] == sum(init_sections)

    if distribution == "normal":
        fill_ = nn.init.normal_
    elif distribution == "random-signs":
        fill_ = lib.deep.init_random_signs_
    else:
        raise ValueError(f"Unknown distribution: {distribution}")

    n_rows = len(weight)
    start = 0
    for width in init_sections:
        stop = start + width
        # One sample per row; the assignment broadcasts it over all columns of
        # the current section, so the whole section shares the same value.
        column = torch.empty((n_rows, 1), dtype=weight.dtype, device=weight.device)
        fill_(column)
        weight[:, start:stop] = column
        start = stop


def regression_output_transform(
    target_transform,
    criterion,
    renormalized_criterion,
    softmax_temperature: float = 0.9,
):
    """Build a function that converts model outputs into regression predictions.

    The logic is extracted from `regressor.py` in TabPFN.

    Args:
        target_transform: the fitted target transform; it is forwarded to
            `_transform_borders_one` together with the borders of `criterion`.
        criterion: the bar distribution used for training; only its `borders`
            attribute is read here.
        renormalized_criterion: the bar distribution whose `mean` method turns
            the translated log-probabilities into the final predictions.
        softmax_temperature: the model outputs are divided by this value before
            being converted to probabilities.

    Returns:
        A function `transform(out)` that maps model outputs (one logit per bin
        in the last dimension) to the result of `renormalized_criterion.mean`
        (computed on CPU).
    """

    std_borders = criterion.borders.cpu().numpy()
    logit_cancel_mask, descending_borders, borders_t = _transform_borders_one(
        std_borders,
        target_transform=target_transform,
        repair_nan_borders_after_transform=True,
    )
    if descending_borders:
        borders_t = borders_t.flip(-1)  # type: ignore

    device = lib.get_device()

    def transform(out):
        # The division always produces a new tensor, so the in-place masking
        # below never modifies the caller's tensor.
        out = out.float() / softmax_temperature
        if logit_cancel_mask is not None:
            # The bins marked by the mask must receive zero probability, so the
            # logits are cancelled BEFORE they are translated across borders.
            # (Previously, the mask was applied after the translation to a
            # tensor that was never used again, i.e. it had no effect.)
            out[..., logit_cancel_mask] = float("-inf")

        probs = translate_probs_across_borders(
            out,
            frm=torch.as_tensor(borders_t, device=device),
            to=criterion.borders.to(device),
        )

        logits = probs.log()
        if logits.dtype == torch.float16:
            logits = logits.float()
        logits = logits.cpu()

        return renormalized_criterion.mean(logits)

    return transform


class Model(nn.Module):
    """MLP & TabM.

    The model consists of optional feature modules (embeddings for numerical
    features, one-hot encoding for categorical features), a backbone created by
    `lib.deep.make_module` and a linear output layer. For all `arch_type` values
    except "plain", the model is an ensemble of `k` submodels, and the output
    contains one prediction per submodel.
    """

    def __init__(
        self,
        *,
        n_num_features: int,
        cat_cardinalities: list[int],
        n_classes: None | int,
        backbone: dict,
        bins: None | list[Tensor],  # For piecewise-linear encoding/embeddings.
        num_embeddings: None | dict = None,
        arch_type: Literal[
            # Plain feed-forward network without any kind of ensembling.
            "plain",
            #
            # TabM-mini
            "tabm-mini",
            #
            # TabM-mini. The first adapter is initialized from the normal distribution.
            # This is used in Section 5.1.
            "tabm-mini-normal",
            #
            # TabM
            "tabm",
            #
            # TabM. The first adapter is initialized from the normal distribution.
            # This variation is not used in the paper, but there is preliminary
            # evidence that it may be a better default strategy.
            "tabm-normal",
        ],
        k: None | int = None,
        **kwargs,
    ) -> None:
        """Build the feature modules, the backbone and the output layer.

        Args:
            n_num_features: the number of continuous (numerical) features.
            cat_cardinalities: the cardinalities of categorical features
                (an empty list means that there are no categorical features).
            n_classes: the output dimension; `None` means a single output.
            backbone: the keyword arguments for `lib.deep.make_module`
                (`d_in` is set automatically); the "d_block" key is also read
                to determine the input size of the output layer.
            bins: the bins for piecewise-linear embeddings; must be `None`
                if `num_embeddings` is `None` or if there are no numerical features.
            num_embeddings: the keyword arguments for `lib.deep.make_module`
                describing embeddings for numerical features; the "type" and
                "d_embedding" keys are read directly.
            arch_type: the architecture type (see the comments in the signature).
            k: the number of submodels; must be `None` for `arch_type="plain"`
                and a positive integer otherwise.
            **kwargs: accepted, but not used anywhere in this class.
        """
        # >>> Validate arguments.
        assert n_num_features >= 0
        assert n_num_features or cat_cardinalities
        if arch_type == "plain":
            assert k is None
        else:
            assert k is not None
            assert k > 0

        super().__init__()

        # >>> Continuous (numerical) features
        first_adapter_sections = []  # See the comment in `_init_first_adapter`.

        if n_num_features == 0:
            assert bins is None
            self.num_module = None
            d_num = 0

        elif num_embeddings is None:
            # Raw numerical features: one input dimension (and one adapter
            # section) per feature.
            assert bins is None
            self.num_module = None
            d_num = n_num_features
            first_adapter_sections.extend(1 for _ in range(n_num_features))

        else:
            # Embedded numerical features: `d_embedding` input dimensions
            # (one adapter section of this width) per feature.
            if bins is None:
                self.num_module = lib.deep.make_module(
                    **num_embeddings, n_features=n_num_features
                )
            else:
                assert num_embeddings["type"].startswith("PiecewiseLinearEmbeddings")
                self.num_module = lib.deep.make_module(**num_embeddings, bins=bins)
            d_num = n_num_features * num_embeddings["d_embedding"]
            first_adapter_sections.extend(
                num_embeddings["d_embedding"] for _ in range(n_num_features)
            )

        # >>> Categorical features
        self.cat_module = (
            lib.deep.OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
        )
        # One adapter section per categorical feature; the section width is
        # the cardinality of the feature.
        first_adapter_sections.extend(cat_cardinalities)
        d_cat = sum(cat_cardinalities)

        # >>> Backbone
        d_flat = d_num + d_cat
        self.minimal_ensemble_adapter = None
        self.backbone = lib.deep.make_module(d_in=d_flat, **backbone)

        if arch_type != "plain":
            assert k is not None
            first_adapter_init = (
                "normal"
                if arch_type in ("tabm-mini-normal", "tabm-normal")
                # For other arch_types, the initialization depends
                # on the presence of num_embeddings.
                else "random-signs"
                if num_embeddings is None
                else "normal"
            )

            if arch_type in ("tabm-mini", "tabm-mini-normal"):
                # Minimal ensemble
                self.minimal_ensemble_adapter = lib.deep.ScaleEnsemble(
                    k,
                    d_flat,
                    init="random-signs" if num_embeddings is None else "normal",
                )
                # The call below overwrites all columns of the adapter weight,
                # because the section widths sum to `d_flat`.
                _init_first_adapter(
                    self.minimal_ensemble_adapter.weight,  # type: ignore[code]
                    first_adapter_init,
                    first_adapter_sections,
                )

            elif arch_type in ("tabm", "tabm-normal"):
                # Like BatchEnsemble, but all multiplicative adapters,
                # except for the very first one, are initialized with ones.
                lib.deep.make_efficient_ensemble(
                    self.backbone,
                    k=k,
                    ensemble_scaling_in=True,
                    ensemble_scaling_out=True,
                    ensemble_bias=True,
                    scaling_init="ones",
                )
                _init_first_adapter(
                    _get_first_ensemble_layer(self.backbone).r,  # type: ignore[code]
                    first_adapter_init,
                    first_adapter_sections,
                )

            else:
                raise ValueError(f"Unknown arch_type: {arch_type}")

        # >>> Output
        d_block = backbone["d_block"]
        d_out = 1 if n_classes is None else n_classes
        self.output = (
            nn.Linear(d_block, d_out)
            if arch_type == "plain"
            else lib.deep.NLinear(k, d_block, d_out)  # type: ignore[code]
        )

        # >>>
        self.arch_type = arch_type
        self.k = k

    def forward(
        self, x_num: None | Tensor = None, x_cat: None | Tensor = None
    ) -> Tensor:
        """Apply the model to a batch.

        Args:
            x_num: the numerical features (or `None`).
            x_cat: the categorical features; must be `None` if and only if
                the model was built without categorical features.

        Returns:
            The output of the output layer. For plain networks, a dimension of
            size one is inserted after the batch dimension: (B, 1, D_OUT).
            For ensembles, the shape is presumably (B, K, D_OUT) -- this depends
            on `lib.deep.NLinear`, TODO confirm.
        """
        x = []
        if x_num is not None:
            x.append(x_num if self.num_module is None else self.num_module(x_num))
        if x_cat is None:
            assert self.cat_module is None
        else:
            assert self.cat_module is not None
            x.append(self.cat_module(x_cat).float())
        # Flatten everything except the batch dimension and concatenate
        # all feature representations.
        x = torch.column_stack([x_.flatten(1, -1) for x_ in x])

        if self.k is not None:
            x = x[:, None].expand(-1, self.k, -1)  # (B, D) -> (B, K, D)
            if self.minimal_ensemble_adapter is not None:
                x = self.minimal_ensemble_adapter(x)
        else:
            assert self.minimal_ensemble_adapter is None

        x = self.backbone(x)
        x = self.output(x)
        if self.k is None:
            # Adjust the output shape for plain networks to make them compatible
            # with the rest of the script (loss, metrics, predictions, ...).
            # (B, D_OUT) -> (B, 1, D_OUT)
            x = x[:, None]
        return x


class Config(TypedDict):
    """The configuration accepted by `main`."""

    # The seed for the global random state, the training batch generator
    # and the quantile target transform.
    seed: int
    # The keyword arguments for `lib.data.build_dataset`.
    data: KWArgs
    # The keyword arguments for `rtdl_num_embeddings.compute_bins`.
    bins: NotRequired[KWArgs]
    # The keyword arguments for `Model`. NOTE(review): `main` also reads
    # the "softmax_temperature" key from this dictionary.
    model: KWArgs
    # Enables the submodel selection after training (default: False).
    head_selection: NotRequired[bool]
    # The keyword arguments for `lib.deep.make_optimizer`.
    optimizer: KWArgs
    # NOTE(review): not read by the code visible in this file.
    n_lr_warmup_epochs: NotRequired[int]
    batch_size: int
    # Default: 32768 (2048 if `compile` is enabled).
    eval_batch_size: NotRequired[int]
    # The patience for early stopping.
    patience: int
    # The maximum number of epochs; -1 means "no limit".
    n_epochs: int
    gradient_clipping_norm: NotRequired[float]
    # Default: True if `seed == 1`, otherwise False.
    parameter_statistics: NotRequired[bool]
    amp: NotRequired[bool]  # torch.autocast
    compile: NotRequired[bool]  # torch.compile
    #
    # loss_type: Literal["mse", "hlgauss", "mse_with_oe"]
    # The transformation applied to standardized targets (default: "identity").
    target_transform: NotRequired[Literal["power", "quantile", "identity"]]
    # The number of quantile intervals used to build the borders of target bins
    # (default: 128).
    n_target_bins: NotRequired[int]
    # n_bins: int
    # NOTE(review): `main` reads "softmax_temperature" from `model`,
    # not from this top-level key -- confirm which one is intended.
    softmax_temperature: NotRequired[float]
    # NOTE(review): not read by the code visible in this file.
    sigma: NotRequired[float]


def main(
    config: Config | str | Path,
    output: None | str | Path = None,
    *,
    force: bool = False,
) -> None | lib.JSONDict:
    """Train and evaluate a regression model with a bar-distribution loss.

    The targets are standardized, optionally transformed (`target_transform`)
    and discretized with quantile-based borders. The model outputs one logit
    per target bin and is trained with `FullSupportBarDistribution`. For
    evaluation, the outputs of all submodels are averaged and converted to label
    predictions by the function built by `regression_output_transform`.

    Args:
        config: the configuration (see `Config`) or a path; resolved by `lib.check`.
        output: the output location; resolved by `lib.check`.
        force: forwarded to `lib.start`.

    Returns:
        `None` if `lib.start` returns a falsy value, otherwise the final report.

    Raises:
        ValueError: if `target_transform` is unknown.
        RuntimeError: if there is not enough memory for evaluation even after
            the evaluation batch size has been halved down to zero.
    """
    # >>> Start
    config, output = lib.check(config, output, config_type=Config)
    if not lib.start(output, force=force):
        return None

    lib.show_config(config)  # type: ignore[code]
    delu.random.seed(config["seed"])
    device = lib.get_device()
    report = lib.create_report(config)  # type: ignore[code]

    # >>> Data
    dataset = lib.data.build_dataset(**config["data"])
    # Only regression is supported: the loss and the prediction transform
    # below are defined for regression only.
    assert dataset.task.is_regression
    dataset.data["y"], regression_label_stats = lib.data.standardize_labels(
        dataset.data["y"]
    )
    # NOTE: `.get` is used instead of `.pop` to avoid mutating the config
    # (the report has already been created from it above).
    target_transform_type = config.get("target_transform", "identity")
    if target_transform_type == "power":
        target_transform = Pipeline(
            [("power", PowerTransformer()), ("standard", StandardScaler())]
        ).fit(dataset.data["y"]["train"].reshape(-1, 1))
    elif target_transform_type == "quantile":
        target_transform = QuantileTransformer(
            output_distribution="normal", random_state=config["seed"]
        ).fit(dataset.data["y"]["train"].reshape(-1, 1))
    elif target_transform_type == "identity":
        target_transform = FunctionTransformer(func=None)
    else:
        raise ValueError(f"Unknown target_transform {target_transform_type}")

    dataset.data["y"] = {
        part: target_transform.transform(dataset.data["y"][part].reshape(-1, 1))
        .astype(np.float32)  # type: ignore
        .squeeze()  # pyright: ignore
        for part in dataset.data["y"]
    }

    # Convert binary features to categorical features.
    if dataset.n_bin_features > 0:
        x_bin = dataset.data.pop("x_bin")
        # Remove binary features with just one unique value in the training set.
        # This must be done, otherwise, the script will fail on one specific dataset
        # from the "why" benchmark.
        n_bin_features = x_bin["train"].shape[1]
        good_bin_idx = [
            i for i in range(n_bin_features) if len(np.unique(x_bin["train"][:, i])) > 1
        ]
        if len(good_bin_idx) < n_bin_features:
            x_bin = {k: v[:, good_bin_idx] for k, v in x_bin.items()}

        if dataset.n_cat_features == 0:
            dataset.data["x_cat"] = {
                part: np.zeros((dataset.size(part), 0), dtype=np.int64)
                for part in x_bin
            }
        for part in x_bin:
            dataset.data["x_cat"][part] = np.column_stack(
                [dataset.data["x_cat"][part], x_bin[part].astype(np.int64)]
            )
        del x_bin
    dataset = dataset.to_torch(device)
    Y_train = dataset.data["y"]["train"].to(
        torch.long if dataset.task.is_classification else torch.float
    )

    def unique_with_precision(arr, precision=1e-5):
        """Drop the values that are within `precision` of the last kept value.

        The input is traversed in the given order, so it is expected to be
        sorted (which is the case for the quantiles computed below).
        """
        unique_values = []
        for value in arr:
            if not unique_values or abs(value - unique_values[-1]) > precision:
                unique_values.append(value)
        return torch.tensor(unique_values, dtype=torch.float32, device=device)

    # The borders of target bins are the (deduplicated) quantiles
    # of the training targets.
    borders = torch.quantile(
        Y_train,
        q=torch.linspace(0.0, 1.0, steps=config.get("n_target_bins", 128) + 1).to(
            Y_train
        ),
    )
    borders = unique_with_precision(borders)
    print("Len of borders:", len(borders))

    # >>> Model
    if "bins" in config:
        # Compute the bins for PiecewiseLinearEncoding and PiecewiseLinearEmbeddings.
        compute_bins_kwargs = (
            {
                "y": Y_train.to(
                    torch.long if dataset.task.is_classification else torch.float
                ),
                "regression": dataset.task.is_regression,
                "verbose": True,
            }
            if "tree_kwargs" in config["bins"]
            else {}
        )
        bin_edges = rtdl_num_embeddings.compute_bins(
            dataset.data["x_num"]["train"], **config["bins"], **compute_bins_kwargs
        )
        logger.info(f"Bin counts: {[len(x) - 1 for x in bin_edges]}")
    else:
        bin_edges = None
    model = Model(
        n_num_features=dataset.n_num_features,
        cat_cardinalities=dataset.compute_cat_cardinalities(),
        # One logit per target bin.
        n_classes=len(borders) - 1,
        **config["model"],
        bins=bin_edges,
    )
    report["n_parameters"] = lib.deep.get_n_parameters(model)
    logger.info(f"n_parameters = {report['n_parameters']}")
    report["prediction_type"] = "labels" if dataset.task.is_regression else "probs"
    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # >>> Training
    step = 0
    batch_size = config["batch_size"]
    report["epoch_size"] = epoch_size = math.ceil(dataset.size("train") / batch_size)
    eval_batch_size = config.get(
        "eval_batch_size",
        # With torch.compile,
        # the largest possible evaluation batch size is noticeably smaller.
        2048 if config.get("compile", False) else 32768,
    )
    chunk_size = None

    optimizer = lib.deep.make_optimizer(
        **config["optimizer"], params=lib.deep.make_parameter_groups(model)
    )
    gradient_clipping_norm = config.get("gradient_clipping_norm")

    assert regression_label_stats is not None
    _loss_fn = FullSupportBarDistribution(borders)
    # The same distribution, but with the borders mapped back
    # from the standardized scale.
    renormalized_criterion = FullSupportBarDistribution(
        _loss_fn.borders * regression_label_stats.std + regression_label_stats.mean,
    ).float()
    # NOTE(review): the temperature is read from `config["model"]`, while `Config`
    # also declares a top-level `softmax_temperature` -- confirm which is intended.
    pred_transform = regression_output_transform(
        target_transform,  # pyright: ignore
        _loss_fn,
        renormalized_criterion,
        softmax_temperature=config["model"].get("softmax_temperature", 1.0),
    )

    def loss_fn(y_pred: Tensor, y_true: Tensor) -> Tensor:
        """Compute the mean loss over all (object, submodel) pairs.

        The first two dimensions of `y_pred` are merged, and every target is
        repeated `y_pred.shape[1]` times to match the merged predictions.
        """
        return _loss_fn(
            y_pred.flatten(0, 1).unsqueeze(1),
            y_true.repeat_interleave(y_pred.shape[1]).unsqueeze(1),
        ).mean()

    # The following generator is used only for creating training batches,
    # so the random seed fully determines the sequence of training objects.
    batch_generator = torch.Generator(device).manual_seed(config["seed"])
    timer = delu.tools.Timer()
    early_stopping = delu.tools.EarlyStopping(config["patience"], mode="max")
    parameter_statistics = config.get("parameter_statistics", config["seed"] == 1)
    training_log = []
    writer = torch.utils.tensorboard.SummaryWriter(output)  # type: ignore[code]

    # Only bfloat16 was tested as amp_dtype.
    # However, float16 is supported as a fallback.
    # To enable float16, uncomment the two lines below.
    amp_dtype = (
        torch.bfloat16
        if config.get("amp", False)
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
        # else torch.float16
        # if config.get('amp', False) and torch.cuda.is_available()
        else None
    )
    amp_enabled = amp_dtype is not None
    # For FP16, the gradient scaler must be used.
    grad_scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore[code]
    logger.info(f"AMP enabled: {amp_enabled}")

    if config.get("compile", False):
        # NOTE
        # `torch.compile` is intentionally called without the `mode` argument,
        # because it caused issues with training.
        model = torch.compile(model)
        evaluation_mode = torch.no_grad
    else:
        evaluation_mode = torch.inference_mode

    @torch.autocast(device.type, enabled=amp_enabled, dtype=amp_dtype)  # type: ignore[code]
    def apply_model(part: PartKey, idx: Tensor) -> Tensor:
        """Apply the model to the objects `idx` of the given part."""
        out = model(
            dataset.data["x_num"][part][idx] if "x_num" in dataset.data else None,
            dataset.data["x_cat"][part][idx] if "x_cat" in dataset.data else None,
        )
        out = out.squeeze(-1).float()
        return out

    @evaluation_mode()
    def evaluate(
        parts: list[PartKey], eval_batch_size: int
    ) -> tuple[
        dict[PartKey, Any], dict[PartKey, np.ndarray], dict[PartKey, np.ndarray], int
    ]:
        """Evaluate the model on the given parts.

        On out-of-memory errors, the evaluation batch size is halved
        and the evaluation of the current part is restarted.

        Returns:
            The metrics, the final predictions, the raw per-submodel model
            outputs and the (possibly reduced) evaluation batch size.

        Raises:
            RuntimeError: if the evaluation batch size has been reduced to zero.
        """
        model.eval()
        head_predictions: dict[PartKey, np.ndarray] = {}
        for part in parts:
            while eval_batch_size:
                try:
                    head_predictions[part] = (
                        torch.cat(
                            [
                                apply_model(part, idx)
                                for idx in torch.arange(
                                    dataset.size(part), device=device
                                ).split(eval_batch_size)
                            ]
                        )
                        .cpu()
                        .numpy()
                    )
                except RuntimeError as err:
                    if not lib.is_oom_exception(err):
                        raise
                    eval_batch_size //= 2
                    logger.warning(f"eval_batch_size = {eval_batch_size}")
                else:
                    break
            if not eval_batch_size:
                # Previously, the exception was created, but not raised.
                raise RuntimeError("Not enough memory even for eval_batch_size=1")

        # Average the raw outputs over submodels, then convert the averaged
        # outputs to label predictions.
        predictions = {k: v.mean(1) for k, v in head_predictions.items()}
        predictions = {
            k: pred_transform(torch.from_numpy(v).to(device)).cpu().numpy()  # pyright: ignore
            for k, v in predictions.items()
        }

        metrics = (
            dataset.task.calculate_metrics(predictions, report["prediction_type"])
            if lib.are_valid_predictions(predictions)
            else {x: {"score": -99999.0} for x in predictions}
        )
        return metrics, predictions, head_predictions, eval_batch_size

    def save_checkpoint() -> None:
        """Dump the checkpoint and the report, then back up the output."""
        lib.dump_checkpoint(
            output,
            {
                "step": step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "batch_generator": batch_generator.get_state(),
                "random_state": delu.random.get_state(),
                "early_stopping": early_stopping,
                "report": report,
                "timer": timer,
                "training_log": training_log,
            }
            | (
                {} if grad_scaler is None else {"grad_scaler": grad_scaler.state_dict()}
            ),
        )
        lib.dump_report(output, report)
        lib.backup_output(output)

    print()
    timer.run()
    while config["n_epochs"] == -1 or step // epoch_size < config["n_epochs"]:
        print(f"[...] {lib.try_get_relative_path(output)} | {timer}")

        model.train()
        epoch_losses = []
        for batch_idx in tqdm(
            torch.randperm(
                len(dataset.data["y"]["train"]),
                generator=batch_generator,
                device=device,
            ).split(batch_size),
            desc=f"Epoch {step // epoch_size} Step {step}",
        ):
            loss, new_chunk_size = lib.deep.zero_grad_forward_backward(
                optimizer,
                lambda idx: loss_fn(apply_model("train", idx), Y_train[idx]),
                batch_idx,
                chunk_size or batch_size,
                grad_scaler,
            )

            if parameter_statistics and (
                step % epoch_size == 0  # The first batch of the epoch.
                or step // epoch_size == 0  # The first epoch.
            ):
                for k, v in lib.deep.compute_parameter_stats(model).items():  # type: ignore
                    writer.add_scalars(k, v, step, timer.elapsed())
                    del k, v

            if gradient_clipping_norm is not None:
                if grad_scaler is not None:
                    grad_scaler.unscale_(optimizer)
                nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), gradient_clipping_norm
                )
            if grad_scaler is None:
                optimizer.step()
            else:
                grad_scaler.step(optimizer)
                grad_scaler.update()

            step += 1
            epoch_losses.append(loss.detach())
            if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                chunk_size = new_chunk_size
                logger.warning(f"chunk_size = {chunk_size}")

        epoch_losses = torch.stack(epoch_losses).tolist()
        mean_loss = statistics.mean(epoch_losses)
        metrics, predictions, _, eval_batch_size = evaluate(
            ["val", "test"], eval_batch_size
        )

        training_log.append(
            {"epoch-losses": epoch_losses, "metrics": metrics, "time": timer.elapsed()}
        )
        lib.print_metrics(mean_loss, metrics)
        writer.add_scalars("loss", {"train": mean_loss}, step, timer.elapsed())
        for part in metrics:
            writer.add_scalars(
                "score", {part: metrics[part]["score"]}, step, timer.elapsed()
            )

        if (
            "metrics" not in report
            or metrics["val"]["score"] > report["metrics"]["val"]["score"]
        ):
            print("🌸 New best epoch! 🌸")
            report["best_step"] = step
            report["metrics"] = metrics
            save_checkpoint()
            lib.dump_predictions(output, predictions)

        early_stopping.update(metrics["val"]["score"])
        if early_stopping.should_stop() or not lib.are_valid_predictions(predictions):
            break

        print()
    # The writer is used only in the training loop, so it can be closed now
    # to flush the pending events and release the event file.
    writer.close()
    report["time"] = str(timer)

    # >>>
    # Restore the best checkpoint (if any) and compute the final metrics.
    if lib.get_checkpoint_path(output).exists():
        model.load_state_dict(lib.load_checkpoint(output)["model"])
    report["metrics"], predictions, head_predictions, _ = evaluate(
        ["train", "val", "test"], eval_batch_size
    )
    report["chunk_size"] = chunk_size
    report["eval_batch_size"] = eval_batch_size
    lib.dump_predictions(output, predictions)
    lib.dump_summary(output, lib.summarize(report))
    save_checkpoint()

    # >>> Submodel selection (TabM[B] & TabM[G]).
    # NOTE(review): `head_predictions` holds the raw per-submodel model outputs
    # returned by `apply_model` (they are NOT passed through `pred_transform`),
    # while below they are scored as label predictions -- confirm that this is
    # intended before enabling `head_selection` in this script.
    if (
        config.get("head_selection", False)
        and head_predictions["train"].shape[1] > 1
        # The following conditions is a hack preventing the head selection during
        # the hyperparameter tuning, because bin/tune.py runs training
        # outside of the project directory.
        and lib.env.get_project_dir() in output.parents
        and output.parent.name != "trials"
    ):
        if output.parent.name.endswith("-evaluation"):
            best_head_output = (
                output.parent.with_name(
                    output.parent.name.removesuffix("-evaluation")
                    + "-best-head-evaluation"
                )
                / output.name
            )
            greedy_heads_output = (
                output.parent.with_name(
                    output.parent.name.removesuffix("-evaluation")
                    + "-greedy-heads-evaluation"
                )
                / output.name
            )
        else:
            best_head_output = output.with_name(output.name + "-best-head")
            greedy_heads_output = output.with_name(output.name + "-greedy-heads")
        for dir_ in [best_head_output, greedy_heads_output]:
            if dir_.exists():
                logger.warning(f"Removing the existing output: {dir_}")
                shutil.rmtree(dir_)

        prediction_type = (
            lib.PredictionType.PROBS
            if dataset.task.is_classification
            else lib.PredictionType.LABELS
        )
        head_selection_timer = delu.tools.Timer()
        head_selection_timer.run()

        # >>> TabM[B]: select the Best submodel.
        n_heads = head_predictions["val"].shape[1]
        head_val_scores = np.array(
            [
                dataset.task.calculate_metrics(
                    {"val": head_predictions["val"][:, i]}, prediction_type
                )["val"]["score"]
                for i in range(n_heads)
            ]
        )
        best_head_idx = int(np.argmax(head_val_scores))
        best_head_output.mkdir(parents=True)
        lib.finish(
            best_head_output,
            report
            | {
                "heads": [best_head_idx],
                "head_selection_time": str(head_selection_timer),
                "metrics": dataset.task.calculate_metrics(
                    {k: v[:, best_head_idx] for k, v in head_predictions.items()},
                    prediction_type,
                ),
            },
        )

        # >>> TabM[G]: Greedily select a powerful subset of submodels.

        # Start with the best head.
        greedy_idx = [best_head_idx]
        greedy_score = head_val_scores[best_head_idx]

        greedy_mask = [False] * n_heads
        greedy_mask[best_head_idx] = True

        while len(greedy_idx) < n_heads:
            new_idx = None
            new_score = None

            # Iterating through all heads.
            for head_idx in range(n_heads):
                # If the head is already in greedy_idx, skip it.
                if greedy_mask[head_idx]:
                    continue

                candidate_idx = [*greedy_idx, head_idx]
                candidate_score = dataset.task.calculate_metrics(
                    {"val": head_predictions["val"][:, candidate_idx].mean(1)},
                    prediction_type,
                )["val"]["score"]
                if candidate_score > greedy_score and (
                    new_score is None or candidate_score > new_score
                ):
                    new_idx = candidate_idx
                    new_score = candidate_score

            # If no head improves the current greedy score,
            # the head selection process is stopped.
            if new_idx is None:
                break
            else:
                assert new_score is not None
                greedy_score = new_score
                greedy_idx = new_idx

        greedy_heads_output.mkdir(parents=True)
        lib.finish(
            greedy_heads_output,
            report
            | {
                "heads": greedy_idx,
                "head_selection_time": str(head_selection_timer),
                "metrics": dataset.task.calculate_metrics(
                    {k: v[:, greedy_idx].mean(1) for k, v in head_predictions.items()},
                    prediction_type,
                ),
            },
        )

    lib.finish(output, report)
    return report


if __name__ == "__main__":
    # Script entry point: call `lib.configure_libraries`, then pass `main`
    # to the project's command-line runner.
    lib.configure_libraries()
    lib.run_MainFunction_cli(main)
