import os
import argparse
import torch
import clip
import json
from PIL import Image
from tqdm import tqdm
from collections import Counter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

def get_args():
    """Build the command-line parser for the pipeline and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the input/output paths, mining sizes,
        clustering and cluster-removal settings, the recare threshold, and
        runtime options (seed, device, plotting).
    """
    parser = argparse.ArgumentParser(
        description="Anonymous CLIP token mining + clustering + intra-cluster refinement (recare variant)",
    )

    # Inputs, outputs and model selection.
    parser.add_argument(
        "--image_dir", type=str, default="./images",
        help="Directory with input images (.jpg/.jpeg/.png)",
    )
    parser.add_argument(
        "--reference_word", type=str, default="target",
        help="Reference word to define the target direction",
    )
    parser.add_argument(
        "--reports_dir", type=str, default="./reports",
        help="Output directory for reports",
    )
    parser.add_argument(
        "--vocab_cache", type=str, default="./cache/vocab_clip_vitb32_fp16.npy",
        help="Path to cache CLIP vocab embeddings",
    )
    parser.add_argument(
        "--model_name", type=str, default="ViT-B/32",
        help="CLIP model name",
    )

    # How many tokens to keep per image and overall.
    parser.add_argument("--top_k_per_image", type=int, default=50)
    parser.add_argument("--top_n_tokens", type=int, default=100)

    # Batch / chunk sizes for text encoding and similarity search.
    parser.add_argument("--batch_size_txt", type=int, default=2048)
    parser.add_argument("--chunk_size_sim", type=int, default=4096)

    # Clustering and extreme-cluster removal.
    parser.add_argument("--n_clusters", type=int, default=6)
    parser.add_argument(
        "--remove_mode", type=str, default="k", choices=["k", "percentile"],
        help="How to drop extreme clusters by residual",
    )
    parser.add_argument("--k_each_side", type=int, default=1)
    parser.add_argument("--lower_pct", type=float, default=10.0)
    parser.add_argument("--upper_pct", type=float, default=90.0)

    # Intra-cluster (recare) refinement.
    parser.add_argument(
        "--alpha", type=float, default=0.01,
        help="Threshold slack for recare removal",
    )
    parser.add_argument("--min_cluster_size", type=int, default=2)

    # Runtime options.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default=default_device)
    parser.add_argument(
        "--plot_labels", action="store_true",
        help="Overlay token texts on t-SNE scatter",
    )
    parser.add_argument("--tsne_png", type=str, default="token_embedding.png")

    parsed = parser.parse_args()
    return parsed

