import torch
import os
import numpy as np
import pandas as pd
import sys
import time
import statistics
import pickle
#import sklearn.datasets as ds
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler, Normalizer
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, pairwise_distances
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
#from sklearn.neighbors import KNeighborsRegressor
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_moons
#from sklearn.dummy import DummyRegressor
#from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
from pmlb import fetch_data
from sklearn.datasets import fetch_openml
from torch import linalg as LA
from numpy import linalg as la

import gc
from timeit import default_timer as timer
from datetime import timedelta

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from args_loader import get_args

# Pick the torch device once; it is also exposed to the rest of the script as args.device.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'device: {device}')

# Command-line options come from the project-local args_loader module.
args = get_args()
args.device = device







# Per-dataset specification for the OpenML-backed datasets:
#   key -> (OpenML name, OpenML version,
#           raw classes kept (None = keep every example),
#           raw label mapped to +1 (every other label maps to -1),
#           whether PCA(n_components=20) is prepended to the pipeline)
_OPENML_DATASETS = {
    'mnist35': ('mnist_784', 1, (3, 5), 3, True),
    'fmnist24': ('Fashion-MNIST', 1, (2, 4), 2, True),
    'fmnist02': ('Fashion-MNIST', 1, (0, 2), 0, True),
    'wine': ('wine', 7, None, 'True', False),
    'segment': ('segment', 2, None, 'P', False),
    'wind': ('wind', 2, None, 'P', False),
    'puma8NH': ('puma8NH', 2, None, 'P', False),
    'cpu_small': ('cpu_small', 3, None, 'P', False),
}


def _to_signed_labels(y, positive):
    '''Return an array holding +1 where an element of y equals `positive`, -1 elsewhere.'''
    return np.array([1 if label == positive else -1 for label in y])


def _load_raw(dataset, data_home):
    '''Return (X, y, use_pca) for `dataset`, with y already mapped to {+1, -1}.'''
    if dataset == 'twomoons':
        X, y = make_moons(n_samples=5000, noise=0.2, random_state=42)
        return X, _to_signed_labels(y, 1), False

    try:
        name, version, kept, positive, use_pca = _OPENML_DATASETS[dataset]
    except KeyError:
        raise ValueError('Unknown dataset: {!r}'.format(dataset)) from None

    if data_home is None:
        # Fall back to the module-level cache directory set by the driver loop.
        data_home = data_path
    X, y = fetch_openml(name, version=version, return_X_y=True, as_frame=False, data_home=data_home)
    if kept is not None:
        # Multi-class image datasets: keep only the two requested digits/classes.
        y = y.astype(np.int64)
        index = (y == kept[0]) | (y == kept[1])
        X, y = X[index, :], y[index]
    return X, _to_signed_labels(y, positive), use_pca


def _build_pipeline(use_pca):
    '''Return the preprocessing pipeline: [PCA(20) ->] StandardScaler -> L2 Normalizer.'''
    steps = [('scl', StandardScaler()),
             ('normalize', Normalizer(norm="l2"))]
    if use_pca:
        steps.insert(0, ('pca', PCA(n_components=20)))
    return Pipeline(steps)


def preprocess_dataset(dataset, data_home=None):
    '''
    Load, binarise, split and preprocess one benchmark dataset.

    Labels are mapped to +1 / -1 (see _OPENML_DATASETS; for 'twomoons',
    class 1 -> +1). The data are split 80/20 into train/test with
    ``args.seed``; when ``args.topk_type == 1`` a further 25% of the training
    split is held out as validation data. The preprocessing pipeline
    (optional PCA(20), then StandardScaler and an L2 Normalizer) is fitted on
    the training split only and then applied to the other splits.

    Parameters:
        dataset (str): a key of _OPENML_DATASETS, or 'twomoons'.
        data_home (str, optional): OpenML cache directory; defaults to the
            module-level ``data_path`` set by the driver loop.

    Returns:
        (X_train, X_val, X_test, y_train, y_test) as torch tensors when
        ``args.topk_type == 1``, otherwise (X_train, X_test, y_train, y_test).
        Validation labels are not returned.

    Raises:
        ValueError: if ``dataset`` is not recognised.
    '''
    X, y, use_pca = _load_raw(dataset, data_home)
    print('**** # of examples: {} ****'.format(y.shape))
    print('**** # of features: {} ****'.format(X.shape[1]))

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=args.seed, train_size=0.8)
    with_validation = args.topk_type == 1
    if with_validation:
        X_train, X_val, y_train, _ = train_test_split(X_train, y_train, random_state=args.seed, test_size=0.25)

    data_pipeline = _build_pipeline(use_pca)
    # Fit on the training split only so no test/validation statistics leak in.
    X_train_tr = torch.tensor(data_pipeline.fit_transform(X_train))
    X_test_tr = torch.tensor(data_pipeline.transform(X_test))
    y_train = torch.tensor(y_train)
    y_test = torch.tensor(y_test)

    if with_validation:
        X_val_tr = torch.tensor(data_pipeline.transform(X_val))
        return X_train_tr, X_val_tr, X_test_tr, y_train, y_test
    return X_train_tr, X_test_tr, y_train, y_test



