
import torch
import numpy as np
import tqdm

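# Find pairs of embeddings (x, y) where y supersedes x, i.e. applying x and then y
# has the same effect on any state s as applying y alone:
#
# (((s \setminus del(x)) \cup add(x)) \setminus del(y)) \cup add(y) = (s \setminus del(y)) \cup add(y)
#
# Since del(y) \cap add(y) = \emptyset, it suffices that
#
# add(x) \subseteq add(y), del(x) \subseteq del(y)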


def redundancy(model, data, topk=None, max_pairs_per_word=100):
    """Find word pairs (x, y) where y's effects contain x's, and write reports.

    Word ``y`` supersedes word ``x`` when ``del(x) ⊆ del(y)`` and
    ``add(x) ⊆ add(y)``. In that case, applying ``x`` and then ``y`` has the
    same effect as applying ``y`` alone. Inclusion is non-strict, so words
    with identical effects supersede each other. A word is never paired with
    itself.

    ``model.word_effects`` is binarised with ``.bool()``. Index ``0`` holds
    the delete effects and index ``1`` the add effects, each with one row per
    word id. Only the ``topk`` most frequent words of ``data.vocab`` are
    compared (``most_common(None)`` returns every word).

    Three files are written through ``model.local``. Each has one row per
    line, with columns separated by a space (despite the ``.csv`` suffix):

    * ``redundant_superset_count_{topk}.csv``: ``word n``, where ``n`` is how
      many other selected words ``word`` supersedes (only rows with n > 0).
    * ``redundant_subset_count_{topk}.csv``: ``word n``, where ``n`` is how
      many other selected words supersede ``word`` (only rows with n > 0).
    * ``redundant_pairs_{topk}.csv``: ``y x`` for each pair where ``y``
      supersedes ``x``, with at most ``max_pairs_per_word`` rows per ``y``.

    Args:
        model: provides ``word_effects`` (used as a torch tensor) and
            ``local(filename)``, which returns the output path.
        data: provides ``vocab.most_common``, ``word_to_idx`` and
            ``_idx_to_word`` (a list or array of words indexed by word id).
        topk: number of most frequent words to compare; None compares all.
        max_pairs_per_word: cap on pairs written per superseding word, to
            keep the pairs file small; None writes every pair.
    """
    print("computing effects...")
    effects = model.word_effects.bool().detach().cpu().numpy()

    selected = data.vocab.most_common(topk)
    # Explicit integer dtype so that an empty selection is still a valid index array.
    idxs = np.array([data.word_to_idx(word) for word, _ in selected], dtype=np.intp)

    dels = effects[0][idxs]
    adds = effects[1][idxs]
    # asarray: the fancy indexing below needs an ndarray even if _idx_to_word is a list.
    names = np.asarray(data._idx_to_word)[idxs]

    superset_count, subset_count, pairs = _redundancy_containment(
        dels, adds, max_pairs_per_word)

    for label, counts in (("superset", superset_count), ("subset", subset_count)):
        path = model.local(f"redundant_{label}_count_{topk}.csv")
        # Convert the counts to strings explicitly instead of relying on NumPy's
        # implicit str/int promotion when stacking them next to the names.
        table = np.column_stack((names, counts.astype(str)))
        _redundancy_write(path, table[counts > 0])

    _redundancy_write(model.local(f"redundant_pairs_{topk}.csv"), names[pairs])


def _redundancy_containment(dels, adds, max_pairs_per_word):
    """Return superset counts, subset counts and (superset, subset) index pairs.

    ``dels`` and ``adds`` are 2-D boolean arrays with one row per word (V rows,
    E columns). Row ``i`` contains row ``j`` when every True in row ``j`` is
    also True in row ``i``, in both arrays.
    """
    n_words = dels.shape[0]
    superset_count = np.zeros(n_words, dtype=np.int64)
    subset_count = np.zeros(n_words, dtype=np.int64)
    # Seed with an empty (0, 2) chunk so that concatenate works when there are no words.
    chunks = [np.empty((0, 2), dtype=np.intp)]
    for i in tqdm.tqdm(range(n_words)):
        # For booleans, ``a >= b`` means ``b implies a``. Row i ([E]) broadcasts
        # against all rows ([V, E]) -> [V, E], then is reduced over E -> [V].
        # contained[j] is True iff word i supersedes word j.
        contained = (dels[i] >= dels).all(axis=1) & (adds[i] >= adds).all(axis=1)
        contained[i] = False  # ignore the trivial self-containment
        superset_count[i] = contained.sum()
        # Each word j that i supersedes is superseded by i, so its subset
        # count gains one. This replaces a second, transposed comparison pass.
        subset_count += contained
        superseded = np.flatnonzero(contained)[:max_pairs_per_word]
        chunks.append(np.column_stack((np.full_like(superseded, i), superseded)))
    return superset_count, subset_count, np.concatenate(chunks, axis=0)


def _redundancy_write(path, rows):
    """Write the 2-D array ``rows`` to ``path`` as UTF-8, one space-separated row per line."""
    print(f"writing to {path}...")
    with open(path, 'wb') as f:
        np.savetxt(f, rows, "%s", encoding='utf8')

