# Two-stage script: triplet pretraining of numerical feature embeddings, then finetuning on the downstream task

import math
import statistics
from pathlib import Path
from typing import Any, Callable

import delu
import numpy as np
import rtdl_num_embeddings
import torch
import torch.nn as nn
import torch.utils.tensorboard
from loguru import logger
from torch import Tensor
from tqdm import tqdm
from typing_extensions import NotRequired, TypedDict
from collections import defaultdict

import lib
from lib import KWArgs, PartKey

EvalOut = tuple[dict[PartKey, Any], dict[PartKey, np.ndarray], int]


class Model(nn.Module):
    def __init__(
        self,
        *,
        n_num_features: int,
        n_bin_features: int,
        cat_cardinalities: list[int],
        n_classes: None | int,
        bins: None | list[Tensor],
        num_embeddings: None | dict = None,
        backbone: dict,
        pretrain: bool = False,
        normalize_similarity: bool = True,
    ) -> None:
        assert n_num_features or n_bin_features or cat_cardinalities
        super().__init__()

        self.m_bin = None

        if num_embeddings is None:
            assert bins is None
            self.m_num = None
            d_num = n_num_features
        else:
            assert n_num_features > 0
            if num_embeddings["type"] in (
                rtdl_num_embeddings.PiecewiseLinearEmbeddings.__name__,
                rtdl_num_embeddings.PiecewiseLinearEncoding.__name__,
                lib.deep.PiecewiseLinearEmbeddingsV2.__name__,
            ):
                assert bins is not None
                self.m_num = lib.deep.make_module(**num_embeddings, bins=bins)
                d_num = (
                    sum(len(x) - 1 for x in bins)
                    if num_embeddings["type"].startswith(
                        rtdl_num_embeddings.PiecewiseLinearEncoding.__name__
                    )
                    else n_num_features * num_embeddings["d_embedding"]
                )
            else:
                assert bins is None
                self.m_num = lib.deep.make_module(
                    **num_embeddings, n_features=n_num_features
                )
                d_num = n_num_features * num_embeddings["d_embedding"]

        self.m_cat = (
            lib.deep.OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
        )
        d_cat = sum(cat_cardinalities)

        backbone["d_in"] = d_num + n_bin_features + d_cat

        self.pretrain = pretrain
        self.normalize_similarity = normalize_similarity

        if self.pretrain:
            self.backbone = nn.Linear(backbone["d_in"], backbone["d_main"])
        else:
            self.backbone = lib.deep.make_module(
                **backbone,
                d_out=lib.deep.get_d_out(n_classes),
            )

    def forward(
        self,
        *,
        x_num: None | Tensor = None,
        x_bin: None | Tensor = None,
        x_cat: None | Tensor = None,
        y: None | Tensor = None,
    ) -> Tensor:
        x = []
        if not self.pretrain:
            assert y is None, "No target should be here if not in pretrain"

        if x_num is not None:
            x.append(x_num if self.m_num is None else self.m_num(x_num))

        if x_bin is not None:
            x.append(x_bin if self.m_bin is None else self.m_bin(x_bin))
        if x_cat is None:
            assert self.m_cat is None
        else:
            assert self.m_cat is not None
            x.append(self.m_cat(x_cat))

        x = torch.column_stack([x_.flatten(1, -1) for x_ in x])
        x = self.backbone(x)

        if self.pretrain:
            # Triplet Loss:
            assert y is not None, "y must be provided for triplet loss"
            assert x_num is not None, "This is a numerical feature embeddings thing"
            thirds_size = x_num.shape[0] // 3

            x_central = x[:thirds_size].flatten(1, -1)
            x_first = x[thirds_size : 2 * thirds_size].flatten(1, -1)
            x_second = x[2 * thirds_size : 3 * thirds_size].flatten(1, -1)

            if self.normalize_similarity:
                sim_first = (x_central * x_first).sum(-1) / x_central.shape[-1] ** 0.5
                sim_second = (x_central * x_second).sum(-1) / x_central.shape[-1] ** 0.5
            else:
                sim_first = (x_central * x_first).sum(-1)
                sim_second = (x_central * x_second).sum(-1)

            diff_first = (y[:thirds_size] - y[thirds_size : 2 * thirds_size]).abs()
            diff_second = (y[:thirds_size] - y[2 * thirds_size : 3 * thirds_size]).abs()
            answer = (diff_first > diff_second) * 1
            flat_preds = torch.cat((sim_first[:, None], sim_second[:, None]), dim=-1)
            loss = nn.CrossEntropyLoss()(flat_preds, answer)

            return loss[None]

        return x


