import argparse

import numpy as np
import torch
from sklearn.preprocessing import normalize
from tqdm import tqdm

parser = argparse.ArgumentParser(
    description="Assign each row of --input to its highest-scoring row of --center."
)
# All three options are required: without them the script would otherwise fail
# much later with an opaque error (np.load(None) / arithmetic on None).
parser.add_argument("--batch-size", type=int, required=True,
                    help="number of rows each runner scores per step (must be > 0)")
parser.add_argument("--center", required=True,
                    help="path to the .npy array of reference (center) features")
parser.add_argument("--input", required=True,
                    help="path to the .npy array of features to label")
args = parser.parse_args()
if args.batch_size <= 0:
    # A non-positive batch never advances a runner, so the main loop would never end.
    parser.error(f"--batch-size must be positive, got {args.batch_size}")

# NOTE(review): not referenced by main(); kept in case anything imports it.
list_gpu_device = [torch.device(i) for i in range(8)]
torch.backends.cudnn.benchmark = True


@torch.no_grad()
def score(x, y, k=10):
    """Match each row of ``x`` against every row of ``y`` by dot product.

    For L2-normalised inputs the dot product is the cosine similarity.

    Args:
        x: 2-D tensor of query rows, shape ``(n, d)``.
        y: 2-D tensor of reference rows, shape ``(m, d)``; must be on the same
            device and of the same dtype as ``x``.
        k: number of highest-scoring matches kept per query row (default 10).

    Returns:
        ``(index_pt, score_pt)`` where ``index_pt`` is an ``(n,)`` int64 tensor
        holding, for each row of ``x``, the row of ``y`` with the highest score,
        and ``score_pt`` is an ``(n, k)`` tensor of the top-``k`` scores per
        row in descending order.

    Raises:
        RuntimeError: from ``torch.topk`` when ``k`` exceeds ``y.shape[0]``.
    """
    similarity = torch.einsum("ik,jk->ij", x, y)
    top_scores, top_index = torch.topk(similarity, k=k, dim=1)
    # topk already returns int64 indices sorted by descending score, so
    # column 0 is the best match.
    return top_index[:, 0], top_scores


class runner:
    """Score one contiguous shard of ``feat`` against ``center`` on one device.

    The rows of ``feat`` are split into ``size`` contiguous shards whose
    lengths differ by at most one; this instance owns shard ``rank``. Each call
    scores the next ``batch_size`` rows of the shard with ``score`` and stores
    the results in ``pt_label`` / ``pt_score``.
    """

    def __init__(self, feat, center, rank, size=8, num_devices=8):
        """Copy this runner's shard of ``feat`` and all of ``center`` to its device.

        Args:
            feat: 2-D numpy array whose rows are split across runners.
            center: numpy array copied in full to this runner's device.
            rank: index of the shard owned by this runner, in ``[0, size)``.
            size: total number of shards.
            num_devices: number of device ordinals to cycle through; this runner
                uses ``torch.device(rank % num_devices)``. Defaults to 8, the
                value previously hard-coded.
        """
        self.start, self.num_local = self._shard_bounds(feat.shape[0], rank, size)
        device = torch.device(rank % num_devices)
        current_feat = feat[self.start: self.start + self.num_local]
        self.feat = torch.from_numpy(current_feat).to(device)
        self.center = torch.from_numpy(center).to(device)
        self.is_end = False  # becomes True once the shard's last batch is scored
        self.index = 0  # shard-relative start of the next batch

        # Per-row outputs: best-match row of ``center`` and the top scores.
        # The 10 columns match the k=10 top-k taken by ``score``.
        self.pt_label = torch.zeros(self.feat.size(0), dtype=torch.long, device=self.feat.device)
        self.pt_score = torch.zeros(self.feat.size(0), 10, device=self.feat.device)

    @staticmethod
    def _shard_bounds(total, rank, size):
        """Return ``(start, length)`` of shard ``rank`` when ``total`` rows are split into ``size`` near-equal contiguous shards."""
        base, extra = divmod(total, size)
        # The first ``extra`` shards each take one additional row.
        length = base + int(rank < extra)
        start = base * rank + min(rank, extra)
        return start, length

    def __call__(self, batch_size):
        """Score the next batch of up to ``batch_size`` rows of this shard.

        Results are written in place into ``pt_label`` / ``pt_score``.
        ``is_end`` is set when the batch reaching the end of the shard has been
        processed. ``index`` advances by ``batch_size`` on every call.

        Raises:
            ValueError: if ``batch_size`` is not positive. Such a call would never
                advance the cursor, so a loop waiting on ``is_end`` would spin
                forever.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        end = self.index + batch_size
        if end >= self.num_local:
            end = self.num_local
            self.is_end = True

        index_pt, score_pt = score(self.feat[self.index: end], self.center)
        self.pt_label[self.index: end] = index_pt
        self.pt_score[self.index: end] = score_pt
        self.index += batch_size


@torch.no_grad()
def main():
    """Label every row of ``--input`` with its best-matching row of ``--center``.

    Both arrays are cut to their first 128 columns and row-normalised. The input
    rows are sharded across ``size`` runners, which are stepped round-robin one
    batch at a time until all shards are done. The results are concatenated in
    rank order, which is input row order because shards are contiguous, and
    saved as ``<input>_label.npy`` (best-match index per row) and
    ``<input>_score.npy`` (the runners' per-row top scores).
    """
    size = 8
    feat = normalize(np.load(args.input)[:, :128])
    center = normalize(np.load(args.center)[:, :128])

    list_runner = [runner(feat, center, rank, size) for rank in range(size)]

    with tqdm(total=feat.shape[0]) as pbar:
        end_list = [False] * len(list_runner)
        while not all(end_list):
            for idx, runner_instance in enumerate(list_runner):
                runner_instance: runner
                if not end_list[idx]:
                    before = runner_instance.index
                    runner_instance(args.batch_size)
                    # A shard's final batch may be short: count only rows actually scored.
                    done = min(runner_instance.index, runner_instance.num_local)
                    pbar.update(done - before)

                if runner_instance.is_end:
                    end_list[idx] = True

    np_label = np.concatenate([x.pt_label.cpu().numpy() for x in list_runner])
    np_score = np.concatenate([x.pt_score.cpu().numpy() for x in list_runner])
    np.save(f"{args.input}_label.npy", np_label)
    np.save(f"{args.input}_score.npy", np_score)


if __name__ == "__main__":
    main()
