import os
from tqdm import tqdm
import numpy as np

def get_tabular_datasets(n_tot_threshold=100000, datadir='data'):
    """List the datasets under ``datadir`` whose total size is within a threshold.

    An entry ``<name>`` of ``datadir`` counts as a dataset when the metadata
    file ``<datadir>/<name>/<name>.txt`` exists; every other entry is skipped
    silently.  The metadata file holds one whitespace-separated ``key value``
    pair per line (blank lines are ignored).  The total pattern count of a
    dataset is the value stored under ``n_patrons1=`` plus the value stored
    under the optional ``n_patrons2=`` key (taken as 0 when absent).

    Args:
        n_tot_threshold: Largest total pattern count (inclusive) a dataset may
            have to be returned.
        datadir: Directory whose entries are scanned.

    Returns:
        list[str]: Names of the accepted datasets, in sorted order.

    Raises:
        KeyError: If a metadata file has no ``n_patrons1=`` entry.
        ValueError: If a non-blank metadata line does not consist of exactly
            two tokens, or a pattern count is not an integer.
    """
    datasets = []
    # Sorted so that both the progress bar and the result are deterministic.
    for dataset in tqdm(sorted(os.listdir(datadir))):
        meta_path = os.path.join(datadir, dataset, dataset + ".txt")
        # A regular file at this path implies that the dataset folder exists,
        # so a single check covers both "not a folder" and "no metadata".
        if not os.path.isfile(meta_path):
            continue

        meta = {}
        # Context manager: the handle is closed before the next dataset is
        # opened instead of being left for the garbage collector.
        with open(meta_path, "r") as meta_file:
            for line in meta_file:
                tokens = line.split()
                if not tokens:
                    continue  # tolerate blank / trailing empty lines
                key, value = tokens
                meta[key] = value

        n_tot = int(meta["n_patrons1="]) + int(meta.get("n_patrons2=", 0))
        if n_tot > n_tot_threshold:
            continue
        datasets.append(dataset)
    return datasets

def print_dataset_info(dataset, datadir='data'):
    """Print a one-line summary of ``dataset``: total patterns, inputs, classes.

    The values are read from the metadata file
    ``<datadir>/<dataset>/<dataset>.txt``, which holds one
    whitespace-separated ``key value`` pair per line (blank lines are
    ignored).  The printed total is ``n_patrons1=`` plus the optional
    ``n_patrons2=`` (0 when absent); the input count comes from
    ``n_entradas=`` and the class count from ``n_clases=``.

    Args:
        dataset: Name of the dataset folder inside ``datadir``.
        datadir: Directory that contains the dataset folders.

    Raises:
        FileNotFoundError: If the dataset folder or its metadata file is
            missing (a subclass of ``Exception``, so existing
            ``except Exception`` handlers keep working).
        KeyError: If a required metadata key is absent.
        ValueError: If a non-blank metadata line does not consist of exactly
            two tokens, or a value is not an integer.
    """
    dataset_dir = os.path.join(datadir, dataset)
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError('Dataset folder not found')
    meta_path = os.path.join(dataset_dir, dataset + ".txt")
    if not os.path.isfile(meta_path):
        raise FileNotFoundError('Dataset file not found')

    meta = {}
    # Context manager so the metadata file is closed deterministically.
    with open(meta_path, "r") as meta_file:
        for line in meta_file:
            tokens = line.split()
            if not tokens:
                continue  # tolerate blank / trailing empty lines
            key, value = tokens
            meta[key] = value

    c = int(meta["n_clases="])
    d = int(meta["n_entradas="])
    n_tot = int(meta["n_patrons1="]) + int(meta.get("n_patrons2=", 0))

    print(dataset, "\tN:", n_tot, "\td:", d, "\tc:", c)