# Custom map-style datasets; Easc.fit and Easc.predict wrap them in a DataLoader
class Dataset(Dataset):
    '''
    Minimal map-style dataset pairing inputs ``x`` with targets ``y``.

    NOTE: this class shadows the torch ``Dataset`` it subclasses; after this
    definition the name ``Dataset`` in this module refers to this class.
    '''

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        # The dataset length is the number of inputs.
        return len(self.x)

    def __getitem__(self, index):
        sample = self.x[index]
        target = self.y[index]
        return sample, target

class Dataset_x(Dataset):
    '''Map-style dataset over inputs only (no targets); item ``i`` is ``x[i]``.'''

    def __init__(self, x):
        self.x = x

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        item = self.x[index]
        return item




class Eas (object):
    '''
    Base expand-and-sparsify (EaS) object.

    It holds the model configuration and the sparsification step (``top_k``).

    Keyword arguments (all optional):
        exf (int): expansion factor; codes have cols*exf coordinates. Default 2.
        k_size (int): coordinates kept per row by the deterministic top-k. Default None.
        device (str): torch device. Default 'cpu'.
        seed (int): random seed for the random projections. Default 50.
        topk_type (int): 0 = deterministic winner-take-all top-k,
            1 = top-k in expectation via per-coordinate thresholds ``th``. Default 0.
        batch_size (int): number of examples processed at a time. Default None.

    NOTE: any other keyword is ignored. In particular ``want_topk_info`` is not
    read, so ``self.want_topk_info`` always starts as False.
    '''

    def __init__(self, **kwargs):
        # 1d weight vector of the binary classifier
        self.w = None
        # per-coordinate thresholds used when topk_type == 1
        self.th = None
        # counts[j]: number of training points whose coordinate j is activated (set to 1)
        self.counts = None
        # counts_test[j]: the same count over test points
        self.counts_test = None
        # topk_counts[i]: number of activated coordinates of training point i
        self.topk_counts = None
        # topk_counts_test[i]: the same for test point i
        self.topk_counts_test = None
        # per-point lists of activated coordinate indices, filled only when want_topk_info is True
        self.topk_info = []
        self.topk_info_test = []
        self.want_topk_info = False

        self.exf = kwargs.get('exf', 2)
        self.k_size = kwargs.get('k_size', None)
        self.device = kwargs.get('device', 'cpu')
        self.seed = kwargs.get('seed', 50)
        self.topk_type = kwargs.get('topk_type', 0)  # 0: WTA (deterministic), 1: in expectation
        self.batch_size = kwargs.get('batch_size', None)

    def top_k(self, matrix):
        '''
        Binarise each row of a 2d tensor according to ``self.topk_type``.

        Parameters:
            matrix (2d torch.tensor): one row per example.

        Returns:
            topk_type == 0: a tensor of ``matrix``'s dtype with 1.0 at the
                ``k_size`` largest entries of each row and 0.0 elsewhere.
            otherwise: a float64 indicator of ``matrix > self.th``, where column
                j is compared against threshold ``th[j]``.
        '''
        if self.topk_type == 0:
            _, index = torch.topk(matrix, self.k_size, largest=True)
            # Write 1.0 at each row's top-k column indices.
            return torch.zeros_like(matrix).scatter_(1, index, 1.)
        # th broadcasts across the rows, so no (rows, n_coords) copy of it is materialised.
        return (matrix > self.th).double()


