import numpy as np
from skmultilearn.dataset import load_from_arff
from skmultilearn.model_selection import IterativeStratification

def read_arff(path, label_count):
    """Load a multi-label ARFF dataset into dense NumPy arrays.

    Args:
        path: Dataset path *without* the ``.arff`` extension; the extension
            is appended here.
        label_count: Number of label attributes. They are read from the end
            of each row (``label_location="end"``).

    Returns:
        Tuple ``(X, Y)`` of 2-D ``numpy.ndarray``: the feature matrix and the
        label matrix.
    """
    path_to_arff_file = path + ".arff"
    X, Y = load_from_arff(
        path_to_arff_file,
        label_count=label_count,
        label_location="end",
        load_sparse=False,
    )
    # load_from_arff returns scipy sparse matrices. ``toarray()`` yields an
    # ndarray directly instead of detouring through ``np.matrix`` (what
    # ``todense()`` returns, and a type NumPy discourages).
    return X.toarray(), Y.toarray()

def readData(path, label_count, numtrain):
    """Load a dataset and split it into a leading train part and a trailing test part.

    No shuffling is done. The first ``numtrain`` rows become the training set
    and all remaining rows become the test set, so the file's row order is
    preserved. NOTE(review): this presumably relies on each ARFF file storing
    its training rows first; confirm per dataset.

    Args:
        path: Dataset path without the ``.arff`` extension (see ``read_arff``).
        label_count: Number of label columns at the end of each row.
        numtrain: Either a fraction in ``[0, 1)`` of the rows to use for
            training, or an absolute row count ``>= 1``.

    Returns:
        ``(X, Y, Xt, Yt)``: training features and labels, then test features
        and labels.

    Raises:
        ValueError: If ``numtrain`` is negative.
    """
    # Reject negative values up front. Previously they fell into the
    # fractional branch and produced negative slice bounds.
    if numtrain < 0:
        raise ValueError(f"numtrain must be non-negative, got {numtrain!r}")
    X, Y = read_arff(path, label_count)
    if numtrain < 1:
        # Fractional split: take the floor of the requested share of rows.
        numtrain = len(X) * numtrain
    # int() also accepts whole-number float counts such as 432.0, which would
    # otherwise raise TypeError when used as slice bounds.
    numtrain = int(numtrain)
    Xt, Yt = X[numtrain:], Y[numtrain:]
    X, Y = X[:numtrain], Y[:numtrain]
    return X, Y, Xt, Yt

def readData_CV(path, label_count, CV=10):
    """Load a dataset and build a multi-label stratified K-fold splitter for it.

    Args:
        path: Dataset path without the ``.arff`` extension.
        label_count: Number of label columns at the end of each row.
        CV: Number of folds.

    Returns:
        ``(k_fold, X, Y)``. ``k_fold`` is an ``IterativeStratification``
        built with ``order=1``. Iterate ``k_fold.split(X, Y)`` to obtain
        ``(train_indices, test_indices)`` pairs.
    """
    features, labels = read_arff(path, label_count)
    splitter = IterativeStratification(n_splits=CV, order=1)
    return splitter, features, labels

if __name__ == "__main__":
    datasnames = ['Birds', 'Enron', 'Langlog', 'Medical', 'Scene', 'VirusGO', 'Yeast', 'Yelp', 'HumanGO', 'TMC2007_500']
    # Presumably per-dataset training-set sizes for readData; unused below.
    num_train = [432, 1151, 978, 659, 1618, 131, 1629, 7240, 2053, 19140]
    num_label = [19, 53, 75, 45, 6, 6, 14, 5, 14, 22]
    path = "data/"

    # Smoke-test the three loaders on the first dataset only.
    for name, n_labels in zip(datasnames[:1], num_label[:1]):
        dataset_path = path + name

        X, Y = read_arff(dataset_path, n_labels)
        print(type(X), np.shape(X), np.shape(Y))

        X, Y, Xt, Yt = readData(dataset_path, n_labels, 2 / 3)
        print(type(X), np.shape(X), np.shape(Y), np.shape(Xt), np.shape(Yt))

        k_fold, X, Y = readData_CV(dataset_path, n_labels)
        for train_idx, test_idx in k_fold.split(X, Y):
            print(type(X), np.shape(train_idx), np.shape(test_idx))