

import numpy as np

#%%

def gensynsep(K, nkey, nkeyonlow, nkeyonhigh, nbad, nbadon, labelnoise, T, seed = 1234):
    """Generate a synthetic K-class dataset of binary feature vectors.

    Every class gets one fixed 0/1 "keyword" signature of length ``nkey``.
    Each sample is the signature of a uniformly drawn class followed by a
    freshly drawn 0/1 block of ``nbad`` noise features.

    Parameters
    ----------
    K : number of classes.
    nkey : length of the keyword part of a feature vector.
    nkeyonlow, nkeyonhigh : inclusive bounds on how many keyword indices are
        drawn for a class. Indices are drawn with replacement, so the number
        of distinct active keywords can be smaller than the drawn count.
    nbad : length of the noise part of a feature vector.
    nbadon : number of noise indices drawn per sample (also with
        replacement, so fewer than ``nbadon`` entries may end up set).
    labelnoise : probability that a sample's label is replaced by a class
        drawn uniformly from all K classes (which may equal the true class).
    T : number of samples.
    seed : value passed to ``np.random.seed``; note that this reseeds
        NumPy's global generator.

    Returns
    -------
    ``(y, X, dataname)``: the label array, the feature array with one row of
    ``nkey + nbad`` entries per sample, and a string that encodes the
    generation parameters.

    If two classes end up with the same signature, the string
    ``"è malissimo"`` is returned instead of the tuple.
    """
    np.random.seed(seed)

    def _draw_signature():
        # Order of the random draws matters for reproducibility:
        # first the number of indices, then the indices themselves.
        signature = np.zeros(nkey)
        n_draws = np.random.randint(nkeyonlow, nkeyonhigh + 1)
        signature[np.random.randint(0, nkey, n_draws)] = 1
        return signature

    prototypes = np.array([_draw_signature() for _ in range(K)])

    # Classes must be distinguishable: bail out when two signatures coincide.
    if len(np.unique(prototypes, axis = 0)) != K:
        return("è malissimo")

    class_ids = list(range(K))
    uniform_probs = np.ones(K) / K
    features = []
    labels = []
    for _ in range(T):
        true_class = np.random.randint(0, K, 1)[0]

        # Noise block: independent of the class, redrawn for every sample.
        noise_idx = np.random.randint(0, nbad, nbadon)
        noise_part = np.zeros(nbad)
        noise_part[noise_idx] = 1
        features.append(np.concatenate((prototypes[true_class], noise_part)))

        # Biased coin: 1 means "replace the label by a uniformly random class".
        flip = np.random.choice([0, 1], 1, p = [1 - labelnoise, labelnoise])[0]
        observed = (np.random.choice(class_ids, 1, p = uniform_probs)[0]
                    if flip == 1 else true_class)
        labels.append(observed)

    # Dataset name: fixed tags, each followed by the matching parameter value.
    name_parts = (
        ("SynSep_K", K),
        ("_keys", nkey),
        ("_keylow", nkeyonlow),
        ("_keyhigh", nkeyonhigh),
        ("_nbad", nbad),
        ("_nbadon", nbadon),
        ("_labelnoise", labelnoise),
    )
    dataname = "".join(tag + str(value) for tag, value in name_parts)
    return(np.array(labels), np.array(features), dataname)