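# Measure how many of the largest-magnitude dimensions are needed for their squared values to reach each specified threshold.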
from collections import Counter
import h5py
import numpy as np
import os

# Experiment configuration: the model, the datasets and the vector variants
# ("kinds") to analyse.
model_name = "gte-Qwen2-7B-instruct"
datasets = [
    "fiqa",
    "nq",
    "hotpotqa",
    "msmarco",
]
kinds = ["origin", "transform"]

# Root directory that holds the per-experiment vector files.
output_rootpath = "/rootpath/vectors"
# Project root directory.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"

def analysis(dataset_name, kind, data_num=128_000,
             thresholds=(0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)):
    """Count how many dimensions each vector needs to reach each energy threshold.

    Every vector's components are squared and sorted in descending order. For
    each threshold ``t``, the result is the smallest number of leading
    (largest) squared components whose running sum exceeds ``t`` times the
    vector's total squared sum.

    Input is the ``"vectors"`` dataset of
    ``<output_rootpath>/<model_name>_<dataset_name>_exp/<kind>/d_vectors.h5``
    (first ``data_num`` rows at most, read as a 2-D ``(rows, dims)`` slice).
    Output is an integer array of shape ``(len(thresholds), rows)``, saved to
    ``<PROJECT_ROOTPATH>/result/threshold_statistics/<dataset_name>_<kind>.npy``.
    The output directory is created if it does not exist.

    Args:
        dataset_name: Dataset name used to build the input and output paths.
        kind: Vector variant sub-directory (e.g. "origin" or "transform").
        data_num: Maximum number of rows to read.
        thresholds: Fractions of the total squared sum to reach; expected to
            lie in (0, 1).

    Note:
        Stored values are dimension *counts*: 1 means the single largest
        component already exceeds the threshold. Rows that never exceed the
        threshold (e.g. all-zero rows, which carry no energy) are stored as 0.
    """
    exp_name = f"{model_name}_{dataset_name}_exp"
    d_vector_path = os.path.join(output_rootpath, exp_name, f"{kind}/d_vectors.h5")
    # Read only the rows we need and close the file before the heavy math.
    with h5py.File(d_vector_path, "r") as dhf:
        vectors = dhf["vectors"]
        data_num = min(data_num, vectors.shape[0])
        d_vectors = vectors[:data_num, :]
    print("read finish")

    # Scale each row by its largest absolute component so the squares lie in
    # [0, 1]. The threshold test is scale-invariant, so this only affects
    # numerics. An all-zero row would divide by zero and turn into NaN;
    # dividing it by 1 keeps it all-zero instead.
    max_val = np.max(np.abs(d_vectors), axis=1, keepdims=True)
    max_val[max_val == 0] = 1
    d_squared = (d_vectors / max_val) ** 2

    # Squared components in descending order (reversed view of an ascending
    # sort, avoiding two negated temporaries) and their running totals.
    d_squared_sorted = np.sort(d_squared, axis=1)[:, ::-1]
    d_squared_sorted_cum = np.cumsum(d_squared_sorted, axis=1)
    # Each row's total squared sum does not depend on the threshold.
    total = np.sum(d_squared, axis=1, keepdims=True)

    result = []
    for threshold in thresholds:
        reached = d_squared_sorted_cum > total * threshold
        # argmax returns the 0-based index of the first True, so add 1 to get
        # a dimension count. When no position qualifies, argmax would also
        # return 0; record such rows as 0 explicitly.
        required_num = np.where(reached.any(axis=1),
                                np.argmax(reached, axis=1) + 1, 0)
        result.append(required_num)

    out_dir = os.path.join(PROJECT_ROOTPATH, "result", "threshold_statistics")
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, f"{dataset_name}_{kind}.npy"), np.stack(result))

# Run the analysis for every (dataset, kind) combination only when executed as
# a script, so that importing this module does not start the full batch job.
if __name__ == "__main__":
    for dataset_name in datasets:
        for kind in kinds:
            analysis(dataset_name, kind)

