import torch
import os
from tqdm import tqdm
import numpy as np
import random
import itertools
import pickle
import copy

def _convert_one_hot(y, c, device='cpu'):
    """Return a one-hot encoding of integer class labels.

    Args:
        y: 1-D integer labels (torch tensor, numpy array or sequence).
        c: number of classes, i.e. the width of the encoding.
        device: device on which the result is allocated.

    Returns:
        Tensor of torch's default floating dtype and shape ``(len(y), c)``
        with a single 1 per row at the label's column.
    """
    # Indices must be long and live on the same device as the target tensor;
    # callers pass int32 labels (see the modular-arithmetic generator), and a
    # numpy arange index would be a CPU object even when device is a GPU.
    labels = torch.as_tensor(y, device=device).long()
    n = labels.shape[0]
    o = torch.zeros((n, c), device=device)
    o[torch.arange(n, device=device), labels] = 1
    return o

class MyDataset:
    """Feature/label arrays plus the fold definitions used to split them.

    Rows are selected with ``self.X[fold]`` / ``self.y[fold]`` and copied into
    torch tensors on demand, so the stored arrays are never modified.

    Args:
        X: feature array indexable by a sequence of row indices (presumably a
            2-D numpy array; ``print_dataset_info`` reads ``X.shape[0]`` and
            ``X.shape[1]``).
        y: label array indexable like ``X``.
        folds_for_validation: dict with ``'train_fold'`` and ``'val_fold'``,
            each a sequence of row indices.
        folds_for_testing: dict with ``'train_fold'`` and ``'test_fold'``,
            each a list of index sequences (one entry per train/test pair).
        mode: ``'classification'`` or ``'regression'``; stored only.
        c: number of classes; required when labels are one-hot encoded.
        name: human-readable dataset name.
    """

    def __init__(self, X, y,
                 folds_for_validation,
                 folds_for_testing,
                 mode='classification',
                 c=None,
                 name=None):
        # Datasets are either classification or regression.
        self.X = X
        self.y = y
        self.folds_for_validation = folds_for_validation
        self.folds_for_testing = folds_for_testing
        self.mode = mode
        self.c = c
        self.name = name

    def print_dataset_info(self):
        """Print the dataset name, number of rows, feature dimension and class count."""
        print(self.name, "\tN:", self.X.shape[0], "\td:", self.X.shape[1], "\tc:", self.c)

    def _build_split(self, fit_idx, eval_idx, normalize, one_hot_y, center, device):
        """Materialise one (fit, eval) split as tensors and apply preprocessing.

        Preprocessing order: optional one-hot labels, then optional centering,
        then optional per-row L2 normalisation.

        Returns:
            ``((X_fit, y_fit), (X_eval, y_eval))``.
        """
        X_fit = torch.tensor(self.X[fit_idx], device=device)
        y_fit = torch.tensor(self.y[fit_idx], device=device)
        X_eval = torch.tensor(self.X[eval_idx], device=device)
        y_eval = torch.tensor(self.y[eval_idx], device=device)

        if one_hot_y:
            y_fit = _convert_one_hot(y_fit, self.c, device=device)
            y_eval = _convert_one_hot(y_eval, self.c, device=device)

        if center:
            # Compute the fit-split mean ONCE, before X_fit is modified, so both
            # splits are shifted by the same training statistics.
            mean = torch.mean(X_fit, dim=0)
            X_fit -= mean
            X_eval -= mean

        if normalize:
            # Each row is scaled to unit L2 norm (an all-zero row yields NaN).
            X_fit /= torch.linalg.norm(X_fit, dim=-1).reshape(-1, 1)
            X_eval /= torch.linalg.norm(X_eval, dim=-1).reshape(-1, 1)

        return (X_fit, y_fit), (X_eval, y_eval)

    def get_train_test_data_iterator(self, normalize=False, one_hot_y=True, center=False, device='cpu'):
        """Yield ``((X_train, y_train), (X_test, y_test))`` for every train/test fold pair.

        Centering (if requested) uses the training rows' mean for both splits.
        """
        for train_fold, test_fold in zip(self.folds_for_testing['train_fold'],
                                         self.folds_for_testing['test_fold']):
            yield self._build_split(train_fold, test_fold,
                                    normalize, one_hot_y, center, device)

    def get_train_val_data(self, normalize=False, one_hot_y=True, center=False, device='cpu'):
        """Return ``((X_train, y_train), (X_val, y_val))`` for the validation split.

        Centering (if requested) uses the training rows' mean for both splits.
        """
        train_fold = self.folds_for_validation['train_fold']
        val_fold = self.folds_for_validation['val_fold']
        return self._build_split(train_fold, val_fold,
                                 normalize, one_hot_y, center, device)


