
import numpy as np
from tqdm import trange

def nearest_point_costs(d, bsz=512, progress=range):
    """For every source point, find its L1-nearest point within every item.

    Parameters
    ----------
    d : array_like, shape (n_items, n_points, n_coords)
        Point sets; ``d[a, p]`` is point ``p`` of item ``a``. The script entry
        point below loads ``bin_to_colors.npy``, which is expected to be
        (4096, 16, 2).
    bsz : int
        Number of items per batch on each side. Each batch materializes a
        temporary of shape (bsz, n_points, bsz, n_points, n_coords), so this
        bounds peak working memory. It does not have to divide ``n_items``.
    progress : callable
        ``range``-compatible callable used for the outer batch loop; pass
        ``tqdm.trange`` to get a progress bar.

    Returns
    -------
    cost_matrix : ndarray of float64, shape (n_items * n_points, n_items)
        ``cost_matrix[a * n_points + p, b] == min_q sum(|d[a, p] - d[b, q]|)``.
    target_indices : ndarray of int, shape (n_items * n_points, n_items)
        The index ``q`` attaining that minimum (the first one on ties).
    """
    d = np.asarray(d)
    n_items, n_points = d.shape[:2]
    n_rows = n_items * n_points
    cost_matrix = np.zeros((n_rows, n_items))
    # Allocate as int up front: a float buffer followed by .astype(int) would
    # keep a second full-size copy alive.
    target_indices = np.zeros((n_rows, n_items), dtype=int)

    for i in progress(0, n_items, bsz):
        src_exp = d[i:i + bsz, :, np.newaxis, np.newaxis, :]  # (bi, P, 1, 1, C)
        rows = slice(i * n_points, (i + src_exp.shape[0]) * n_points)
        for j in range(0, n_items, bsz):
            tgt_exp = d[np.newaxis, np.newaxis, j:j + bsz]  # (1, 1, bj, P, C)
            # The last batch may be shorter than bsz, so use its real width.
            n_cols = tgt_exp.shape[2]
            l1_dist = np.abs(src_exp - tgt_exp).sum(axis=-1)  # (bi, P, bj, P)
            min_index = l1_dist.argmin(axis=-1)  # (bi, P, bj)
            # Read the minima at the argmin positions instead of making a
            # second full pass over l1_dist with np.min.
            min_dist = np.take_along_axis(
                l1_dist, min_index[..., np.newaxis], axis=-1
            )[..., 0]
            cols = slice(j, j + n_cols)
            # Flattening (bi, P, bj) -> (bi * P, bj) puts source point p of
            # item i + a on row (i + a) * P + p, matching cost_matrix's layout.
            cost_matrix[rows, cols] = min_dist.reshape(-1, n_cols)
            target_indices[rows, cols] = min_index.reshape(-1, n_cols)

    return cost_matrix, target_indices


if __name__ == "__main__":
    # Presumably 8-bit values rescaled to [0, 1] -- confirm against the
    # producer of bin_to_colors.npy.
    d = np.load('bin_to_colors.npy') / 255.0
    bsz = 512
    cost_matrix, target_indices = nearest_point_costs(d, bsz=bsz, progress=trange)

    np.save('cost_matrix', cost_matrix)
    np.save('target_indices', target_indices)
