import sys
sys.path.append(".")

import os
import random
import logging
import inspect
import warnings

import yaml

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from typing import TYPE_CHECKING
if TYPE_CHECKING:  # This is a hack to make VS Code intellisense work
    # from tensorflow.python import keras
    from keras.api._v2 import keras
else:
    keras = tf.keras

from nxcl.rich import Progress
from nxcl.config import load_config, ConfigDict

from missing import models, data, metrics
from missing import functional as F


def _seed_everything(seed):
    """Seed Python's ``random``, NumPy and TensorFlow from ``seed``.

    ``PYTHONHASHSEED`` is exported too; it cannot change hashing in the running
    interpreter, only in processes started afterwards.
    """
    random.seed(int(seed))
    np.random.seed(int(seed))
    os.environ["PYTHONHASHSEED"] = str(seed)
    tf.random.set_seed(int(seed))


def _metrics_output_path(checkpoint, filename="test_cal_metric.yaml"):
    """Return the path of ``filename`` in the same directory as ``checkpoint``.

    ``os.path.dirname`` keeps absolute checkpoint paths absolute. Splitting on
    ``"/"`` and re-joining the parts dropped the leading slash, so the file was
    written relative to the current working directory instead. A bare file name
    (no directory component) resolves to ``filename`` in the CWD.
    """
    return os.path.join(os.path.dirname(checkpoint), filename)


def main(config):
    """Evaluate a trained checkpoint, temperature-calibrate it and save metrics.

    Steps:
    1. Seed the RNGs.
    2. Build ``models.<config.model.name>`` and load weights from
       ``config.test.checkpoint``.
    3. Average ``config.test.ensemble_size`` prediction passes over the
       validation and test splits.
    4. Calibrate the test predictions with ``F.calibrate_prediction``, using the
       validation predictions and labels.
    5. Print the metrics and dump them to ``test_cal_metric.yaml`` in the
       checkpoint's directory.

    Args:
        config: Loaded ``ConfigDict``. Reads:
            - ``test.seed`` (set to 42 and written back when missing),
              ``test.checkpoint`` and ``test.ensemble_size``;
            - ``model.*``;
            - ``dataset.name``, ``dataset.output_activation`` and
              ``dataset.output_dims``;
            - ``train.batch_size``.
            ``model.output_activation`` and ``model.output_dims`` are filled in
            from ``dataset`` when unset.

    Raises:
        ValueError: If ``config.model.name`` is not an attribute of ``models``.
    """
    # Random seed
    if config.test.get("seed") is None:
        config.test.seed = 42
    _seed_everything(config.test.seed)

    # Create model
    if not hasattr(models, config.model.name):
        raise ValueError(f"Unknown model: {config.model.name}")

    config.model.setdefault("output_activation", config.dataset.output_activation)
    config.model.setdefault("output_dims", config.dataset.output_dims)

    model_class = getattr(models, config.model.name)
    model_params = inspect.signature(model_class).parameters
    # Only forward config entries that the constructor actually accepts.
    model: keras.Model = model_class(**{k: v for k, v in config.model.items() if k in model_params})
    model_config = model.get_config()

    for name, param in model_params.items():
        # *args / **kwargs are not settings with a default value and are
        # typically absent from get_config(), so don't look them up.
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        if name not in config.model:
            # Fall back to the signature default if get_config() omits the argument.
            default = model_config.get(name, param.default)
            warnings.warn(f"Using default model argument: {name} = {default}")

    # Create dataloader
    normalize_fn = data.get_dataset_normalize_fn(config.dataset.name)
    model_preprocess_fn = getattr(model, "data_preprocessing_fn", lambda: None)()
    valid_preprocess_fn = data.build_preprocess_fn(normalize_fn, model_preprocess_fn, class_weights=None)

    # Test
    # Re-seed so the RNG state during evaluation does not depend on how many
    # random draws model construction consumed.
    _seed_everything(config.test.seed)

    def select_labels(_inputs, labels):
        return labels

    valid_iterator, _ = data.build_valid_iterator(
        dataset_name=config.dataset.name,
        batch_size=config.train.batch_size,
        preprocess_fn=valid_preprocess_fn,
    )
    test_iterator, _ = data.build_test_iterator(
        dataset_name=config.dataset.name,
        batch_size=config.train.batch_size,
        preprocess_fn=valid_preprocess_fn,
    )

    # NOTE(review): labels are collected in a separate pass from model.predict.
    # This assumes both iterators yield examples in a fixed (unshuffled) order;
    # confirm against data.build_*_iterator.
    valid_labels = np.concatenate(list(tfds.as_numpy(valid_iterator.map(select_labels))), axis=0)
    test_labels = np.concatenate(list(tfds.as_numpy(test_iterator.map(select_labels))), axis=0)

    # Run one batch through the model so its variables exist before load_weights.
    x, _ = next(iter(test_iterator))
    model(x, training=False)

    print(f"Load weight from {config.test.checkpoint}")
    model.load_weights(config.test.checkpoint)

    ensemble_size = config.test.ensemble_size

    # Average `ensemble_size` full prediction passes per split.
    with Progress() as p:
        valid_preds = np.mean(np.stack([
            model.predict(valid_iterator, verbose=0) for _ in p.trange(ensemble_size, description="Valid")
        ], axis=0), axis=0)

        test_preds = np.mean(np.stack([
            model.predict(test_iterator, verbose=0) for _ in p.trange(ensemble_size, description="Test")
        ], axis=0), axis=0)

    test_cal_preds, opt_temp = F.calibrate_prediction(valid_preds, valid_labels, test_preds)

    print(f"Optimal temperature: {opt_temp:.4f}")

    standard_metrics_fn = {
        "auprc": metrics.auprc,
        "auroc": metrics.auroc,
        "accuracy": metrics.accuracy,
    }
    uncertainty_metrics_fn = {
        "brier": metrics.brier,
        "bal_brier": metrics.bal_brier,
        "ece": metrics.ece,
        "bal_ece": metrics.bal_ece,
        "logloss": metrics.logloss,
        "bal_logloss": metrics.bal_logloss,
    }

    test_metrics = {"Valid/temperature": round(float(opt_temp), 4)}

    def record(prefix, metric_fns, preds):
        """Score `preds` against the test labels; store and print each metric."""
        for metric_name, metric_fn in metric_fns.items():
            score = metric_fn(test_labels, preds)
            key = f"Test/{prefix}{metric_name}"
            test_metrics[key] = round(score.item(), 4)
            print(f"{key}: {score:.4f}")

    record("", standard_metrics_fn, test_preds)
    record("", uncertainty_metrics_fn, test_preds)
    # Calibrated predictions are scored with the uncertainty metrics only.
    record("cal_", uncertainty_metrics_fn, test_cal_preds)

    with open(_metrics_output_path(config.test.checkpoint), "w", encoding="utf-8") as f:
        yaml.dump(test_metrics, f, sort_keys=False)


