import numpy as np
import os

# Root directory; each dataset variant below joins its own output folder onto it.
parent_dir = '/root/main/ranking/sim'

# Number of query vectors generated at the bottom of the script (also embedded
# in some of the output folder names).
num_q = 100

def _unit_rows(mat):
    """Scale every vector along the last axis of *mat* to unit L2 norm.

    Vectors whose norm is exactly zero are returned unchanged rather than
    being turned into NaN by a 0/0 division.
    """
    norms = np.linalg.norm(mat, axis=-1, keepdims=True)
    return mat / np.where(norms > 0, norms, 1.0)


def preproc(vecs: np.ndarray, noise=True, noise_scale=0.2):
    """Center, L2-normalize and optionally jitter a set of row vectors.

    The mean over axis 0 is subtracted, then every row is scaled to unit
    length. When ``noise`` is true, a random unit-length direction (drawn
    from NumPy's global random state) scaled by ``noise_scale`` is added to
    every row and the rows are normalized again.

    Args:
        vecs: Array whose last axis holds the vector components. The input
            array is not modified.
        noise: Whether to add the random perturbation.
        noise_scale: Length of the perturbation added to each unit vector
            before the final re-normalization.

    Returns:
        A new array with the same shape as ``vecs``. Rows are unit-norm,
        except that with ``noise=False`` a row identical to the column mean
        stays all-zero (it has no direction to normalize).
    """
    vecs = _unit_rows(vecs - vecs.mean(0, keepdims=True))
    if noise:
        noise_term = _unit_rows(np.random.randn(*vecs.shape))
        # Re-normalize so the perturbed vectors lie on the unit sphere again.
        vecs = _unit_rows(vecs + noise_scale * noise_term)
    return vecs

# Which dataset variant to build. Must be one of the names handled below:
# 'random', 'separated', 'g2', 'g2_trunc', 'synthetic_cluster'.
DATASET = 'synthetic_cluster'

if DATASET == 'random':
    # Random Gaussian directions, normalized to unit length.
    dim = 16
    num_vectors = 1000
    folder = os.path.join(parent_dir, f'random_{num_vectors}_{dim}_{num_q}')
    os.makedirs(folder, exist_ok=True)
    vecs = np.random.randn(num_vectors, dim)
    vecs = vecs/np.linalg.norm(vecs, axis=-1, keepdims=True)
elif DATASET == 'separated':
    # Vectors loaded from separated/dim{dim}.txt, preprocessed with noise.
    dim = 128
    parent_data_dir = "/root/main/ranking/data"
    data_file = os.path.join(parent_data_dir, f'separated/dim{dim}.txt')
    folder = os.path.join(parent_dir, f'separated_{dim}_{num_q}')
    os.makedirs(folder, exist_ok=True)
    vecs = np.loadtxt(data_file)
    vecs = preproc(vecs)
elif DATASET == 'g2':
    # Vectors loaded from g2-txt/g2-{dim}-{overlap}.txt, preprocessed without noise.
    dim = 32
    overlap = 100
    parent_data_dir = "/root/main/ranking/data"
    data_file = os.path.join(parent_data_dir, f'g2-txt/g2-{dim}-{overlap}.txt')
    folder = os.path.join(parent_dir, f'g2_{dim}_{overlap}')
    os.makedirs(folder, exist_ok=True)
    vecs = np.loadtxt(data_file)
    vecs = preproc(vecs, noise=False)
elif DATASET == 'g2_trunc':
    # Same source file as 'g2', but only a random subset of n rows is kept.
    n = 300
    dim = 32
    overlap = 100
    parent_data_dir = "/root/main/ranking/data"
    data_file = os.path.join(parent_data_dir, f'g2-txt/g2-{dim}-{overlap}.txt')
    folder = os.path.join(parent_dir, f'g2_{dim}_{overlap}_trunc')
    os.makedirs(folder, exist_ok=True)
    vecs = np.loadtxt(data_file)
    vecs = preproc(vecs, noise=False)
    # Shuffle rows in place, then keep the first n of them.
    np.random.shuffle(vecs)
    vecs = vecs[:n, :]
elif DATASET == 'synthetic_cluster':
    # NOTE(review): dim is not used in this branch; confirm it matches the
    # dimensionality of the loaded file.
    dim = 128
    parent_data_dir = "/root/main/ranking/data/synthetic_cluster"
    data_file = os.path.join(parent_data_dir, '10_64_160.txt')
    folder = os.path.join(parent_dir, 'synthetic_cluster')
    os.makedirs(folder, exist_ok=True)
    vecs = np.loadtxt(data_file)
    vecs = preproc(vecs)
else:
    raise ValueError(f'unknown DATASET: {DATASET!r}')

def get_query_from_vecs(vecs, magnitude=0.4):
    """Create one query by perturbing a randomly chosen row of ``vecs``.

    A random unit-length direction scaled by ``magnitude`` is added to a
    uniformly chosen row, and the sum is rescaled to unit L2 norm. Both
    random draws use NumPy's global random state.

    Args:
        vecs: 2-D array; rows are candidate vectors, columns are components.
        magnitude: Length of the perturbation added before re-normalizing.

    Returns:
        A 1-D array of length ``vecs.shape[1]`` with unit L2 norm.
    """
    # Draw the direction first and the row index second, so results under a
    # fixed seed stay reproducible.
    direction = np.random.randn(vecs.shape[1])
    direction /= np.linalg.norm(direction)
    row = np.random.randint(low=0, high=vecs.shape[0])
    perturbed = vecs[row, :] + magnitude * direction
    return perturbed / np.linalg.norm(perturbed)

# Generate num_q independent queries and stack them as rows of one 2-D array.
q = np.stack(
    [get_query_from_vecs(vecs) for _ in range(num_q)],
    axis=0,
)
print(q.shape)

# Write the dataset vectors first, then the queries, as plain-text matrices
# inside the output folder chosen above.
for _fname, _matrix in (('vecs.txt', vecs), ('q.txt', q)):
    np.savetxt(fname=os.path.join(folder, _fname), X=_matrix)