def hyperparam_select_dataset(dataset, hyperparam_selection_method, device='cpu', datadir='data'):
    """Load ``dataset`` and run a hyperparameter-selection routine on it.

    Files read from ``<datadir>/<dataset>/``:

    * ``<dataset>.txt`` -- metadata, one whitespace-separated ``key value``
      pair per line (blank lines are ignored).  ``n_clases=`` gives the class
      count and ``fich1=`` the name of the data file.
    * the data file -- its first line is skipped; on every following line the
      first token is dropped, the last token is parsed as the integer label
      and the tokens in between are parsed as float features.
    * ``conxuntos.dat`` -- its first line holds the training indices and its
      second line the validation indices (whitespace-separated integers).

    Args:
        dataset: Name of the dataset folder inside ``datadir``.
        hyperparam_selection_method: Callable invoked as
            ``method(X, y, train_fold, val_fold, c, device=device)``; it must
            return a pair.
        device: Forwarded unchanged to ``hyperparam_selection_method``.
        datadir: Directory that contains the dataset folders.

    Returns:
        tuple: The two values returned by ``hyperparam_selection_method``
        (named ``method_hyperparams`` and ``best_train_traj`` here).

    Raises:
        FileNotFoundError: If the dataset folder or its metadata file is
            missing (a subclass of ``Exception``, so existing
            ``except Exception`` handlers keep working).
        KeyError: If ``n_clases=`` or ``fich1=`` is absent from the metadata.
        ValueError: If a non-blank metadata line does not consist of exactly
            two tokens, or a numeric field cannot be parsed.
    """
    dataset_dir = os.path.join(datadir, dataset)
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError('Dataset folder not found')
    meta_path = os.path.join(dataset_dir, dataset + ".txt")
    if not os.path.isfile(meta_path):
        raise FileNotFoundError('Dataset file not found')

    # Every file below is opened through a context manager so that no handle
    # outlives the statement that reads it.
    meta = {}
    with open(meta_path, "r") as meta_file:
        for line in meta_file:
            tokens = line.split()
            if not tokens:
                continue  # tolerate blank / trailing empty lines
            key, value = tokens
            meta[key] = value
    c = int(meta["n_clases="])

    # Load data: skip the first line, split every row only once.
    with open(os.path.join(dataset_dir, meta["fich1="]), "r") as data_file:
        rows = [line.split() for line in data_file.readlines()[1:]]
    X = np.asarray([[float(tok) for tok in row[1:-1]] for row in rows])
    y = np.asarray([int(row[-1]) for row in rows])

    # Hyperparameter selection split.
    with open(os.path.join(dataset_dir, "conxuntos.dat"), "r") as fold_file:
        folds = [[int(tok) for tok in line.split()] for line in fold_file]
    train_fold, val_fold = folds[0], folds[1]

    method_hyperparams, best_train_traj = hyperparam_selection_method(
        X, y, train_fold, val_fold, c, device=device)
    return method_hyperparams, best_train_traj

def test_dataset(dataset, hyperparams, test_method, device='cpu', datadir='data', n_folds=4):
    """Evaluate ``test_method`` on ``dataset`` with k-fold cross-validation.

    Files read from ``<datadir>/<dataset>/``:

    * ``<dataset>.txt`` -- metadata, one whitespace-separated ``key value``
      pair per line (blank lines are ignored).  ``n_clases=`` gives the class
      count and ``fich1=`` the name of the data file.
    * the data file -- its first line is skipped; on every following line the
      first token is dropped, the last token is parsed as the integer label
      and the tokens in between are parsed as float features.
    * ``conxuntos_kfold.dat`` -- lines ``2*i`` and ``2*i + 1`` hold the
      training and test indices of fold ``i`` (whitespace-separated integers).

    Args:
        dataset: Name of the dataset folder inside ``datadir``.
        hyperparams: Mapping expanded as extra keyword arguments of
            ``test_method``.
        test_method: Callable invoked once per fold as
            ``method(X_train, y_train, X_test, y_test, c, device=device,
            **hyperparams)``; its return value is averaged over the folds.
        device: Forwarded unchanged to ``test_method``.
        datadir: Directory that contains the dataset folders.
        n_folds: Number of folds to evaluate; the default of 4 reproduces the
            previously hard-coded behaviour.

    Returns:
        The mean of the per-fold values returned by ``test_method``.

    Raises:
        FileNotFoundError: If the dataset folder or its metadata file is
            missing (a subclass of ``Exception``, so existing
            ``except Exception`` handlers keep working).
        KeyError: If ``n_clases=`` or ``fich1=`` is absent from the metadata.
        ValueError: If ``n_folds`` is smaller than 1, a non-blank metadata
            line does not consist of exactly two tokens, or a numeric field
            cannot be parsed.
        IndexError: If the fold file has fewer than ``2 * n_folds`` lines.
    """
    if n_folds < 1:
        raise ValueError('n_folds must be at least 1')
    dataset_dir = os.path.join(datadir, dataset)
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError('Dataset folder not found')
    meta_path = os.path.join(dataset_dir, dataset + ".txt")
    if not os.path.isfile(meta_path):
        raise FileNotFoundError('Dataset file not found')

    # Every file below is opened through a context manager so that no handle
    # outlives the statement that reads it.
    meta = {}
    with open(meta_path, "r") as meta_file:
        for line in meta_file:
            tokens = line.split()
            if not tokens:
                continue  # tolerate blank / trailing empty lines
            key, value = tokens
            meta[key] = value
    c = int(meta["n_clases="])

    # Load data: skip the first line, split every row only once.
    with open(os.path.join(dataset_dir, meta["fich1="]), "r") as data_file:
        rows = [line.split() for line in data_file.readlines()[1:]]
    X = np.asarray([[float(tok) for tok in row[1:-1]] for row in rows])
    y = np.asarray([int(row[-1]) for row in rows])

    # k-fold cross-validation
    with open(os.path.join(dataset_dir, "conxuntos_kfold.dat"), "r") as fold_file:
        folds = [[int(tok) for tok in line.split()] for line in fold_file]

    avg_acc = 0.0
    print("Training")
    for repeat in range(n_folds):
        train_fold, test_fold = folds[repeat * 2], folds[repeat * 2 + 1]

        acc = test_method(X[train_fold], y[train_fold], X[test_fold], y[test_fold],
                          c, device=device, **hyperparams)
        # Equal weight per fold; with n_folds == 4 this equals 0.25 * acc.
        avg_acc += acc / n_folds

    return avg_acc


