import argparse
import json
import os
import h5py
import numpy as np
from cuml.decomposition import PCA
from copy import deepcopy
from tqdm import tqdm

# Project checkout root; configs/transform_config.json is read relative to it.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"
# Placeholder roots; the __main__ block rebinds them from the config file.
MODEL_ROOTPATH = ""
DATA_ROOTPATH = ""
# No trailing comma here: `"",` would make this a 1-tuple, not a path string.
OUTPUT_ROOTPATH = ""
# Maximum number of document rows used to fit the PCA.
MAX_SAMPLES = 128_000
# Seed the global NumPy RNG so the fit-sample row selection is reproducible.
np.random.seed(1234)

def _load_fit_sample(vectors, max_samples):
    """Read the rows of *vectors* used to fit the PCA.

    Returns every row when there are at most ``max_samples``. Otherwise it
    returns ``max_samples`` distinct rows chosen with the global NumPy RNG,
    in ascending row order.

    The subset is gathered one contiguous block of ``max_samples`` rows at a
    time. Peak memory is therefore about one block plus the sample rather
    than the whole dataset, and blocks with no sampled row are never read.

    Args:
        vectors: Row-indexable array-like with a ``shape`` attribute that
            supports slice reads (e.g. an h5py dataset).
        max_samples: Maximum number of rows to return.

    Returns:
        numpy.ndarray holding the selected rows.
    """
    n_rows = vectors.shape[0]
    if n_rows <= max_samples:
        return vectors[:]

    print("generate random indices")
    indices = np.random.choice(n_rows, max_samples, replace=False)
    # Ascending order lets each block find its sampled rows with one searchsorted.
    indices.sort()
    print("generate completed")
    if indices.size == 0:
        return vectors[:0]

    block_size = max_samples
    parts = []
    for start in range(0, n_rows, block_size):
        stop = min(start + block_size, n_rows)
        # Positions in `indices` whose row numbers fall inside [start, stop).
        lo, hi = np.searchsorted(indices, (start, stop))
        if lo == hi:
            continue  # nothing sampled from this block; skip the read
        parts.append(vectors[start:stop][indices[lo:hi] - start])
    return np.concatenate(parts, axis=0)


def _write_transformed(pca, src_path, dst_path, batch_size):
    """Write ``pca.transform`` of the vectors in *src_path* to *dst_path*.

    The ``ids`` dataset is copied as variable-length strings. ``vectors`` is
    transformed ``batch_size`` rows at a time and stored as float32, so the
    whole input is never held in memory at once. The output width comes from
    the transformed data, so it stays correct even if the PCA produces fewer
    columns than the input has.
    """
    with h5py.File(src_path, "r") as h5f_s, h5py.File(dst_path, "w") as h5f_t:
        h5f_t.create_dataset("ids", data=h5f_s["ids"][:], dtype=h5py.string_dtype())
        src_vectors = h5f_s["vectors"]
        n_rows = src_vectors.shape[0]
        dst_vectors = None
        for start in tqdm(range(0, n_rows, batch_size)):
            batch = pca.transform(src_vectors[start:start + batch_size])
            if dst_vectors is None:
                # Created lazily: the output width is known only after the first transform.
                dst_vectors = h5f_t.create_dataset(
                    "vectors", shape=(n_rows, batch.shape[1]), dtype=np.float32
                )
            dst_vectors[start:start + batch.shape[0]] = batch
        if dst_vectors is None:
            # Empty input: nothing to transform, so store an empty dataset shaped like the input.
            h5f_t.create_dataset("vectors", shape=src_vectors.shape, dtype=np.float32)


def main(model_name, dataset_name):
    """Fit a PCA on an experiment's document vectors and write transformed copies.

    Inputs are ``origin/q_vectors.h5`` and ``origin/d_vectors.h5`` under
    ``OUTPUT_ROOTPATH/<model_name>_<dataset_name>_exp``; each has an ``ids``
    dataset and a ``vectors`` dataset.

    ``PCA(n_components=None)`` is fitted on the document vectors, using a
    random subset of ``MAX_SAMPLES`` rows when there are more than that. The
    transformed query and document vectors are written, with their ids, to
    ``transform/q_vectors.h5`` and ``transform/d_vectors.h5`` in the same
    experiment directory.

    Args:
        model_name: Model identifier; the first part of the experiment directory name.
        dataset_name: Dataset identifier; the second part of the experiment directory name.
    """
    exp_dir = os.path.join(OUTPUT_ROOTPATH, model_name + "_" + dataset_name + "_exp")
    q_input_path = os.path.join(exp_dir, "origin/q_vectors.h5")
    d_input_path = os.path.join(exp_dir, "origin/d_vectors.h5")
    q_output_path = os.path.join(exp_dir, "transform/q_vectors.h5")
    d_output_path = os.path.join(exp_dir, "transform/d_vectors.h5")

    # 1. Fit on the document vectors (subsampled when there are many).
    pca = PCA(n_components=None)
    with h5py.File(d_input_path, "r") as h5f:
        d_vectors = _load_fit_sample(h5f["vectors"], MAX_SAMPLES)
    pca.fit(d_vectors)
    del d_vectors  # release the fit sample before the transform pass
    print("fit finished")

    # 2. Transform queries and documents with the fitted model.
    os.makedirs(os.path.join(exp_dir, "transform"), exist_ok=True)
    _write_transformed(pca, q_input_path, q_output_path, MAX_SAMPLES)
    _write_transformed(pca, d_input_path, d_output_path, MAX_SAMPLES)

    print(f"model {model_name} data {dataset_name} finished")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Fit PCA on an experiment's document vectors and write "
        "PCA-transformed query and document vectors."
    )
    parser.add_argument("--model_name", type=str, required=True,
                        help="model identifier used in the experiment directory name")
    parser.add_argument("--dataset_name", type=str, required=True,
                        help="dataset identifier used in the experiment directory name")
    args = parser.parse_args()

    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    config_path = os.path.join(PROJECT_ROOTPATH, "configs", "transform_config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # These assignments run at module scope, so they rebind the globals main() reads.
    MODEL_ROOTPATH = config["model_rootpath"]
    DATA_ROOTPATH = config["data_rootpath"]
    OUTPUT_ROOTPATH = config["output_rootpath"]

    main(args.model_name, args.dataset_name)
