import json

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer


def main():
    """Embed the distinct documentation texts and save them with their embeddings.

    Steps:
      1. Load the two parquet chunks and concatenate them.
      2. Print record counts and length statistics for the distinct,
         non-null values of the ``text`` column.
      3. Keep only texts whose character length is within
         ``[MIN_CHARS, MAX_CHARS]``.
      4. Encode the kept texts with the sentence-transformers model.
      5. Write three files to the current working directory:
         ``docs_embeddings.npy`` (the embedding matrix),
         ``docs_embeddings.metadata.json`` (the model name) and
         ``docs_embeddings.clean_texts.json`` (the encoded texts, in the
         order they were passed to ``encode``).

    Raises:
        ValueError: If the ``text`` column contains no non-null values.
    """
    chunk1 = "servicenow-docs/data/train_chunks-00000-of-00002.parquet"
    chunk2 = "servicenow-docs/data/train_chunks-00001-of-00002.parquet"
    EMB_MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
    # Inclusive character-length bounds for texts that get embedded.
    MIN_CHARS = 100
    MAX_CHARS = 6000

    # Load data from parquet chunks
    df1 = pd.read_parquet(chunk1)
    df2 = pd.read_parquet(chunk2)

    # Combine the chunks
    df = pd.concat([df1, df2], ignore_index=True)

    # Show number of records
    print(f"Number of records in chunk1: {len(df1)}")
    print(f"Number of records in chunk2: {len(df2)}")
    print(f"Total number of records: {len(df)}")

    # Drop nulls up front: Series.unique() keeps NaN/None (nunique() does
    # not), and len() on a null value would raise TypeError further down.
    text_column = df["text"].dropna()

    # Compute number of distinct values in the 'text' column
    num_distinct_text = text_column.nunique()
    print(f"Number of distinct values in 'text' column: {num_distinct_text}")

    # Get unique texts
    unique_texts = text_column.unique().tolist()
    print(f"Number of unique texts: {len(unique_texts)}")

    # Fail with a clear message instead of an opaque min()/division error.
    if not unique_texts:
        raise ValueError("No non-null values found in the 'text' column; nothing to embed.")

    # Build the length array once so every statistic below reuses it.
    text_lengths = np.fromiter(
        (len(text) for text in unique_texts),
        dtype=np.int64,
        count=len(unique_texts),
    )

    # Print some statistics
    print(f"Min text length: {text_lengths.min()}")
    print(f"Max text length: {text_lengths.max()}")
    print(f"Average text length: {text_lengths.mean():.2f}")

    # 5th percentile, median (50th percentile) and 95th percentile in one pass
    percentile_5, median, percentile_95 = np.percentile(text_lengths, [5, 50, 95])

    print(f"5th percentile text length: {percentile_5:.2f}")
    print(f"95th percentile text length: {percentile_95:.2f}")
    print(f"Median text length: {median:.2f}")

    # Filter texts to keep only those within the configured length bounds
    clean_texts = [text for text in unique_texts if MIN_CHARS <= len(text) <= MAX_CHARS]

    print(f"Size of clean_texts list: {len(clean_texts)}")
    print(f"Percentage of texts kept: {len(clean_texts) / len(unique_texts) * 100:.2f}%")

    emb_model = SentenceTransformer(EMB_MODEL_NAME, model_kwargs={"torch_dtype": "bfloat16"})
    # Smoke-test the model on a tiny input so load/device problems surface
    # before the long encoding run starts.
    emb_model.encode(["hi"])
    print("Model loaded successfully.")

    # Encode all clean texts using the loaded embedding model
    print(f"Encoding {len(clean_texts)} texts...")
    embeddings = emb_model.encode(clean_texts, show_progress_bar=True)

    # Convert to numpy array if not already (asarray avoids a needless copy)
    embeddings = np.asarray(embeddings)

    print(f"Embeddings shape: {embeddings.shape}")
    print(f"Embeddings dtype: {embeddings.dtype}")

    with open("docs_embeddings.npy", "wb") as f:
        np.save(f, embeddings)
    # Explicit UTF-8: with ensure_ascii=False the texts are written verbatim,
    # so the platform default encoding must not be relied upon.
    with open("docs_embeddings.metadata.json", "w", encoding="utf-8") as f:
        json.dump({"model": EMB_MODEL_NAME}, f, indent=2)
    with open("docs_embeddings.clean_texts.json", "w", encoding="utf-8") as f:
        json.dump(clean_texts, f, indent=2, ensure_ascii=False)
    print("Embeddings saved successfully.")

# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
