import argparse
import json
import os
import h5py
import numpy as np
from cuml.decomposition import PCA
from cuml import TruncatedSVD, KMeans
from copy import deepcopy
from tqdm import tqdm

# Root of the project checkout; the transform config is read from
# <PROJECT_ROOTPATH>/configs/transform_config.json by the script entry point.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"

# Placeholder roots. The script entry point overwrites these with the values
# from the transform config before calling main().
MODEL_ROOTPATH = ""
DATA_ROOTPATH = ""
# NOTE: must be a plain string. A trailing comma here would make it the tuple
# ("",), which os.path.join() rejects with a TypeError.
OUTPUT_ROOTPATH = ""

# Cap on the number of vectors used to fit the PCA; also reused as the batch
# size when transforming the document vectors.
MAX_SAMPLES = 128_000

# Seed NumPy's global RNG so the random subsampling is reproducible.
seed = 1234
np.random.seed(seed)

def _stratified_sample_indices(vectors, max_samples, n_clusters=25, predict_batch_size=128_000):
    """Return row indices of ``vectors`` chosen by KMeans-stratified sampling.

    A KMeans model is fit on a random subset of the rows, every row is then
    assigned to a cluster, and each cluster contributes a number of rows
    proportional to its share of the data, so that roughly ``max_samples``
    rows are selected in total (per-cluster counts are rounded down).

    Args:
        vectors: 2-D NumPy array of vectors, one per row.
        max_samples: Target total number of sampled rows.
        n_clusters: Number of KMeans clusters used as strata.
        predict_batch_size: Number of rows passed to ``kmeans.predict`` at once.

    Returns:
        1-D integer array of selected row indices, grouped by cluster.
    """
    total = vectors.shape[0]

    # KMeans is fit on at most 8 * max_samples rows. When fewer rows exist,
    # use all of them: np.random.choice(..., replace=False) raises ValueError
    # if asked for more samples than the population holds.
    kmeans_fit_size = max_samples * 8
    if total >= kmeans_fit_size:
        fit_indices = np.random.choice(total, kmeans_fit_size, replace=False)
        fit_indices.sort()
        kmeans_fit_vectors = vectors[fit_indices]
    else:
        kmeans_fit_vectors = vectors

    kmeans = KMeans(n_clusters=n_clusters, random_state=seed, max_iter=1000)
    kmeans.fit(kmeans_fit_vectors)
    print("kmeans fit completed")

    # Assign every row to a cluster, one batch at a time.
    cluster_labels = np.empty(total, dtype=np.int32)
    for start in range(0, total, predict_batch_size):
        stop = min(start + predict_batch_size, total)
        cluster_labels[start:stop] = kmeans.predict(vectors[start:stop]).flatten()

    # Stratified sampling: each cluster gets a quota proportional to its size.
    unique_labels, counts = np.unique(cluster_labels, return_counts=True)
    samples_per_cluster = (counts / total * max_samples).astype(int)
    sampled = [
        np.random.choice(np.flatnonzero(cluster_labels == label), num, replace=False)
        for label, num in zip(unique_labels, samples_per_cluster)
    ]
    return np.concatenate(sampled)


def main(model_name: str, dataset_name: str) -> None:
    """Fit a PCA on document vectors and write PCA-transformed vectors to HDF5.

    Reads ``origin/q_vectors.h5`` and ``origin/d_vectors.h5`` from
    ``<OUTPUT_ROOTPATH>/<model_name>_<dataset_name>_exp`` and writes the
    transformed query and document vectors (together with their ``ids``) to
    the ``cluster_transform`` sub-directory of the same experiment folder.

    When the corpus holds at least ``MAX_SAMPLES`` document vectors, the PCA
    is fit on a KMeans-stratified sample of them; otherwise it is fit on all
    document vectors.

    Args:
        model_name: Model identifier used in the experiment directory name.
        dataset_name: Dataset identifier used in the experiment directory name.
    """
    # Input / output file paths.
    exp_dir = os.path.join(OUTPUT_ROOTPATH, model_name + "_" + dataset_name + "_exp")
    output_dir = os.path.join(exp_dir, "cluster_transform")
    q_input_path = os.path.join(exp_dir, "origin/q_vectors.h5")
    d_input_path = os.path.join(exp_dir, "origin/d_vectors.h5")
    q_output_path = os.path.join(output_dir, "q_vectors.h5")
    d_output_path = os.path.join(output_dir, "d_vectors.h5")

    # 1. fit
    with h5py.File(d_input_path, "r") as h5f:
        d_vectors = h5f["vectors"][:].astype(np.float32)

    if d_vectors.shape[0] >= MAX_SAMPLES:
        sampled_indices = _stratified_sample_indices(d_vectors, MAX_SAMPLES)
        d_vectors_sampled = d_vectors[sampled_indices]
    else:
        d_vectors_sampled = d_vectors

    pca = PCA(n_components=None)
    pca.fit(d_vectors_sampled)
    print("PCA fitting completed")

    # The document vectors are re-read from disk in batches below, so drop the
    # in-memory copies before transforming.
    del d_vectors, d_vectors_sampled

    # 2. transform
    os.makedirs(output_dir, exist_ok=True)
    with h5py.File(q_input_path, "r") as h5f_s, h5py.File(q_output_path, "w") as h5f_t:
        # Cast to float32 so the transform input matches the dtype used for fitting.
        q_vectors = h5f_s["vectors"][:].astype(np.float32, copy=False)
        q_transformed = pca.transform(q_vectors)
        h5f_t.create_dataset("ids", data=h5f_s["ids"][:], dtype=h5py.string_dtype())
        h5f_t.create_dataset("vectors", data=q_transformed, dtype=np.float32)

    with h5py.File(d_input_path, "r") as h5f_s, h5py.File(d_output_path, "w") as h5f_t:
        h5f_t.create_dataset("ids", data=h5f_s["ids"][:], dtype=h5py.string_dtype())
        source_vectors = h5f_s["vectors"]
        target_vectors = h5f_t.create_dataset("vectors", shape=source_vectors.shape, dtype=np.float32)
        n_docs = source_vectors.shape[0]
        batch_size = MAX_SAMPLES
        for start_idx in tqdm(range(0, n_docs, batch_size)):
            end_idx = min(start_idx + batch_size, n_docs)
            # Same float32 cast as for the fit and the query vectors.
            batch_vectors = source_vectors[start_idx:end_idx].astype(np.float32, copy=False)
            target_vectors[start_idx:end_idx] = pca.transform(batch_vectors)

    print(f"model {model_name} data {dataset_name} finished")

if __name__ == "__main__":
    # Command-line interface: both options are required strings.
    cli = argparse.ArgumentParser()
    for option in ("--model_name", "--dataset_name"):
        cli.add_argument(option, type=str, required=True)
    cli_args = cli.parse_args()

    # Read the root paths from the project's transform config and override
    # the module-level placeholders before running the pipeline.
    with open(os.path.join(PROJECT_ROOTPATH, "configs", "transform_config.json")) as config_file:
        config = json.load(config_file)
    MODEL_ROOTPATH = config["model_rootpath"]
    DATA_ROOTPATH = config["data_rootpath"]
    OUTPUT_ROOTPATH = config["output_rootpath"]

    main(cli_args.model_name, cli_args.dataset_name)
