# Feed-forward network[s].

import math
import statistics
from pathlib import Path
from typing import Any

import delu
import numpy as np
import rtdl_num_embeddings
import torch
import torch.nn as nn
import torch.utils.tensorboard
from loguru import logger
from torch import Tensor
from tqdm import tqdm
from typing_extensions import NotRequired, TypedDict

import lib
from lib import KWArgs, PartKey
from tabm_reference_cifar import Model


class SNN(nn.Module):
    """Self-normalizing MLP from 'Self-Normalizing Neural Networks'.

    The network is a stack of ``n_blocks`` hidden blocks, each
    ``Linear -> SELU -> AlphaDropout``, followed by a linear output layer
    mapping ``d_block`` features to ``d_out``. Every linear layer gets
    Kaiming-normal weights (fan-in, linear gain, i.e. LeCun-normal) and zero
    biases.
    """

    def __init__(
        self, *, d_in: int, d_out: int, n_blocks: int, d_block: int, dropout: float
    ) -> None:
        super().__init__()
        hidden = []
        in_features = d_in  # Only the first block sees the raw input width.
        for _ in range(n_blocks):
            hidden.append(
                nn.Sequential(
                    nn.Linear(in_features, d_block),
                    nn.SELU(),
                    nn.AlphaDropout(dropout),
                )
            )
            in_features = d_block
        self.blocks = nn.ModuleList(hidden)
        self.output = nn.Linear(d_block, d_out)

        # The initialization follows the official implementation:
        # https://github.com/bioinf-jku/SNNs/blob/b578499301fcb801f8d4135dbd7cebb246722bfc/Pytorch/SelfNormalizingNetworks_MLP_MNIST.ipynb
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="linear")
            nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the hidden blocks, then the output layer if one is set."""
        h = x
        for block in self.blocks:
            h = block(h)
        return h if self.output is None else self.output(h)


class DCNv2(nn.Module):
    def __init__(
        self,
        d_in: int,
        d_deep: int,
        d_out: int,
        n_cross_layers: int,
        n_deep_layers: int,
        dropout_p: float,
        k_low_rank_cross: int | None = None,
        nonlin_cross: bool = False,
    ):
        super().__init__()

        def get_cross_layer():
            "Cross layer variations from the paper, no MoE variation"

            if k_low_rank_cross is None:
                m = nn.Linear(d_in, d_in, bias=True)
                torch.nn.init.zeros_(m.bias)
            elif nonlin_cross:
                d_low_rank_cross = d_in // k_low_rank_cross
                m = nn.Sequential(
                    nn.Linear(d_in, d_low_rank_cross, bias=False),
                    nn.ReLU(True),
                    nn.Linear(d_low_rank_cross, d_low_rank_cross, bias=False),
                    nn.ReLU(True),
                    nn.Linear(d_low_rank_cross, d_in),
                )
                torch.nn.init.zeros_(m[-1].bias)
            else:
                d_low_rank_cross = d_in // k_low_rank_cross
                m = nn.Sequential(
                    nn.Linear(d_in, d_low_rank_cross, bias=False),
                    nn.Linear(d_low_rank_cross, d_in),
                )
                torch.nn.init.zeros_(m[-1].bias)
            return m

        self.cross_layers = nn.ModuleList(
            [get_cross_layer() for _ in range(n_cross_layers)]
        )

        def get_dnn_layer(d_in=None):
            return nn.Sequential(
                nn.Linear(d_in if d_in is not None else d_deep, d_deep),
                nn.ReLU(True),
                nn.Dropout(dropout_p),
            )

        self.deep_layers = nn.Sequential(
            *[
                get_dnn_layer(d_in=d_in if i == 0 else None)
                for i in range(n_deep_layers)
            ]
        )
        self.head = nn.Linear(d_deep, d_out)

    def forward(self, x):
        x0 = x
        for c in self.cross_layers:
            x = x0 * c(x)

        x = self.deep_layers(x)
        x = self.head(x)

        return x


# Register the custom backbones with `lib.deep` under their class names.
for _module_cls in (DCNv2, SNN):
    lib.deep.register_module(_module_cls.__name__, _module_cls)
del _module_cls


class Config(TypedDict):
    """Experiment configuration consumed by `main`."""

    seed: int  # seeds `delu.random` and the batch-shuffling generator
    data: KWArgs  # kwargs for `lib.data.build_dataset`
    bins: NotRequired[KWArgs]  # kwargs for `rtdl_num_embeddings.compute_bins`; no bins if absent
    model: KWArgs  # kwargs for `Model`; `k` is also read to replicate targets in the loss
    optimizer: KWArgs  # kwargs for `lib.deep.make_optimizer`
    n_lr_warmup_epochs: NotRequired[int]  # enables a linear LR warmup when present
    batch_size: int  # training batch size
    patience: int  # early-stopping patience, counted in epochs
    n_epochs: int  # maximum number of epochs; -1 means no limit
    gradient_clipping_norm: NotRequired[float]  # max gradient norm for clipping, if set
    parameter_statistics: NotRequired[bool]  # log parameter stats to TensorBoard; defaults to seed == 1


def main(
    config: Config, output: str | Path, *, force: bool = False
) -> None | lib.JSONDict:
    """Train and evaluate a ``Model`` described by ``config``.

    The report, checkpoints, predictions and TensorBoard logs are written to
    ``output`` through the ``lib`` helpers.

    Args:
        config: experiment configuration; see ``Config`` for the keys.
        output: experiment output directory.
        force: passed through to ``lib.start``.

    Returns:
        The final report, or ``None`` when ``lib.start`` returns a falsy value
        and the run does not proceed.

    Raises:
        RuntimeError: if evaluation runs out of memory even with a batch size of 1.
    """
    # >>> start
    assert set(config) >= Config.__required_keys__
    assert set(config) <= Config.__required_keys__ | Config.__optional_keys__
    if not lib.start(output, force=force):
        return None

    lib.show_config(config)  # type: ignore[code]
    output = Path(output)
    delu.random.seed(config["seed"])
    device = lib.get_device()
    report = lib.create_report(config)  # type: ignore[code]

    # >>> dataset
    dataset = lib.data.build_dataset(**config["data"])
    if dataset.task.is_regression:
        # Train on standardized labels; `evaluate` maps predictions back.
        dataset.data["y"], regression_label_stats = lib.data.standardize_labels(
            dataset.data["y"]
        )
    else:
        regression_label_stats = None
    dataset = dataset.to_torch(device)
    Y_train = dataset.data["y"]["train"].to(
        torch.long if dataset.task.is_multiclass else torch.float
    )

    # >>> model
    if "bins" in config:
        # Target-aware (tree-based) binning needs the labels; otherwise no extras.
        compute_bins_kwargs = (
            {
                "y": Y_train.to(
                    torch.long if dataset.task.is_classification else torch.float
                ),
                "regression": dataset.task.is_regression,
                "verbose": True,
            }
            if "tree_kwargs" in config["bins"]
            else {}
        )
        bin_edges = rtdl_num_embeddings.compute_bins(
            dataset["x_num"]["train"], **config["bins"], **compute_bins_kwargs
        )
        logger.info(f"Bin counts: {[len(x) - 1 for x in bin_edges]}")
    else:
        bin_edges = None

    config["model"].setdefault("arch_type", "tabm")

    model = Model(
        n_num_features=dataset.n_num_features,
        # Binary features are fed to the model as 2-category categorical features.
        cat_cardinalities=dataset.compute_cat_cardinalities()
        + [2] * dataset.n_bin_features,
        n_classes=dataset.task.try_compute_n_classes(),
        bins=bin_edges,
        **config["model"],
    )
    report["n_parameters"] = lib.deep.get_n_parameters(model)
    logger.info(f"n_parameters = {report['n_parameters']}")
    report["prediction_type"] = "labels" if dataset.task.is_regression else "logits"
    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    # >>> training
    optimizer = lib.deep.make_optimizer(
        **config["optimizer"], params=lib.deep.make_parameter_groups(model)
    )
    loss_fn = lib.deep.get_loss_fn(dataset.task.type_)
    gradient_clipping_norm = config.get("gradient_clipping_norm")

    step = 0
    batch_size = config["batch_size"]
    report["epoch_size"] = epoch_size = math.ceil(dataset.size("train") / batch_size)
    eval_batch_size = 32768
    chunk_size = None
    generator = torch.Generator(device).manual_seed(config["seed"])

    report["metrics"] = {"val": {"score": -math.inf}}
    if "n_lr_warmup_epochs" in config:
        # Cap warmup at 10k steps, then round down to whole epochs (at least one).
        n_warmup_steps = min(10000, config["n_lr_warmup_epochs"] * epoch_size)
        n_warmup_steps = max(1, math.trunc(n_warmup_steps / epoch_size)) * epoch_size
        logger.info(f"{n_warmup_steps=}")
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=0.0001, total_iters=n_warmup_steps
        )
    else:
        lr_scheduler = None
    timer = delu.tools.Timer()
    early_stopping = delu.tools.EarlyStopping(config["patience"], mode="max")
    parameter_statistics = config.get("parameter_statistics", config["seed"] == 1)
    training_log = []
    writer = torch.utils.tensorboard.SummaryWriter(output)  # type: ignore[code]

    def apply_model(part: PartKey, idx: Tensor) -> Tensor:
        """Run the model on rows ``idx`` of ``part``.

        Categorical features and binary features (cast to ``long``), when
        present, are concatenated in that order into a single ``x_cat``.
        Only the trailing dimension is squeezed, so a batch of size 1 keeps
        its batch dimension.
        """
        x_cat_parts = []
        if "x_cat" in dataset.data:
            x_cat_parts.append(dataset.data["x_cat"][part][idx])
        if "x_bin" in dataset.data:
            x_cat_parts.append(dataset.data["x_bin"][part][idx].long())
        x_cat = torch.cat(x_cat_parts, dim=-1) if x_cat_parts else None
        return model(x_num=dataset.data["x_num"][part][idx], x_cat=x_cat).squeeze(-1)

    @torch.inference_mode()
    def evaluate(
        parts: list[PartKey], eval_batch_size: int
    ) -> tuple[dict[PartKey, Any], dict[PartKey, np.ndarray], int]:
        """Predict on ``parts``, halving the batch size on OOM.

        Returns the metrics, the predictions (de-standardized for regression)
        and the possibly reduced evaluation batch size.
        """
        model.eval()
        predictions: dict[PartKey, np.ndarray] = {}
        for part in parts:
            while eval_batch_size:
                try:
                    predictions[part] = (
                        torch.cat(
                            [
                                # NOTE(review): averages over the last dim, which
                                # assumes one output per ensemble member, i.e.
                                # (batch, k) after squeeze(-1). Confirm this for
                                # multiclass outputs of `Model`.
                                apply_model(part, idx).mean(-1)
                                for idx in torch.arange(
                                    len(dataset.data["y"][part]),
                                    device=device,
                                ).split(eval_batch_size)
                            ]
                        )
                        .cpu()
                        .numpy()
                    )
                except RuntimeError as err:
                    if not lib.is_oom_exception(err):
                        raise
                    eval_batch_size //= 2
                    logger.warning(f"eval_batch_size = {eval_batch_size}")
                else:
                    break
            if not eval_batch_size:
                raise RuntimeError("Not enough memory even for eval_batch_size=1")
        if regression_label_stats is not None:
            predictions = {
                k: v * regression_label_stats.std + regression_label_stats.mean
                for k, v in predictions.items()
            }
        metrics = (
            dataset.task.calculate_metrics(predictions, report["prediction_type"])
            if lib.are_valid_predictions(predictions)
            else {x: {"score": -999999.0} for x in predictions}
        )
        return metrics, predictions, eval_batch_size

    def save_checkpoint() -> None:
        """Persist the full training state, the report and a backup of ``output``."""
        lib.dump_checkpoint(
            output,
            {
                "step": step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "generator": generator.get_state(),
                "random_state": delu.random.get_state(),
                "early_stopping": early_stopping,
                "report": report,
                "timer": timer,
                "training_log": training_log,
            }
            | (
                {}
                if lr_scheduler is None
                else {"lr_scheduler": lr_scheduler.state_dict()}
            ),
        )
        lib.dump_report(output, report)
        lib.backup_output(output)

    print()
    timer.run()
    while config["n_epochs"] == -1 or step // epoch_size < config["n_epochs"]:
        print(f"[...] {output} | {timer}")

        # >>>
        model.train()
        epoch_losses = []
        for batch_idx in tqdm(
            torch.randperm(
                len(dataset.data["y"]["train"]), generator=generator, device=device
            ).split(batch_size),
            desc=f"Epoch {step // epoch_size} Step {step}",
        ):
            loss, new_chunk_size = lib.deep.zero_grad_forward_backward(
                optimizer,
                lambda idx: loss_fn(
                    apply_model("train", idx),
                    # One copy of the target per ensemble member: (batch,) -> (batch, k).
                    Y_train[idx][:, None].repeat(1, config["model"]["k"]),
                ),
                batch_idx,
                chunk_size or batch_size,
            )

            if parameter_statistics and (
                step % epoch_size == 0  # The first batch of the epoch.
                or step // epoch_size == 0  # The first epoch.
            ):
                for k, v in lib.deep.compute_parameter_stats(model).items():
                    writer.add_scalars(k, v, step, timer.elapsed())
                    del k, v

            if gradient_clipping_norm is not None:
                nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), gradient_clipping_norm
                )
            optimizer.step()

            if lr_scheduler is not None:
                lr_scheduler.step()
            step += 1
            epoch_losses.append(loss.detach())
            if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                chunk_size = new_chunk_size
                logger.warning(f"chunk_size = {chunk_size}")

        epoch_losses = torch.stack(epoch_losses).tolist()
        mean_loss = statistics.mean(epoch_losses)
        metrics, predictions, eval_batch_size = evaluate(
            ["val", "test"], eval_batch_size
        )

        training_log.append(
            {"epoch-losses": epoch_losses, "metrics": metrics, "time": timer.elapsed()}
        )
        lib.print_metrics(mean_loss, metrics)
        writer.add_scalars("loss", {"train": mean_loss}, step, timer.elapsed())
        for part in metrics:
            writer.add_scalars(
                "score", {part: metrics[part]["score"]}, step, timer.elapsed()
            )

        if metrics["val"]["score"] > report["metrics"]["val"]["score"]:
            print("🌸 New best epoch! 🌸")
            report["best_step"] = step
            report["metrics"] = metrics
            save_checkpoint()
            lib.dump_predictions(output, predictions)

        early_stopping.update(metrics["val"]["score"])
        if early_stopping.should_stop() or not lib.are_valid_predictions(predictions):
            break

        print()
    report["time"] = str(timer)

    # >>> finish
    # Reload the best checkpoint and evaluate it on all parts.
    model.load_state_dict(lib.load_checkpoint(output)["model"])
    report["metrics"], predictions, _ = evaluate(
        ["train", "val", "test"], eval_batch_size
    )
    report["chunk_size"] = chunk_size
    report["eval_batch_size"] = eval_batch_size
    lib.dump_predictions(output, predictions)
    lib.dump_summary(output, lib.summarize(report))
    save_checkpoint()
    lib.finish(output, report)
    return report


if __name__ == "__main__":
    # Script entry point: configure libraries via `lib`, then run `main` through lib's CLI wrapper.
    lib.configure_libraries()
    lib.run_MainFunction_cli(main)