# A modified, simplified implementation of the expand-and-sparsify (EaS) classifier
class Easc(Eas):
    '''
    Expand-and-sparsify binary classifier for labels in {+1, -1}.

    Each input is projected through ``exf`` random (cols x cols) blocks. This
    gives a cols*exf dimensional code, which ``top_k`` binarises. Training
    stores, for every code coordinate, the mean label of the training points
    that activate it. Prediction sums those per-coordinate weights over a test
    point's active coordinates and takes the sign.

    The projection blocks are never stored. They are regenerated for every
    batch by re-seeding torch's global RNG with ``self.seed``.
    '''

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _projection_blocks(self, cols):
        '''
        Yield the ``exf`` random projection blocks.

        Each block is a (cols, cols) float64 tensor on ``self.device``. torch's
        global RNG is re-seeded with ``self.seed`` when iteration starts, so
        every full iteration yields the same sequence of blocks. Columns are
        scaled to unit L2 norm when topk_type == 0, and by 1/sqrt(cols)
        otherwise.
        '''
        torch.manual_seed(self.seed)
        for _ in range(self.exf):
            M = torch.randn(cols, cols)
            if self.topk_type == 0:
                M = M / LA.vector_norm(M, dim=0).reshape(1, -1)
            else:
                M = M / np.sqrt(cols)
            yield M.double().to(device=self.device)

    def _hash(self, X_slice, cols):
        '''Return the dense (batch, cols*exf) projection of X_slice, concatenated block by block.'''
        # One hstack over all blocks (instead of growing the matrix per block) avoids O(exf^2) copying.
        return torch.hstack([torch.mm(X_slice, M) for M in self._projection_blocks(cols)])

    @staticmethod
    def _release_memory():
        '''Collect Python garbage and release torch's cached CUDA allocations.'''
        gc.collect()
        torch.cuda.empty_cache()

    def _closest_directions(self, X_slice, hash_matrix, cols, n_keep):
        '''
        Keep only the ``n_keep`` active coordinates closest to each input.

        "Closest" means the coordinate's projection direction (column of its
        block) has the smallest squared distance to the input. This uses
        ||x - m||^2 = ||x||^2 + ||m||^2 - 2<x, m>, with ||x||^2 taken as 1.
        NOTE(review): this assumes unit-norm inputs, which preprocess_dataset's
        final Normalizer(norm="l2") provides.

        Inactive coordinates get distance +inf, and so do active ones with an
        exactly zero inner product. If a row has fewer than ``n_keep`` finite
        distances, topk still fills the remaining slots with +inf coordinates.

        Returns a new 0/1 tensor shaped like ``hash_matrix``.
        '''
        distances = []
        for j, M in enumerate(self._projection_blocks(cols)):
            inner_prod = torch.mm(X_slice, M) * hash_matrix[:, j * cols:(j + 1) * cols]
            inner_prod[inner_prod == 0.] = -float("inf")
            col_sq_norm = LA.vector_norm(M, dim=0).square()
            distances.append(1. + (col_sq_norm - 2. * inner_prod))
        _, nearest = torch.topk(torch.hstack(distances), n_keep, largest=False)
        return torch.zeros_like(hash_matrix).scatter_(1, nearest, 1.)

    def fit(self,X_t,y,X_v=None):
        '''
        Learn the 1d weight vector ``self.w`` of the binary classifier.

        Parameters:
            X_t (torch.Tensor): (n, cols) training inputs.
            y (torch.Tensor): n training labels (expected to be +1 / -1).
            X_v (torch.Tensor): validation inputs. They are used only to compute
                the top-k-in-expectation thresholds ``self.th``, so they are
                required when topk_type == 1 and ignored otherwise.

        Examples are processed ``batch_size`` at a time.

        Side effects: sets self.w, self.th, self.counts and self.topk_counts,
        and resets self.topk_info / self.topk_info_test.

        Raises:
            ValueError: if topk_type == 1 and X_v is None.
        '''
        _, cols = X_t.shape
        n_codes = cols * self.exf
        y = y.reshape(-1, 1)

        self.w = torch.zeros((n_codes)).to(device=self.device)
        self.topk_info = []
        self.topk_info_test = []
        self.th = torch.Tensor().to(device=self.device)
        self.counts = torch.zeros((n_codes)).to(device=self.device)
        self.topk_counts = torch.Tensor().to(device=self.device)

        X_t = X_t.to(device=self.device)

        # Decide on this model's own topk_type (not the global CLI args).
        if self.topk_type == 1:
            if X_v is None:
                raise ValueError('X_v (validation inputs) is required when topk_type == 1')
            X_v = X_v.to(device=self.device)
            # Threshold each coordinate at the (1 - k/(cols*exf)) quantile of the
            # validation projections. On the validation data, each coordinate then
            # fires for about a k/(cols*exf) fraction of points (top-k in expectation).
            quantile = 1.0 - (self.k_size / n_codes)
            self.th = torch.hstack([torch.quantile(torch.mm(X_v, M), quantile, dim=0)
                                    for M in self._projection_blocks(cols)])
            self._release_memory()

        per_batch_counts = [self.topk_counts]
        train_loader = DataLoader(dataset=Dataset(X_t, y), batch_size=self.batch_size, shuffle=False)
        for X_slice, y_slice in train_loader:
            hash_matrix = self.top_k(self._hash(X_slice, cols))
            self.counts = self.counts + torch.sum(hash_matrix, dim=0)
            per_batch_counts.append(torch.sum(hash_matrix, dim=1))
            # y_slice is (batch, 1) and broadcasts across the code coordinates, so
            # each active coordinate accumulates the label of its point.
            self.w = self.w + torch.sum(hash_matrix * y_slice.to(device=self.device), dim=0)

            if self.want_topk_info:
                self.topk_info.extend(row.nonzero().squeeze().cpu().tolist() for row in hash_matrix)

            del hash_matrix
            self._release_memory()
        self.topk_counts = torch.hstack(per_batch_counts)

        # Average the accumulated labels per coordinate. Coordinates that were
        # never activated have w == 0; bump their count to 1 to avoid 0/0.
        self.counts = self.counts + (self.counts == 0).double()
        self.w = self.w / self.counts

    def predict(self,X):
        '''
        Predict labels for inputs X.

        Parameters:
            X (torch.Tensor): (n, cols) inputs; cols must match the training inputs.

        Returns:
            torch.Tensor: n predictions on the CPU. Each is the sign of the
            averaged weight score: +1 or -1, or 0 when the score is exactly zero.

        With topk_type == 1, each thresholded code is further reduced to the
        ``int(k_size / 10)`` active directions closest to the input (5 if that
        is 0). The score is then averaged over that number instead of k_size.

        Side effects: sets self.counts_test and self.topk_counts_test. When
        want_topk_info is True, it also appends to self.topk_info_test.
        '''
        _, cols = X.shape
        X = X.to(device=self.device)
        self.counts_test = torch.zeros((cols * self.exf)).to(device=self.device)

        if self.topk_type == 1:
            reduced_k_size = int(self.k_size / 10.)
            if reduced_k_size == 0:
                reduced_k_size = 5
            n_active = reduced_k_size
        else:
            n_active = self.k_size

        per_batch_counts = [torch.Tensor().to(device=self.device)]
        per_batch_scores = [torch.Tensor().to(device=self.device)]
        test_loader = DataLoader(dataset=Dataset_x(X), batch_size=self.batch_size, shuffle=False)
        for X_slice in test_loader:
            hash_matrix = self.top_k(self._hash(X_slice, cols))
            if self.topk_type == 1:
                hash_matrix = self._closest_directions(X_slice, hash_matrix, cols, reduced_k_size)

            self.counts_test = self.counts_test + torch.sum(hash_matrix, dim=0)
            per_batch_counts.append(torch.sum(hash_matrix, dim=1))
            scores = torch.sparse.mm(hash_matrix.double(), torch.unsqueeze(self.w, 1)).squeeze()
            per_batch_scores.append(scores)

            if self.want_topk_info:
                self.topk_info_test.extend(row.nonzero().squeeze().tolist() for row in hash_matrix)

            del hash_matrix
            self._release_memory()
        self.topk_counts_test = torch.hstack(per_batch_counts)

        output = torch.hstack(per_batch_scores) / n_active
        return torch.sign(output).cpu().detach()


