#!/usr/bin/env python3
"""
graph_retrieval.py — Build a kNN graph from a training JSONL file and provide a graph-aware retriever
- Builds a kNN graph over sentence-transformers embeddings, with edges weighted by cosine similarity
- Saves the graph (graph.gpickle) and the embedding matrix with its ids (embeddings.npz)
- Provides a GraphRetriever class for pipelines A/B with:
    * class balance (favor FLAGGED items up to a target fraction)
    * MMR diversity
    * multi-objective score: alpha*sim - beta*redundancy + gamma*balance_bonus
"""

import os, json, math, pickle
from typing import List, Dict, Tuple
import numpy as np
import networkx as nx
from sklearn.neighbors import NearestNeighbors
from sentence_transformers import SentenceTransformer

def load_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file and return one parsed object per non-blank line.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    instead of being handed to ``json.loads``, which would raise on them.

    Args:
        path: path to a UTF-8 encoded JSONL file.

    Returns:
        List of the decoded JSON values, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(ln) for ln in f if ln.strip()]

def build_graph(train_jsonl: str, out_dir: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", k: int = 10):
    """Embed the training texts, link each one to its k nearest neighbours, and save the result.

    Each JSONL row must have "text" and "text_id"; "label" is optional
    (defaults to "UNKNOWN") and is stored as the node attribute ``label``.

    Writes two files into ``out_dir`` (created if missing):
      * embeddings.npz -- ``X`` (embeddings encoded with normalize_embeddings=True,
        one row per input row) and ``ids`` (an object array of text_ids, so loading
        it needs ``allow_pickle=True``);
      * graph.gpickle -- a pickled undirected ``networkx.Graph`` whose edges carry the
        cosine similarity as ``weight``; pairs with similarity <= 0 are not linked.

    Args:
        train_jsonl: path to the training JSONL file.
        out_dir: output directory.
        model_name: sentence-transformers model used for encoding.
        k: number of neighbours per node, not counting the node itself.

    Raises:
        ValueError: if ``train_jsonl`` contains no rows.
    """
    os.makedirs(out_dir, exist_ok=True)
    rows = load_jsonl(train_jsonl)
    if not rows:
        raise ValueError(f"no rows found in {train_jsonl}")
    texts = [r["text"] for r in rows]
    labels = [r.get("label", "UNKNOWN") for r in rows]
    ids = [r["text_id"] for r in rows]
    model = SentenceTransformer(model_name)
    X = model.encode(texts, batch_size=64, convert_to_numpy=True, show_progress_bar=True, normalize_embeddings=True)
    # Ask for one extra neighbour because each point is normally returned among its own neighbours.
    nn = NearestNeighbors(n_neighbors=min(k + 1, len(X)), metric="cosine").fit(X)
    dist, idx = nn.kneighbors(X)
    sim = 1.0 - dist  # cosine distance -> similarity

    G = nx.Graph()
    for tid, lbl in zip(ids, labels):
        G.add_node(tid, label=lbl)
    for i, (row_idx, row_sim) in enumerate(zip(idx, sim)):
        # Exclude the point itself by index rather than assuming it is at position 0:
        # with duplicate texts several points are at distance 0 and any of them may come first.
        neighbours = [(int(j), float(s)) for j, s in zip(row_idx, row_sim) if j != i][:k]
        for j, s in neighbours:
            if s <= 0:
                continue
            G.add_edge(ids[i], ids[j], weight=s)
    with open(os.path.join(out_dir, "embeddings.npz"), "wb") as f:
        np.savez(f, X=X, ids=np.array(ids, dtype=object))
    with open(os.path.join(out_dir, "graph.gpickle"), "wb") as f:
        pickle.dump(G, f)
    print(f"[graph] nodes={G.number_of_nodes()}, edges={G.number_of_edges()} → {out_dir}/graph.gpickle")

class GraphRetriever:
    """Retriever over the embeddings and kNN graph written by ``build_graph``.

    Ranks stored items for a query text using similarity, an MMR-style redundancy
    penalty and a FLAGGED class-balance bonus (see ``retrieve``).
    """

    def __init__(self, store_dir: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """Load ``embeddings.npz`` and ``graph.gpickle`` from ``store_dir`` and the encoder model.

        Security: both files are deserialized with pickle (``allow_pickle=True`` and
        ``pickle.load``), which can execute arbitrary code -- only load trusted stores.
        """
        with open(os.path.join(store_dir, "embeddings.npz"), "rb") as f:
            data = np.load(f, allow_pickle=True)
            self.X = data["X"]
            self.ids = data["ids"].tolist()
        with open(os.path.join(store_dir, "graph.gpickle"), "rb") as f:
            # NOTE: pickle.load on untrusted data is unsafe; see class docstring.
            self.G = pickle.load(f)
        self.model = SentenceTransformer(model_name)

    def _encode(self, text: str) -> np.ndarray:
        """Encode a single text into a normalized embedding vector."""
        v = self.model.encode([text], convert_to_numpy=True, normalize_embeddings=True)[0]
        return v

    def retrieve(self, text: str, labels: Dict[str, str], top_k: int = 10, alpha: float = 1.0, beta: float = 0.5, gamma: float = 0.5, balance_target: float = 0.5) -> List[Tuple[str, float]]:
        """Greedy multi-objective (MMR-style) reranking of the nearest stored items.

        At every step each remaining candidate is scored as::

            score = alpha*sim - beta*redundancy + gamma*balance_bonus

        ``sim`` is the dot product of the candidate with the query embedding.
        ``redundancy`` is the maximum dot product with the items selected so far
        (0 before the first pick). ``balance_bonus`` is
        ``balance_target - flagged_ratio`` for FLAGGED candidates and
        ``flagged_ratio - balance_target`` for every other label, where
        ``flagged_ratio`` is the FLAGGED fraction of the current selection.
        The best candidate is taken, the running state is updated, and the step
        repeats until ``top_k`` items are chosen or the shortlist is exhausted.

        Args:
            text: query text.
            labels: mapping text_id -> label; ids missing from it count as "UNKNOWN".
            top_k: maximum number of results.
            alpha: weight of query similarity.
            beta: weight of the redundancy penalty.
            gamma: weight of the balance bonus.
            balance_target: desired fraction of FLAGGED items among the results.

        Returns:
            List of (text_id, score) in selection order; each score is the value the
            item had at the moment it was selected.
        """
        q = self._encode(text)
        sims = self.X @ q  # cosine similarity when stored and query vectors are normalized
        # Rerank only a shortlist of the most similar items.
        cand_idx = np.argsort(-sims)[: max(50, top_k * 5)]
        n_cand = len(cand_idx)
        if top_k <= 0 or n_cand == 0:
            return []

        cand_vecs = self.X[cand_idx]
        cand_sims = sims[cand_idx].astype(float)
        cand_ids = [self.ids[i] for i in cand_idx]
        cand_lbls = [labels.get(tid, "UNKNOWN") for tid in cand_ids]
        is_flagged = np.array([lbl == "FLAGGED" for lbl in cand_lbls], dtype=bool)

        remaining = np.ones(n_cand, dtype=bool)
        redundancy = np.zeros(n_cand)  # max similarity to anything already selected
        counts: Dict[str, int] = {"FLAGGED": 0, "NOT FLAGGED": 0}
        final: List[Tuple[str, float]] = []

        for step in range(min(top_k, n_cand)):
            # The 1e-6 keeps the ratio defined (0.0) before anything has been selected.
            tot = sum(counts.values()) + 1e-6
            cur_ratio = counts["FLAGGED"] / tot
            # Encourage FLAGGED until the target is reached; favor other labels once above it.
            bonus = np.where(is_flagged, balance_target - cur_ratio, cur_ratio - balance_target)
            scores = alpha * cand_sims - beta * redundancy + gamma * bonus
            scores[~remaining] = -np.inf
            best = int(np.argmax(scores))

            remaining[best] = False
            final.append((cand_ids[best], float(scores[best])))
            lbl = cand_lbls[best]
            counts[lbl] = counts.get(lbl, 0) + 1
            # Update each candidate's max similarity to the selected set.
            dots = cand_vecs @ cand_vecs[best]
            redundancy = dots if step == 0 else np.maximum(redundancy, dots)
        return final