class Config(TypedDict):
    """Schema of the configuration dict accepted by ``main``.

    Keys with a ``_triplet_pretrain`` suffix configure the pretraining phase;
    their plain counterparts configure the finetuning phase.
    """

    seed: int
    data: KWArgs  # kwargs for lib.data.build_dataset
    bins: NotRequired[KWArgs]  # kwargs for rtdl_num_embeddings.compute_bins
    # Model kwargs; must contain "backbone" and "backbone_triplet_pretrain".
    model: KWArgs
    optimizer: KWArgs  # finetuning optimizer (lib.deep.make_optimizer kwargs)
    optimizer_triplet_pretrain: KWArgs  # pretraining optimizer
    n_lr_warmup_epochs: NotRequired[int]
    batch_size: int
    batch_size_triplet_pretrain: NotRequired[int]  # defaults to batch_size
    patience: int
    # Optional: main falls back to `patience` when this key is absent.
    patience_triplet_pretrain: NotRequired[int]
    n_epochs: int  # -1 trains until early stopping
    n_epochs_triplet_pretrain: int
    normalize_similarity: NotRequired[bool]
    gradient_clipping_norm: NotRequired[float]
    parameter_statistics: NotRequired[bool]  # defaults to (seed == 1)
    amp: NotRequired[bool]  # bfloat16 autocast; only used on CUDA with bf16 support


