from data_utils import *
from tqdm import tqdm

def generate_data_SymReg(data_dir, num_samples, num_actions):
    """Render a mesh under random poses, each composed with a fixed set of actions.

    Reads ``chair.off`` from ``data_dir``, centres the vertices on their mean
    and rescales them so the mean vertex norm is 1, then renders one image for
    every (action, sample) pair and saves the stack to ``<data_dir>/images.npy``
    (``np.save`` appends the ``.npy`` suffix).

    Args:
        data_dir: Directory holding ``chair.off``; also the output directory.
            A leading ``~`` is expanded to the user's home directory.
        num_samples: Number of random initial orientations to draw.
        num_actions: Number of random actions to draw; each one is shared by
            all samples.

    The saved array has shape ``(num_actions, num_samples, 3, 48, 48)`` and
    dtype ``uint8``, so whatever ``render`` returns is cast to ``uint8`` on
    assignment.
    NOTE(review): if ``render`` yields floats in [0, 1] that cast truncates
    everything to 0 -- confirm the value range ``render`` produces.
    """
    # Local import: the module's wildcard import is not guaranteed to expose os.
    import os

    # Python's file APIs do not expand '~', so resolve it before building paths.
    root = os.path.expanduser(data_dir)

    vertices, triangles = read_off(os.path.join(root, 'chair.off'))
    vertices -= np.mean(vertices, axis=0)
    vertices /= np.mean(np.linalg.norm(vertices, axis=1))

    # expm(0.2 * logm(R)) is the fractional matrix power R**0.2 of a random
    # draw R, i.e. a scaled-down version of that random transformation.
    action_list = [expm(0.2 * logm(special_ortho_group.rvs(3))) for _ in range(num_actions)]
    init_list = [special_ortho_group.rvs(3) for _ in range(num_samples)]

    images = np.empty((num_actions, num_samples, 3, 48, 48), dtype=np.uint8)
    for i in tqdm(range(num_samples)):
        for j in range(num_actions):
            # The action is applied on the right of the initial orientation.
            R = init_list[i] @ action_list[j]
            images[j, i] = render(vertices, triangles, R)

    np.save(os.path.join(root, 'images'), images)
    print("Generated Images")

def generate_data_drlim(data_dir, num_samples):
    """Render pairs of images of a mesh related by a small random action.

    Reads ``chair.off`` from ``data_dir``, centres the vertices on their mean
    and rescales them so the mean vertex norm is 1. For every sample it draws
    a fresh action and a fresh orientation ``R1``, then renders the mesh at
    ``R1`` and at ``action @ R1``. The pairs are saved to
    ``<data_dir>/images.npy`` (``np.save`` appends the ``.npy`` suffix), the
    same file name ``generate_data_SymReg`` writes to.

    Args:
        data_dir: Directory holding ``chair.off``; also the output directory.
            A leading ``~`` is expanded to the user's home directory.
        num_samples: Number of image pairs to generate.

    The saved array has shape ``(num_samples, 2, 3, 48, 48)`` and NumPy's
    default dtype (float64).
    NOTE(review): at 20000 samples this buffer is roughly 2 GiB, and
    ``generate_data_SymReg`` stores uint8 instead -- confirm which dtype the
    consumers of this file expect before changing it.
    """
    # Local import: the module's wildcard import is not guaranteed to expose os.
    import os

    # Python's file APIs do not expand '~', so resolve it before building paths.
    root = os.path.expanduser(data_dir)

    vertices, triangles = read_off(os.path.join(root, 'chair.off'))
    vertices -= np.mean(vertices, axis=0)
    vertices /= np.mean(np.linalg.norm(vertices, axis=1))

    x = np.empty((num_samples, 2, 3, 48, 48))
    for i in tqdm(range(num_samples)):
        # expm(0.05 * logm(R)) is the fractional matrix power R**0.05 of a
        # random draw R, i.e. a small step along that random transformation.
        action = expm(0.05 * logm(special_ortho_group.rvs(3)))
        R1 = special_ortho_group.rvs(3)
        x[i, 0] = render(vertices, triangles, R1)
        # Here the action is applied on the left of the base orientation.
        R2 = action @ R1
        x[i, 1] = render(vertices, triangles, R2)

    np.save(os.path.join(root, 'images'), x)
    print(f'Data saved in {root}')

if __name__ == '__main__':
    import os

    # '~' is not expanded by open()/np.save(), so resolve the home directory
    # here; otherwise the paths would point at a directory literally named '~'.
    data_dir = os.path.expanduser('~/scratch/Chair_SymmetryReg')
    num_samples = 20000
    # Alternative dataset (image pairs): generate_data_drlim(data_dir, num_samples)
    generate_data_SymReg(data_dir, num_samples, num_actions=4)