def _modular_arithmetic_data(p, splits=None, seed=None):
    """Generate shuffled, one-hot encoded examples of ``(i + j) % p``.

    All ``p**2`` pairs ``(i, j)`` are enumerated, shuffled with Python's
    ``random`` module, and consecutive chunks are assigned to the splits in
    the iteration order of ``splits``. Each input row is the one-hot code of
    ``i`` followed by the one-hot code of ``j`` (width ``2 * p``).

    Args:
        p: modulus, also the number of classes.
        splits: mapping from split name to the fraction of the ``p**2`` pairs
            it receives. Defaults to ``{'train': 0.5, 'test': 0.1}``.
        seed: if not None, the global ``random`` module is seeded with it
            before shuffling.

    Returns:
        dict with ``'c'`` -> ``p`` and, per split name, ``(X, y)`` where ``X``
        is a float tensor of shape ``(n_split, 2*p)`` and ``y`` an int32
        tensor of shape ``(n_split,)``.

    Raises:
        ValueError: if the requested fractions need more than ``p**2`` pairs.
    """
    if splits is None:
        splits = {'train': 0.5, 'test': 0.1}
    if seed is not None:
        random.seed(seed)
    n = p ** 2
    input_pairs = list(itertools.product(range(p), repeat=2))
    random.shuffle(input_pairs)

    data = {'c': p}  # number of classes
    curr_n = 0
    for split, frac in splits.items():
        n_split = int(frac * n)
        # Without this check a short slice would leave trailing all-zero rows
        # labelled 0 in the split.
        if curr_n + n_split > n:
            raise ValueError(
                f"splits request more than the {n} available pairs (p={p})")

        # reshape keeps a (0, 2) shape when the split is empty.
        pairs = torch.tensor(input_pairs[curr_n:curr_n + n_split],
                             dtype=torch.long).reshape(-1, 2)
        rows = torch.arange(n_split)
        X_curr = torch.zeros((n_split, 2 * p))
        X_curr[rows, pairs[:, 0]] = 1
        X_curr[rows, pairs[:, 1] + p] = 1
        y_curr = ((pairs[:, 0] + pairs[:, 1]) % p).to(torch.int32)

        data[split] = (X_curr, y_curr)
        curr_n += n_split
    return data