def compute_single_run(exp_factor,X_train_tr,y_train,X_test_tr,y_test,seed,X_valid=None):
    '''
    Train one EaS classifier and return (exp_factor, test accuracy).

    The sparsity level is k = int(d * ln(d * exp_factor)), where d is the input
    dimension. Batch size, top-k type and device come from the global ``args``.
    X_valid is forwarded to ``fit`` (needed for the top-k-in-expectation
    thresholds).
    '''
    n_dims = X_train_tr.shape[1]
    k_size = int(n_dims * np.log(n_dims * exp_factor))
    # NOTE: Eas does not read the want_topk_info keyword, so passing it has no effect.
    model = Easc(
        exf=exp_factor,
        k_size=k_size,
        batch_size=args.batch_size,
        topk_type=args.topk_type,
        want_topk_info=True,
        device=args.device,
        seed=seed,
    )
    model.fit(X_train_tr, y_train, X_valid)
    test_accuracy = accuracy_score(y_test, model.predict(X_test_tr))
    return exp_factor, test_accuracy


def compute_multiple_runs(exp_factor_list,X_train_tr,y_train,X_test,y_test,X_valid=None):
    '''
    Evaluate the EaS classifier over several seeds for each expansion factor.

    For every exp_factor, ``args.no_runs`` runs are made; run i uses seed
    ``args.seed + i``. Each run is evaluated on ``X_test`` / ``y_test``.

    Returns:
        dict: exp_factor -> [mean accuracy, population standard deviation].
    '''
    acc_list = {}
    for exp_factor in exp_factor_list:
        # Use the X_test argument (the previous version read the global X_test_tr instead).
        acc_over_runs = [
            compute_single_run(exp_factor, X_train_tr, y_train, X_test, y_test, args.seed + run, X_valid)[1]
            for run in range(args.no_runs)
        ]
        acc_list[exp_factor] = [statistics.mean(acc_over_runs), statistics.pstdev(acc_over_runs)]
        print('expansion factor: {} done.'.format(exp_factor))
    return acc_list