if __name__ == "__main__":
    import argparse

    # Stage 1: locate the config file and the checkpoint to evaluate.
    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("-d", "--train-dir", type=str)
    group.add_argument("-f", "--config-file", type=str)
    parser.add_argument("-c", "--checkpoint", "--test.checkpoint", type=str, dest="checkpoint")
    parser.add_argument("-tf", "--test-config-file", type=str)
    args, rest_args = parser.parse_known_args()

    if args.test_config_file is not None:
        # NOTE(review): this option is parsed but nothing consumes it yet;
        # warn instead of ignoring it silently.
        warnings.warn(f"--test-config-file is currently ignored: {args.test_config_file}")

    if args.train_dir is not None:
        # A training directory implies its config.yaml and, by default, weights.h5.
        args.config_file = os.path.join(args.train_dir, "config.yaml")
        if args.checkpoint is None:
            args.checkpoint = os.path.join(args.train_dir, "weights.h5")
    elif args.checkpoint is None:
        # Report through argparse: prints usage and exits with status 2.
        parser.error("Must specify checkpoint if --config-file is specified")

    rest_args.extend(["--test.checkpoint", args.checkpoint])

    # Stage 2: apply the resolved checkpoint onto the loaded config.
    # parse_args (not parse_known_args) exits on any other leftover argument.
    config: ConfigDict = load_config(args.config_file)
    override_parser = argparse.ArgumentParser(conflict_handler="resolve")
    override_parser.add_argument("-c", "--checkpoint", "--test.checkpoint", type=str, dest="test.checkpoint")
    overrides = override_parser.parse_args(rest_args)

    config.update(vars(overrides))

    print("Configs:")
    for k, v in config.items(flatten=True):
        print(f"    {k:<25}: {v}")

    main(config)
