import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy.linalg import orth
import os
import knnie

# Global NumPy seed kept from the original script. Nothing in this file draws
# random numbers; NOTE(review): confirm whether any imported module relies on it.
np.random.seed(10)


def sample_size_grid(n_min=5, n_max=5000, num=500):
    """Return sorted, unique integer sample sizes log-spaced in [n_min, n_max].

    The generalization plot uses a log-scaled x-axis, so a geometric grid
    spreads the points evenly across it. A linear grid would leave only a
    handful of points at small n, where the curves change fastest. Rounding
    keeps every value a valid integer sample size. ``np.unique`` drops the
    duplicates that rounding creates at the low end, so the result may hold
    fewer than ``num`` points.

    Args:
        n_min: Smallest sample size (inclusive).
        n_max: Largest sample size (inclusive).
        num: Number of geometric grid points before de-duplication.

    Returns:
        1-D integer array, strictly increasing.
    """
    grid = np.geomspace(n_min, n_max, num=num)
    return np.unique(np.round(grid).astype(int))


def compute_bounds(D, list_d, list_n):
    """Compute the closed-form risks, generalization error and SMI bound.

    For every pair (d, n), with d taken from ``list_d`` and n from ``list_n``:

        train risk = D - d / n
        test risk  = D + d / n
        gen. error = 2 d / n
        lambda     = d (1 + 1/n)^2 + (D - d)
        mi         = d / 2 * log(n / (n - 1))
        SMI bound  = 2 sqrt(lambda * mi)

    Args:
        D: Dimension D used in the formulas above (the script's d values go up
            to D).
        list_d: Sequence of dimensions d.
        list_n: Sequence of sample sizes n, each > 1.

    Returns:
        Tuple ``(train_risks, test_risks, gen_errors, smi_bound)`` of float
        arrays, each of shape ``(len(list_d), len(list_n))``.

    Raises:
        ValueError: If any n <= 1, where log(n / (n - 1)) is infinite or
            undefined.
    """
    # Broadcast a column of d values against a row of n values, so every
    # formula yields a (len(list_d), len(list_n)) grid without Python loops.
    d = np.asarray(list_d, dtype=float)[:, np.newaxis]
    n = np.asarray(list_n, dtype=float)[np.newaxis, :]
    if np.any(n <= 1):
        raise ValueError("all sample sizes n must be > 1")

    train_risks = D - d / n
    test_risks = D + d / n
    gen_errors = 2 * d / n
    lmbda = d * (1 + 1 / n) ** 2 + (D - d)
    # log(n / (n - 1)) == -log1p(-1/n). The log1p form avoids the cancellation
    # in n / (n - 1) ~ 1 for large n.
    mi = d / 2 * -np.log1p(-1 / n)
    smi_bound = 2 * np.sqrt(lmbda * mi)
    return train_risks, test_risks, gen_errors, smi_bound


def bu_et_al_bound(D, list_n):
    """Return the 'Bu et al.' reference curve D * sqrt(2 (1 + 1/n)^2 log(n/(n-1))).

    Algebraically this equals the SMI bound of ``compute_bounds`` at d == D,
    so in the plot the two curves coincide.

    Raises:
        ValueError: If any n <= 1.
    """
    n = np.asarray(list_n, dtype=float)
    if np.any(n <= 1):
        raise ValueError("all sample sizes n must be > 1")
    return D * np.sqrt(2 * (1 + 1 / n) ** 2 * -np.log1p(-1 / n))


def plot_generalization(D, list_d, list_n, gen_errors, smi_bound,
                        namefig="analytical_gme_generr.pdf"):
    """Plot generalization error vs. SMI bound (log-log) and save to ``namefig``.

    One colour is used per d: dashed for the generalization error, solid for
    the SMI bound. The Bu et al. curve is added for d == D.
    """
    fig, ax = plt.subplots(figsize=(4, 5))
    cmap = matplotlib.colormaps["Set1"]
    for i, d in enumerate(list_d):
        color = cmap(i)
        ax.plot(list_n, gen_errors[i], '--', label=f"Gen. error, d={d}", c=color, lw=2.)
        ax.plot(list_n, smi_bound[i], '-', label=f"SMI bound, d={d}", c=color, lw=2.)
        if d == D:
            # Overlaps the d == D SMI curve exactly; drawn wider and
            # semi-transparent, presumably so both stay visible.
            ax.plot(list_n, bu_et_al_bound(D, list_n), '-.', c="k",
                    label="Bu et al.", lw=2.6, alpha=.5)

    ax.set_xlabel(r"$n$", fontsize=14)
    ax.set_ylabel("Generalization error", fontsize=14)
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.legend(fontsize=11)
    ax.grid()
    fig.tight_layout()
    fig.savefig(namefig)
    # Release the figure; pyplot otherwise keeps it alive.
    plt.close(fig)


def main(D=15, list_d=(1, 5, 10, 15), namefig="analytical_gme_generr.pdf"):
    """Compute the analytical curves and write the generalization-error figure.

    The empirical/true risks are also returned by ``compute_bounds``; they are
    not plotted here.
    """
    list_d = np.asarray(list_d)
    list_n = sample_size_grid(5, 5000, num=500)
    _, _, gen_errors, smi_bound = compute_bounds(D, list_d, list_n)
    plot_generalization(D, list_d, list_n, gen_errors, smi_bound, namefig=namefig)


if __name__ == "__main__":
    main()
