import os

import numpy as np

# Dataset split to process; it selects both the input file and the output
# file names used at the end of the script.
traintest = 'train'
_labelvecs_path = 'data/labelvecs/labelvecs_{}.npy'.format(traintest)
labelvecs = np.load(_labelvecs_path)

# Outputs filled in by the main loop:
#   Y     -- one row per sample, 10 columns (the loop reshapes each sample into
#            10 candidate rows); the loop writes 1. at the chosen row's index.
#   vecX  -- the chosen candidate row of each sample, in sample order.
#   alert -- indices of samples for which the candidate search fell back to
#            picking a random row.
Y = np.zeros(shape=(len(labelvecs), 10))
vecX, alert = [], []

def choose_candidate(labelvec, candidates):
    """Randomly pick a candidate row index whose row is unique within ``labelvec``.

    Rejection sampling: an index is drawn uniformly at random from
    ``candidates``. If any *other* row of ``labelvec`` is exactly equal to
    the drawn row, that index is discarded and another one is drawn.

    Args:
        labelvec: 2-D array whose rows are compared for exact equality.
        candidates: list of row indices to choose from. NOTE: mutated in
            place -- every rejected (duplicated) index is removed from it.

    Returns:
        The chosen row index, or ``None`` if every candidate's row is
        duplicated elsewhere in ``labelvec`` (``candidates`` is then empty).
        Test the result with ``is None``: ``0`` is a valid index.
    """
    while candidates:
        chosen = candidates[np.random.randint(len(candidates))]
        # Compare whole rows rather than counting equal entries against a
        # hard-coded row width, so any row length is handled correctly.
        duplicated = any(
            np.array_equal(labelvec[other], labelvec[chosen])
            for other in range(len(labelvec))
            if other != chosen
        )
        if not duplicated:
            return chosen
        candidates.remove(chosen)
    return None

def _modal_trait_rows(traits):
    """Return the set of row indices whose active trait column is a most-frequent one.

    Column sums are taken over all rows. Every column tied for the maximum
    sum is a "winner", and a row qualifies if its nonzero entry is in a
    winning column. Assumes exactly one nonzero entry per row (the loop
    below checks that each row of the group sums to 1; presumably the data
    is one-hot -- TODO confirm), otherwise row positions would misalign.
    """
    counts = np.sum(traits, axis=0)
    winners = np.argwhere(counts == np.amax(counts)).reshape(-1)
    # np.where(...)[1] lists the column of each nonzero entry in row-major
    # order; with one nonzero per row, position i corresponds to row i.
    active_cols = np.where(traits)[1]
    return {row for row, col in enumerate(active_cols) if col in winners}


for idx, flat_labelvec in enumerate(labelvecs):
    # Each sample holds 10 candidate rows of 11 values: columns :6 and 6:
    # are two trait groups, each of which must sum to 1 on every row.
    labelvec = flat_labelvec.reshape(10, 11)
    if not np.all(np.sum(labelvec[:, 6:], axis=1) == 1):
        raise ValueError(
            'sample {}: every row must sum to 1 over columns 6:'.format(idx))
    if not np.all(np.sum(labelvec[:, :6], axis=1) == 1):
        raise ValueError(
            'sample {}: every row must sum to 1 over columns :6'.format(idx))

    colorblind_candidates = _modal_trait_rows(labelvec[:, :6])
    shapeblind_candidates = _modal_trait_rows(labelvec[:, 6:])
    modal_rows = colorblind_candidates | shapeblind_candidates

    # Prefer rows that are modal in both trait groups; if there are none,
    # use one group's modal rows, picked by a coin flip.
    candidates = list(colorblind_candidates & shapeblind_candidates)
    if not candidates:
        if np.random.random() > 0.5:
            candidates = list(colorblind_candidates)
        else:
            candidates = list(shapeblind_candidates)

    # choose_candidate returns a row index or None. 0 is a valid index, so
    # every check below uses ``is None`` rather than truthiness.
    chosen = choose_candidate(labelvec, candidates)
    if chosen is None:
        chosen = choose_candidate(labelvec, list(modal_rows))
    if chosen is None:
        # choose_candidate empties the list it is given, so derive the
        # untried rows from modal_rows instead of from that list.
        untried = [row for row in range(len(labelvec)) if row not in modal_rows]
        chosen = choose_candidate(labelvec, untried)
    if chosen is None:
        # Every row duplicates another one: record the sample, pick any row.
        alert.append(idx)
        chosen = np.random.randint(len(labelvec))

    Y[idx, chosen] = 1.
    vecX.append(labelvec[chosen])
    if idx % 1000 == 0:
        print(idx)

# Stack the chosen rows into a single array, one row per processed sample.
vecX = np.array(vecX)

# np.save does not create missing directories; make sure the output
# subdirectory exists so the script does not fail at the end of a long run.
os.makedirs('data/vecX', exist_ok=True)
np.save('data/vecX/vecX_{}.npy'.format(traintest), vecX)
np.save('data/Y_{}.npy'.format(traintest), Y)