import os
import argparse
import yaml
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min
from openai import OpenAI
import hashlib
import json


# ============================================================
#                 EMBEDDING CACHE UTILITIES
# ============================================================

def hash_text(text: str) -> str:
    """Return the SHA-256 hex digest of ``text`` encoded as UTF-8.

    The digest depends only on the text content, which makes it usable as a
    stable cache key across runs.
    """
    hasher = hashlib.sha256()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()


def load_cache(cache_path: str):
    """Read the embedding cache stored as JSON at ``cache_path``.

    Returns a dict mapping each stored key to a float32 NumPy array built
    from the stored list. When ``cache_path`` does not exist, an empty dict
    is returned instead.
    """
    if not os.path.exists(cache_path):
        return {}

    with open(cache_path, "r") as handle:
        stored = json.load(handle)

    loaded = {}
    for key, values in stored.items():
        loaded[key] = np.array(values, dtype=np.float32)
    return loaded


def save_cache(cache: dict, cache_path: str):
    """Persist the embedding cache to ``cache_path`` as JSON.

    Each value in ``cache`` is converted with ``.tolist()`` before dumping.
    The data is written to a temporary sibling file first and then moved over
    ``cache_path`` with ``os.replace``, so a failure while serializing or
    writing never leaves a truncated cache behind; the previous file stays
    intact. The parent directory of ``cache_path`` is created when missing.

    Raises:
        TypeError: if a value cannot be serialized to JSON.
        OSError: if the file cannot be written or replaced.
    """
    serializable = {k: v.tolist() for k, v in cache.items()}

    # dirname is "" for a bare file name; os.makedirs("") would raise.
    parent = os.path.dirname(cache_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # The temp file lives next to the target so os.replace stays on one
    # filesystem; the pid keeps concurrent writers from sharing a temp file.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(serializable, f)
        os.replace(tmp_path, cache_path)
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up when the dump or the replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# ============================================================
#          EMBEDDING WITH CACHE (recommended default)
# ============================================================

def embed_texts_cached(
    texts,
    model="text-embedding-3-large",
    batch_size=100,
    cache_path="embedding_cache.json"
):
    """
    Embed texts with persistent caching. Only missing entries trigger API calls.

    Args:
        texts: Sequence of strings to embed.
        model: Embedding model name passed to the OpenAI embeddings endpoint.
        batch_size: Maximum number of texts sent per API request.
        cache_path: JSON file used to load and persist cached embeddings.

    Returns:
        A 2-D array with one row per entry of ``texts``, in input order.

    Raises:
        ValueError: if ``texts`` is empty or ``batch_size`` is smaller than 1.
        RuntimeError: if the API returns a different number of embeddings
            than the number of texts sent in a batch.

    NOTE(review): cache keys are derived from the text alone (see the
    ``hash_text`` call below), so a cache file written with a different model
    is reused as-is. Use a separate ``cache_path`` per model.
    """
    texts = list(texts)
    if not texts:
        raise ValueError("embed_texts_cached() needs at least one text to embed")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    cache = load_cache(cache_path)
    keys = [hash_text(t) for t in texts]

    # 1. Identify missing embeddings; each distinct text is fetched only once
    #    even when it occurs several times in the input.
    missing = {}
    for key, text in zip(keys, texts):
        if key not in cache and key not in missing:
            missing[key] = text

    # 2. Call OpenAI only for missing items
    if missing:
        print(f"[INFO] Missing embeddings: {len(missing)} / {len(texts)}")

        # The client is created only when a request is needed, so a fully
        # cached run does not require API credentials.
        client = OpenAI()
        pending_keys = list(missing)
        pending_texts = list(missing.values())
        fetched_any = False

        try:
            for start in tqdm(
                range(0, len(pending_texts), batch_size),
                desc="Embedding missing texts",
            ):
                batch = pending_texts[start:start + batch_size]
                resp = client.embeddings.create(model=model, input=batch)

                # zip() would silently drop items on a short response.
                if len(resp.data) != len(batch):
                    raise RuntimeError(
                        f"Embedding API returned {len(resp.data)} vectors "
                        f"for a batch of {len(batch)} texts"
                    )

                batch_keys = pending_keys[start:start + batch_size]
                for key, item in zip(batch_keys, resp.data):
                    cache[key] = np.array(item.embedding, dtype=np.float32)
                    fetched_any = True
        finally:
            # Persist whatever was fetched, even if a later batch failed, so
            # a retry does not have to request those embeddings again.
            if fetched_any:
                save_cache(cache, cache_path)
                print(f"[CACHE] Updated embedding cache at {cache_path}")

    # Convert list → array
    return np.vstack([cache[key] for key in keys])


# ============================================================
#                     CLUSTERING FUNCTION
# ============================================================

def cluster_endowments(df, k, embed_model, batch_size, cache_path):
    """
    Embed the ``endow_text`` column of ``df`` and group the rows with KMeans.

    The embeddings come from ``embed_texts_cached`` using ``embed_model``,
    ``batch_size`` and ``cache_path``. KMeans runs with ``k`` clusters and a
    fixed ``random_state`` of 42.

    Returns a pair ``(representatives, labels)``: for every cluster centroid,
    the positional index of the row whose embedding is nearest to it (used by
    the caller as that cluster's medoid), and the KMeans label of every row.
    """
    vectors = embed_texts_cached(
        df["endow_text"].tolist(),
        model=embed_model,
        batch_size=batch_size,
        cache_path=cache_path,
    )

    print(f"[INFO] Clustering into k={k} clusters…")

    fitted = KMeans(n_clusters=k, random_state=42, n_init="auto").fit(vectors)

    # Centroids are synthetic points; pick the actual row nearest to each one.
    nearest_rows = pairwise_distances_argmin_min(fitted.cluster_centers_, vectors)[0]

    return nearest_rows, fitted.labels_


# ============================================================
#                       MAIN RUNNER
# ============================================================

def _ensure_parent_dir(path):
    """Create the directory that will contain ``path``, if it names one."""
    # dirname is "" for a bare file name such as "out.csv"; os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is given.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def main(config_path):
    """
    Run the clustering pipeline described by the YAML file at ``config_path``.

    Config keys read:
        clustering.input_csv: CSV with at least the columns eid, endow_text.
        clustering.output_csv: where the representative (medoid) rows go.
        clustering.cluster_map_csv: optional; per-row cluster assignments.
        clustering.k: number of clusters.
        embedding.model / batch_size / cache_path: optional embedding
            settings; defaults are applied when absent.

    Raises:
        ValueError: if required columns are missing, or if ``k`` is not
            between 1 and the number of input rows.
    """

    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    cluster_cfg = config["clustering"]
    embed_cfg = config.get("embedding", {})

    input_csv = cluster_cfg["input_csv"]
    output_csv = cluster_cfg["output_csv"]
    cluster_map_csv = cluster_cfg.get("cluster_map_csv")
    k = cluster_cfg["k"]

    embed_model = embed_cfg.get("model", "text-embedding-3-large")
    batch_size = embed_cfg.get("batch_size", 100)
    cache_path = embed_cfg.get("cache_path", "embedding_cache.json")

    # Load input CSV
    df = pd.read_csv(input_csv)
    print(f"[INFO] Loaded {len(df)} endowments from {input_csv}")

    if not {"eid", "endow_text"}.issubset(df.columns):
        raise ValueError("Input CSV must contain columns: eid, endow_text")

    # Reject an impossible k before any embedding request is made; otherwise
    # the error would only surface from KMeans after the API calls.
    if not 1 <= k <= len(df):
        raise ValueError(
            f"k must be between 1 and the number of rows ({len(df)}), got {k}"
        )

    # Ensure output directory exists
    _ensure_parent_dir(output_csv)
    if cluster_map_csv:
        _ensure_parent_dir(cluster_map_csv)

    # Run clustering
    medoid_indices, labels = cluster_endowments(
        df, k, embed_model, batch_size, cache_path
    )

    # Extract cluster medoids
    medoid_df = df.iloc[medoid_indices].reset_index(drop=True)

    # Save
    medoid_df.to_csv(output_csv, index=False)
    print(f"[DONE] Saved {len(medoid_df)} medoids to {output_csv}")

    if cluster_map_csv:
        df_map = pd.DataFrame({"eid": df["eid"], "cluster": labels})
        df_map.to_csv(cluster_map_csv, index=False)
        print(f"[DONE] Saved cluster assignments to {cluster_map_csv}")

# ============================================================
#                          CLI
# ============================================================

if __name__ == "__main__":
    # Read the single required option and hand the config path to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        required=True,
        help="Path to YAML clustering config",
    )
    main(cli.parse_args().config)