import os, json, argparse, logging, time, requests, re
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
import torch, faiss
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.cluster import AgglomerativeClustering

import asyncio, random
try:
    import aiohttp
except ImportError:
    aiohttp = None

from utils.config import CONFIG

# Module-level logger for this file.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the *root* logger (unless it already has
# handlers) as a side effect of importing this module; consider moving this into
# the `if __name__ == "__main__":` block so importers keep their own logging setup.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s | %(levelname)s | %(message)s")
def _chunks(xs, n):
    for i in range(0, len(xs), n):
        yield xs[i:i+n]

async def _post_one(session, url, headers, payload, retry=4, base_wait=0.7, timeout=360):
    for t in range(retry+1):
        try:
            async with session.post(url, headers=headers, json=payload, timeout=timeout) as rsp:
                if rsp.status == 200:
                    data = await rsp.json()
                    return (True, data)
                if rsp.status in (429, 500, 502, 503, 504):
                    wait = base_wait * (2**t) + random.uniform(0, 0.2)
                    await asyncio.sleep(wait)
                    continue
                txt = await rsp.text()
                return (False, {"status": rsp.status, "text": txt})
        except Exception as e:
            wait = base_wait * (2**t) + random.uniform(0, 0.2)
            await asyncio.sleep(wait)
    return (False, {"status": "timeout", "text": "retry_exhausted"})

async def _post_many(url, headers, payloads, concurrency=8, retry=4, timeout=360):
    """POST every payload to ``url`` concurrently; results keep the input order.

    Args:
        url: Endpoint to POST to.
        headers: HTTP headers sent with every request.
        payloads: JSON-serialisable request bodies, one request each.
        concurrency: Maximum number of requests in flight at once (must be >= 1).
        retry: Extra attempts per request, forwarded to ``_post_one``.
        timeout: Per-request timeout, forwarded to ``_post_one``.

    Returns:
        A list of ``(ok, data)`` tuples from ``_post_one``, aligned with ``payloads``.

    Raises:
        RuntimeError: If ``aiohttp`` is not installed.
        ValueError: If ``concurrency`` is less than 1.
    """
    if aiohttp is None:
        raise RuntimeError("aiohttp is required for concurrent requests; install it with `pip install aiohttp`.")
    if concurrency < 1:
        # asyncio.Semaphore(0) would leave every task waiting forever.
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")
    if not payloads:
        return []
    sem = asyncio.Semaphore(concurrency)
    async with aiohttp.ClientSession() as session:
        async def run_one(payload):
            async with sem:
                return await _post_one(session, url, headers, payload, retry=retry, timeout=timeout)
        # gather preserves argument order, so results line up with payloads.
        return await asyncio.gather(*(run_one(p) for p in payloads))

def _normalize(vec: np.ndarray) -> np.ndarray:
    """L2-normalise the rows of ``vec`` in place via FAISS and return the same array.

    ``faiss.normalize_L2`` writes into its argument and returns nothing, so the
    returned object is ``vec`` itself (not a copy).
    """
    normalized = vec  # alias only: FAISS mutates this buffer directly
    faiss.normalize_L2(normalized)
    return normalized

def _cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b))

@torch.no_grad()
def _batched_encode(model: SentenceTransformer, texts: List[str], bs: int = 32) -> np.ndarray:
    """Embed ``texts`` with ``model`` in batches of ``bs`` and stack the results.

    Args:
        model: SentenceTransformer used for encoding.
        texts: Strings to embed.
        bs: Number of texts per ``model.encode`` call.

    Returns:
        A float32 matrix with one row per text. For empty ``texts`` this is an
        empty ``(0, dim)`` matrix instead of the ``np.vstack([])`` ValueError.
    """
    if not texts:
        dim = model.get_sentence_embedding_dimension() or 0
        return np.zeros((0, dim), dtype="float32")
    embs = [
        model.encode(texts[st:st + bs], convert_to_tensor=False, show_progress_bar=False)
        for st in tqdm(range(0, len(texts), bs), desc="Embedding", leave=False)
    ]
    return np.vstack(embs).astype("float32")

def _load_records(path: str) -> List[Dict[str, Any]]:
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(l) for l in f if l.strip()]


