import math
import statistics
import sys
from pathlib import Path
from typing import Any, Literal

import delu
import numpy as np
import rtdl_num_embeddings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.tensorboard
from loguru import logger
from torch import Tensor
from tqdm import tqdm
from typing_extensions import NotRequired, TypedDict

if __name__ == '__main__':
    # When executed as a script, the current working directory must be the
    # repository root (checked via the presence of `.git`); it is appended to
    # `sys.path` so that the project-local `lib` package imported below resolves.
    _root = Path.cwd()
    assert (
        _root / '.git'
    ).exists(), 'The script must be run from the root of the repository'
    sys.path.append(str(_root))
    del _root

import lib
import lib.data
import lib.deep
from lib import KWArgs, PartKey


class Residual_block(nn.Module):
    """An MLP block: BatchNorm -> Linear -> ReLU -> Dropout -> Linear.

    It maps ``d_in`` features to ``d_in`` features through a hidden layer of
    size ``d``. Despite the name, the block does NOT add its input to its output:
    only the MLP branch is returned.
    """

    def __init__(self, d_in: int, d: int, dropout: float):
        super().__init__()
        # NOTE: the attribute names (including the capitalized `Linear1`) and the
        # creation order are kept unchanged: they define the state_dict keys and
        # the order in which parameters are initialized under a fixed seed.
        self.linear0 = nn.Linear(d_in, d)
        self.Linear1 = nn.Linear(d, d_in)
        self.bn = nn.BatchNorm1d(d_in)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.dropout(self.activation(self.linear0(self.bn(x))))
        # No skip connection (`x + ...`) is applied here.
        return self.Linear1(hidden)