def get_arithmetic_dataset(p, splits=None, seed=None):
    """Build a ``MyDataset`` for modular addition ``(i + j) % p``.

    The generated train/val/test splits are concatenated (in that order) into
    single numpy arrays. The folds are the contiguous index ranges of each
    split, so there is exactly one train/test fold pair.

    Args:
        p: modulus / number of classes.
        splits: fractions of the ``p**2`` pairs per split; must contain the
            keys ``'train'``, ``'val'`` and ``'test'``. Defaults to
            ``{'train': 0.5, 'val': 0.1, 'test': 0.1}``.
        seed: forwarded to the data generator's shuffle.

    Returns:
        ``MyDataset`` named ``modarith_p{p}_seed{seed}_train.._val.._test..``.
    """
    if splits is None:
        splits = {'train': 0.5, 'val': 0.1, 'test': 0.1}
    data = _modular_arithmetic_data(p, splits, seed)
    c = data['c']
    X_train, y_train = data['train']
    X_val, y_val = data['val']
    X_test, y_test = data['test']
    X = np.asarray(torch.cat((X_train, X_val, X_test)))
    y = np.asarray(torch.cat((y_train, y_val, y_test)))
    n_train = y_train.shape[0]
    n_val = y_val.shape[0]
    n_test = y_test.shape[0]
    train_fold = range(n_train)
    val_fold = range(n_train, n_train + n_val)
    test_fold = range(n_train + n_val, n_train + n_val + n_test)
    folds_for_validation = {'train_fold': train_fold, 'val_fold': val_fold}
    folds_for_testing = {'train_fold': [train_fold], 'test_fold': [test_fold]}
    # Inner double quotes: reusing the outer quote character inside an
    # f-string replacement field is only legal from Python 3.12 (PEP 701).
    split_name = f'train{splits["train"]}_val{splits["val"]}_test{splits["test"]}'
    dataset_obj = MyDataset(X=X, y=y,
                            folds_for_validation=folds_for_validation,
                            folds_for_testing=folds_for_testing,
                            mode='classification', c=c,
                            name=f'modarith_p{p}_seed{seed}_{split_name}')
    return dataset_obj

def get_sparse_parity_datasets(datadir='data/sparse_parity'):
    """Return the entry names in ``datadir``, sorted.

    ``os.listdir`` order is arbitrary and filesystem-dependent, so the names
    are sorted to make experiment ordering reproducible (matching
    ``get_tabular_datasets``).
    """
    return sorted(os.listdir(datadir))

def load_sparse_parity_dataset(dataset, datadir='data/sparse_parity'):
    """Load a stored train/val/test split from ``<datadir>/<dataset>/data.pkl``.

    The pickle is read as a dict with keys ``'c'`` (number of classes) and
    ``'train'``, ``'val'``, ``'test'``, each an ``(X, y)`` pair. The pairs are
    joined with ``torch.cat``, so they are presumably torch tensors (TODO
    confirm against the data writer). The result is a single numpy array, and
    the folds index the contiguous range of each split, so the stored split is
    reproduced exactly.

    Returns:
        ``MyDataset`` named ``datadir + '/' + dataset``.
    """
    path = os.path.join(datadir, dataset, 'data.pkl')
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only use this on trusted, locally generated data.
    with open(path, 'rb') as fh:
        data = pickle.load(fh)
    c = data['c']
    X_train, y_train = data['train']
    X_val, y_val = data['val']
    X_test, y_test = data['test']
    X = np.asarray(torch.cat((X_train, X_val, X_test)))
    y = np.asarray(torch.cat((y_train, y_val, y_test)))
    n_train = y_train.shape[0]
    n_val = y_val.shape[0]
    n_test = y_test.shape[0]
    train_fold = range(n_train)
    val_fold = range(n_train, n_train + n_val)
    test_fold = range(n_train + n_val, n_train + n_val + n_test)
    folds_for_validation = {'train_fold': train_fold, 'val_fold': val_fold}
    folds_for_testing = {'train_fold': [train_fold], 'test_fold': [test_fold]}
    dataset_obj = MyDataset(X=X, y=y,
                            folds_for_validation=folds_for_validation,
                            folds_for_testing=folds_for_testing,
                            mode='classification', c=c, name=datadir + '/' + dataset)
    return dataset_obj

def get_multi_index_datasets(datadir='data/multi_index'):
    """List the multi-index dataset names found under ``datadir``.

    Multi-index data uses the same directory layout as the sparse-parity
    data, so the sparse-parity lister is reused with a different root.
    """
    listing = get_sparse_parity_datasets(datadir=datadir)
    return listing

def load_multi_index_dataset(dataset, datadir='data/multi_index'):
    """Load one multi-index dataset from ``datadir``.

    The storage format matches the sparse-parity datasets, so this forwards
    to ``load_sparse_parity_dataset`` with the multi-index root directory.
    """
    loaded = load_sparse_parity_dataset(dataset, datadir=datadir)
    return loaded