class NaiveRAG:
    """Minimal retrieval-augmented QA pipeline.

    Passages are embedded with a SentenceTransformer and stored in a FAISS
    inner-product index over L2-normalised vectors. The top-k hits are passed to
    an answer model, either a local Hugging Face causal LM or a model served via
    OpenRouter. If no passages were retrieved, or the RAG answer is empty or
    looks like an abstention, the question is re-asked from general knowledge.
    """

    # Case-insensitive heuristic for answers in which the model declined to answer.
    ABSTAIN_PAT = re.compile(
        r"(not (enough|sufficient) (information|context)|"
        r"insufficient (information|context)|"
        r"i\s*(am|'m)\s*not\s*sure|"
        r"cannot\s*(answer|determine)|"
        r"can't\s*(answer|determine)|"
        r"no\s*(information|evidence)\s*(provided|to answer)|"
        r"passages?\s*do\s*not\s*provide)",
        re.I
    )

    def __init__(self,
                 embedding_model_id: str,
                 use_local: bool,
                 answer_model_id: Optional[str] = None,
                 openrouter_key: Optional[str] = None):
        """Load the embedding model and, if requested, a local answer model.

        Args:
            embedding_model_id: SentenceTransformer model name or path.
            use_local: Generate answers with a local Hugging Face model instead of OpenRouter.
            answer_model_id: Answer model id (local HF id, or OpenRouter model name).
            openrouter_key: OpenRouter API key. Defaults to ``$OPENROUTER_API_KEY``,
                then to ``CONFIG["OPENROUTER_KEY"]``.
        """
        self.embed_model = SentenceTransformer(
            embedding_model_id, device="cuda" if torch.cuda.is_available() else "cpu"
        ).eval()
        self.embed_dim = self.embed_model.get_sentence_embedding_dimension()

        self.use_local = use_local
        self.answer_model_id = answer_model_id
        self.openrouter_key = openrouter_key or os.getenv("OPENROUTER_API_KEY", CONFIG.get("OPENROUTER_KEY"))

        # Raw passages collected by load_jsonl(); may contain duplicates.
        self.documents: List[Dict[str, Any]] = []

        self.index: Optional[faiss.Index] = None
        # Deduplicated passages; build_index() assigns doc_id == list position == FAISS id.
        self.docstore: List[Dict[str, Any]] = []
        self.doc_by_id: Dict[int, Dict[str, Any]] = {}

        # tokenizer/lm exist only when a local model is loaded (see _has_local_lm).
        if use_local and answer_model_id:
            self.tokenizer = AutoTokenizer.from_pretrained(
                answer_model_id, trust_remote_code=True, pad_token='<|endoftext|>'
            )
            self.lm = self._load_local_lm(answer_model_id)

    @staticmethod
    def _load_local_lm(model_id: str):
        """Load ``model_id`` as an eval-mode causal LM.

        On CUDA, FlashAttention-2 is requested first. If it is unavailable
        (``flash_attn`` not installed, or not supported by the model), the model
        is loaded again with the library's default attention rather than
        aborting construction.
        """
        on_cuda = torch.cuda.is_available()
        load_kwargs = dict(
            device_map="auto",
            torch_dtype=torch.float16 if on_cuda else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        if on_cuda:
            try:
                return AutoModelForCausalLM.from_pretrained(
                    model_id, attn_implementation="flash_attention_2", **load_kwargs
                ).eval()
            except (ImportError, ValueError) as exc:
                logger.warning("FlashAttention-2 unavailable (%s); falling back to default attention.", exc)
        return AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs).eval()

    def load_jsonl(self, path: str) -> List[Dict]:
        """Read a JSONL dataset and append its supporting passages to ``self.documents``.

        A record may carry ``supporting_docs: {<key>: {"documents": {page_id: passage}}}``.
        Each passage is stored as ``{"qid", "page_id", "passage"}``. Missing or
        ``null`` ``supporting_docs`` / ``documents`` fields count as empty.

        Returns:
            The parsed records in file order (blank lines skipped).
        """
        with open(path, "r", encoding="utf-8") as f:
            data = [json.loads(line) for line in f if line.strip()]
        for rec in data:
            for cq in (rec.get("supporting_docs") or {}).values():
                for pid, passage in ((cq or {}).get("documents") or {}).items():
                    self.documents.append({"qid": rec.get("qid"), "page_id": pid, "passage": passage})
        return data

    @torch.no_grad()
    def build_index(self, batch_size: int = 16) -> None:
        """Deduplicate ``self.documents`` into ``self.docstore`` and build the FAISS index.

        Empty passages are dropped. When several documents share identical text,
        only the first is kept, along with its ``qid``/``page_id``. Vectors are
        L2-normalised and stored in an ``IndexFlatIP`` wrapped by ``IndexIDMap2``,
        so search scores are cosine similarities. An empty corpus yields an
        empty index instead of crashing.
        """
        seen = set()
        self.docstore = []
        for d in self.documents:
            txt = (d["passage"] or "").strip()
            if not txt or txt in seen:
                continue
            seen.add(txt)
            self.docstore.append({
                "doc_id": len(self.docstore),
                "passage": txt,
                "qid": d.get("qid"),
                "page_id": d.get("page_id"),
            })

        base = faiss.IndexFlatIP(self.embed_dim)
        idx = faiss.IndexIDMap2(base)

        texts = [x["passage"] for x in self.docstore]
        if texts:
            batches = [
                self.embed_model.encode(texts[i:i + batch_size], convert_to_tensor=False, show_progress_bar=False)
                for i in tqdm(range(0, len(texts), batch_size), desc="Embedding")
            ]
            embs = np.vstack(batches).astype("float32")
            _normalize(embs)
            idx.add_with_ids(embs, np.arange(len(self.docstore), dtype="int64"))

        self.index = idx
        self.doc_by_id = {d["doc_id"]: d for d in self.docstore}

    def attach_index_and_meta(self, index_path: str, meta_path: str) -> None:
        """Load a FAISS index and its ``{"documents": [...]}`` metadata JSON from disk."""
        self.index = faiss.read_index(index_path)
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        self.docstore = meta["documents"]
        self.doc_by_id = {int(d["doc_id"]): d for d in self.docstore}

    @torch.no_grad()
    def search(self, query: str, k: int) -> List[Dict]:
        """Return up to ``k`` docstore entries most similar to ``query``.

        Each hit is a copy of its docstore entry plus a ``"score"`` key: the
        index's similarity score (cosine for an index built by build_index()).
        Returns ``[]`` when ``k <= 0``.
        """
        assert self.index is not None, "Index is not loaded. Call build_index() or attach_index_and_meta()."
        if k <= 0:
            return []
        # This runs once per query, so an encode() progress bar would only add noise.
        q_emb = np.ascontiguousarray(
            self.embed_model.encode([query], convert_to_tensor=False, show_progress_bar=False),
            dtype=np.float32,
        )
        _normalize(q_emb)
        scores, labels = self.index.search(q_emb, k)
        hits = []
        for score, label in zip(scores[0], labels[0]):
            # FAISS pads with -1 when fewer than k vectors are available.
            if label == -1:
                continue
            doc = self.doc_by_id.get(int(label))
            if not doc:
                continue
            hits.append({**doc, "score": float(score)})
        return hits

    def _looks_like_abstention(self, text: str) -> bool:
        """True if ``text`` is non-empty and matches ABSTAIN_PAT."""
        return bool(text and self.ABSTAIN_PAT.search(text))

    def _has_local_lm(self) -> bool:
        """True when answers should be generated by the locally loaded model."""
        return bool(self.use_local and hasattr(self, "lm"))

    def _needs_fallback(self, docs: List[Dict[str, Any]], ans: str) -> bool:
        """True when a RAG answer must be replaced by a general-knowledge answer."""
        return (not docs) or (not ans) or self._looks_like_abstention(ans)

    def _auth_headers(self) -> Dict[str, str]:
        """HTTP headers for OpenRouter requests."""
        return {"Authorization": f"Bearer {self.openrouter_key}", "Content-Type": "application/json"}

    @staticmethod
    def _extract_content(data: Any) -> str:
        """Pull the first choice's message text out of a chat-completions response.

        Returns ``""`` for malformed responses or non-string (e.g. ``null``) content.
        """
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            return ""
        return content.strip() if isinstance(content, str) else ""

    def _chat_openrouter(self, system: str, user: str,
                         temperature: float = 0.0, max_tokens: int = 512) -> str:
        """Send one system+user chat turn to OpenRouter and return the stripped reply.

        A ``null`` message content is returned as ``""``.

        Raises:
            requests.HTTPError: On a non-2xx HTTP response.
        """
        rsp = requests.post(
            CONFIG['OPENROUTER_URL'],
            headers=self._auth_headers(),
            json=self._openrouter_payload(system, user, temperature, max_tokens),
            timeout=360,
        )
        rsp.raise_for_status()
        content = rsp.json()["choices"][0]["message"]["content"]
        return (content or "").strip()

    def _greedy_generate_local(self, prompt: str) -> str:
        """Greedy-decode up to 512 new tokens from the local model; return only the continuation."""
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.lm.device)
        gen_ids = self.lm.generate(
            **input_ids, max_new_tokens=512, temperature=0.0, do_sample=False,
            pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id
        )[0][input_ids["input_ids"].shape[1]:]  # drop the echoed prompt tokens
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    def _greedy_generate_openrouter(self, prompt: str) -> str:
        """Answer ``prompt`` via OpenRouter with a citation-aware QA system prompt."""
        system = ("You are a careful QA assistant. Prefer using the provided passages "
                  "when they contain sufficient evidence; otherwise you MUST rely on "
                  "your general knowledge to answer. Only say you cannot determine the answer "
                  "if it truly requires up-to-the-minute data or personal info. "
                  "If you cite a passage, include its index like [P3]. Keep answers concise and accurate.")
        return self._chat_openrouter(system, prompt, temperature=0.0, max_tokens=512)

    def _answer_general_knowledge(self, query: str) -> str:
        """Answer ``query`` without retrieved passages, using the model's own knowledge."""
        system = (
            "You are a knowledgeable QA assistant. When retrieval is poor or missing, "
            "you MUST answer using your general knowledge. Do NOT say there is insufficient "
            "information unless the question is inherently unanswerable."
        )
        user = f"Question: {query}\nAnswer:"
        if self._has_local_lm():
            return self._greedy_generate_local(system + "\n\n" + user)
        return self._chat_openrouter(system, user, temperature=0.0, max_tokens=512)

    def answer(self, query: str, top_k: int = 10) -> Dict[str, Any]:
        """Retrieve ``top_k`` passages and answer ``query`` from them.

        Falls back to _answer_general_knowledge() when nothing was retrieved, or
        when the RAG answer is empty or looks like an abstention.

        Returns:
            ``{"query": query, "answer": <str>, "docs": <retrieved hits>}``.
        """
        docs = self.search(query, top_k)
        ans = ""
        # With no passages the RAG answer would be discarded anyway, so skip that model call.
        if docs:
            system_guidelines, user_prompt = self._build_prompt(query, docs)
            if self._has_local_lm():
                ans = self._greedy_generate_local(system_guidelines + "\n\n" + user_prompt)
            else:
                ans = self._chat_openrouter(system_guidelines, user_prompt, temperature=0.0, max_tokens=512)
        if self._needs_fallback(docs, ans):
            ans = self._answer_general_knowledge(query)
        return {"query": query, "answer": ans, "docs": docs}

    def _build_prompt(self, query: str, docs: List[Dict[str, Any]]) -> Tuple[str, str]:
        """Build the ``(system, user)`` prompt pair for a RAG answer.

        Passages are listed as ``[P1] ...``, ``[P2] ...`` in retrieval order so
        the model can cite them.
        """
        ctx = "\n".join(f"[P{i+1}] {d['passage']}" for i, d in enumerate(docs)) if docs else "(no passages retrieved)"
        system_guidelines = (
            "You are a careful QA assistant. Prefer using the provided passages when they contain sufficient evidence. "
            "If they don't, you MUST answer using your general knowledge. Only say you cannot determine the answer "
            "if it truly requires up-to-the-minute data or personal info. If you cite a passage, include its index."
        )
        user_prompt = (
            f"You are given up to {len(docs)} retrieved passages (may be empty) and a user question.\n"
            "Guidelines:\n"
            "1) Use a passage and cite [P#] if it clearly contains the answer.\n"
            "2) If passages lack evidence, answer from general knowledge.\n"
            "3) Be concise. Do not generate long answer. Just generate one word.\n\n"
            f"Passages:\n{ctx}\n\nQuestion: {query}\nAnswer:"
        )
        return system_guidelines, user_prompt

    def _openrouter_payload(self, system: str, user: str,
                            temperature: float = 0.0, max_tokens: int = 512) -> Dict[str, Any]:
        """Build an OpenRouter chat-completions request body for one system+user turn."""
        return {
            "model": self.answer_model_id,
            "messages": [{"role": "system", "content": system},
                         {"role": "user", "content": user}],
            "temperature": temperature,
            "max_tokens": max_tokens
        }

    def answer_many(self, queries: List[str], top_k: int = 10, concurrency: int = 8,
                    batch: int = 64, temperature: float = 0.0, max_tokens: int = 512) -> List[Dict[str, Any]]:
        """Answer many queries; results are returned in the same order as ``queries``.

        With a local model this calls answer() once per query, and ``concurrency``,
        ``batch``, ``temperature`` and ``max_tokens`` are ignored. Otherwise,
        retrieval runs for every query up front and the RAG prompts are sent to
        OpenRouter ``batch`` at a time, with up to ``concurrency`` requests in
        flight. Failed requests, empty replies, abstentions and queries with no
        retrieved passages fall back to _answer_general_knowledge(), which is
        called sequentially.

        Raises:
            ValueError: If ``batch < 1`` on the OpenRouter path.
        """
        if self._has_local_lm():
            return [self.answer(q, top_k=top_k) for q in queries]
        if batch < 1:
            raise ValueError(f"batch must be >= 1, got {batch}")

        packs = []
        for q in queries:
            docs = self.search(q, top_k)
            system, user = self._build_prompt(q, docs)
            packs.append((q, docs, self._openrouter_payload(system, user, temperature, max_tokens)))

        headers = self._auth_headers()
        url = CONFIG["OPENROUTER_URL"]
        n_chunks = (len(packs) + batch - 1) // batch

        results: List[Dict[str, Any]] = []
        for chunk in tqdm(_chunks(packs, batch), total=n_chunks):
            # asyncio.run needs no running event loop, so this method cannot be called from async code.
            responses = asyncio.run(_post_many(url, headers, [p[2] for p in chunk], concurrency=concurrency))
            for (ok, data), (q, docs, _) in zip(responses, chunk):
                ans = self._extract_content(data) if ok else ""
                if self._needs_fallback(docs, ans):
                    ans = self._answer_general_knowledge(q)
                results.append({"query": q, "answer": ans, "docs": docs})
        return results

