from collections import Counter

import h5py
import numpy as np
import os

dim = 768  # columns [:dim] of each row are the dense part; columns [dim:] are the sparse part
threshold = 0.7  # fraction of each row's total squared mass that must be covered

model_name = "gte-Qwen2-7B-instruct"
dataset_name = "nq"
output_rootpath = "/rootpath/vectors"
exp_name = f"{model_name}_{dataset_name}_exp"
d_vector_path = os.path.join(output_rootpath, exp_name, "transform/d_vectors.h5")


def _normalize_rows(vectors):
    """Scale each row so its largest absolute entry is 1; all-zero rows stay zero."""
    max_val = np.max(np.abs(vectors), axis=1, keepdims=True)
    # Dividing an all-zero row by 0 would yield NaN; divide such rows by 1 instead.
    safe_max = np.where(max_val == 0, 1, max_val)
    return vectors / safe_max


def _required_counts(sorted_sq, sparse_required):
    """Return, per row, how many leading entries of ``sorted_sq`` are needed so that
    their cumulative sum strictly exceeds ``sparse_required`` (shape ``(n, 1)``).

    Rows with a non-positive requirement need 0 entries. Rows whose requirement
    can never be exceeded take all of their entries.
    """
    n_sparse = sorted_sq.shape[1]
    cum = np.cumsum(sorted_sq, axis=1)
    # Entries are squares sorted in descending order, so ``cum`` is non-decreasing.
    # The number of prefixes that do not yet exceed the requirement therefore
    # equals the index of the first prefix that does. One more entry (the
    # crossing one) must be included to actually exceed the requirement.
    counts = np.sum(cum <= sparse_required, axis=1) + 1
    counts = np.minimum(counts, n_sparse)
    # The dense part alone already meets the threshold: no sparse entries needed.
    counts[sparse_required[:, 0] <= 0] = 0
    return counts


def build_inverted_index(d_vectors, dense_dim=dim, threshold=threshold):
    """Build an inverted index over the sparse columns of ``d_vectors``.

    Each row is first scaled so its largest absolute value is 1, then squared.
    Columns ``[:dense_dim]`` are treated as dense and ``[dense_dim:]`` as sparse.
    For every row, sparse entries are taken in descending order of squared value
    until their cumulative sum strictly exceeds
    ``max(threshold * row_total - dense_total, 0)``.

    Args:
        d_vectors: 2-D array indexed as ``[row, column]``.
        dense_dim: number of leading dense columns.
        threshold: fraction of each row's total squared mass to cover.

    Returns:
        A list with one posting list per sparse column. Each posting is
        ``(row_index, normalized_squared_value)``, appended in row order.
    """
    d_squared = _normalize_rows(d_vectors) ** 2
    d_dense_squared = d_squared[:, :dense_dim]
    d_sparse_squared = d_squared[:, dense_dim:]
    d_sum = np.sum(d_squared, axis=1, keepdims=True)
    d_dense_sum = np.sum(d_dense_squared, axis=1, keepdims=True)
    sparse_required = np.maximum(d_sum * threshold - d_dense_sum, 0)
    print("sum finish")
    # argsort already yields sparse column ids in descending-value order, so no
    # separate index matrix is needed to map sorted positions back to columns.
    sorted_idx = np.argsort(-d_sparse_squared, axis=1)
    d_sparse_squared_sorted = np.take_along_axis(d_sparse_squared, sorted_idx, axis=1)
    required_num = _required_counts(d_sparse_squared_sorted, sparse_required)
    print("building inverted index...")
    inverted_index = [[] for _ in range(d_sparse_squared.shape[1])]
    for i, n_take in enumerate(required_num):
        for j in range(n_take):
            inverted_index[sorted_idx[i, j]].append((i, d_sparse_squared_sorted[i, j]))
    return inverted_index


if __name__ == "__main__":
    # Keep the HDF5 file open only for the read itself.
    with h5py.File(d_vector_path, "r") as dhf:
        d_vectors = dhf["vectors"][:, :]
    print("read finish")
    inverted_index = build_inverted_index(d_vectors, dim, threshold)
    inverted_index_num = [len(x) for x in inverted_index]
    count = Counter(inverted_index_num)
    print(count)