# The ``test_`` prefix would make pytest collect this helper as a test case
# whenever it is imported into a test module; opt out explicitly.
test_dataset.__test__ = False

def get_dataset_train_val_data(dataset, datadir='data'):
    """Load the data of ``dataset`` together with its train/validation split.

    Files read from ``<datadir>/<dataset>/``:

    * ``<dataset>.txt`` -- metadata, one whitespace-separated ``key value``
      pair per line (blank lines are ignored).  ``n_clases=`` gives the class
      count and ``fich1=`` the name of the data file.
    * the data file -- its first line is skipped; on every following line the
      first token is dropped, the last token is parsed as the integer label
      and the tokens in between are parsed as float features.
    * ``conxuntos.dat`` -- its first line holds the training indices and its
      second line the validation indices (whitespace-separated integers).

    Args:
        dataset: Name of the dataset folder inside ``datadir``.
        datadir: Directory that contains the dataset folders.

    Returns:
        dict: ``'X'`` (feature array), ``'y'`` (label array), ``'train_fold'``
        and ``'val_fold'`` (lists of int indices) and ``'c'`` (class count).

    Raises:
        FileNotFoundError: If the dataset folder or its metadata file is
            missing (a subclass of ``Exception``, so existing
            ``except Exception`` handlers keep working).
        KeyError: If ``n_clases=`` or ``fich1=`` is absent from the metadata.
        ValueError: If a non-blank metadata line does not consist of exactly
            two tokens, or a numeric field cannot be parsed.
    """
    dataset_dir = os.path.join(datadir, dataset)
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError('Dataset folder not found')
    meta_path = os.path.join(dataset_dir, dataset + ".txt")
    if not os.path.isfile(meta_path):
        raise FileNotFoundError('Dataset file not found')

    # Every file below is opened through a context manager so that no handle
    # outlives the statement that reads it.
    meta = {}
    with open(meta_path, "r") as meta_file:
        for line in meta_file:
            tokens = line.split()
            if not tokens:
                continue  # tolerate blank / trailing empty lines
            key, value = tokens
            meta[key] = value

    # Load data: skip the first line, split every row only once.
    with open(os.path.join(dataset_dir, meta["fich1="]), "r") as data_file:
        rows = [line.split() for line in data_file.readlines()[1:]]
    X = np.asarray([[float(tok) for tok in row[1:-1]] for row in rows])
    y = np.asarray([int(row[-1]) for row in rows])

    # Train / validation split used for hyperparameter selection.
    with open(os.path.join(dataset_dir, "conxuntos.dat"), "r") as fold_file:
        folds = [[int(tok) for tok in line.split()] for line in fold_file]
    train_fold, val_fold = folds[0], folds[1]

    return {'X': X, 'y': y, 'train_fold': train_fold, 'val_fold': val_fold,
            'c': int(meta["n_clases="])}