import math
import statistics
import sys
from pathlib import Path
from typing import Any, Literal

import delu
import numpy as np
import rtdl_num_embeddings
from sklearn.discriminant_analysis import StandardScaler
from sklearn.pipeline import FunctionTransformer, Pipeline
from sklearn.preprocessing import PowerTransformer, QuantileTransformer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.tensorboard
from loguru import logger
from torch import Tensor
from tqdm import tqdm
from typing_extensions import NotRequired, TypedDict
from lib.tabpfn.model.bar_distribution import (
    BarDistribution,
    FullSupportBarDistribution,
)
from lib.tabpfn.utils import _transform_borders_one, translate_probs_across_borders

if __name__ == "__main__":
    # When executed as a script, require the current working directory to be the
    # repository root (detected by the presence of `.git`) and append it to
    # `sys.path`, presumably so that the project-local `lib` package is importable.
    # NOTE(review): the `lib.tabpfn` imports above run before this path tweak --
    # confirm they resolve without it.
    _cwd = Path.cwd()
    assert _cwd.joinpath(".git").exists(), (
        "The script must be run from the root of the repository"
    )
    sys.path.append(str(_cwd))
    # Do not leave the temporary name in the module namespace.
    del _cwd

import lib
import lib.data
import lib.deep
from lib import KWArgs, PartKey


def regression_output_transform(
    target_transform,
    criterion,
    renormalized_criterion,
    softmax_temperature: float = 0.9,
):
    """Build a function that maps bucket logits to regression predictions.

    Extracted from ``regressor.py`` in tabpfn.

    Args:
        target_transform: the fitted target transformer; it is forwarded to
            ``_transform_borders_one`` together with the criterion borders.
        criterion: the bar distribution used for training; only its ``borders``
            attribute is read here.
        renormalized_criterion: the bar distribution whose ``mean`` method turns
            the translated logits into the final predictions.
        softmax_temperature: the model outputs are divided by this value before
            being translated across borders.

    Returns:
        A callable ``transform(out)`` that takes the raw model outputs and returns
        ``renormalized_criterion.mean(...)`` of the translated log-probabilities
        (computed on CPU).
    """

    std_borders = criterion.borders.cpu().numpy()
    logit_cancel_mask, descending_borders, borders_t = _transform_borders_one(
        std_borders,
        target_transform=target_transform,
        repair_nan_borders_after_transform=True,
    )
    if descending_borders:
        # NOTE(review): only the borders are reversed here, the logits passed to
        # `transform` are not. Confirm against the upstream tabpfn implementation
        # whether the outputs must be flipped as well.
        borders_t = borders_t.flip(-1)  # type: ignore

    device = lib.get_device()
    # The borders do not depend on the model outputs, so they are moved to the
    # device once instead of on every call.
    borders_from = torch.as_tensor(borders_t, device=device)
    borders_to = criterion.borders.to(device)

    def transform(out):
        # The division always yields a new tensor, so the in-place masking below
        # never modifies the caller's tensor.
        scaled = out.float() / softmax_temperature
        if logit_cancel_mask is not None:
            # The buckets whose borders were repaired must be cancelled BEFORE the
            # translation. Previously, the mask was applied to a tensor that was
            # never read again, so it had no effect.
            scaled[..., logit_cancel_mask] = float("-inf")

        probs = translate_probs_across_borders(
            scaled,
            frm=borders_from,
            to=borders_to,
        )

        logits = probs.log()
        if logits.dtype == torch.float16:
            logits = logits.float()
        logits = logits.cpu()

        return renormalized_criterion.mean(logits)

    return transform


class Residual_block(nn.Module):
    """A pre-normalized bottleneck MLP: BatchNorm -> Linear -> ReLU -> Dropout -> Linear.

    The input and output widths are both ``d_in``; ``d`` is the hidden width.
    Despite the name, the skip connection is disabled, so the block returns
    only the transformed branch.
    """

    def __init__(self, d_in, d, dropout):
        super().__init__()
        # The attribute names and the registration order define the layout of
        # `state_dict`, so they are kept as is.
        self.linear0 = nn.Linear(d_in, d)
        self.Linear1 = nn.Linear(d, d_in)
        self.bn = nn.BatchNorm1d(d_in)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.activation(self.linear0(self.bn(x)))
        # No residual addition here: the output replaces the input.
        return self.Linear1(self.dropout(hidden))