def set_seed(seed: int = 42):
    """Seed NumPy's and PyTorch's random generators with ``seed``.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def build_or_load_vocab_embeddings(model, vocab_tokens, vocab_cache: str, batch_size_txt: int, device):
    """Return L2-normalised float16 text embeddings for ``vocab_tokens``, cached on disk.

    If ``vocab_cache`` already exists it is memory-mapped and returned as-is;
    the cached content is not validated against ``model`` or ``vocab_tokens``.
    Otherwise every token is encoded with ``model.encode_text`` in batches of
    ``batch_size_txt``, the result is written to ``vocab_cache`` and the file
    is re-opened memory-mapped.

    Args:
        model: Object exposing ``encode_text`` (the loaded CLIP model).
        vocab_tokens: Sequence of strings to embed; row ``i`` of the result
            corresponds to ``vocab_tokens[i]``.
        vocab_cache: Path of the ``.npy`` cache file. May be a bare filename.
        batch_size_txt: Number of tokens encoded per forward pass.
        device: Torch device the tokenised batch is moved to.

    Returns:
        Read-only ``numpy.memmap`` with one row per token.
    """
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    cache_dir = os.path.dirname(vocab_cache)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    if os.path.exists(vocab_cache):
        return np.load(vocab_cache, mmap_mode="r")

    all_feats = []
    with torch.no_grad():
        for s in tqdm(range(0, len(vocab_tokens), batch_size_txt), desc="Encoding vocab (batched)"):
            batch_tokens = vocab_tokens[s:s + batch_size_txt]
            toks = clip.tokenize(batch_tokens).to(device)
            text_feat = model.encode_text(toks)
            text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
            all_feats.append(text_feat.half().cpu().numpy())
            # Free the batch before the next one to keep peak GPU memory low.
            del toks, text_feat
            torch.cuda.empty_cache()
    arr = np.concatenate(all_feats, axis=0).astype(np.float16, copy=False)

    # Save through a file handle: np.save(path, ...) appends ".npy" to paths
    # lacking that suffix, which would make the np.load below miss the file.
    with open(vocab_cache, "wb") as fh:
        np.save(fh, arr)
    return np.load(vocab_cache, mmap_mode="r")

def encode_words(model, words, device):
    """Embed ``words`` with the model's text encoder and L2-normalise each row.

    Args:
        model: Object exposing ``encode_text``.
        words: Strings passed to ``clip.tokenize``.
        device: Torch device the token ids are moved to before encoding.

    Returns:
        CPU float32 tensor with one unit-norm row per word.
    """
    token_ids = clip.tokenize(words).to(device)
    with torch.no_grad():
        embeddings = model.encode_text(token_ids)
        row_norms = embeddings.norm(dim=-1, keepdim=True)
        embeddings = embeddings / row_norms
    return embeddings.cpu().float()

def topk_sim_over_chunks(img_feat_cpu: np.ndarray, text_features_np: np.ndarray, k: int, chunk_size_sim: int):
    """Find the ``k`` rows of ``text_features_np`` with the largest dot product against a query.

    The text matrix is scanned ``chunk_size_sim`` rows at a time so that only
    one chunk is converted to float32 at once (it may be a memory-mapped
    float16 array). After each chunk the running best-``k`` is merged with
    the chunk's scores in a single vectorised sort.

    Args:
        img_feat_cpu: 1-D query vector of length D.
        text_features_np: 2-D array of shape (V, D).
        k: Number of results to keep. ``k <= 0`` yields two empty lists.
        chunk_size_sim: Number of rows processed per step.

    Returns:
        ``(indices, sims)``: two lists of length ``min(k, V)`` holding Python
        ints and floats, ordered by similarity descending. Exact ties are
        ordered by row index descending.
    """
    if k <= 0:
        # Nothing requested; the heap-based version crashed on this input.
        return [], []

    query = np.asarray(img_feat_cpu, dtype=np.float32)
    n_rows = text_features_np.shape[0]

    best_sims = np.empty(0, dtype=np.float32)
    best_idxs = np.empty(0, dtype=np.int64)

    for start in range(0, n_rows, chunk_size_sim):
        stop = min(start + chunk_size_sim, n_rows)
        chunk = np.asarray(text_features_np[start:stop], dtype=np.float32)
        chunk_sims = chunk @ query

        cand_sims = np.concatenate((best_sims, chunk_sims))
        cand_idxs = np.concatenate((best_idxs, np.arange(start, stop, dtype=np.int64)))

        # lexsort sorts by the LAST key first: similarity is primary, row
        # index breaks ties. Reversing gives descending order on both.
        order = np.lexsort((cand_idxs, cand_sims))[::-1][:k]
        best_sims = cand_sims[order]
        best_idxs = cand_idxs[order]

    return best_idxs.tolist(), best_sims.tolist()


def minmax_torch(x: torch.Tensor):
    """Rescale ``x`` linearly so its minimum maps to 0 and its maximum to 1.

    When the value range is smaller than 1e-12 (effectively constant input),
    a zero tensor shaped like ``x`` is returned instead of dividing by ~0.
    """
    lowest = x.min().item()
    span = x.max().item() - lowest
    if span < 1e-12:
        return torch.zeros_like(x)
    shifted = x - lowest
    return shifted / span

def _mine_token_counts(model, preprocess, vocab_tokens, text_features_np, image_dir, top_k, chunk_size_sim, device):
    """Count, over all images in ``image_dir``, how often each vocab token is in an image's top-k."""
    counter = Counter()
    with torch.no_grad():
        for fname in tqdm(sorted(os.listdir(image_dir))):
            if not fname.lower().endswith((".jpg", ".jpeg", ".png")):
                continue
            # The context manager closes the underlying file once the pixels
            # have been converted, instead of leaking one handle per image.
            with Image.open(os.path.join(image_dir, fname)) as img:
                image = preprocess(img.convert("RGB")).unsqueeze(0).to(device)

            img_feat = model.encode_image(image)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            img_feat_cpu = img_feat.squeeze(0).detach().cpu().numpy()  # [D]

            topk_indices, _ = topk_sim_over_chunks(
                img_feat_cpu=img_feat_cpu,
                text_features_np=text_features_np,
                k=top_k,
                chunk_size_sim=chunk_size_sim,
            )
            counter.update(vocab_tokens[i] for i in topk_indices)

            del image, img_feat, img_feat_cpu
            torch.cuda.empty_cache()
    return counter


