import numpy as np

from dataset import N101, N201
from Graph_GP import graphGP_fit
from perf_metrics import *

import tensorflow as tf
import matplotlib.pyplot as plt

def plot_regression(Y_train, Y_test, mean_tr, mean_test, var_tr, var_test, dataset_name=None):
    """Show predicted-vs-true plots with one-sigma error bars for train and test data.

    Two side-by-side panels (shared axes) are drawn: training points on the
    left, test points on the right. Each panel also shows the ``y = x``
    reference line spanning the joint range of the true targets.

    Args:
        Y_train: True training targets (array-like or eager tensor).
        Y_test: True test targets.
        mean_tr: Predictive means for the training inputs.
        mean_test: Predictive means for the test inputs.
        var_tr: Predictive variances for the training inputs.
        var_test: Predictive variances for the test inputs.
        dataset_name: Name shown in the figure title. When omitted, a
            module-level ``dataset_name`` is used if one exists (this is how
            the script entry point supplied it historically); otherwise the
            name is left out of the title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    if dataset_name is None:
        # Backward compatibility: earlier versions read the name from a module
        # global, which raised NameError whenever the global was not defined.
        dataset_name = globals().get("dataset_name", "")
    name_part = f"{dataset_name} " if dataset_name else ""

    def _flat(values):
        # One conversion to a 1-D NumPy array so matplotlib receives plain arrays.
        return np.asarray(values).reshape(-1)

    def _std(variance):
        # Clamp tiny negative variances (numerical noise) so sqrt never yields NaN.
        return np.sqrt(np.maximum(_flat(variance), 0.0))

    y_train, y_test = _flat(Y_train), _flat(Y_test)
    panels = (
        (y_train, _flat(mean_tr), _std(var_tr), 'Training Data', 'Train Predictions'),
        (y_test, _flat(mean_test), _std(var_test), 'Test Data', 'Test Predictions'),
    )

    # Common extent of the y = x reference line across both panels.
    min_y = min(np.min(y_train), np.min(y_test))
    max_y = max(np.max(y_train), np.max(y_test))

    fig, axes = plt.subplots(1, 2, figsize=(18, 8), sharex=True, sharey=True)
    for ax, (truth, pred, std, label, title) in zip(axes, panels):
        # errorbar with fmt='o' already draws the markers; a separate scatter
        # would only paint opaque points underneath and defeat the alpha.
        ax.errorbar(
            truth,
            pred,
            yerr=std,
            fmt='o',
            ecolor='blue',
            color='blue',
            alpha=0.6,
            label=label,
            capsize=3,
        )
        ax.plot([min_y, max_y], [min_y, max_y], 'k--', label='Ideal Fit')
        ax.set_xlabel('True $y$', fontsize=14)
        ax.set_title(title, fontsize=16)
        ax.legend(fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.5)

    # The y axis is shared, so only the left panel carries its label.
    axes[0].set_ylabel('Predicted $y$', fontsize=14)

    # Overall title and layout (leave headroom for the suptitle).
    fig.suptitle(
        f'Graph GP {name_part}dataset, train size {y_train.shape[0]} test size {y_test.shape[0]}',
        fontsize=18,
    )
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()


if __name__ == "__main__":
    # Script entry point: sample architectures, fit a graph GP on the training
    # split, report regression metrics on the test split and plot both splits.

    dataset_name = "N101"
    metric = 0    # 0: val, 1: test
    task = "cifar10-valid"
    noisy = False
    SCALE = True
    seed = 0
    n_train = 50
    n_test = 400
    n_sample = n_train + n_test
    kernel_exp = True

    # Seed every random source used here so repeated runs give the same
    # split (previously `seed` was defined but never applied).
    rng = np.random.default_rng(seed)
    tf.random.set_seed(seed)

    # Preparing dataset and parameters
    if dataset_name == "N101":
        dataset = N101(path="data/")
        task = "cifar10-valid"
        n_archs = 423624
    elif dataset_name == "N201":
        dataset = N201(path="data/")
        valid_tasks = ('cifar10-valid', 'cifar100', 'ImageNet16-120')
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if task not in valid_tasks:
            raise ValueError(f"Unknown task {task!r}; expected one of {valid_tasks}.")
        # NOTE(review): NAS-Bench-201 is documented as having 15,625
        # architectures; confirm this matches the index range N201 accepts.
        n_archs = 15625
    else:
        raise NotImplementedError("Dataset not implemented yet.")

    # Draw distinct indices so no architecture appears in both the training
    # and the test split.
    idxs = rng.choice(n_archs, size=n_sample, replace=False)
    G = dataset.index_sampling(idxs)
    ys = np.array(
        [dataset.eval(graph=g, score="log-err", task=task, noisy=noisy)[metric] for g in G]
    )[..., None].astype("float64")

    # Train test split
    G_train, G_test = G[:n_train], G[-n_test:]
    if SCALE:
        # Standardise y with statistics of the training targets only, so no
        # information about the test targets leaks into the fitted model.
        ys_mean, ys_std = ys[:n_train].mean(), ys[:n_train].std()
        ys = (tf.convert_to_tensor(ys) - ys_mean) / ys_std
    Y_train, Y_test = ys[:n_train], ys[-n_test:]

    # fit GP using G
    kernel = dataset.get_kernel(kernel_exp)
    GPmodel = graphGP_fit(G_train, Y_train, kernel)

    # predict
    mean_tr, var_tr = GPmodel.predict_f(G_train)
    mean_test, var_test = GPmodel.predict_f(G_test)

    if SCALE:
        # Map targets and predictions back to the original units.
        def unscale(y):
            return y * ys_std + ys_mean

        def unscale_var(var):
            # Var[a * X + b] = a**2 * Var[X]
            return np.asarray(var) * ys_std ** 2

        Y_train = unscale(Y_train)
        Y_test = unscale(Y_test)
        mean_tr = unscale(mean_tr)
        mean_test = unscale(mean_test)
        var_tr = unscale_var(var_tr)
        var_test = unscale_var(var_test)
    else:
        # NOTE(review): only Y_test is flattened here while mean_test keeps the
        # shape returned by predict_f; confirm the metric functions handle the
        # resulting shapes without unintended broadcasting.
        Y_test = Y_test.flatten()

    print('RMSE: ', rmse(mean_test, Y_test), end=" ")
    print('Spearman: ', spearman(mean_test, Y_test), end=" ")
    print('NLL: ', nll(mean_test, np.sqrt(var_test), Y_test), end=" ")
    print('Average prediction error', average_error(mean_test, Y_test))

    plot_regression(Y_train, Y_test, mean_tr, mean_test, var_tr, var_test)