# This implementation is based on the official implementation of ModernNCA:
# https://github.com/qile2000/LAMDA-TALENT/blob/2d7a166772ca2e79d0fa8b5f73a1ae6dbb8c5f09/LAMDA-TALENT/model/models/modernNCA.py
class ModernNCA(nn.Module):
    def __init__(
        self,
        *,
        n_num_features: int,
        cat_cardinalities: list[int],
        n_classes: None | int,
        #
        dim: int,
        dropout: int,
        d_block: None | int = None,
        d_block_multiplier: None | float = None,
        n_blocks: int,
        bins: None | list[Tensor],
        num_embeddings: None | dict = None,
        temperature: float = 1.0,
        sample_rate: float = 0.8,
    ) -> None:
        if d_block is None:
            assert d_block_multiplier is not None
            d_block = int(d_block_multiplier * dim)
        else:
            assert d_block_multiplier is None
        super().__init__()

        if n_num_features == 0:
            assert bins is None
            self.num_module = None
        elif num_embeddings is None:
            assert bins is None
            self.num_module = None
        else:
            if bins is None:
                self.num_module = lib.deep.make_module(
                    **num_embeddings, n_features=n_num_features
                )
            else:
                assert num_embeddings["type"].startswith("PiecewiseLinearEmbeddings")
                self.num_module = lib.deep.make_module(**num_embeddings, bins=bins)

        self.cat_module = (
            lib.deep.OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
        )

        self.d_in = n_num_features * (
            1 if num_embeddings is None else num_embeddings["d_embedding"]
        ) + sum(cat_cardinalities)
        self.n_classes = n_classes
        self.dim = dim
        self.dropout = dropout
        self.d_block = d_block
        self.n_blocks = n_blocks
        self.T = temperature
        self.sample_rate = sample_rate
        if n_blocks > 0:
            self.post_encoder = nn.Sequential()
            for i in range(n_blocks):
                name = f"ResidualBlock{i}"
                self.post_encoder.add_module(name, self.make_layer())
            self.post_encoder.add_module("bn", nn.BatchNorm1d(dim))
        self.encoder = nn.Linear(self.d_in, dim)
        # self.bn=nn.BatchNorm1d(dim)

    def make_layer(self):
        block = Residual_block(self.dim, self.d_block, self.dropout)
        return block

    def _pre_encoder(self, x_num: None | Tensor, x_cat: None | Tensor) -> Tensor:
        x = []
        if x_num is not None:
            x.append(x_num if self.num_module is None else self.num_module(x_num))
        if x_cat is None:
            assert self.cat_module is None
        else:
            assert self.cat_module is not None
            x.append(self.cat_module(x_cat))
        x = torch.column_stack([x_.flatten(1, -1) for x_ in x])
        return x

    def forward(
        self,
        *,
        x_num: None | Tensor = None,
        x_cat: None | Tensor = None,
        y: None | Tensor,
        candidate_x_num: None | Tensor = None,
        candidate_x_cat: None | Tensor = None,
        candidate_y: Tensor,
        is_train: bool,
    ):
        if is_train:
            data_size = len(candidate_y)
            retrival_size = int(data_size * self.sample_rate)
            sample_idx = torch.randperm(data_size)[:retrival_size]
            candidate_x_num = (
                None if candidate_x_num is None else candidate_x_num[sample_idx]
            )
            candidate_x_cat = (
                None if candidate_x_cat is None else candidate_x_cat[sample_idx]
            )
            candidate_y = candidate_y[sample_idx]

        x = self._pre_encoder(x_num, x_cat)
        candidate_x = self._pre_encoder(candidate_x_num, candidate_x_cat)
        dtype = x.dtype if x.dtype != torch.int64 else torch.float32

        if self.n_blocks > 0:
            candidate_x = self.post_encoder(self.encoder(candidate_x.to(dtype)))
            x = self.post_encoder(self.encoder(x.to(dtype)))
        else:
            candidate_x = self.encoder(candidate_x.to(dtype))
            x = self.encoder(x.to(dtype))
        if is_train:
            assert y is not None
            candidate_x = torch.cat([x, candidate_x])
            candidate_y = torch.cat([y, candidate_y])
        else:
            assert y is None

        if self.n_classes is not None:
            candidate_y = F.one_hot(
                self._loss_fn.map_to_bucket_idx(candidate_y), self.n_classes
            ).to(dtype)
        elif len(candidate_y.shape) == 1:
            candidate_y = candidate_y.unsqueeze(-1)

        # The NCA-related computations are always performed at least in float32.
        if x.dtype != torch.float64:
            x = x.float()
            candidate_x = candidate_x.float()
            candidate_y = candidate_y.float()

        # calculate distance
        distances = torch.cdist(x, candidate_x, p=2)
        distances = distances / self.T
        # remove the label of training index
        if is_train:
            distances = distances.clone().fill_diagonal_(torch.inf)
        distances = F.softmax(-distances, dim=-1)
        logits = torch.mm(distances, candidate_y)
        # print(logits.shape)
        # print(logits[:, 1])
        eps = 1e-7
        if self.n_classes is not None:
            logits = torch.log(logits + eps)
        # return logits.squeeze()
        return logits.to(dtype)


