import math
import shutil
import statistics
import sys
from pathlib import Path
from typing import Any, Literal

import delu
import numpy as np
import rtdl_num_embeddings
import scipy
import torch
import torch.nn as nn
from loguru import logger
from tabpfn import TabPFNClassifier, TabPFNRegressor
from torch import Tensor
from tqdm import tqdm
from typing_extensions import NotRequired, TypedDict

if __name__ == '__main__':
    # When executed as a script, put the repository root on sys.path so that the
    # project-local `lib` package imported below can be resolved. The `.git`
    # check ensures the current working directory really is that root.
    _repo_root = Path.cwd()
    _git_marker = _repo_root / '.git'
    assert _git_marker.exists(), 'The script must be run from the root of the repository'
    sys.path.append(str(_repo_root))
    del _repo_root, _git_marker

import lib
import lib.data
import lib.env
from lib import KWArgs, PartKey


class Config(TypedDict):
    """Run configuration validated by ``lib.check`` in ``main``.

    Keys:
        seed: random seed passed to ``delu.random.seed``.
        data: keyword arguments forwarded to ``lib.data.build_dataset``.
    """

    seed: int  # global random seed for the run
    data: KWArgs  # kwargs for lib.data.build_dataset


def main(
    config: Config | str | Path,
    output: None | str | Path = None,
    *,
    force: bool = False,
) -> None | lib.JSONDict:
    """Fit TabPFN on the training split and evaluate it on train/val/test.

    Args:
        config: path to a config file. The loaded file must contain a ``space``
            key whose value is used as the run config (see ``Config``).
            NOTE: although the annotation admits a ``Config`` dict, only paths
            are supported here because ``Path(config)`` is called directly.
        output: output directory. Defaults to the config path without its
            suffix.
        force: forwarded to ``lib.start``.

    Returns:
        The report, or ``None`` if ``lib.start`` declines to start the run.

    Raises:
        RuntimeError: if evaluation runs out of memory even with batch size 1.
    """
    # >>> Start
    config_path = Path(config)
    if output is None:
        output = config_path.with_suffix('')
    # The config file is a tuning-style config; the actual run settings live
    # under its "space" key.
    config = lib.load_config(config_path).pop('space')
    config, output = lib.check(config, output, config_type=Config)
    if not lib.start(output, force=force):
        return None

    lib.print_config(config)  # type: ignore[code]
    delu.random.seed(config['seed'])
    # TabPFN is run on CPU here.
    device = 'cpu'
    report = lib.create_report(main, config)

    # >>> Data
    dataset = lib.data.build_dataset(**config['data'])
    if dataset.task.is_regression:
        # The model is trained on standardized labels; predictions are mapped
        # back to the original scale in `evaluate`.
        dataset.data['y'], regression_label_stats = lib.data.standardize_labels(
            dataset.data['y']
        )
    else:
        regression_label_stats = None

    # Convert binary features to categorical features.
    if dataset.n_bin_features > 0:
        x_bin = dataset.data.pop('x_bin')
        # Remove binary features with just one unique value in the training set.
        # This must be done, otherwise, the script will fail on one specific dataset
        # from the "why" benchmark.
        n_bin_features = x_bin['train'].shape[1]
        good_bin_idx = [
            i for i in range(n_bin_features) if len(np.unique(x_bin['train'][:, i])) > 1
        ]
        if len(good_bin_idx) < n_bin_features:
            x_bin = {k: v[:, good_bin_idx] for k, v in x_bin.items()}

        if dataset.n_cat_features == 0:
            dataset.data['x_cat'] = {
                part: np.zeros((dataset.size(part), 0), dtype=np.int64)
                for part in x_bin
            }
        for part in x_bin:
            dataset.data['x_cat'][part] = np.column_stack(
                [dataset.data['x_cat'][part], x_bin[part].astype(np.int64)]
            )
        del x_bin
    dataset = dataset.to_torch(device)

    def get_features(part: PartKey, idx: None | Tensor = None) -> Tensor:
        """Concatenate the available numerical and categorical features of a part.

        If ``idx`` is given, only those rows are selected.
        """
        tensors = [
            dataset.data[key][part]
            for key in ('x_num', 'x_cat')
            if key in dataset.data
        ]
        if idx is not None:
            tensors = [t[idx] for t in tensors]
        return torch.cat(tensors, 1)

    # >>> Model
    model = TabPFNRegressor() if dataset.task.is_regression else TabPFNClassifier()
    model.fit(get_features('train'), dataset.data['y']['train'])

    report['prediction_type'] = 'labels' if dataset.task.is_regression else 'probs'

    def apply_model(part: PartKey, idx: Tensor) -> Tensor:
        """Predict on rows ``idx`` of ``part``.

        Returns regression predictions (on the standardized label scale),
        positive-class probabilities for binary classification, or the full
        probability matrix for multiclass classification.
        """
        x = get_features(part, idx)
        if dataset.task.is_regression:
            return torch.tensor(model.predict(x)).float()
        probs = torch.tensor(model.predict_proba(x)).float()
        return probs[:, 1] if dataset.task.is_binclass else probs

    def evaluate(
        parts: list[PartKey], eval_batch_size: int
    ) -> tuple[dict[PartKey, Any], dict[PartKey, np.ndarray], int]:
        """Predict on each of ``parts`` in batches and compute metrics.

        On an out-of-memory error (as detected by ``lib.is_oom_exception``) the
        batch size is halved and the part is retried. The reduced batch size is
        kept for later parts and returned to the caller.

        Raises:
            RuntimeError: if the batch size drops to zero.
        """
        predictions: dict[PartKey, np.ndarray] = {}
        for part in parts:
            while eval_batch_size:
                try:
                    predictions[part] = (
                        torch.cat(
                            [
                                apply_model(part, idx)
                                for idx in torch.arange(
                                    len(dataset.data['y'][part]),
                                    device=device,
                                ).split(eval_batch_size)
                            ]
                        )
                        .cpu()
                        .numpy()
                    )
                except RuntimeError as err:
                    if not lib.is_oom_exception(err):
                        raise
                    eval_batch_size //= 2
                    logger.warning(f'eval_batch_size = {eval_batch_size}')
                else:
                    break
            if not eval_batch_size:
                # Previously this exception was constructed but never raised,
                # which silently skipped the part.
                raise RuntimeError('Not enough memory even for eval_batch_size=1')
        if regression_label_stats is not None:
            # Undo the label standardization applied before training.
            predictions = {
                k: v * regression_label_stats.std + regression_label_stats.mean
                for k, v in predictions.items()
            }
        metrics = (
            dataset.task.calculate_metrics(predictions, report['prediction_type'])
            if lib.are_valid_predictions(predictions)
            else {x: {'score': -999999.0} for x in predictions}
        )
        return metrics, predictions, eval_batch_size

    print()
    eval_batch_size = 32768
    report['metrics'], predictions, eval_batch_size = evaluate(
        ['train', 'val', 'test'], eval_batch_size
    )
    lib.dump_predictions(output, predictions)
    lib.dump_summary(output, lib.summarize(report))
    lib.finish(output, report)
    return report


if __name__ == '__main__':
    # Script entry point: apply the project's global library configuration first,
    # then let `lib.run` invoke `main`.
    lib.configure_libraries()
    lib.run(main)