def main(
    config: Config, output: str | Path, *, force: bool = False
) -> None | lib.JSONDict:
    """Pretrain numerical feature embeddings with a triplet objective, then finetune.

    The run has two phases that share one output directory:

    1. Pretraining: a ``Model`` with ``pretrain=True`` (a linear layer on top
       of the feature encodings) is trained with the triplet ranking loss
       computed inside ``Model.forward``.
    2. Finetuning: a fresh ``Model`` with the regular backbone is built. The
       pretrained ``m_num`` weights are copied into it on a best-effort basis,
       and it is trained on the task with a fresh optimizer and warmup
       schedule.

    Args:
        config: experiment configuration (see ``Config``). ``main`` itself does
            not mutate it.
        output: output directory for checkpoints, predictions and the report.
        force: forwarded to ``lib.start``.

    Returns:
        The final report, or ``None`` when ``lib.start`` declines to run.
    """
    # >>> start
    required_keys = set(Config.__required_keys__)
    allowed_keys = required_keys | set(Config.__optional_keys__)
    assert set(config) >= required_keys, required_keys - set(config)
    # Report only the keys that are actually unknown.
    assert set(config) <= allowed_keys, set(config) - allowed_keys
    if not lib.start(output, force=force):
        return None

    lib.show_config(config)  # type: ignore[code]
    output = Path(output)
    delu.random.seed(config["seed"])
    device = lib.get_device()
    report = lib.create_report(config)  # type: ignore[code]

    # >>> dataset
    dataset = lib.data.build_dataset(**config["data"])
    if dataset.task.is_regression:
        dataset.data["y"], regression_label_stats = lib.data.standardize_labels(
            dataset.data["y"]
        )
    else:
        regression_label_stats = None
    dataset = dataset.to_torch(device)
    Y_train = dataset.data["y"]["train"].to(
        torch.long if dataset.task.is_multiclass else torch.float
    )

    # >>> model
    if "bins" in config:
        compute_bins_kwargs = (
            {
                "y": Y_train.to(
                    torch.long if dataset.task.is_classification else torch.float
                ),
                "regression": dataset.task.is_regression,
                "verbose": True,
            }
            if "tree_kwargs" in config["bins"]
            else {}
        )
        bin_edges = rtdl_num_embeddings.compute_bins(
            dataset["x_num"]["train"], **config["bins"], **compute_bins_kwargs
        )
        logger.info(f"Bin counts: {[len(x) - 1 for x in bin_edges]}")
    else:
        bin_edges = None

    # Work on copies so the caller's config (and anything that may hold a
    # reference to it) keeps its backbone settings intact.
    model_kwargs = dict(config["model"])
    backbone_triplet_pretrain_conf = dict(model_kwargs.pop("backbone_triplet_pretrain"))
    backbone_conf = dict(model_kwargs.pop("backbone"))
    if "normalize_similarity" in config:
        # The top-level key is part of Config; forward it unless the model
        # config already sets it explicitly.
        model_kwargs.setdefault("normalize_similarity", config["normalize_similarity"])

    def unwrap(module: nn.Module) -> Any:
        """Return the underlying Model, stripping an nn.DataParallel wrapper."""
        return module.module if isinstance(module, nn.DataParallel) else module

    def build_model(backbone: dict, *, pretrain: bool) -> nn.Module:
        """Build a Model on ``device``; wrap it in DataParallel on multi-GPU hosts."""
        module = Model(
            n_num_features=dataset.n_num_features,
            n_bin_features=dataset.n_bin_features,
            cat_cardinalities=dataset.compute_cat_cardinalities(),
            n_classes=dataset.task.try_compute_n_classes(),
            backbone=backbone,
            pretrain=pretrain,
            bins=bin_edges,
            **model_kwargs,
        )
        module.to(device)
        if torch.cuda.device_count() > 1:
            return nn.DataParallel(module)
        return module

    def extract_submodule_state(state: dict, prefix: str) -> dict:
        """Select ``state`` entries under ``prefix.`` with the prefix removed.

        A leading ``module.`` (added by nn.DataParallel) is ignored.
        """
        selected = {}
        for key, value in state.items():
            key = key.removeprefix("module.")
            if key.startswith(prefix + "."):
                selected[key[len(prefix) + 1 :]] = value
        return selected

    model = build_model(backbone_triplet_pretrain_conf, pretrain=True)
    report["n_parameters"] = lib.deep.get_n_parameters(unwrap(model))
    logger.info(f"n_parameters = {report['n_parameters']}")
    report["prediction_type"] = "labels" if dataset.task.is_regression else "logits"

    # >>> training
    loss_fn = lib.deep.get_loss_fn(dataset.task.type_)
    gradient_clipping_norm = config.get("gradient_clipping_norm")

    def compute_epoch_size(batch_size_: int) -> int:
        """Number of training batches per epoch for the given batch size."""
        return math.ceil(dataset.size("train") / batch_size_)

    batch_size = config["batch_size"]
    pretrain_batch_size = config.get("batch_size_triplet_pretrain", batch_size)
    report["epoch_size"] = epoch_size = compute_epoch_size(batch_size)
    eval_batch_size = 32768
    chunk_size = None
    generator = torch.Generator(device).manual_seed(config["seed"])

    report["metrics"] = {"val": {"score": -math.inf}}

    def make_lr_scheduler(optimizer_: torch.optim.Optimizer, epoch_size_: int) -> Any:
        """Linear warmup over whole epochs (capped near 10000 steps), or None."""
        if "n_lr_warmup_epochs" not in config:
            return None
        n_warmup_steps = min(10000, config["n_lr_warmup_epochs"] * epoch_size_)
        n_warmup_steps = max(1, math.trunc(n_warmup_steps / epoch_size_)) * epoch_size_
        logger.info(f"{n_warmup_steps=}")
        return torch.optim.lr_scheduler.LinearLR(
            optimizer_, start_factor=0.01, total_iters=n_warmup_steps
        )

    optimizer = lib.deep.make_optimizer(
        **config["optimizer_triplet_pretrain"],
        params=lib.deep.make_parameter_groups(model),
    )
    lr_scheduler = make_lr_scheduler(optimizer, compute_epoch_size(pretrain_batch_size))

    timer = delu.tools.Timer()
    parameter_statistics = config.get("parameter_statistics", config["seed"] == 1)
    training_log = []
    writer = torch.utils.tensorboard.SummaryWriter(output)  # type: ignore[code]

    amp_enabled = (
        config.get("amp", False)
        and device.type == "cuda"
        and torch.cuda.is_bf16_supported()
    )
    logger.info(f"AMP enabled: {amp_enabled}")

    @torch.autocast(  # type: ignore[code]
        device.type, enabled=amp_enabled, dtype=torch.bfloat16 if amp_enabled else None
    )
    def apply_model(part: PartKey, idx: Tensor) -> Tensor:
        """Run the current model on rows ``idx`` of ``part``.

        In pretraining mode the targets are passed too and the result is the
        triplet loss.
        """
        model_input = {
            key: dataset.data[key][part][idx]  # type: ignore[code]
            for key in ["x_num", "x_bin", "x_cat"]
            if key in dataset  # type: ignore[index]
        }
        # Read the flag from the unwrapped module: nn.DataParallel does not
        # forward arbitrary attributes.
        if unwrap(model).pretrain:
            model_input["y"] = dataset.data["y"][part][idx]
            return model(**model_input).float()
        return model(**model_input).squeeze(-1).float()

    @torch.inference_mode()
    def evaluate(parts: list[PartKey], eval_batch_size: int) -> EvalOut:
        """Predict on ``parts`` and compute metrics.

        In pretraining mode the "predictions" are per-batch triplet losses and
        the score is their negated mean. Otherwise they are task predictions
        scored with ``dataset.task.calculate_metrics``. ``eval_batch_size`` is
        halved on OOM, and the possibly reduced value is returned.
        """
        model.eval()
        pretrain = unwrap(model).pretrain
        predictions: dict[PartKey, np.ndarray] = {}
        for part in parts:
            while eval_batch_size:
                try:
                    predictions[part] = (
                        torch.cat(
                            [
                                apply_model(part, idx)
                                for idx in torch.arange(
                                    len(dataset.data["y"][part]),
                                    device=device,
                                ).split(eval_batch_size)
                            ]
                        )
                        .cpu()
                        .numpy()
                    )
                except RuntimeError as err:
                    if not lib.is_oom_exception(err):
                        raise
                    eval_batch_size //= 2
                    logger.warning(f"eval_batch_size = {eval_batch_size}")
                else:
                    break
            if not eval_batch_size:
                raise RuntimeError("Not enough memory even for eval_batch_size=1")
        if pretrain:
            metrics = {
                part: {"score": float(-part_losses.mean())}
                for part, part_losses in predictions.items()
            }
        else:
            if regression_label_stats is not None:
                predictions = {
                    k: v * regression_label_stats.std + regression_label_stats.mean
                    for k, v in predictions.items()
                }
            metrics = (
                dataset.task.calculate_metrics(predictions, report["prediction_type"])
                if lib.are_valid_predictions(predictions)
                else {x: {"score": -999999.0} for x in predictions}
            )

        return metrics, predictions, eval_batch_size

    def train_loop(
        *,
        step_fn: Callable[[Tensor], Tensor],
        eval_fn: Callable[..., tuple],
        optimizer: torch.optim.Optimizer,
        lr_scheduler: Any,
        batch_size: int,
        n_epochs: int,
        patience: int,
        report_key: str,
        chunk_size: None | int = None,
        eval_batch_size: int = eval_batch_size,
    ) -> tuple[None | int, int]:
        """Run one training phase with early stopping on the validation score.

        A checkpoint is saved, and predictions are dumped, whenever the
        validation score improves. Returns the possibly reduced
        ``(chunk_size, eval_batch_size)`` so that later phases can reuse them.
        """
        # Epochs are measured in batches of *this* phase's batch size.
        epoch_size = compute_epoch_size(batch_size)
        n_train = len(dataset.data["y"]["train"])

        def save_checkpoint(step: int) -> None:
            lib.dump_checkpoint(
                output,
                {
                    "step": step,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "generator": generator.get_state(),
                    "random_state": delu.random.get_state(),
                    "early_stopping": early_stopping,
                    "report": report,
                    "timer": timer,
                    "training_log": training_log,
                }
                | (
                    {}
                    if lr_scheduler is None
                    else {"lr_scheduler": lr_scheduler.state_dict()}
                ),
            )
            lib.dump_report(output, report)
            lib.backup_output(output)

        step = 0
        early_stopping = delu.tools.EarlyStopping(patience, mode="max")
        report[report_key] = {"metrics": {"val": {"score": -math.inf}}}

        while n_epochs == -1 or step // epoch_size < n_epochs:
            print(f"[...] {output} | {timer}")

            # >>>
            model.train()
            epoch_losses = []
            logs_train = defaultdict(list)

            for batch_idx in tqdm(
                torch.randperm(n_train, generator=generator, device=device).split(
                    batch_size
                ),
                desc=f"Epoch {step // epoch_size} Step {step}",
            ):
                loss, new_chunk_size = lib.deep.zero_grad_forward_backward(
                    optimizer,
                    step_fn,
                    batch_idx,
                    chunk_size or batch_size,
                )

                for k, v in log_dict.items():
                    logs_train[k].append(v)

                if parameter_statistics and (
                    step % epoch_size == 0  # The first batch of the epoch.
                    or step // epoch_size == 0  # The first epoch.
                ):
                    for k, v in lib.deep.compute_parameter_stats(model).items():
                        writer.add_scalars(
                            f"{report_key}/{k}", v, step, timer.elapsed()
                        )
                        del k, v

                if gradient_clipping_norm is not None:
                    nn.utils.clip_grad.clip_grad_norm_(
                        model.parameters(), gradient_clipping_norm
                    )
                optimizer.step()

                if lr_scheduler is not None:
                    lr_scheduler.step()
                step += 1
                epoch_losses.append(loss.detach())
                if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                    chunk_size = new_chunk_size
                    logger.warning(f"chunk_size = {chunk_size}")

            epoch_losses = torch.stack(epoch_losses).tolist()
            mean_loss = statistics.mean(epoch_losses)

            metrics, predictions, eval_batch_size = eval_fn(
                ["val", "test"], eval_batch_size
            )
            metrics["train"] = {}
            for k, v in logs_train.items():
                metrics["train"][k] = np.mean(v).item()

            training_log.append(
                {
                    "epoch-losses": epoch_losses,
                    "metrics": metrics,
                    "time": timer.elapsed(),
                }
            )
            lib.print_metrics(mean_loss, metrics)
            writer.add_scalars(
                f"{report_key}/loss", {"train": mean_loss}, step, timer.elapsed()
            )
            for part in metrics:
                for k in metrics[part].keys():
                    if k != "score":
                        continue
                    writer.add_scalars(
                        f"{report_key}/{k}",
                        {part: metrics[part][k]},
                        step,
                        timer.elapsed(),
                    )

            if metrics["val"]["score"] > report[report_key]["metrics"]["val"]["score"]:
                print("🌸 New best epoch! 🌸")
                report[report_key]["best_step"] = step
                report[report_key]["metrics"] = metrics
                save_checkpoint(step)
                lib.dump_predictions(output, predictions)

            early_stopping.update(metrics["val"]["score"])
            if early_stopping.should_stop() or not lib.are_valid_predictions(
                predictions
            ):
                break

            print()
        return chunk_size, eval_batch_size

    def pretrain_step(idx):
        # With DataParallel the model returns one loss per replica; average them.
        return apply_model("train", idx).mean()

    def finetune_step(idx):
        return loss_fn(apply_model("train", idx), Y_train[idx])

    # Hook for extra per-batch training metrics (merged into metrics["train"]).
    # Nothing in this script populates it at the moment.
    log_dict: dict[str, Any] = {}

    # >>> pretrain
    print("Pretraining")
    timer.run()

    chunk_size, eval_batch_size = train_loop(
        step_fn=pretrain_step,
        eval_fn=evaluate,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        batch_size=pretrain_batch_size,
        n_epochs=config["n_epochs_triplet_pretrain"],
        patience=config.get("patience_triplet_pretrain", config["patience"]),
        report_key="pretrain",
        chunk_size=chunk_size,
        eval_batch_size=eval_batch_size,
    )

    log_dict = {}
    # >>> finetune
    print("Finetuning")
    model = build_model(backbone_conf, pretrain=False)
    finetune_module = unwrap(model)
    if finetune_module.m_num is None:
        logger.warning("No numerical embeddings to initialize from pretraining")
    else:
        try:
            pretrained_state = lib.load_checkpoint(output)["model"]
            finetune_module.m_num.load_state_dict(  # pyright: ignore
                extract_submodule_state(pretrained_state, "m_num")
            )
        except Exception as err:
            # Best effort, as before: finetune from scratch, but log the reason.
            logger.warning(f"Failed loading pretrained m_num weights: {err}")

    # Fresh optimizer and a fresh warmup schedule bound to it (the pretraining
    # scheduler drives the pretraining optimizer only).
    optimizer = lib.deep.make_optimizer(
        **config["optimizer"], params=lib.deep.make_parameter_groups(model)
    )
    lr_scheduler = make_lr_scheduler(optimizer, epoch_size)

    chunk_size, eval_batch_size = train_loop(
        step_fn=finetune_step,
        eval_fn=evaluate,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        batch_size=batch_size,
        n_epochs=config["n_epochs"],
        patience=config["patience"],
        report_key="finetune",
        chunk_size=chunk_size,
        eval_batch_size=eval_batch_size,
    )
    report["time"] = str(timer)

    # >>> finish
    model.load_state_dict(lib.load_checkpoint(output)["model"])
    report["metrics"], predictions, _ = evaluate(
        ["train", "val", "test"], eval_batch_size
    )
    report["chunk_size"] = chunk_size
    report["eval_batch_size"] = eval_batch_size
    writer.close()
    lib.dump_predictions(output, predictions)
    lib.dump_summary(output, lib.summarize(report))
    lib.finish(output, report)
    return report


if __name__ == "__main__":
    # Global setup provided by `lib`; what it configures lives in `lib`.
    # It is invoked before the CLI entry point below.
    lib.configure_libraries()
    # Hand `main` to lib's CLI runner, which presumably parses command-line
    # arguments into (config, output, force) — verify in lib.
    lib.run_MainFunction_cli(main)