def build_index_cli(args):
    """Handle ``build-index``: embed ``args.data`` and persist the index plus metadata.

    Writes the FAISS index to ``args.index`` and ``{"documents": docstore}`` as
    JSON to ``args.meta``, creating parent directories as needed.
    """
    index_path, meta_path = Path(args.index), Path(args.meta)

    rag = NaiveRAG(args.embed_model, use_local=False)
    rag.load_jsonl(args.data)
    for target in (index_path, meta_path):
        target.parent.mkdir(parents=True, exist_ok=True)
    rag.build_index()

    faiss.write_index(rag.index, args.index)
    payload = {"documents": rag.docstore}
    with open(args.meta, "w", encoding="utf-8") as out:
        json.dump(payload, out, ensure_ascii=False, indent=2)

def get_parser():
    """Return the CLI argument parser; ``cmd`` holds the chosen subcommand."""
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="cmd", required=True)

    build = commands.add_parser("build-index")
    build_options = (
        ("--data", {"required": True}),
        ("--embed_model", {"default": "Qwen/Qwen3-Embedding-8B"}),
        ("--index", {"default": "faiss/faiss.index"}),
        ("--meta", {"default": "faiss/faiss_meta.json"}),
    )
    for flag, spec in build_options:
        build.add_argument(flag, **spec)

    return parser

if __name__ == "__main__":
    # Dispatch the parsed subcommand to its handler.
    _COMMANDS = {"build-index": build_index_cli}
    cli_args = get_parser().parse_args()
    handler = _COMMANDS.get(cli_args.cmd)
    if handler is not None:
        handler(cli_args)
