import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt

from utils_knn import *

def distance_file_path(path, method, nproj, dataset, rho, k):
    """Return the file name of a precomputed pairwise-distance matrix.

    The name encodes the number of projections, the method tag, the dataset
    name, the rho value (used for both rho1 and rho2) and the trial index::

        <path>d_projs<nproj>_<method>_unnormalize_<dataset>_rho1<rho>_rho2<rho>_k<k>

    Values are rendered with ``str()`` semantics (f-string with empty format
    spec), so e.g. ``0.000001`` appears as ``1e-06``. ``path`` is concatenated
    directly, so it must end with a separator.
    """
    return f"{path}d_projs{nproj}_{method}_unnormalize_{dataset}_rho1{rho}_rho2{rho}_k{k}"


def mean_knn_accuracy(path, method, nproj, dataset, rhos, ntry, y, idx_train, idx_test):
    """Score precomputed distance matrices with k-NN for every (rho, trial) pair.

    Args:
        path: directory prefix of the distance files (must end with a separator).
        method: method tag embedded in the file name (e.g. "rsw", "sopt").
        nproj: number of projections embedded in the file name.
        dataset: dataset name embedded in the file name.
        rhos: sequence of rho values to evaluate.
        ntry: number of trials; trial k reads the file suffixed ``_k<k>``.
        y, idx_train, idx_test: forwarded unchanged to ``get_acc_knn``.

    Returns:
        ndarray of shape (len(rhos), ntry). Entry [i, k] is the mean of what
        ``get_acc_knn`` returns for ``rhos[i]`` on trial ``k``.
    """
    results = np.zeros((len(rhos), ntry))
    for k in range(ntry):
        for i, rho in enumerate(rhos):
            dist = np.loadtxt(distance_file_path(path, method, nproj, dataset, rho, k))
            # get_acc_knn comes from utils_knn (star import). Its return value is
            # averaged here; presumably one accuracy per neighbour count, so
            # confirm in utils_knn.
            results[i, k] = np.mean(get_acc_knn(dist, y, idx_train, idx_test))
    return results


if __name__ == '__main__':
    mat_contents = sio.loadmat("./data/bbcsport-emd_tr_te_split.mat")

    # The split file stores 1-based (MATLAB) indices; shift them to 0-based.
    idx_train = mat_contents["TR"] - 1
    idx_test = mat_contents["TE"] - 1
    y = mat_contents["Y"][0]

    X = mat_contents["X"][0]
    w = mat_contents["BOW_X"][0]

    dataset = "BBC"
    path = f"./results_{dataset}/"

    # Both methods are evaluated over the same grid (rho1 == rho2 == rho).
    rhos = [0.000001, 0.000005, 0.00001, 0.00005, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1, 1.0]
    ntry = 1

    nproj = 500
    L_mean_rsw = mean_knn_accuracy(path, "rsw", nproj, dataset, rhos, ntry, y, idx_train, idx_test)
    print(f"NPROJ = {nproj}")
    # NOTE(review): results from the "rsw" files are reported under the label
    # "USOT"; confirm this naming is intended.
    print("USOT", np.mean(L_mean_rsw, axis=-1), np.std(L_mean_rsw, axis=-1))

    nproj = 50
    L_mean_sopt = mean_knn_accuracy(path, "sopt", nproj, dataset, rhos, ntry, y, idx_train, idx_test)
    print(f"\n\nNPROJ = {nproj}")
    print("SOPT", np.mean(L_mean_sopt, axis=-1), np.std(L_mean_sopt, axis=-1))