class Config(TypedDict):
    """The configuration of `main`."""

    # The random seed: passed to `delu.random.seed`, to the batch generator and
    # to the quantile target transformation.
    seed: int
    # The keyword arguments for `lib.data.build_dataset`.
    data: KWArgs
    # If present, the keyword arguments for `rtdl_num_embeddings.compute_bins`
    # (the resulting bins are passed to the model).
    bins: NotRequired[KWArgs]
    # The keyword arguments for `ModernNCA` (in addition to the arguments that
    # `main` infers from the data).
    model: KWArgs
    # The keyword arguments for `lib.deep.make_optimizer`.
    optimizer: KWArgs
    # The training batch size.
    batch_size: int
    # The patience of `delu.tools.EarlyStopping` (updated once per epoch with
    # the validation score).
    patience: int
    # The maximum number of epochs; -1 disables the limit.
    n_epochs: int
    # If present, the gradients are clipped to this norm before each optimizer step.
    gradient_clipping_norm: NotRequired[float]
    # If present, enables automatic mixed precision.
    amp_dtype: NotRequired[Literal["float16", "bfloat16"]]
    # The transformation applied to the (standardized) labels;
    # "identity" is used when the key is missing.
    target_transform: NotRequired[Literal["power", "quantile", "identity"]]
    # The number of quantile-based target bins requested for the loss
    # (128 when the key is missing).
    n_target_bins: NotRequired[int]