#evaluate on each dataset
for dataset in args.datasets:
    # data_path is read as a module-level global by preprocess_dataset (OpenML cache).
    data_path = os.path.join(args.data_path, 'datasets', dataset)
    result_path = os.path.join(args.data_path, 'results', dataset)
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)
    print('****** Dataset: {} ******'.format(dataset))

    if args.topk_type == 1:
        X_train_tr, X_val_tr, X_test_tr, y_train, y_test = preprocess_dataset(dataset)
    else:
        X_train_tr, X_test_tr, y_train, y_test = preprocess_dataset(dataset)
        X_val_tr = None

    # k-nearest-neighbour baselines with k = 1 and k = 10
    knn_acc = {}
    for n_neighbors in (1, 10):
        knn = KNeighborsClassifier(n_neighbors=n_neighbors)
        knn.fit(X_train_tr, y_train)
        knn_acc[n_neighbors] = accuracy_score(y_test, knn.predict(X_test_tr))
        print('knn {} acc: {}'.format(n_neighbors, knn_acc[n_neighbors]))
    knn_1_accuracy, knn_10_accuracy = knn_acc[1], knn_acc[10]

    # Random forest baseline, with the number of trees chosen by 3-fold CV on ROC AUC
    param_grid = [{'n_estimators': [250, 500, 750, 1000], 'n_jobs': [-1]}]
    grid_search = GridSearchCV(RandomForestClassifier(), param_grid, cv=3, scoring='roc_auc')
    grid_search.fit(X_train_tr, y_train)
    # With refit=True (the default) best_estimator_ is already retrained on the whole training set.
    final_rf_model = grid_search.best_estimator_
    rf_accuracy = accuracy_score(y_test, final_rf_model.predict(X_test_tr))
    print('rf_acc: {}'.format(rf_accuracy))

    # EaS classifier over a grid of expansion factors
    exp_factor_list_large = [50, 100, 500, 1000, 1500, 2000, 2500, 3000]  # mnist35, fmnist24, wine, wind, cpu_small, segment
    exp_factor_list_small = [10, 20, 30, 40, 50, 60, 80, 90, 100]  # twomoons, puma8NH
    if dataset in ('twomoons', 'puma8NH'):
        exp_factor_list = exp_factor_list_small
    else:
        exp_factor_list = exp_factor_list_large
    # X_val_tr is None unless topk_type == 1, matching compute_multiple_runs' default.
    acc_list = compute_multiple_runs(exp_factor_list, X_train_tr, y_train, X_test_tr, y_test, X_val_tr)

    file_prefix = os.path.join(result_path, dataset + str(args.topk_type))
    result_file = file_prefix + '_results'
    knn_result_file = file_prefix + '_knn_results'
    rf_result_file = file_prefix + '_rf_results'

    for exp_factor in exp_factor_list:
        print('expansion_factor: {}, mean accuracy: {:.4f}, standard_dev: {:.4f}'.format(exp_factor, acc_list[exp_factor][0], acc_list[exp_factor][1]))

    # Persist results; context managers guarantee the files are closed even on error.
    with open(result_file, 'wb') as dbfile:
        pickle.dump(acc_list, dbfile)
    with open(knn_result_file, 'wb') as knn_dbfile:
        pickle.dump({'knn_1_acc': knn_1_accuracy, 'knn_10_acc': knn_10_accuracy}, knn_dbfile)
    with open(rf_result_file, 'wb') as rf_dbfile:
        pickle.dump({'rf_acc': rf_accuracy}, rf_dbfile)