# This implementation is based on the official implementation of ModernNCA:
# https://github.com/qile2000/LAMDA-TALENT/blob/2d7a166772ca2e79d0fa8b5f73a1ae6dbb8c5f09/LAMDA-TALENT/model/models/modernNCA.py
class ModernNCA(nn.Module):
    def __init__(
        self,
        *,
        n_num_features: int,
        cat_cardinalities: list[int],
        n_classes: None | int,
        #
        dim: int,
        dropout: int,
        d_block: None | int = None,
        d_block_multiplier: None | float = None,
        n_blocks: int,
        bins: None | list[Tensor],
        num_embeddings: None | dict = None,
        temperature: float = 1.0,
        sample_rate: float = 0.8,
    ) -> None:
        if d_block is None:
            assert d_block_multiplier is not None
            d_block = int(d_block_multiplier * dim)
        else:
            assert d_block_multiplier is None
        super().__init__()

        if n_num_features == 0:
            assert bins is None
            self.num_module = None
        elif num_embeddings is None:
            assert bins is None
            self.num_module = None
        else:
            if bins is None:
                self.num_module = lib.deep.make_module(
                    **num_embeddings, n_features=n_num_features
                )
            else:
                assert num_embeddings['type'].startswith('PiecewiseLinearEmbeddings')
                self.num_module = lib.deep.make_module(**num_embeddings, bins=bins)

        self.cat_module = (
            lib.deep.OneHotEncoding0d(cat_cardinalities) if cat_cardinalities else None
        )

        self.d_in = n_num_features * (
            1 if num_embeddings is None else num_embeddings['d_embedding']
        ) + sum(cat_cardinalities)
        self.n_classes = n_classes
        self.dim = dim
        self.dropout = dropout
        self.d_block = d_block
        self.n_blocks = n_blocks
        self.T = temperature
        self.sample_rate = sample_rate
        if n_blocks > 0:
            self.post_encoder = nn.Sequential()
            for i in range(n_blocks):
                name = f'ResidualBlock{i}'
                self.post_encoder.add_module(name, self.make_layer())
            self.post_encoder.add_module('bn', nn.BatchNorm1d(dim))
        self.encoder = nn.Linear(self.d_in, dim)
        # self.bn=nn.BatchNorm1d(dim)

    def make_layer(self):
        block = Residual_block(self.dim, self.d_block, self.dropout)
        return block

    def _pre_encoder(self, x_num: None | Tensor, x_cat: None | Tensor) -> Tensor:
        x = []
        if x_num is not None:
            x.append(x_num if self.num_module is None else self.num_module(x_num))
        if x_cat is None:
            assert self.cat_module is None
        else:
            assert self.cat_module is not None
            x.append(self.cat_module(x_cat))
        x = torch.column_stack([x_.flatten(1, -1) for x_ in x])
        return x

    def forward(
        self,
        *,
        x_num: None | Tensor = None,
        x_cat: None | Tensor = None,
        y: None | Tensor,
        candidate_x_num: None | Tensor = None,
        candidate_x_cat: None | Tensor = None,
        candidate_y: Tensor,
        is_train: bool,
    ):
        if is_train:
            data_size = len(candidate_y)
            retrival_size = int(data_size * self.sample_rate)
            sample_idx = torch.randperm(data_size)[:retrival_size]
            candidate_x_num = (
                None if candidate_x_num is None else candidate_x_num[sample_idx]
            )
            candidate_x_cat = (
                None if candidate_x_cat is None else candidate_x_cat[sample_idx]
            )
            candidate_y = candidate_y[sample_idx]

        x = self._pre_encoder(x_num, x_cat)
        candidate_x = self._pre_encoder(candidate_x_num, candidate_x_cat)
        dtype = x.dtype if x.dtype != torch.int64 else torch.float32

        if self.n_blocks > 0:
            candidate_x = self.post_encoder(self.encoder(candidate_x.to(dtype)))
            x = self.post_encoder(self.encoder(x.to(dtype)))
        else:
            candidate_x = self.encoder(candidate_x.to(dtype))
            x = self.encoder(x.to(dtype))
        if is_train:
            assert y is not None
            candidate_x = torch.cat([x, candidate_x])
            candidate_y = torch.cat([y, candidate_y])
        else:
            assert y is None

        if self.n_classes is not None:
            candidate_y = F.one_hot(candidate_y, self.n_classes).to(dtype)
        elif len(candidate_y.shape) == 1:
            candidate_y = candidate_y.unsqueeze(-1)

        # The NCA-related computations are always performed at least in float32.
        if x.dtype != torch.float64:
            x = x.float()
            candidate_x = candidate_x.float()
            candidate_y = candidate_y.float()

        # calculate distance
        distances = torch.cdist(x, candidate_x, p=2)
        distances = distances / self.T
        # remove the label of training index
        if is_train:
            distances = distances.clone().fill_diagonal_(torch.inf)
        distances = F.softmax(-distances, dim=-1)
        logits = torch.mm(distances, candidate_y)
        # print(logits.shape)
        # print(logits[:, 1])
        eps = 1e-7
        if self.n_classes is not None:
            logits = torch.log(logits + eps)
        # return logits.squeeze()
        return logits.to(dtype)


# The schema of the experiment config consumed by `main`. Keys wrapped in
# `NotRequired` are optional.
Config = TypedDict(
    'Config',
    {
        'seed': int,
        # Keyword arguments for `lib.data.build_dataset`.
        'data': KWArgs,
        # Keyword arguments for `rtdl_num_embeddings.compute_bins`.
        'bins': NotRequired[KWArgs],
        # Keyword arguments for `ModernNCA`.
        'model': KWArgs,
        # Keyword arguments for `lib.deep.make_optimizer`.
        'optimizer': KWArgs,
        'batch_size': int,
        # The early-stopping patience (in epochs).
        'patience': int,
        # -1 means that training runs until early stopping triggers.
        'n_epochs': int,
        'gradient_clipping_norm': NotRequired[float],
        # Enables automatic mixed precision with the given dtype.
        'amp_dtype': NotRequired[Literal['float16', 'bfloat16']],
    },
)


