import argparse
import json
import os
import asyncio
import numpy as np
import math
import umap
# Restrict this process to CUDA device index 3. The assignment deliberately sits
# before the sentence_transformers import below — presumably so the CUDA runtime
# only ever sees that one device; TODO confirm the index for the target machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from kblam.gpt_session import GPT
from kblam.gpt_session_async import GPTAsync
from kblam.utils.data_utils import DataPoint
import psutil
from sklearn.mixture import GaussianMixture
from numpy.lib.format import open_memmap


def check_memory():
    """Tell the caller whether system memory usage is still acceptable.

    Returns:
        bool: ``False`` (after printing a red warning line) when the
        virtual-memory usage reported by ``psutil`` is above 85 percent,
        ``True`` otherwise.
    """
    usage = psutil.virtual_memory()
    over_limit = usage.percent > 85
    if not over_limit:
        return True
    print(f"\033[91mWarning: Memory usage {usage.percent}%. Shutdown...\033[0m")
    return False

def parser_args():
    """Parse the command-line options of the embedding/clustering script.

    The options are declared as a table of ``(flag, add_argument kwargs)``
    pairs and registered in order, so ``--help`` lists them in the same
    sequence as the table.

    Returns:
        argparse.Namespace: the parsed arguments (read from ``sys.argv``).
    """
    option_table = [
        # Embedding model / data source
        (
            "--model_name",
            dict(
                type=str,
                default="text-embedding-3-large",
                choices=["all-MiniLM-L6-v2", "text-embedding-3-large", "ada-embeddings", "text-embedding-ada-002"],
            ),
        ),
        ("--dataset_name", dict(type=str, default="synthetic_data_QA_augmented")),
        ("--endpoint_url", dict(type=str)),
        ("--endpoint_api_key", dict(type=str)),
        ("--dataset_path", dict(type=str, required=False, help="Path to the dataset in JSON format.")),
        ("--output_path", dict(type=str, default="***")),
        ("--max_concurrency", dict(type=int, default=10, help="Maximum number of concurrent API calls")),
        ("--batch_size", dict(type=int, default=8192, help="Batch size for embedding generation")),
        ("--random", dict(action="store_true")),
        ("--generating_embeddings", dict(action="store_true")),
        # Chunked processing settings
        ("--chunked_processing", dict(action="store_true", help="Enable chunked processing for large datasets")),
        ("--chunk_size", dict(type=int, default=100000, help="Size of each chunk for processing")),
        ("--max_chunks", dict(type=int, default=None, help="Maximum number of chunks to process (None for all)")),
        # GMM cluster settings
        ("--cluster", dict(action="store_true")),
        ("--gmm_n_init", dict(type=int, default=2, help="Number of initializations for GMM")),
        ("--gmm_max_iter", dict(type=int, default=50, help="Maximum number of iterations for GMM")),
        ("--n_layers", dict(type=int, default=2, help="Number of layers for hierarchical clustering")),
        ("--umap_dim", dict(type=int, default=64, help="Number of dimensions for UMAP")),
        ("--umap_n_neighbors", dict(type=int, default=30, help="Number of neighbors for UMAP")),
        ("--umap_min_dist", dict(type=float, default=0.1, help="Minimum distance for UMAP")),
        ("--umap_metric", dict(type=str, default="cosine", help="Metric for UMAP")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def _cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of two vectors as a Python float.

    A tiny constant (1e-12) is added to the product of the norms so that a
    zero vector yields 0.0 instead of a division by zero.
    """
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return float((a @ b) / (norm_product + 1e-12))


def _window_spot_check(emb_matrix: np.ndarray, texts: list[str], recompute_batch_fn, indices: list[int], window: int = 3, tol: float = 1e-3):
    """Spot-check that stored embeddings sit at the row of the text they encode.

    For every sampled index ``i`` the text is re-embedded and compared (cosine
    similarity) against the stored rows ``i - window .. i + window``. A sample
    is reported when some neighbouring row matches the fresh embedding better
    than row ``i`` by more than ``tol``, which suggests the rows are shifted.

    Args:
        emb_matrix: stored embeddings, indexed by row.
        texts: source texts, in the row order ``emb_matrix`` is expected to have.
        recompute_batch_fn: callable taking a list of texts and returning one
            embedding per text, in the same order.
        indices: row indices to check.
        window: number of neighbouring rows inspected on each side.
        tol: similarity margin a neighbour must win by to count as a mismatch.

    Returns:
        list of ``(i, best_row, sim_at_i, best_sim)`` tuples, one per suspect row.

    Raises:
        ValueError: if ``recompute_batch_fn`` returns a different number of
            embeddings than texts it was given (e.g. a failed API batch).
    """
    problems = []
    # Recompute in batches of at most 32 to keep the number of encoder/API calls low.
    batch_size = max(1, min(32, len(indices)))
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        recomputed = recompute_batch_fn([texts[i] for i in batch_idx])
        # A short result would otherwise surface as an opaque IndexError below.
        if len(recomputed) != len(batch_idx):
            raise ValueError(
                f"recompute_batch_fn returned {len(recomputed)} embeddings "
                f"for {len(batch_idx)} texts"
            )
        for e, i in zip(recomputed, batch_idx):
            win_start = max(0, i - window)
            win_end = min(len(texts), i + window + 1)
            sims = [_cosine_sim(e, emb_matrix[j]) for j in range(win_start, win_end)]
            j_star = win_start + int(np.argmax(sims))
            sim_i = sims[int(i - win_start)]
            sim_star = max(sims)
            if j_star != i and (sim_i + tol) < sim_star:
                problems.append((int(i), int(j_star), float(sim_i), float(sim_star)))
    return problems


def compute_embeddings(
    encoder_model_spec: str, dataset: list[DataPoint], part: str, batch_size: int = 100
) -> np.ndarray:
    """Embed one text field of every data point with a SentenceTransformer model.

    Args:
        encoder_model_spec: model name or path handed to ``SentenceTransformer``.
        dataset: data points to embed.
        part: which attribute to embed, ``"key_string"`` or ``"description"``.
        batch_size: number of texts passed to ``model.encode`` per call.

    Returns:
        The per-batch ``model.encode`` outputs concatenated along axis 0, one
        row per data point in dataset order. An empty dataset yields an empty
        ``(0, 0)`` float32 array without loading the model.

    Raises:
        ValueError: if ``part`` is not one of the supported attribute names.
    """
    if part not in ("key_string", "description"):
        raise ValueError(f"Part {part} not supported.")

    all_elements = [getattr(entity, part) for entity in dataset]
    if not all_elements:
        # np.concatenate([]) would raise; there is nothing to encode anyway.
        return np.empty((0, 0), dtype=np.float32)

    model = SentenceTransformer(encoder_model_spec, device="cuda", cache_folder="***")
    batches = [
        model.encode(all_elements[i : i + batch_size], convert_to_numpy=True)
        for i in tqdm(range(0, len(all_elements), batch_size))
    ]

    embeddings = np.concatenate(batches, 0)
    assert len(embeddings) == len(all_elements), "encoder returned a different number of rows than inputs"
    return embeddings


async def compute_embeddings_async(
    gpt: GPTAsync, dataset: list[DataPoint], part: str, max_concurrency: int = 10
) -> list[list[float]]:
    """Embed one text field of every data point through the async GPT client.

    Requests run concurrently, limited to ``max_concurrency`` in flight at a
    time. Results are stored by input position, so the returned list is aligned
    with ``dataset`` even though requests finish in arbitrary order.

    Args:
        gpt: client exposing an awaitable ``generate_embedding(text)``.
        dataset: data points to embed.
        part: which attribute to embed, ``"key_string"`` or ``"description"``.
        max_concurrency: maximum number of simultaneous API calls.

    Returns:
        One embedding per data point, in dataset order (empty list for an
        empty dataset).

    Raises:
        ValueError: if ``part`` is not one of the supported attribute names.
    """
    if part not in ("key_string", "description"):
        raise ValueError(f"Part {part} not supported.")

    all_elements = [getattr(entity, part) for entity in dataset]
    if not all_elements:
        return []

    # Semaphore caps the number of API calls that are in flight at once.
    semaphore = asyncio.Semaphore(max_concurrency)

    async def embed_at(position, element):
        async with semaphore:
            return position, await gpt.generate_embedding(element)

    tasks = [embed_at(position, element) for position, element in enumerate(all_elements)]

    # as_completed yields in completion order, so each result carries its own
    # position and is written back to that slot.
    embeddings = [None] * len(all_elements)
    for finished in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"Generating {part} embeddings"):
        position, embedding = await finished
        embeddings[position] = embedding

    return embeddings


def process_chunked_embeddings(args, dataset_path, save_name):
    """Embed a line-delimited JSON dataset chunk by chunk and merge the results.

    The file at ``dataset_path`` is read once, one JSON object per line. Every
    ``args.chunk_size`` records are embedded with ``process_single_chunk`` and
    written to temporary ``.npy`` files, which are finally merged with
    ``merge_chunk_files`` and then deleted.

    Args:
        args: parsed CLI namespace; reads ``model_name``, ``chunk_size``,
            ``output_path``, ``dataset_name``, ``batch_size`` and ``max_chunks``.
        dataset_path: path of the line-delimited JSON dataset.
        save_name: model tag used in the output file names.

    Returns:
        int: number of records whose chunk was embedded successfully.

    Raises:
        NotImplementedError: for any model other than ``all-MiniLM-L6-v2``.

    Note:
        Lines that fail to parse and chunks that fail to embed are reported and
        skipped, so the merged output can contain fewer rows than the file has
        lines.
    """
    print(f"\033[94m[Starting streaming chunked processing for large dataset]\033[0m")

    # Only the local SentenceTransformer model is wired up for this mode.
    if args.model_name != "all-MiniLM-L6-v2":
        raise NotImplementedError(f"Chunked processing not implemented for model {args.model_name}")
    model = SentenceTransformer(args.model_name, device="cuda:0", cache_folder="***")

    chunk_size = args.chunk_size
    chunk_files = []
    total_processed = 0
    chunk_idx = 0

    # Temporary per-chunk files live in their own directory under the output path.
    chunk_dir = f"{args.output_path}/chunks_{args.dataset_name}_{save_name}"
    os.makedirs(chunk_dir, exist_ok=True)

    print(f"Processing with chunk size: {chunk_size}")

    def embed_chunk(records, index):
        """Embed one chunk, remember its files, and report whether it succeeded."""
        key_file, val_file = process_single_chunk(records, index, chunk_dir, model, args.batch_size)
        if key_file and val_file:
            chunk_files.append((key_file, val_file))
            return True
        return False

    # Stream through the file once; JSON text is decoded as UTF-8.
    with open(dataset_path, 'r', encoding="utf-8") as f:
        chunk_data = []

        for line_num, line in enumerate(f):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records.
                continue
            try:
                chunk_data.append(DataPoint(**json.loads(line)))
            except Exception as e:
                print(f"Error loading line {line_num}: {e}")
                continue

            if len(chunk_data) < chunk_size:
                continue

            # The chunk is full: embed it before reading any further.
            print(f"\033[94m[Processing chunk {chunk_idx + 1} with {len(chunk_data)} records]\033[0m")

            if not check_memory():
                print("Insufficient memory, stopping...")
                # NOTE(review): exits with status 0 although the run was aborted — confirm intended.
                exit(0)

            if embed_chunk(chunk_data, chunk_idx):
                total_processed += len(chunk_data)
                print(f"Chunk {chunk_idx + 1} completed, total processed: {total_processed}")

            chunk_data = []
            chunk_idx += 1

            if args.max_chunks and chunk_idx >= args.max_chunks:
                print(f"Reached maximum chunks limit: {args.max_chunks}")
                break

        # Whatever is left over forms a final, shorter chunk.
        if chunk_data:
            print(f"\033[94m[Processing final chunk {chunk_idx + 1} with {len(chunk_data)} records]\033[0m")

            if check_memory():
                if embed_chunk(chunk_data, chunk_idx):
                    total_processed += len(chunk_data)
                    print(f"Final chunk completed, total processed: {total_processed}")
            else:
                print("Insufficient memory, skipping final chunk; its records are not in the merged output.")

    print(f"\033[94m[All chunks processed, merging {len(chunk_files)} chunk files]\033[0m")

    merge_chunk_files(chunk_files, args.output_path, args.dataset_name, save_name)

    # Best-effort cleanup: a leftover temporary file is reported, never fatal.
    for chunk_pair in chunk_files:
        for path in chunk_pair:
            try:
                os.remove(path)
            except OSError as e:
                print(f"Warning: could not remove temporary file {path}: {e}")

    try:
        os.rmdir(chunk_dir)
    except OSError as e:
        print(f"Warning: could not remove chunk directory {chunk_dir}: {e}")

    print(f"\033[92m[Streaming processing completed! Total {total_processed} records processed]\033[0m")
    return total_processed


def process_single_chunk(chunk_data, chunk_idx, chunk_dir, model, batch_size):
    """Embed the keys and descriptions of one chunk and store them as ``.npy``.

    Args:
        chunk_data: records exposing ``key_string`` and ``description``.
        chunk_idx: zero-based chunk number, used in the file names.
        chunk_dir: directory receiving the two temporary files.
        model: encoder exposing ``encode(texts, convert_to_numpy=True)``.
        batch_size: number of texts per ``encode`` call.

    Returns:
        ``(key_file, value_file)`` paths on success, ``(None, None)`` when any
        exception occurred (the error is printed).
    """

    def encode_in_batches(texts, guard_memory):
        # Only the description pass checks memory before each batch; the key pass does not.
        pieces = []
        for offset in tqdm(range(0, len(texts), batch_size)):
            if guard_memory and not check_memory():
                exit(0)
            window = texts[offset:offset + batch_size]
            pieces.append(model.encode(window, convert_to_numpy=True))
        return pieces

    try:
        keys = [record.key_string for record in chunk_data]
        descriptions = [record.description for record in chunk_data]

        key_pieces = encode_in_batches(keys, guard_memory=False)
        val_pieces = encode_in_batches(descriptions, guard_memory=True)

        key_matrix = np.concatenate(key_pieces, axis=0)
        val_matrix = np.concatenate(val_pieces, axis=0)

        key_path = f"{chunk_dir}/chunk_{chunk_idx:06d}_key.npy"
        val_path = f"{chunk_dir}/chunk_{chunk_idx:06d}_val.npy"

        np.save(key_path, key_matrix.astype(np.float32))
        np.save(val_path, val_matrix.astype(np.float32))
    except Exception as e:
        print(f"Error processing chunk {chunk_idx}: {e}")
        return None, None

    return key_path, val_path


def merge_chunk_files(chunk_files, output_path, dataset_name, save_name):
    """Concatenate per-chunk embedding files into the two final ``.npy`` files.

    Args:
        chunk_files: list of ``(key_file, value_file)`` paths, in row order.
        output_path: directory receiving the merged files.
        dataset_name: dataset tag used in the output file names.
        save_name: model tag used in the output file names.

    Returns:
        None. Writes ``{dataset_name}_{save_name}_embd_key.npy`` and
        ``..._embd_value.npy`` under ``output_path``; does nothing (besides a
        message) when ``chunk_files`` is empty.
    """
    if not chunk_files:
        print("No chunk files to merge")
        return

    print(f"Merging {len(chunk_files)} chunk files...")

    # Probe shapes through read-only memory maps so that sizing the output
    # does not pull every chunk into memory a first time.
    key_shapes = [np.load(key_file, mmap_mode='r').shape for key_file, _ in chunk_files]
    first_val_shape = np.load(chunk_files[0][1], mmap_mode='r').shape

    key_dim = key_shapes[0][1]
    val_dim = first_val_shape[1]
    total_size = sum(shape[0] for shape in key_shapes)

    final_key_file = f"{output_path}/{dataset_name}_{save_name}_embd_key.npy"
    final_val_file = f"{output_path}/{dataset_name}_{save_name}_embd_value.npy"

    # The merged arrays are written through memory maps, one chunk at a time.
    key_mm = open_memmap(final_key_file, mode='w+', dtype=np.float32, shape=(total_size, key_dim))
    val_mm = open_memmap(final_val_file, mode='w+', dtype=np.float32, shape=(total_size, val_dim))

    current_idx = 0
    for chunk_key_file, chunk_val_file in tqdm(chunk_files, desc="Merging chunks"):
        key_embd = np.load(chunk_key_file)
        val_embd = np.load(chunk_val_file)

        end_idx = current_idx + key_embd.shape[0]

        key_mm[current_idx:end_idx] = key_embd.astype(np.float32)
        val_mm[current_idx:end_idx] = val_embd.astype(np.float32)

        current_idx = end_idx

        # Release the chunk before loading the next one.
        del key_embd, val_embd

    # Write pending changes out explicitly, then release the maps.
    key_mm.flush()
    val_mm.flush()
    del key_mm, val_mm

    print(f"Merge completed! Final files saved to {final_key_file} and {final_val_file}")


def _embedding_file_paths(args, save_name):
    """Return the ``(key, value)`` embedding file paths for this dataset/model pair."""
    stem = f"{args.output_path}/{args.dataset_name}_{save_name}_embd"
    return f"{stem}_key.npy", f"{stem}_value.npy"


def _shuffle_pairs(key_texts, value_texts):
    """Reorder both lists with one fixed-seed permutation so key/value pairs stay aligned."""
    perm = np.random.default_rng(2025).permutation(len(key_texts))
    return [key_texts[i] for i in perm], [value_texts[i] for i in perm]


def _stream_local_embeddings(model, texts, out_file, batch_size, desc):
    """Encode ``texts`` with a SentenceTransformer, writing each batch into a memory-mapped ``.npy``."""
    total_n = len(texts)
    # The first batch is encoded up front because its width fixes the file shape.
    first_batch = texts[:batch_size]
    first_embd = model.encode(first_batch, convert_to_numpy=True)
    mm = open_memmap(out_file, mode='w+', dtype=np.float32, shape=(total_n, int(first_embd.shape[1])))
    mm[0:len(first_batch)] = first_embd.astype(np.float32)
    for i in tqdm(range(batch_size, total_n, batch_size), desc=desc):
        if not check_memory():
            exit(0)
        batch_texts = texts[i:i + batch_size]
        embd = model.encode(batch_texts, convert_to_numpy=True)
        mm[i:i + len(batch_texts)] = embd.astype(np.float32)
    mm.flush()
    del mm


def _stream_gpt_embeddings(gpt, texts, out_file, batch_size, part, kind):
    """Embed ``texts`` through the batch GPT client, writing each batch into a memory-mapped ``.npy``.

    ``part`` ("key_string"/"description") and ``kind`` ("key"/"value") only
    select the wording of progress and error messages. Rows of a batch the
    client fails to return are left as the file was created.
    """
    print(f"Generating {part} embeddings...")
    total_n = len(texts)
    # The first batch is requested up front because its width fixes the file shape.
    first_embeddings = gpt.generate_embeddings_batch(texts[:batch_size]) or []
    if len(first_embeddings) == 0:
        raise RuntimeError(f"Failed to generate first batch of {kind} embeddings")
    mm = open_memmap(out_file, mode='w+', dtype=np.float32, shape=(total_n, int(len(first_embeddings[0]))))
    mm[0:len(first_embeddings)] = np.asarray(first_embeddings, dtype=np.float32)
    for i in tqdm(range(batch_size, total_n, batch_size), desc=f"Generating {part} embeddings"):
        batch_texts = texts[i:i + batch_size]
        batch_embeddings = gpt.generate_embeddings_batch(batch_texts) or []
        if len(batch_embeddings) != len(batch_texts):
            print(f"Warning: {part.capitalize()} embedding generation failed for batch {i//batch_size + 1}")
        if len(batch_embeddings) > 0:
            mm[i:i + len(batch_embeddings)] = np.asarray(batch_embeddings, dtype=np.float32)
    mm.flush()
    del mm


async def _stream_async_embeddings(gpt, texts, out_file, max_concurrency, part):
    """Embed ``texts`` concurrently through the async GPT client, writing each row by its index."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def get_one(idx, text):
        async with semaphore:
            return idx, await gpt.generate_embedding(text)

    # Fetch row 0 alone first: its length fixes the file shape.
    _, first_emb = await get_one(0, texts[0])
    mm = open_memmap(out_file, mode='w+', dtype=np.float32, shape=(len(texts), int(len(first_emb))))
    mm[0] = np.asarray(first_emb, dtype=np.float32)

    tasks = [asyncio.create_task(get_one(i, t)) for i, t in enumerate(texts) if i != 0]
    for finished in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc=f"Generating {part} embeddings"):
        # Results arrive in completion order; the returned index picks the row.
        idx, emb = await finished
        mm[idx] = np.asarray(emb, dtype=np.float32)
    mm.flush()
    del mm


async def _generate_embeddings(args, dataset, key_file, val_file):
    """Generate key and value embeddings for ``dataset`` with the model selected in ``args``."""
    if not dataset:
        raise ValueError("Dataset is empty; nothing to embed.")

    key_texts = [data.key_string for data in dataset]
    value_texts = [data.description for data in dataset]

    if args.model_name == "all-MiniLM-L6-v2":
        suffix = ""
        if args.random:
            # Pair-wise shuffle: randomises row order but keeps key/value alignment.
            key_texts, value_texts = _shuffle_pairs(key_texts, value_texts)
            suffix = " (shuffled)"
        model = SentenceTransformer(args.model_name, device="cuda", cache_folder="***")
        _stream_local_embeddings(model, key_texts, key_file, args.batch_size, f"Generating key_string embeddings{suffix}")
        _stream_local_embeddings(model, value_texts, val_file, args.batch_size, f"Generating description embeddings{suffix}")
    elif args.model_name in ["ada-embeddings", "text-embedding-3-large", "text-embedding-ada-002"]:
        if not args.random:
            gpt = GPT(args.model_name, args.endpoint_url, args.endpoint_api_key)
            _stream_gpt_embeddings(gpt, key_texts, key_file, args.batch_size, "key_string", "key")
            _stream_gpt_embeddings(gpt, value_texts, val_file, args.batch_size, "description", "value")
        else:
            # With --random the OpenAI models use the async, one-request-per-text client.
            # These awaits belong to this branch only: the sync branch above has
            # already written both files.
            gpt = GPTAsync(args.model_name, args.endpoint_url, args.endpoint_api_key)
            await _stream_async_embeddings(gpt, key_texts, key_file, args.max_concurrency, "key_string")
            await _stream_async_embeddings(gpt, value_texts, val_file, args.max_concurrency, "description")
    else:
        raise ValueError(f"Model {args.model_name} not supported.")


def _run_consistency_check(args, dataset, key_embeds, value_embeds):
    """Re-embed up to 200 sampled rows and report rows whose stored embedding looks misplaced."""
    if dataset is None:
        print("[Consistency Check] Skipped (dataset not available in this run context).")
        return

    key_texts = [data.key_string for data in dataset]
    value_texts = [data.description for data in dataset]
    if args.random and args.model_name == "all-MiniLM-L6-v2":
        # That path stores rows in the fixed-seed shuffled order, so the
        # reference texts must be put in the same order before comparing.
        key_texts, value_texts = _shuffle_pairs(key_texts, value_texts)

    if not key_texts:
        return

    rng = np.random.default_rng(2025)
    num_samples = int(min(200, len(key_texts)))
    sample_indices = sorted(rng.choice(len(key_texts), size=num_samples, replace=False).tolist())

    if args.model_name == "all-MiniLM-L6-v2":
        st_model = SentenceTransformer(args.model_name, device="cuda", cache_folder="***")

        def recompute(text_batch: list[str]):
            return st_model.encode(text_batch, convert_to_numpy=True)
    else:
        gpt_checker = GPT(args.model_name, args.endpoint_url, args.endpoint_api_key)

        def recompute(text_batch: list[str]):
            return np.asarray(gpt_checker.generate_embeddings_batch(text_batch) or [], dtype=np.float32)

    key_problems = _window_spot_check(key_embeds, key_texts, recompute, sample_indices, window=3, tol=1e-3)
    val_problems = _window_spot_check(value_embeds, value_texts, recompute, sample_indices, window=3, tol=1e-3)

    print("\n===== Consistency Check Report =====")
    print(f"Checked samples: {len(sample_indices)} (window=3)")
    print(f"Key mismatches: {len(key_problems)}")
    if len(key_problems) > 0:
        print(f"Key examples (up to 5): {key_problems[:5]}")
    print(f"Value mismatches: {len(val_problems)}")
    if len(val_problems) > 0:
        print(f"Value examples (up to 5): {val_problems[:5]}")
    print("===================================\n")


def _cluster_key_hierarchy(args, key_embeds, save_name):
    """Reduce the key embeddings with UMAP, then cluster them layer by layer with GMMs.

    Each layer clusters the previous layer's GMM means and stores, for every
    cluster, the mean of its members in the original embedding space together
    with the cluster<->member index mappings. The last layer is saved as
    ``root``; earlier layers as ``inter{n}``.
    """
    keys = np.array(key_embeds, dtype=np.float32)   # [N, dim]
    print(f"\033[94m[UMAP dimension reduction from {keys.shape[1]} to {args.umap_dim} dimensions]\033[0m")
    reducer = umap.UMAP(
        n_components=args.umap_dim,
        n_neighbors=args.umap_n_neighbors,
        min_dist=args.umap_min_dist,
        metric=args.umap_metric,
        random_state=42,
        verbose=True
    )
    keys_low = reducer.fit_transform(keys.astype(np.float32))   # [N, umap_dim]

    keys_last = keys_low
    original_keys_last = keys
    for layer in range(args.n_layers):
        # The cluster count shrinks geometrically: N^(L/(L+1)), N^((L-1)/(L+1)), ...
        k = int(math.ceil(math.pow(len(keys_low), (args.n_layers - layer) / (args.n_layers + 1))))
        print(f"\033[94m[Hierarchical Clustering {len(keys_last)} keys with {layer+1} / {args.n_layers} layers and {k} clusters]\033[0m")
        gmm = GaussianMixture(
            n_components=k,
            n_init=args.gmm_n_init,
            max_iter=args.gmm_max_iter,
            random_state=42,
            verbose=1,
            init_params='k-means++'
        ).fit(keys_last)
        labels = [int(c) for c in gmm.predict(keys_last)]
        centers = gmm.means_

        # Build both mappings in a single pass over the labels.
        clust2idlist_mapping = {c: [] for c in range(k)}  # clust_id -> [idx]
        for idx, c in enumerate(labels):
            clust2idlist_mapping[c].append(idx)
        id2clust_mapping = dict(enumerate(labels))  # idx -> clust_id

        # Cluster means in the original space; an empty cluster yields a NaN
        # row (mean of an empty slice).
        original_centers = np.stack([
            np.mean(original_keys_last[np.asarray(clust2idlist_mapping[c], dtype=np.intp)], axis=0)
            for c in range(k)
        ])

        tag = "root" if layer == args.n_layers - 1 else f"inter{args.n_layers - 1 - layer}"
        stem = f"{args.output_path}/{args.dataset_name}_{save_name}_embd_key_{tag}"
        np.save(f"{stem}.npy", original_centers.astype(np.float32))
        with open(f"{stem}_c2id_mapping.json", "w") as f:
            json.dump(clust2idlist_mapping, f)
        with open(f"{stem}_id2c_mapping.json", "w") as f:
            json.dump(id2clust_mapping, f)

        keys_last = centers
        original_keys_last = original_centers


async def main():
    """Script entry point: generate or load KB embeddings, sanity-check them, optionally cluster.

    Steps:
        1. Parse CLI arguments and derive the model tag used in file names.
        2. With ``--generating_embeddings``, write key/value embeddings to disk
           (chunked streaming, local SentenceTransformer, or OpenAI sync/async).
        3. Load both embedding files from ``--output_path``.
        4. Run a best-effort consistency spot-check when the dataset is in memory.
        5. With ``--cluster``, build the UMAP + GMM hierarchy and save it.
    """
    args = parser_args()
    os.makedirs(args.output_path, exist_ok=True)

    if args.model_name == "all-MiniLM-L6-v2":
        save_name = "all-MiniLM-L6-v2"
    elif args.model_name in ["ada-embeddings", "text-embedding-ada-002"]:
        save_name = "OAI"
    else:
        save_name = "BigOAI"

    key_file, val_file = _embedding_file_paths(args, save_name)
    dataset = None  # stays None when the dataset is never loaded into memory

    if args.chunked_processing:
        print(f"\033[94m[Using chunked processing mode]\033[0m")
        if args.generating_embeddings:
            process_chunked_embeddings(args, args.dataset_path, save_name)
    else:
        # Dataset names containing a split marker are not loaded here.
        if not any(split in args.dataset_name for split in ["train", "val", "test"]):
            with open(args.dataset_path, "r", encoding="utf-8") as file:
                dataset = [DataPoint(**line) for line in json.loads(file.read())]

        if args.generating_embeddings:
            if dataset is None:
                raise RuntimeError(
                    f"Dataset '{args.dataset_name}' was not loaded (its name contains a split marker), "
                    "so embeddings cannot be generated."
                )
            await _generate_embeddings(args, dataset, key_file, val_file)

    # Every mode ends with the embeddings on disk; load them for the steps below.
    key_embeds = np.load(key_file)
    value_embeds = np.load(val_file)

    # The consistency check is diagnostic only and must never abort the run.
    try:
        _run_consistency_check(args, dataset, key_embeds, value_embeds)
    except Exception as e:
        print(f"[Consistency Check] Failed with error: {e}")

    if args.cluster:
        _cluster_key_hierarchy(args, key_embeds, save_name)
        

# Script entry point: only when executed directly (not on import), run the
# async main() coroutine to completion via asyncio.run.
if __name__ == "__main__":
    asyncio.run(main())

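# Usage example for chunked processing.
# NOTE: in chunked mode the dataset file is parsed line by line, so it must
# contain one JSON object per line (JSON Lines), not a single JSON array.
# Chunked generation is only implemented for the all-MiniLM-L6-v2 model.
# python generate_kb_embeddings_gmm.py \
#     --model_name all-MiniLM-L6-v2 \
#     --dataset_path /path/to/your/large_dataset.json \
#     --dataset_name your_dataset \
#     --output_path /path/to/output/ \
#     --generating_embeddings \
#     --chunked_processing \
#     --chunk_size 10000 \
#     --batch_size 512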