def _summarize_clusters(frame):
    """Aggregate per-cluster size, total count, mean residual and mean similarity (mean_residual ASC)."""
    return (
        frame.groupby("cluster")
        .agg(
            n_tokens=("token", "size"),
            total_count=("count", "sum"),
            mean_residual=("residual_norm", "mean"),
            mean_sim=("sim", "mean"),
        )
        .reset_index()
        .sort_values("mean_residual", ascending=True)
    )


def _pick_extreme_clusters(cluster_stats, args):
    """Return the set of cluster ids to drop as extremes of ``mean_residual``."""
    if args.remove_mode == "k":
        n_clusters_detected = cluster_stats["cluster"].nunique()
        # Never remove more than half of the clusters from each side.
        k_eff = max(0, min(args.k_each_side, n_clusters_detected // 2))
        if k_eff == 0:
            return set()
        ordered = cluster_stats.sort_values("mean_residual", ascending=True)["cluster"].tolist()
        return set(ordered[:k_eff] + ordered[-k_eff:])

    low_thr = cluster_stats["mean_residual"].quantile(args.lower_pct / 100.0)
    high_thr = cluster_stats["mean_residual"].quantile(args.upper_pct / 100.0)
    is_extreme = (cluster_stats["mean_residual"] <= low_thr) | (cluster_stats["mean_residual"] >= high_thr)
    return set(cluster_stats.loc[is_extreme, "cluster"].tolist())


def main():
    """Run the full pipeline: mine tokens, cluster them, prune, and write reports.

    Steps:
      1. Load CLIP and embed (or load cached embeddings of) every tokenizer
         vocabulary entry.
      2. For each image, take the top-k most similar vocab tokens and count
         how often each token appears; keep the ``top_n_tokens`` most common.
      3. Embed those tokens, compute their similarity to the reference word
         and the norm of their component orthogonal to it.
      4. Project to 2-D with t-SNE, save a scatter plot, and run KMeans on
         the 2-D points.
      5. Drop whole clusters at the extremes of mean residual norm.
      6. Inside each kept cluster, flag tokens whose leave-one-out score
         exceeds ``(1 + alpha)`` times the mean of the other tokens' scores.
      7. Write CSV reports and a JSON file of prompts built from the
         surviving tokens.

    Raises:
        SystemExit: If fewer than two tokens were mined (for example no
            matching images in ``--image_dir``).
    """
    args = get_args()
    set_seed(args.seed)

    os.makedirs(args.reports_dir, exist_ok=True)

    # Load CLIP
    device = torch.device(args.device)
    model, preprocess = clip.load(args.model_name, device=device)
    model.eval()

    # NOTE(review): the encoder keys are the tokenizer's raw vocabulary
    # entries and are used verbatim as text prompts - confirm that is intended.
    tokenizer = clip.simple_tokenizer.SimpleTokenizer()
    vocab_tokens = list(tokenizer.encoder.keys())

    text_features_np = build_or_load_vocab_embeddings(
        model=model,
        vocab_tokens=vocab_tokens,
        vocab_cache=args.vocab_cache,
        batch_size_txt=args.batch_size_txt,
        device=device,
    )
    D = text_features_np.shape[1]

    print("Mining tokens from images ...")
    most_common_counter = _mine_token_counts(
        model=model,
        preprocess=preprocess,
        vocab_tokens=vocab_tokens,
        text_features_np=text_features_np,
        image_dir=args.image_dir,
        top_k=args.top_k_per_image,
        chunk_size_sim=args.chunk_size_sim,
        device=device,
    )

    most_common_tokens = [tok for tok, _ in most_common_counter.most_common(args.top_n_tokens)]
    n_tokens = len(most_common_tokens)
    if n_tokens < 2:
        # Fail early with an actionable message rather than deep inside
        # the text encoder or t-SNE.
        raise SystemExit(
            f"Need at least 2 mined tokens to embed and cluster, got {n_tokens}. "
            f"Check --image_dir ({args.image_dir!r}), --top_k_per_image and --top_n_tokens."
        )

    W = encode_words(model, most_common_tokens, device)
    C = encode_words(model, [args.reference_word], device).T

    similarities = (W @ C).squeeze(1)
    # PC projects onto the reference direction; proj_orth removes it.
    PC = (C @ C.T) / (C.T @ C + 1e-12)
    proj_orth = torch.eye(D, dtype=W.dtype) - PC
    residuals_all = W @ proj_orth.T
    residual_norms_all = residuals_all.norm(dim=1)

    # t-SNE requires perplexity < number of samples; 30.0 is its default.
    perplexity = min(30.0, float(n_tokens - 1))
    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        random_state=args.seed,
        init="pca",
        learning_rate="auto",
    )
    emb_2d = tsne.fit_transform(W.numpy())

    plt.figure(figsize=(10, 8))
    plt.scatter(emb_2d[:, 0], emb_2d[:, 1], s=16)
    if args.plot_labels:
        for i, tok in enumerate(most_common_tokens):
            plt.text(emb_2d[i, 0], emb_2d[i, 1], tok, fontsize=7)
    plt.title("Token Embeddings (t-SNE)")
    plt.tight_layout()
    plt.savefig(os.path.join(args.reports_dir, args.tsne_png), metadata={})
    plt.close()

    # KMeans cannot form more clusters than there are points.
    n_clusters = min(args.n_clusters, n_tokens)
    if n_clusters < args.n_clusters:
        print(f"Only {n_tokens} tokens mined; using {n_clusters} clusters instead of {args.n_clusters}.")
    kmeans = KMeans(n_clusters=n_clusters, random_state=args.seed, n_init="auto")
    labels = kmeans.fit_predict(emb_2d)

    sim_n = minmax_torch(similarities)
    res_n = minmax_torch(residual_norms_all)

    df = pd.DataFrame({
        "token": most_common_tokens,
        "cluster": labels,
        "count": [most_common_counter[tok] for tok in most_common_tokens],
        "sim": similarities.numpy().astype(float),
        "residual_norm": residual_norms_all.numpy().astype(float),
        "sim_n": sim_n.numpy().astype(float),
        "residual_norm_n": res_n.numpy().astype(float),
    })

    print("\n=== Initial clustering: tokens per cluster (before removal) ===")
    for cid in sorted(np.unique(labels)):
        sub = df[df["cluster"] == cid].sort_values(
            ["residual_norm", "sim", "count"], ascending=[False, False, False]
        )
        toks = sub["token"].tolist()
        print(f"\n[Cluster {cid}]  (#tokens={len(toks)})")
        print(", ".join(toks))

    cluster_stats = _summarize_clusters(df)

    print("\n=== Cluster-level stats (sorted by mean_residual ASC) ===")
    print(cluster_stats.to_string(index=False, float_format=lambda x: f"{x:.4f}"))
    cluster_stats.to_csv(os.path.join(args.reports_dir, "cluster_stats.csv"), index=False)

    remove_clusters = _pick_extreme_clusters(cluster_stats, args)

    in_removed_cluster = df["cluster"].isin(remove_clusters)
    df_removed = df[in_removed_cluster].copy()
    df_kept = df[~in_removed_cluster].copy()

    print("\n=== Clusters to remove ===")
    print(sorted(remove_clusters) if remove_clusters else "(none)")

    print("\n=== Removed token count / Kept token count ===")
    print(f"removed: {len(df_removed)} / kept: {len(df_kept)}")

    kept_cluster_stats = _summarize_clusters(df_kept)
    print("\n=== Kept clusters (mean_residual ASC) ===")
    print(kept_cluster_stats.to_string(index=False, float_format=lambda x: f"{x:.4f}"))
    df_removed.to_csv(os.path.join(args.reports_dir, "token_metrics_removed_by_extremes.csv"), index=False)
    df_kept.to_csv(os.path.join(args.reports_dir, "token_metrics_kept_after_extremes.csv"), index=False)

    # ---- Intra-cluster (recare) refinement ----
    tok2row = {tok: i for i, tok in enumerate(most_common_tokens)}

    recare_scores = {}      # token -> leave-one-out score
    recare_flagged = set()  # tokens marked for removal
    removed_in_cluster = {}

    # The leave-one-out mean divides by (size - 1), so a cluster needs at
    # least two tokens regardless of the configured minimum.
    min_size = max(args.min_cluster_size, 2)

    for c_id in sorted(df_kept["cluster"].unique()):
        tokens_c = df_kept.loc[df_kept["cluster"] == c_id, "token"].tolist()
        if len(tokens_c) < min_size:
            removed_in_cluster[c_id] = []
            continue

        W_c = W[[tok2row[t] for t in tokens_c]]
        counts = W_c.shape[0]

        # Row i: mean embedding of the cluster with token i left out.
        sums = W_c.sum(dim=0, keepdim=True)
        P_minus_i = (sums - W_c) / (counts - 1)

        # Squared norm of each leave-one-out mean after removing the
        # reference direction.
        D_minus = P_minus_i @ proj_orth.T
        d_norm2 = (D_minus ** 2).sum(dim=1)

        # Compare each score against the mean of the OTHER tokens' scores.
        mean_excl_i = (d_norm2.sum() - d_norm2) / (counts - 1)
        thresh = (1.0 + args.alpha) * mean_excl_i
        remove_mask = (d_norm2 > thresh)

        removed_tokens = []
        for tok, rm, score in zip(tokens_c, remove_mask.tolist(), d_norm2.tolist()):
            recare_scores[tok] = float(score)
            if rm:
                recare_flagged.add(tok)
                removed_tokens.append(tok)
        removed_in_cluster[c_id] = removed_tokens

    # Tokens are unique per row, so one vectorised pass fills both columns;
    # tokens from skipped clusters get False / NaN.
    df_kept["recare_remove"] = df_kept["token"].isin(recare_flagged)
    df_kept["recare_score"] = df_kept["token"].map(recare_scores).astype(float)

    df_final_removed = df_kept[df_kept["recare_remove"]].copy()
    df_final_kept = df_kept[~df_kept["recare_remove"]].copy()

    print("\n=== Intra-cluster refinement summary ===")
    for c_id in sorted(removed_in_cluster.keys()):
        dropped = removed_in_cluster[c_id]
        print(f"Cluster {c_id}: removed {len(dropped)} tokens")
        if dropped:
            preview = ", ".join(dropped[:20])
            print("  - " + preview + (" ..." if len(dropped) > 20 else ""))

    print("\nTotals -> removed:", len(df_final_removed), "/ kept:", len(df_final_kept))

    df_final_removed.to_csv(os.path.join(args.reports_dir, "token_metrics_removed_recare.csv"), index=False)
    df_final_kept.to_csv(os.path.join(args.reports_dir, "token_metrics_kept_after_recare.csv"), index=False)

    careset_tokens = df_final_kept["token"].tolist()

    output_dict = {
        "nudity_careset": [f"A {tok}" for tok in careset_tokens],
        "nudity_careset_photo": [f"A photo of {tok}" for tok in careset_tokens],
    }

    json_path = os.path.join(args.reports_dir, "careset_prompts.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(output_dict, f, indent=2, ensure_ascii=False)

    print(f"\nCARE set prompts saved to {json_path}")

    print("\nDone. Outputs saved under", os.path.abspath(args.reports_dir))

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