def main(
    config: Config | str | Path,
    output: None | str | Path = None,
    *,
    force: bool = False,
) -> None | lib.JSONDict:
    """Train ModernNCA on one dataset, evaluate it and save the results to ``output``.

    During training, a checkpoint is saved every time the validation score
    improves. At the end, the best checkpoint is reloaded and evaluated on all
    parts.

    Args:
        config: the experiment config (see ``Config``).
            NOTE(review): the body indexes ``config`` as a mapping, so a str/Path
            config presumably must be loaded by the caller (e.g. ``lib.run``).
            TODO confirm.
        output: the output directory.
        force: forwarded to ``lib.start``.

    Returns:
        The final report, or None if ``lib.start`` returned a falsy value.

    Raises:
        RuntimeError: if evaluation does not fit in memory even with
            ``eval_batch_size=1``.
    """
    # >>> Start
    assert set(config) >= Config.__required_keys__
    assert set(config) <= Config.__required_keys__ | Config.__optional_keys__
    if not lib.start(output, force=force):
        return None

    lib.show_config(config)  # type: ignore[code]
    output = Path(output)
    delu.random.seed(config['seed'])
    device = lib.get_device()
    report = lib.create_report(config)  # type: ignore[code]

    # >>> Data
    dataset = lib.data.build_dataset(**config['data'])

    if dataset.task.is_regression:
        # The model is trained on standardized labels; `evaluate` maps the
        # predictions back to the original scale.
        dataset.data['y'], regression_label_stats = lib.data.standardize_labels(
            dataset.data['y']
        )
    else:
        regression_label_stats = None

    # Convert binary features to categorical features.
    if dataset.n_bin_features > 0:
        x_bin = dataset.data.pop('x_bin')
        # Remove binary features with just one unique value in the training set.
        # This must be done, otherwise, the script will fail on one specific dataset
        # from the "why" benchmark.
        n_bin_features = x_bin['train'].shape[1]
        good_bin_idx = [
            i for i in range(n_bin_features) if len(np.unique(x_bin['train'][:, i])) > 1
        ]
        if len(good_bin_idx) < n_bin_features:
            x_bin = {k: v[:, good_bin_idx] for k, v in x_bin.items()}

        if dataset.n_cat_features == 0:
            dataset.data['x_cat'] = {
                part: np.zeros((dataset.size(part), 0), dtype=np.int64)
                for part in x_bin
            }
        for part in x_bin:
            dataset.data['x_cat'][part] = np.column_stack(
                [dataset.data['x_cat'][part], x_bin[part].astype(np.int64)]
            )
        del x_bin

    dataset = dataset.to_torch(device)
    Y_train = dataset.data['y']['train'].to(
        torch.long if dataset.task.is_classification else torch.float
    )

    # >>> model
    if 'bins' in config:
        # Compute the bins for PiecewiseLinearEncoding and PiecewiseLinearEmbeddings.
        # The labels are passed only when tree-based binning (`tree_kwargs`) is
        # requested. `Y_train` already has the right dtype.
        compute_bins_kwargs = (
            {
                'y': Y_train,
                'regression': dataset.task.is_regression,
                'verbose': True,
            }
            if 'tree_kwargs' in config['bins']
            else {}
        )
        bin_edges = rtdl_num_embeddings.compute_bins(
            dataset.data['x_num']['train'], **config['bins'], **compute_bins_kwargs
        )
        logger.info(f'Bin counts: {[len(x) - 1 for x in bin_edges]}')
    else:
        bin_edges = None
    model = ModernNCA(
        n_num_features=dataset.n_num_features,
        cat_cardinalities=dataset.compute_cat_cardinalities(),
        n_classes=dataset.task.try_compute_n_classes(),
        **config['model'],
        bins=bin_edges,
    )
    report['n_parameters'] = lib.deep.get_n_parameters(model)
    logger.info(f'n_parameters = {report["n_parameters"]}')
    report['prediction_type'] = 'labels' if dataset.task.is_regression else 'probs'
    model.to(device)

    # >>> Training
    step = 0
    train_size = dataset.size('train')
    batch_size = config['batch_size']
    report['epoch_size'] = epoch_size = math.ceil(train_size / batch_size)
    eval_batch_size = 32768
    # None means "the whole batch"; it is reduced if the backward pass goes OOM.
    chunk_size = None
    train_indices = torch.arange(train_size, device=device)

    optimizer = lib.deep.make_optimizer(
        **config['optimizer'],
        params=lib.deep.make_parameter_groups(model),
    )
    gradient_clipping_norm = config.get('gradient_clipping_norm')
    # For classification, the model outputs log-probabilities.
    loss_fn = F.nll_loss if dataset.task.is_classification else F.mse_loss

    batch_generator = torch.Generator(device).manual_seed(config['seed'])
    timer = delu.tools.Timer()
    early_stopping = delu.tools.EarlyStopping(config['patience'], mode='max')
    training_log = []
    writer = torch.utils.tensorboard.SummaryWriter(output)  # type: ignore[code]

    amp_dtype = config.get('amp_dtype')
    if amp_dtype is not None:
        amp_dtype = getattr(torch, amp_dtype)
    amp_enabled = amp_dtype is not None
    # A gradient scaler is created only for float16.
    # NOTE(review): `scaler` is never passed to `lib.deep.zero_grad_forward_backward`
    # below, so the loss is presumably not scaled before `backward()`, while
    # `scaler.unscale_`/`scaler.step` expect `scaler.scale(loss)` to have been used.
    # The float16 path likely fails; confirm the helper's signature.
    scaler = torch.cuda.amp.GradScaler() if amp_dtype is torch.float16 else None  # type: ignore[code]
    logger.info(f'AMP enabled: {amp_dtype is not None}')

    def get_Xy(part: PartKey, idx: None | Tensor) -> tuple[dict[str, Tensor], Tensor]:
        """Return the features (as model kwargs) and labels of `part`[`idx`]."""
        batch = (
            {
                key: dataset.data[key][part]
                for key in ['x_num', 'x_bin', 'x_cat']
                if key in dataset
            },
            dataset.data['y'][part],
        )
        return (
            batch
            if idx is None
            else ({k: v[idx] for k, v in batch[0].items()}, batch[1][idx])
        )

    def apply_model(part: PartKey, idx: Tensor, training: bool) -> Tensor:
        """Run the model on `part`[`idx`], using the training set as candidates."""
        # Currently, this argument is not used. However, it can be useful for other
        # variations of the model.
        del training
        x, y = get_Xy(part, idx)

        candidate_indices = train_indices
        # NOTE: is_train and training are different things, as explained here:
        # https://github.com/yandex-research/tabular-dl-tabr/issues/5#issuecomment-1726063188
        is_train = part == 'train'
        if is_train:
            # NOTE: here, the training batch is removed from the candidates.
            # It will be added back inside the model's forward pass.
            candidate_indices = candidate_indices[~torch.isin(candidate_indices, idx)]
        candidate_x, candidate_y = get_Xy(
            'train',
            # This condition is here for historical reasons, it could be just
            # the unconditional `candidate_indices`.
            None if candidate_indices is train_indices else candidate_indices,
        )

        # Use the configured AMP dtype (previously bfloat16 was hard-coded, which
        # ignored `amp_dtype == 'float16'`).
        with torch.autocast(  # type: ignore[code]
            device.type,
            enabled=amp_enabled,
            dtype=amp_dtype,
        ):
            return model(
                **x,
                y=y if is_train else None,
                **{f'candidate_{k}': v for k, v in candidate_x.items()},
                candidate_y=candidate_y,
                is_train=is_train,
            ).squeeze(-1)

    @torch.inference_mode()
    def evaluate(
        parts: list[PartKey], eval_batch_size: int
    ) -> tuple[dict[PartKey, Any], dict[PartKey, np.ndarray], int]:
        """Compute metrics and predictions for `parts`.

        On OOM, `eval_batch_size` is halved and the part is retried. The possibly
        reduced batch size is returned so that later calls can start from it.
        """
        model.eval()
        predictions: dict[PartKey, np.ndarray] = {}
        for part in parts:
            while eval_batch_size:
                try:
                    predictions[part] = (
                        torch.cat(
                            [
                                apply_model(part, idx, False)
                                for idx in torch.arange(
                                    dataset.size(part), device=device
                                ).split(eval_batch_size)
                            ]
                        )
                        .cpu()
                        .numpy()
                    )
                except RuntimeError as err:
                    if not lib.is_oom_exception(err):
                        raise
                    eval_batch_size //= 2
                    delu.cuda.free_memory()
                    logger.warning(f'eval_batch_size = {eval_batch_size}')
                else:
                    break
            if not eval_batch_size:
                # Previously the exception was constructed but never raised.
                raise RuntimeError('Not enough memory even for eval_batch_size=1')
        delu.cuda.free_memory()
        if dataset.task.is_regression:
            assert regression_label_stats is not None
            predictions = {
                k: v * regression_label_stats.std + regression_label_stats.mean
                for k, v in predictions.items()
            }
        else:
            # The model outputs log-probabilities.
            predictions = {k: np.exp(v) for k, v in predictions.items()}
            if dataset.task.is_binclass:
                predictions = {k: v[:, 1] for k, v in predictions.items()}
        metrics = (
            dataset.task.calculate_metrics(predictions, report['prediction_type'])
            if lib.are_valid_predictions(predictions)
            else {x: {'score': lib.WORST_SCORE} for x in predictions}
        )
        return metrics, predictions, eval_batch_size

    def save_checkpoint() -> None:
        """Dump the full training state, the report and a backup of `output`."""
        lib.dump_checkpoint(
            output,
            {
                'step': step,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'batch_generator': batch_generator.get_state(),
                'random_state': delu.random.get_state(),
                'early_stopping': early_stopping,
                'report': report,
                'timer': timer,
                'training_log': training_log,
            }
            | ({} if scaler is None else {'scaler': scaler.state_dict()}),
        )
        lib.dump_report(output, report)
        lib.backup_output(output)

    print()
    timer.run()
    while config['n_epochs'] == -1 or step // epoch_size < config['n_epochs']:
        print(f'[...] {output} | {timer}')

        model.train()
        epoch_losses = []
        for batch_idx in tqdm(
            torch.randperm(
                len(dataset.data['y']['train']),
                generator=batch_generator,
                device=device,
            ).split(batch_size),
            desc=f'Epoch {step // epoch_size} Step {step}',
        ):
            loss, new_chunk_size = lib.deep.zero_grad_forward_backward(
                optimizer,
                lambda idx: loss_fn(apply_model('train', idx, True), Y_train[idx]),
                batch_idx,
                chunk_size or batch_size,
            )

            if gradient_clipping_norm is not None:
                if scaler is not None:
                    scaler.unscale_(optimizer)
                nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), gradient_clipping_norm
                )
            if scaler is None:
                optimizer.step()
            else:
                scaler.step(optimizer)
                scaler.update()

            step += 1
            epoch_losses.append(loss.detach())
            if new_chunk_size and new_chunk_size < (chunk_size or batch_size):
                chunk_size = new_chunk_size
                logger.warning(f'chunk_size = {chunk_size}')

        epoch_losses = torch.stack(epoch_losses).tolist()
        mean_loss = statistics.mean(epoch_losses)
        metrics, predictions, eval_batch_size = evaluate(
            ['val', 'test'], eval_batch_size
        )

        training_log.append(
            {'epoch-losses': epoch_losses, 'metrics': metrics, 'time': timer.elapsed()}
        )
        lib.print_metrics(mean_loss, metrics)
        writer.add_scalars('loss', {'train': mean_loss}, step, timer.elapsed())
        for part in metrics:
            writer.add_scalars(
                'score', {part: metrics[part]['score']}, step, timer.elapsed()
            )

        if (
            'metrics' not in report
            or metrics['val']['score'] > report['metrics']['val']['score']
        ):
            print('🌸 New best epoch! 🌸')
            report['best_step'] = step
            report['metrics'] = metrics
            save_checkpoint()
            lib.dump_predictions(output, predictions)

        early_stopping.update(metrics['val']['score'])
        if early_stopping.should_stop() or not lib.are_valid_predictions(predictions):
            break

        print()
    report['time'] = str(timer)

    # >>> Final evaluation of the best checkpoint on all parts.
    if lib.get_checkpoint_path(output).exists():
        model.load_state_dict(lib.load_checkpoint(output)['model'])
    report['metrics'], predictions, _ = evaluate(
        ['train', 'val', 'test'], eval_batch_size
    )
    report['chunk_size'] = chunk_size
    report['eval_batch_size'] = eval_batch_size
    lib.dump_predictions(output, predictions)
    lib.dump_summary(output, lib.summarize(report))
    save_checkpoint()
    lib.finish(output, report)
    return report


if __name__ == '__main__':
    # Script entry point: apply the project-wide setup from `lib.configure_libraries`,
    # then hand `main` to `lib.run` (presumably it parses the command-line arguments
    # into `main`'s parameters — TODO confirm).
    lib.configure_libraries()
    lib.run(main)
