import numpy as np
from sklearn.datasets import make_multilabel_classification


#%%

def gensynsep(K, nkey, nkeyonlow, nkeyonhigh, nbad, nbadon, labelnoise, T, seed = 1234):
    """Generate a synthetic multiclass dataset whose classes are (noisily) separable.

    Each of the ``K`` classes receives a binary "keyword" signature of length
    ``nkey``. Every example is its class signature followed by ``nbad`` binary
    noise ("bad") features. With probability ``labelnoise`` the example's label
    is replaced by one drawn uniformly from all ``K`` classes (which may happen
    to be the true class).

    The global NumPy RNG is reseeded with ``seed`` and the dataset is fully
    determined by the sequence of draws from it.

    Parameters
    ----------
    K : number of classes.
    nkey : length of the keyword block.
    nkeyonlow, nkeyonhigh : inclusive bounds on how many keyword positions are
        drawn per class. Positions are drawn with replacement, so a class may
        end up with fewer active keywords than the number drawn.
    nbad : length of the noise block.
    nbadon : number of noise positions drawn (with replacement) per example.
    labelnoise : probability of resampling an example's label uniformly.
    T : number of examples.
    seed : value passed to ``np.random.seed``.

    Returns
    -------
    ``(y, X, dataname)``: labels, feature rows of length ``nkey + nbad``, and a
    string encoding the generation parameters. If two classes end up with
    identical keyword signatures, a message is printed and ``0`` is returned
    instead of the tuple.
    """
    np.random.seed(seed)

    def draw_signature():
        # RNG order matters for reproducibility: first the count, then the positions.
        signature = np.zeros(nkey)
        n_on = np.random.randint(nkeyonlow, nkeyonhigh + 1)
        signature[np.random.randint(0, nkey, n_on)] = 1
        return signature

    keysclass = np.array([draw_signature() for _ in range(K)])

    n_distinct = np.unique(keysclass, axis = 0).shape[0]
    if n_distinct != K:
        # "è malissimo" is Italian for "this is very bad": several classes share the same keywords.
        print("è malissimo")
        return 0

    labels = list(range(K))
    uniform = np.ones(K) / K
    X, y = [], []
    for _ in range(T):
        true_class = np.random.randint(0, K, 1)[0]
        noise = np.zeros(nbad)
        noise[np.random.randint(0, nbad, nbadon)] = 1
        X.append(np.concatenate((keysclass[true_class], noise)))
        flip = np.random.choice([0, 1], 1, p = [1 - labelnoise, labelnoise])[0]
        # The uniform resampling draw only happens when the label is flipped.
        y.append(np.random.choice(labels, 1, p = uniform)[0] if flip == 1 else true_class)

    fields = (
        ("SynSep_K", K),
        ("_keys", nkey),
        ("_keylow", nkeyonlow),
        ("_keyhigh", nkeyonhigh),
        ("_nbad", nbad),
        ("_nbadon", nbadon),
        ("_labelnoise", labelnoise),
    )
    dataname = "".join(tag + str(value) for tag, value in fields)
    return np.array(y), np.array(X), dataname



def gen_FIXED_multilabel_data(n: int, d: int, m: int, T: int, seed = 1234):
    """Generate ``T`` multilabel samples that each have exactly ``m`` active labels.

    Batches of ``T`` samples are drawn with
    ``sklearn.datasets.make_multilabel_classification``. Only rows whose label
    vector has exactly ``m`` ones are kept, and batches are drawn until at
    least ``T`` rows have been collected. The result is then truncated to ``T``.

    Parameters
    ----------
    n : number of features (``n_features``).
    d : number of labels (``n_classes``).
    m : required number of active labels per sample; also passed to the
        generator as ``n_labels``.
    T : number of samples to return (also the size of each generated batch).
    seed : seeds the batch generator (and, as before, the global NumPy RNG).

    Returns
    -------
    ``(X, Y)`` with ``T`` rows each.

    Raises
    ------
    ValueError
        If ``m`` is outside ``[0, d]``. No sample can then have exactly ``m``
        labels, and the refill loop would never terminate.
    """
    if not 0 <= m <= d:
        raise ValueError(f"m must be between 0 and d={d}, got {m}")
    # Kept for backward compatibility: callers may rely on the global RNG being seeded here.
    np.random.seed(seed)
    # A single generator shared by all batches, so each batch is fresh. Previously
    # every batch used random_state=42: `seed` was ignored and the refill loop
    # re-appended identical rows (or spun forever when a batch had no valid rows).
    rng = np.random.RandomState(seed)

    X_parts, Y_parts = [], []
    num = 0
    while True:
        X_, Y_ = make_multilabel_classification(
            n_samples=T, n_features=n, n_classes=d, n_labels=m, random_state=rng
        )
        keep = np.sum(Y_ == 1, axis=1) == m
        X_parts.append(X_[keep])
        Y_parts.append(Y_[keep])
        num += int(keep.sum())
        if num >= T:
            break

    # Concatenate once rather than on every iteration.
    X = np.concatenate(X_parts, axis=0)
    Y = np.concatenate(Y_parts, axis=0)
    return X[:T], Y[:T]

def int_to_binary_list(n, length=5):
    """Return the binary digits of a non-negative integer as a list of 0/1 ints.

    The most significant bit comes first, and the result is left-padded with
    zeros to ``length`` digits. Values that need more than ``length`` bits are
    not truncated; the list is as long as necessary.

    >>> int_to_binary_list(5)
    [0, 0, 1, 0, 1]

    Raises
    ------
    ValueError
        If ``n`` is negative (previously this surfaced as a cryptic
        ``int('-')`` failure), or if ``n`` cannot be formatted as binary.
    """
    digits = format(n, f'0{length}b')
    # A negative value formats with a leading '-', which is not a bit.
    if digits.startswith('-'):
        raise ValueError(f"n must be non-negative, got {n}")
    return [int(bit) for bit in digits]