# Plot results of each dataset
# for dataset in args.datasets:
#     result_path = os.path.join(args.data_path,'results',dataset)
#     result_file = os.path.join(result_path,dataset+str(args.topk_type)+'_results')
#     knn_result_file = os.path.join(result_path,dataset+str(args.topk_type)+'_knn_results')
#     rf_result_file = os.path.join(result_path,dataset+str(args.topk_type)+'_rf_results')
#     exp_factor_list =[]
#     acc_list = []
#     std_list = []
#     dbfile = open(result_file, 'rb')
#     db = pickle.load(dbfile)
#     for key in db:
#         print('exp_factor: {}, mean_acc: {}, std_acc:{}'.format(key,db[key][0],db[key][1]))
#         exp_factor_list.append(key)
#         acc_list.append(db[key][0])
#         std_list.append(db[key][1])
#     dbfile.close()
#
#     knn_dbfile = open(knn_result_file, 'rb')
#     knn_db = pickle.load(knn_dbfile)
#     knn_1_acc_arr = knn_db['knn_1_acc']*np.ones_like(np.array(exp_factor_list))
#     knn_10_acc_arr = knn_db['knn_10_acc']*np.ones_like(np.array(exp_factor_list))
#     knn_dbfile.close()
#
#     rf_dbfile = open(rf_result_file, 'rb')
#     rf_db = pickle.load(rf_dbfile)
#     rf_acc_arr = rf_db['rf_acc']*np.ones_like(np.array(exp_factor_list))
#     rf_dbfile.close()
#
#     print('exp_fac_list: {}'.format(exp_factor_list))
#     print('acc_list: {}'.format(acc_list))
#     print('std_list: {}'.format(std_list))
#     exp_factor_arr =np.array(exp_factor_list)
#     acc_arr = np.array(acc_list)
#     std_arr = np.array(std_list)
#     plt.plot(exp_factor_arr, acc_arr, 'Db-')
#     plt.fill_between(exp_factor_arr, acc_arr-std_arr, acc_arr+std_arr,color='b',alpha=0.2)
#     plt.plot(exp_factor_arr,knn_1_acc_arr,'k-')
#     plt.plot(exp_factor_arr,knn_10_acc_arr,'m-')
#     plt.plot(exp_factor_arr,rf_acc_arr,'g-')
#     plt.savefig(os.path.join(args.data_path,'results',dataset,str(args.topk_type)+'mygraph.jpg'))
#     plt.show()