def main(
    config: Config | str | Path,
    output: None | str | Path = None,
    *,
    force: bool = False,
) -> None | lib.JSONDict:
    """Train and evaluate ModernNCA on a regression task with a binned target.

    The labels are standardized, optionally transformed, and split into
    quantile-based buckets. The model is trained with `FullSupportBarDistribution`
    over these buckets, and its outputs are converted back to values of the
    original labels for evaluation.

    Args:
        config: the configuration (or anything else accepted by `lib.check`).
        output: the output directory (or anything else accepted by `lib.check`).
        force: forwarded to `lib.start`.

    Returns:
        None if `lib.start` returns a falsy value, otherwise the report.

    Raises:
        ValueError: if `target_transform` in the config is unknown.
        RuntimeError: if the evaluation runs out of memory even with the
            evaluation batch size equal to one.
    """
    # >>> Start
    config, output = lib.check(config, output, config_type=Config)
    if not lib.start(output, force=force):
        return None

    lib.show_config(config)  # type: ignore[code]
    delu.random.seed(config["seed"])
    device = lib.get_device()
    report = lib.create_report(config)  # type: ignore

    # >>> Data
    dataset = lib.data.build_dataset(**config["data"])
    assert dataset.task.is_regression, "Only regression tasks are supported"

    dataset.data["y"], regression_label_stats = lib.data.standardize_labels(
        dataset.data["y"]
    )
    # NOTE: this removes the key from `config` (the report is already created).
    target_transform_type = config.pop("target_transform", "identity")
    if target_transform_type == "power":
        target_transform = Pipeline(
            [("power", PowerTransformer()), ("standard", StandardScaler())]
        ).fit(dataset.data["y"]["train"].reshape(-1, 1))
    elif target_transform_type == "quantile":
        target_transform = QuantileTransformer(
            output_distribution="normal", random_state=config["seed"]
        ).fit(dataset.data["y"]["train"].reshape(-1, 1))
    elif target_transform_type == "identity":
        target_transform = FunctionTransformer(func=None)
    else:
        raise ValueError(f"Unknown target_transform {target_transform_type}")

    dataset.data["y"] = {
        part: target_transform.transform(dataset.data["y"][part].reshape(-1, 1))
        .astype(np.float32)  # type: ignore
        .squeeze()  # pyright: ignore
        for part in dataset.data["y"]
    }

    # Convert binary features to categorical features.
    if dataset.n_bin_features > 0:
        x_bin = dataset.data.pop("x_bin")
        # Remove binary features with just one unique value in the training set.
        # This must be done, otherwise, the script will fail on one specific dataset
        # from the "why" benchmark.
        n_bin_features = x_bin["train"].shape[1]
        good_bin_idx = [
            i for i in range(n_bin_features) if len(np.unique(x_bin["train"][:, i])) > 1
        ]
        if len(good_bin_idx) < n_bin_features:
            x_bin = {k: v[:, good_bin_idx] for k, v in x_bin.items()}

        if dataset.n_cat_features == 0:
            dataset.data["x_cat"] = {
                part: np.zeros((dataset.size(part), 0), dtype=np.int64)
                for part in x_bin
            }
        for part in x_bin:
            dataset.data["x_cat"][part] = np.column_stack(
                [dataset.data["x_cat"][part], x_bin[part].astype(np.int64)]
            )
        del x_bin

    dataset = dataset.to_torch(device)
    Y_train = dataset.data["y"]["train"].to(
        torch.long if dataset.task.is_classification else torch.float
    )

    def unique_with_precision(arr, precision=1e-5):
        """Drop the values that are within `precision` of the last kept value.

        The input is expected to be sorted (here, it is a tensor of quantiles).
        """
        unique_values = []
        for value in arr:
            if not unique_values or abs(value - unique_values[-1]) > precision:
                unique_values.append(value)
        return torch.tensor(unique_values, dtype=torch.float32, device=device)

    # The borders of the target buckets are the quantiles of the training labels.
    # (Almost) equal borders are merged, so the final number of buckets
    # can be less than requested.
    borders = torch.quantile(
        Y_train,
        q=torch.linspace(0.0, 1.0, steps=config.get("n_target_bins", 128) + 1).to(
            Y_train
        ),
    )
    borders = unique_with_precision(borders)
    print("Len of borders:", len(borders))

    # >>> model
    if "bins" in config:
        # Compute the bins for PiecewiseLinearEncoding and PiecewiseLinearEmbeddings.
        compute_bins_kwargs = (
            {
                "y": Y_train.to(
                    torch.long if dataset.task.is_classification else torch.float
                ),
                "regression": dataset.task.is_regression,
                "verbose": True,
            }
            if "tree_kwargs" in config["bins"]
            else {}
        )
        bin_edges = rtdl_num_embeddings.compute_bins(
            dataset["x_num"]["train"], **config["bins"], **compute_bins_kwargs
        )
        logger.info(f"Bin counts: {[len(x) - 1 for x in bin_edges]}")
    else:
        bin_edges = None
    model = ModernNCA(
        n_num_features=dataset.n_num_features,
        cat_cardinalities=dataset.compute_cat_cardinalities(),
        n_classes=len(borders) - 1,
        **config["model"],
        bins=bin_edges,
    )
    report["n_parameters"] = lib.deep.get_n_parameters(model)
    logger.info(f"n_parameters = {report['n_parameters']}")
    report["prediction_type"] = "labels" if dataset.task.is_regression else "probs"
    model.to(device)

    # >>> Training
    step = 0
    train_size = dataset.size("train")
    batch_size = config["batch_size"]
    report["epoch_size"] = epoch_size = math.ceil(train_size / batch_size)
    eval_batch_size = 32768
    chunk_size = None
    train_indices = torch.arange(train_size, device=device)

    optimizer = lib.deep.make_optimizer(
        **config["optimizer"],
        params=lib.deep.make_parameter_groups(model),
    )
    gradient_clipping_norm = config.get("gradient_clipping_norm")

    assert regression_label_stats is not None
    _loss_fn = FullSupportBarDistribution(borders)
    # The model needs the loss function to convert the candidate labels
    # to bucket indices in its forward pass.
    model._loss_fn = _loss_fn
    # The same buckets, but in the scale of the labels before the standardization.
    renormalized_criterion = FullSupportBarDistribution(
        _loss_fn.borders * regression_label_stats.std + regression_label_stats.mean,
    ).float()
    # NOTE(review): `config["model"]` is also unpacked into `ModernNCA(...)`, which
    # does not accept `softmax_temperature`, so the default (1.0) seems to be
    # the only usable value here -- confirm the intended location of this option.
    pred_transform = regression_output_transform(
        target_transform,  # pyright: ignore
        _loss_fn,
        renormalized_criterion,
        softmax_temperature=config["model"].get("softmax_temperature", 1.0),
    )

    def loss_fn(y_pred: Tensor, y_true: Tensor) -> Tensor:
        return _loss_fn(
            y_pred.unsqueeze(1),
            y_true.unsqueeze(1),
        ).mean()

    batch_generator = torch.Generator(device).manual_seed(config["seed"])
    timer = delu.tools.Timer()
    early_stopping = delu.tools.EarlyStopping(config["patience"], mode="max")
    training_log = []
    writer = torch.utils.tensorboard.SummaryWriter(output)  # type: ignore[code]

    amp_dtype = config.get("amp_dtype")
    if amp_dtype is not None:
        amp_dtype = getattr(torch, amp_dtype)
    amp_enabled = amp_dtype is not None
    # The gradient scaling is needed only for float16.
    scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore[code]
    logger.info(f"AMP enabled: {amp_enabled}")

    def get_Xy(part: PartKey, idx: None | Tensor) -> tuple[dict[str, Tensor], Tensor]:
        """Get the features and the labels of the given part (optionally, a subset)."""
        batch = (
            {
                key: dataset.data[key][part]
                for key in ["x_num", "x_bin", "x_cat"]
                if key in dataset.data
            },
            dataset.data["y"][part],
        )
        return (
            batch
            if idx is None
            else ({k: v[idx] for k, v in batch[0].items()}, batch[1][idx])
        )

    def apply_model(part: str, idx: Tensor, training: bool) -> Tensor:
        """Apply the model to the given objects using the training set as candidates."""
        # Currently, this argument is not used. However, it can be useful for other
        # variations of the model.
        del training
        x, y = get_Xy(part, idx)

        candidate_indices = train_indices
        # NOTE: is_train and training are different things, as explained here:
        # https://github.com/yandex-research/tabular-dl-tabr/issues/5#issuecomment-1726063188
        is_train = part == "train"
        if is_train:
            # NOTE: here, the training batch is removed from the candidates.
            # It will be added back inside the model's forward pass.
            candidate_indices = candidate_indices[~torch.isin(candidate_indices, idx)]
        candidate_x, candidate_y = get_Xy(
            "train",
            # This condition is here for historical reasons, it could be just
            # the unconditional `candidate_indices`.
            None if candidate_indices is train_indices else candidate_indices,
        )

        # The configured AMP dtype is used. Previously, bfloat16 was hard-coded
        # here, so `amp_dtype="float16"` was silently ignored (while the gradient
        # scaler was still created for it).
        with torch.autocast(  # type: ignore[code]
            device.type,
            enabled=amp_enabled,
            dtype=amp_dtype,
        ):
            return model(
                **x,
                y=y if is_train else None,
                **{f"candidate_{k}": v for k, v in candidate_x.items()},
                candidate_y=candidate_y,
                is_train=is_train,
            ).squeeze(-1)

    @torch.inference_mode()
    def evaluate(
        parts: list[PartKey], eval_batch_size: int
    ) -> tuple[dict[PartKey, Any], dict[PartKey, np.ndarray], int]:
        """Compute the metrics and the predictions for the given parts.

        On out-of-memory errors, the batch size is halved and the part is retried.
        The (possibly reduced) batch size is returned as the last item.
        """
        model.eval()
        predictions: dict[PartKey, np.ndarray] = {}
        for part in parts:
            while eval_batch_size:
                try:
                    predictions[part] = (
                        torch.cat(
                            [
                                apply_model(part, idx, False)
                                for idx in torch.arange(
                                    dataset.size(part), device=device
                                ).split(eval_batch_size)
                            ]
                        )
                        .cpu()
                        .numpy()
                    )
                except RuntimeError as err:
                    if not lib.is_oom_exception(err):
                        raise
                    eval_batch_size //= 2
                    delu.cuda.free_memory()
                    logger.warning(f"eval_batch_size = {eval_batch_size}")
                else:
                    break
            if not eval_batch_size:
                # Previously, the exception was created, but not raised,
                # so the failure was silently ignored.
                raise RuntimeError("Not enough memory even for eval_batch_size=1")
        delu.cuda.free_memory()
        # Convert the outputs over the buckets to values of the original labels.
        predictions = {
            k: pred_transform(torch.from_numpy(v).to(device)).cpu().numpy()  # pyright: ignore
            for k, v in predictions.items()
        }
        metrics = (
            dataset.task.calculate_metrics(predictions, report["prediction_type"])
            if lib.are_valid_predictions(predictions)
            else {x: {"score": -99999.0} for x in predictions}
        )
        return metrics, predictions, eval_batch_size

    def save_checkpoint() -> None:
        lib.dump_checkpoint(
            output,
            {
                "step": step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "batch_generator": batch_generator.get_state(),
                "random_state": delu.random.get_state(),
                "early_stopping": early_stopping,
                "report": report,
                "timer": timer,
                "training_log": training_log,
            }
            | ({} if scaler is None else {"scaler": scaler.state_dict()}),
        )
        lib.dump_report(output, report)
        lib.backup_output(output)

    print()
    timer.run()
    while config["n_epochs"] == -1 or step // epoch_size < config["n_epochs"]:
        print(f"[...] {lib.try_get_relative_path(output)} | {timer}")

        model.train()
        epoch_losses = []
        for batch_idx in tqdm(
            torch.randperm(
                len(dataset.data["y"]["train"]),
                generator=batch_generator,
                device=device,
            ).split(batch_size),
            desc=f"Epoch {step // epoch_size} Step {step}",
        ):
            loss, new_chunk_size = lib.deep.zero_grad_forward_backward(
                optimizer,
                lambda idx: loss_fn(apply_model("train", idx, True), Y_train[idx]),
                batch_idx,
                chunk_size or batch_size,
                scaler,
            )

            if gradient_clipping_norm is not None:
                if scaler is not None:
                    # The gradients must be unscaled before the clipping.
                    scaler.unscale_(optimizer)
                nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), gradient_clipping_norm
                )
            if scaler is None:
                optimizer.step()
            else:
                scaler.step(optimizer)
                scaler.update()

            step += 1
            epoch_losses.append(loss.detach())
            if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                chunk_size = new_chunk_size
                logger.warning(f"chunk_size = {chunk_size}")

        epoch_losses = torch.stack(epoch_losses).tolist()
        mean_loss = statistics.mean(epoch_losses)
        metrics, predictions, eval_batch_size = evaluate(
            ["val", "test"], eval_batch_size
        )

        training_log.append(
            {"epoch-losses": epoch_losses, "metrics": metrics, "time": timer.elapsed()}
        )
        lib.print_metrics(mean_loss, metrics)
        writer.add_scalars("loss", {"train": mean_loss}, step, timer.elapsed())
        for part in metrics:
            writer.add_scalars(
                "score", {part: metrics[part]["score"]}, step, timer.elapsed()
            )

        if (
            "metrics" not in report
            or metrics["val"]["score"] > report["metrics"]["val"]["score"]
        ):
            print("🌸 New best epoch! 🌸")
            report["best_step"] = step
            report["metrics"] = metrics
            save_checkpoint()
            lib.dump_predictions(output, predictions)

        early_stopping.update(metrics["val"]["score"])
        if early_stopping.should_stop() or not lib.are_valid_predictions(predictions):
            break

        print()
    report["time"] = str(timer)

    # >>> Finish: evaluate the best checkpoint on all parts.
    if lib.get_checkpoint_path(output).exists():
        model.load_state_dict(lib.load_checkpoint(output)["model"])
    report["metrics"], predictions, _ = evaluate(
        ["train", "val", "test"], eval_batch_size
    )
    report["chunk_size"] = chunk_size
    report["eval_batch_size"] = eval_batch_size
    lib.dump_predictions(output, predictions)
    lib.dump_summary(output, lib.summarize(report))
    save_checkpoint()
    # Flush the pending TensorBoard events and release the file handle.
    writer.close()
    lib.finish(output, report)
    return report


if __name__ == "__main__":
    # Script entry point: `main` is handed to the project's CLI runner after the
    # project-level library configuration.
    lib.configure_libraries()
    lib.run_MainFunction_cli(main)
