#!/usr/bin/env python3
"""
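Compute cosine similarity to the Human centroid (original embedding space) across all points,
plus top-10 kNN and nearest k-means-centroid cosine statistics against the Human points.

Outputs (per model, written under --outdir, default "tables"):
  tables/cosine_to_human_LUAR.md / .tex / .csv
  tables/cosine_to_human_CISR.md / .tex / .csv
  tables/cosine_to_human_SD.md   / .tex / .csv
  tables/knn_to_humans_<MODEL>.md / .tex
  tables/mixcentroid_to_humans_<MODEL>.md / .tex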
"""

import os
import sys
import random
import argparse
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer, set_seed

# Project-local imports (assumes same environment as your plotting script)
from genpaths import *
from nicks_dpo.create_preference_data import get_luar_embeddings

from sklearn.cluster import KMeans
from typing import Dict, List, Tuple

def topk_cosine_to_humans(method_emb: torch.Tensor, human_emb: torch.Tensor, k: int = 10) -> Tuple[float, float]:
    """
    Mean and std of each point's average cosine similarity to its k nearest human points.

    Both inputs are L2-normalized here before the dot product, so the result is a
    true cosine similarity even when a caller hands in raw (un-normalized)
    embeddings. Inputs that are already unit-norm are unaffected up to
    floating-point rounding.

    Args:
        method_emb: (N_m, D) embeddings of the points being scored.
        human_emb:  (N_h, D) embeddings of the human reference set.
        k: number of most-similar human points averaged per scored point;
           clamped to N_h when fewer human points are available.

    Returns:
        (mean, std) over the N_m per-point averages; std is the population
        standard deviation (unbiased=False).
    """
    method_unit = F.normalize(method_emb, p=2, dim=-1)
    human_unit = F.normalize(human_emb, p=2, dim=-1)

    # (N_m, D) @ (D, N_h) -> (N_m, N_h) cosine matrix
    sims = method_unit @ human_unit.T

    k_eff = min(k, human_unit.size(0))
    topk = torch.topk(sims, k=k_eff, dim=1).values   # (N_m, k_eff)
    per_point = topk.mean(dim=1)                     # (N_m,)
    return per_point.mean().item(), per_point.std(unbiased=False).item()

def mixture_centroid_stats(method_emb: torch.Tensor, human_emb: torch.Tensor, n_clusters: int = 5, seed: int = 43) -> Tuple[float, float]:
    """
    Mean and std of each point's cosine similarity to its nearest human cluster centroid.

    The human embeddings are clustered with k-means exactly as passed in (no
    normalization before clustering); the resulting centroids are L2-normalized.
    The scored points are L2-normalized here as well, so the dot product with the
    centroids is a true cosine even when a caller passes raw embeddings.

    Args:
        method_emb: (N_m, D) embeddings of the points being scored.
        human_emb:  (N_h, D) embeddings of the human reference set.
        n_clusters: requested number of k-means clusters; clamped to N_h so a
                    small human set does not make KMeans raise.
        seed: random_state forwarded to KMeans.

    Returns:
        (mean, std) over the N_m nearest-centroid cosines; std is the population
        standard deviation (unbiased=False).
    """
    # detach/cpu so tensors that track gradients or live on an accelerator still convert.
    H = human_emb.detach().cpu().numpy()

    k_eff = min(n_clusters, H.shape[0])
    km = KMeans(n_clusters=k_eff, n_init="auto", random_state=seed)
    km.fit(H)

    # Match the scored points' dtype/device so the matmul below cannot mismatch.
    C = torch.from_numpy(km.cluster_centers_).to(dtype=method_emb.dtype, device=method_emb.device)
    C = C / (C.norm(dim=1, keepdim=True) + 1e-12)  # normalize centroids

    method_unit = F.normalize(method_emb, p=2, dim=-1)
    sims = method_unit @ C.T                        # (N_m, K)
    nearest = sims.max(dim=1).values                # (N_m,)
    return nearest.mean().item(), nearest.std(unbiased=False).item()



# ----------------------------
# Config / defaults
# ----------------------------
# Hex colour codes; not referenced by any code in this part of the file.
PALETTE = [
    "#0072B2",
    "#E69F00",
    "#009E73",
    "#D55E00",
    "#CC79A7",
    "#000000",
    "#56B4E9",
    "#F0E442",
    "#999999",
]

# Default --methods value. main() upper-cases each name and looks it up as a
# module global (e.g. "Human" -> HUMAN).
DEFAULT_METHODS = [
    "Human",
    "Machine",
    "LLMOPT",
    "OUTFOX",
    "Paraphrasing",
    "DIPPER",
    "Prompting",
    "TinyStyler",
    "Ours",
]

# Default --models value: the embedding models tables are produced for.
DEFAULT_MODELS = [
    "LUAR",
    "CISR",
    "SD",
]

# ----------------------------
# Helpers
# ----------------------------
def load_text(arr) -> List[str]:
    """
    Load the text column of one method's JSONL file.

    arr: list-like descriptor (e.g. the HUMAN / MACHINE constants from genpaths).
      arr[2] -> path to a JSON-lines file
      arr[3] -> column name holding the text

    Rows whose value is a list are reduced to their first element; every other
    value is returned unchanged. The check is made per row, so a column mixing
    lists and plain strings is handled correctly instead of indexing into a
    plain string (which would silently keep only its first character).

    Raises IndexError if a row holds an empty list.
    """
    column = pd.read_json(arr[2], lines=True)[arr[3]].tolist()
    return [entry[0] if isinstance(entry, list) else entry for entry in column]


def get_embeddings_for_model(model_name: str,
                             method_data: Dict[str, List[str]],
                             luar_batch_size: int = 1024,
                             st_batch_size: int = 256) -> Dict[str, torch.Tensor]:
    """
    Embed every method's texts with one embedding model (original space, no projection).

    Args:
        model_name: "LUAR", "CISR" or "SD".
        method_data: method name -> list of texts.
        luar_batch_size: batch size forwarded to get_luar_embeddings (LUAR only).
        st_batch_size: batch size forwarded to SentenceTransformer.encode (CISR / SD).

    Returns:
        dict: method name -> CPU torch.Tensor, expected shape [N, D].
        NOTE: normalization is requested for every method EXCEPT "Human"
        (normalize=False for "Human"), so the "Human" tensor holds raw embeddings.
        Downstream code must normalize it before treating dot products as cosine.

    Raises:
        ValueError: if model_name is not one of the supported names. Previously any
        unrecognised name silently loaded the StyleDistance model.
    """
    # Hugging Face ids of the SentenceTransformer-based models.
    st_model_ids = {
        "CISR": "AnnaWegmann/Style-Embedding",
        "SD": "StyleDistance/styledistance",
    }
    if model_name != "LUAR" and model_name not in st_model_ids:
        supported = ["LUAR"] + sorted(st_model_ids)
        raise ValueError(f"Unknown embedding model {model_name!r}; expected one of {supported}.")

    # Use the GPU when present, otherwise stay on CPU instead of crashing in .cuda().
    device = "cuda" if torch.cuda.is_available() else "cpu"

    method_embeddings: Dict[str, torch.Tensor] = {}

    if model_name == "LUAR":
        model = AutoModel.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)
        model.eval()
        model.to(device)
        tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD", trust_remote_code=True)

        for method_name, texts in method_data.items():
            print(f"[{model_name}] Embedding: {method_name} (n={len(texts)})")
            # "Human" is kept un-normalized; see the docstring note.
            normalize = method_name != "Human"
            # NOTE(review): `normalize` presumably toggles L2 normalization inside
            # get_luar_embeddings -- confirm in nicks_dpo.create_preference_data.
            emb = get_luar_embeddings(texts, model, tokenizer, batch_size=luar_batch_size, single=False, normalize=normalize).cpu()
            method_embeddings[method_name] = emb

    else:
        model = SentenceTransformer(st_model_ids[model_name])
        model.eval()
        model.to(device)

        for method_name, texts in method_data.items():
            print(f"[{model_name}] Embedding: {method_name} (n={len(texts)})")
            # "Human" is kept un-normalized; see the docstring note.
            normalize = method_name != "Human"
            with torch.no_grad():
                emb = model.encode(
                    texts,
                    batch_size=st_batch_size,
                    convert_to_tensor=True,
                    normalize_embeddings=normalize,
                    show_progress_bar=True
                ).cpu()
            method_embeddings[method_name] = emb

    return method_embeddings


def _write_and_print_table(df: pd.DataFrame,
                           file_title: str,
                           console_title: str,
                           md_path: str,
                           tex_path: str,
                           rule_char: str) -> None:
    """Render df as Markdown and LaTeX, write both files, and echo the Markdown table to stdout."""
    md_table = df.to_markdown(index=False, floatfmt=".4f")
    tex_table = df.to_latex(index=False, float_format="%.4f", escape=True)

    with open(md_path, "w", encoding="utf-8") as f:
        f.write(f"# {file_title}\n\n{md_table}\n")
    with open(tex_path, "w", encoding="utf-8") as f:
        f.write(tex_table + "\n")

    print("\n" + rule_char * 76)
    print(console_title)
    print(rule_char * 76)
    print(md_table + "\n")


def cosine_tables_for_model(model_name: str,
                            methods: List[str],
                            method_embeddings: Dict[str, torch.Tensor],
                            outdir: str = "tables"):
    """
    Compute similarity-to-Human statistics for one embedding model and save them as tables.

    Three tables are produced, each sorted by its mean column (descending):
      1. cosine of every point to the normalized Human centroid
         -> cosine_to_human_{model_name}.md / .tex / .csv   ("Human" row skipped)
      2. per-point mean cosine to the 10 most similar Human points
         -> knn_to_humans_{model_name}.md / .tex
      3. per-point cosine to the nearest of 5 k-means centroids of the Human points
         -> mixcentroid_to_humans_{model_name}.md / .tex
    Every table is also printed to stdout.

    Args:
        model_name: label used in titles and file names.
        methods: method names to tabulate; each must be a key of method_embeddings.
        method_embeddings: method name -> [N, D] tensor; must contain "Human".
        outdir: output directory, created if missing.

    Raises:
        ValueError: if "Human" is not in method_embeddings.
    """
    os.makedirs(outdir, exist_ok=True)

    # Explicit raise instead of assert, so the check survives `python -O`.
    if "Human" not in method_embeddings:
        raise ValueError("Human method missing.")

    # The centroid is the mean of the Human embeddings exactly as provided, then L2-normalized.
    human_emb = method_embeddings["Human"]                             # [Nh, D]
    human_centroid = F.normalize(human_emb.mean(dim=0), p=2, dim=-1)

    # --- Table 1: cosine to the Human centroid ---
    centroid_rows = []
    for method_name in methods:
        if method_name == "Human":
            continue
        emb = method_embeddings[method_name]                           # [N, D]
        # cosine_similarity normalizes both operands; the (1, D) centroid broadcasts over N.
        cos = F.cosine_similarity(emb, human_centroid.unsqueeze(0), dim=-1)
        centroid_rows.append({
            "Method": method_name,
            "Mean Cosine": cos.mean().item(),
            "Std": cos.std(unbiased=False).item(),
            "N": emb.size(0),
        })

    # Explicit columns keep sort_values working when `methods` holds only "Human".
    centroid_cols = ["Method", "Mean Cosine", "Std", "N"]
    stats_df = (
        pd.DataFrame(centroid_rows, columns=centroid_cols)
        .sort_values("Mean Cosine", ascending=False)
        .reset_index(drop=True)
    )

    md_path = os.path.join(outdir, f"cosine_to_human_{model_name}.md")
    tex_path = os.path.join(outdir, f"cosine_to_human_{model_name}.tex")
    stats_df.to_csv(os.path.join(outdir, f"cosine_to_human_{model_name}.csv"), index=False)

    centroid_title = f"{model_name}: Cosine Similarity to Human Centroid (all points)"
    _write_and_print_table(stats_df, centroid_title, centroid_title, md_path, tex_path, "=")
    print(f"Saved: {md_path}")
    print(f"Saved: {tex_path}")

    # --- Extra: k-NN-to-Humans and Mixture-Centroid tables ---
    # Both helpers use plain dot products as cosine, so hand them unit-norm
    # vectors explicitly rather than relying on how the embeddings were produced.
    human_unit = F.normalize(human_emb, p=2, dim=-1)
    extras_rows_knn, extras_rows_mix = [], []

    for method_name in methods:
        emb = F.normalize(method_embeddings[method_name], p=2, dim=-1)

        # k-NN cosine (neighbors in Human set).
        # NOTE(review): for the "Human" row each point's best match is itself
        # (cosine 1.0), which inflates that row -- confirm this is intended.
        knn_mean, knn_std = topk_cosine_to_humans(emb, human_unit, k=10)
        extras_rows_knn.append({"Method": method_name, "Top-10 kNN Mean": knn_mean, "Std": knn_std, "N": emb.size(0)})

        # Mixture centroids (nearest human cluster centroid). Clustering runs on
        # the Human embeddings as provided, matching the centroid table above.
        mix_mean, mix_std = mixture_centroid_stats(emb, human_emb, n_clusters=5, seed=43)
        extras_rows_mix.append({"Method": method_name, "MixCentroid Mean": mix_mean, "Std": mix_std, "N": emb.size(0)})

    # Save/print kNN table
    knn_df = pd.DataFrame(extras_rows_knn).sort_values("Top-10 kNN Mean", ascending=False).reset_index(drop=True)
    _write_and_print_table(
        knn_df,
        f"{model_name}: Top-10 kNN Cosine to Humans (all points)",
        f"{model_name}: Top-10 kNN Cosine to Humans",
        os.path.join(outdir, f"knn_to_humans_{model_name}.md"),
        os.path.join(outdir, f"knn_to_humans_{model_name}.tex"),
        "-",
    )

    # Save/print Mixture-Centroid table
    mix_df = pd.DataFrame(extras_rows_mix).sort_values("MixCentroid Mean", ascending=False).reset_index(drop=True)
    _write_and_print_table(
        mix_df,
        f"{model_name}: Nearest Human Mixture-Centroid Cosine (all points)",
        f"{model_name}: Nearest Human Mixture-Centroid Cosine",
        os.path.join(outdir, f"mixcentroid_to_humans_{model_name}.md"),
        os.path.join(outdir, f"mixcentroid_to_humans_{model_name}.tex"),
        "-",
    )


# ----------------------------
# Main
# ----------------------------
def main():
    """
    Command-line entry point.

    Parses the options, seeds the random number generators, loads the texts of
    every requested method, then embeds them and writes the tables once per
    requested embedding model. Returns 0, which is used as the exit status.
    """
    cli = argparse.ArgumentParser(description="Compute cosine-to-human-centroid tables over ALL points.")
    cli.add_argument("--seed", type=int, default=43)
    cli.add_argument("--outdir", type=str, default="tables")
    cli.add_argument(
        "--methods",
        nargs="*",
        default=DEFAULT_METHODS,
        help="Methods to include (must match genpaths globals).",
    )
    cli.add_argument(
        "--models",
        nargs="*",
        default=DEFAULT_MODELS,
        help="Which embedding models to use: subset of LUAR, CISR, SD.",
    )
    cli.add_argument("--luar_batch_size", type=int, default=1024)
    cli.add_argument("--st_batch_size", type=int, default=256)
    opts = cli.parse_args()

    # Seed every RNG in play, in the same order as before.
    seed = opts.seed
    set_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Step 1: read the full text set of each method (no subsampling).
    texts_by_method: Dict[str, List[str]] = {}
    for method_name in opts.methods:
        try:
            # Descriptors are module globals named after the upper-cased method.
            descriptor = globals()[method_name.upper()]
        except KeyError:
            raise KeyError(f"Method {method_name} not found as a global in genpaths (looked for {method_name.upper()}).")
        loaded = load_text(descriptor)
        print(f"Loaded {len(loaded):6d} texts for {method_name}")
        texts_by_method[method_name] = loaded

    # Step 2: one embedding pass and one set of tables per model.
    for model_name in opts.models:
        embeddings = get_embeddings_for_model(
            model_name,
            texts_by_method,
            luar_batch_size=opts.luar_batch_size,
            st_batch_size=opts.st_batch_size,
        )
        cosine_tables_for_model(model_name, opts.methods, embeddings, outdir=opts.outdir)

    return 0


if __name__ == "__main__":
    sys.exit(main())