def get_tabular_datasets(n_tot_threshold=100000, datadir='data/tabular_benchmarks'):
    """Return sorted names of tabular datasets with at most ``n_tot_threshold`` rows.

    A sub-directory ``<datadir>/<name>`` is considered only if it contains a
    metadata file ``<name>.txt`` made of whitespace-separated ``key value``
    lines. The size used for filtering is the count under ``n_patrons1=``
    plus the count under ``n_patrons2=`` (0 when that key is absent).

    Raises:
        KeyError: if a metadata file lacks ``n_patrons1=``.
        ValueError: if a metadata line does not split into exactly two tokens
            or a count is not an integer.
    """
    datasets = []
    for dataset in tqdm(sorted(os.listdir(datadir))):
        dataset_dir = os.path.join(datadir, dataset)
        if not os.path.isdir(dataset_dir):
            continue
        meta_path = os.path.join(dataset_dir, dataset + ".txt")
        if not os.path.isfile(meta_path):
            continue

        with open(meta_path, "r") as fh:
            meta = dict(line.split() for line in fh)

        n_train_val = int(meta["n_patrons1="])
        n_test = int(meta.get("n_patrons2=", 0))
        if n_train_val + n_test > n_tot_threshold:
            continue
        datasets.append(dataset)
    return datasets

def load_tabular_dataset(dataset, datadir='data/tabular_benchmarks', n_folds=4):
    """Load one tabular benchmark dataset as a ``MyDataset``.

    Files read from ``<datadir>/<dataset>/``:
      * ``<dataset>.txt``: whitespace-separated ``key value`` metadata. Only
        ``n_clases=`` (number of classes) and ``fich1=`` (data file name) are used.
      * the ``fich1=`` data file: the first line is skipped (presumably a
        header). On every other line the first column is dropped (presumably
        a row id), the middle columns are parsed as float features and the
        last column as the integer label.
      * ``conxuntos.dat``: line 0 holds training indices and line 1 validation
        indices, used for hyperparameter selection.
      * ``conxuntos_kfold.dat``: alternating train/test index lines, one pair
        per fold.

    Args:
        dataset: dataset directory name.
        datadir: root directory of the benchmarks.
        n_folds: number of train/test pairs read from ``conxuntos_kfold.dat``.

    Raises:
        FileNotFoundError: the dataset directory or metadata file is missing.
    """
    dataset_dir = os.path.join(datadir, dataset)
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError(f'Dataset folder not found: {dataset_dir}')
    meta_path = os.path.join(dataset_dir, dataset + ".txt")
    if not os.path.isfile(meta_path):
        raise FileNotFoundError(f'Dataset file not found: {meta_path}')

    with open(meta_path, "r") as fh:
        meta = dict(line.split() for line in fh)
    c = int(meta["n_clases="])

    # Load features and labels.
    with open(os.path.join(dataset_dir, meta["fich1="]), "r") as fh:
        rows = [line.split() for line in fh.readlines()[1:]]
    X = np.asarray([list(map(float, r[1:-1])) for r in rows])
    y = np.asarray([int(r[-1]) for r in rows])

    def _read_index_lines(path):
        # One list of integer row indices per line of the file.
        with open(path, "r") as fh:
            return [list(map(int, line.split())) for line in fh]

    # Hyperparameter-selection split.
    val_split = _read_index_lines(os.path.join(dataset_dir, "conxuntos.dat"))
    folds_for_validation = {'train_fold': val_split[0], 'val_fold': val_split[1]}

    # Evaluation folds: lines 2k / 2k+1 are the train / test indices of fold k.
    kfold = _read_index_lines(os.path.join(dataset_dir, "conxuntos_kfold.dat"))
    folds_for_testing = {
        'train_fold': [kfold[2 * r] for r in range(n_folds)],
        'test_fold': [kfold[2 * r + 1] for r in range(n_folds)],
    }

    dataset_obj = MyDataset(X=X, y=y,
                            folds_for_validation=folds_for_validation,
                            folds_for_testing=folds_for_testing,
                            mode='classification',
                            c=c,
                            name=datadir + '/' + dataset)
    return dataset_obj