import os, re, json, time, random, argparse, logging, requests
from pathlib import Path
from typing import Any, Dict, List, Tuple, Mapping, MutableMapping, Set
import torch
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm
from utils.config import CONFIG

from utils.prompt import build_query_decomposition_prompt

# ── Module configuration (evaluated at import time) ──
# Service endpoints / credentials, read from utils.config.CONFIG.
WIKI          = CONFIG["WIKI_ENDPOINT"]
OR_URL        = CONFIG["OPENROUTER_URL"]
OR_KEY        = CONFIG['OPENROUTER_KEY']
# Chat model requested from OpenRouter for every LLM call (see openrouter()).
OR_MODEL      = "openai/gpt-4.1"
# First CUDA device when available, else CPU; half precision only on GPU.
DEVICE        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DTYPE         = torch.float16 if DEVICE.type == "cuda" else torch.float32
# NOTE(review): this changes the process-wide default floating dtype before the
# embedding model below is built, so it may influence how the model is
# instantiated; keep it ahead of the SentenceTransformer(...) line and confirm
# float16 is really wanted globally on GPU.
torch.set_default_dtype(DTYPE)
# Embedding model used by rel_score()/passages_for(); loaded eagerly on import.
# max_seq_length caps the tokens per input text (longer inputs are truncated).
SENT_MODEL    = SentenceTransformer("Qwen/Qwen3-Embedding-8B", device=str(DEVICE))
SENT_MODEL.max_seq_length = 256
SENT_MODEL.eval()
HEADERS_OR    = {"Authorization": f"Bearer {OR_KEY}", "Content-Type": "application/json"}
HEADERS_WIKI  = {"User-Agent": "search-and-answer/1.0"}
# TOP_K: default k of wiki_search (clamped there to MAX_SRLIMIT, so at most 20 hits).
# SUB_MAX: maximum number of sub-queries kept by decompose().
TOP_K, MAX_SRLIMIT, SUB_MAX   = 50, 20, 6
# RELEV_TH: minimum cosine similarity for a sentence to count as relevant.
# PASSLEN: passage window in sentences (PASSLEN // 2 on each side of a hit).
RELEV_TH, PASSLEN             = 0.45, 5
# WAIT_REQ: base pause (s) between page fetches in supporting_docs();
# RETRY / WTIME: attempts and per-request timeout (s) for Wikipedia calls in _wiki().
WAIT_REQ, RETRY, WTIME        = 0.3, 6, 25
# MAX_LL_TRY: attempts for the JSON-returning LLM calls; LL_TIMEOUT: HTTP
# timeout (s) passed to openrouter() by the short-answer steps.
MAX_LL_TRY, LL_TIMEOUT        = 4, 120
# Upper bound on passages shown to the model by choose_passage().
SHORT_MAX_PASSAGES            = 10

def _wiki(params: Dict[str, Any]) -> Dict[str, Any]:
    """Call the MediaWiki ``action=query`` API with retries and return the decoded JSON.

    ``params`` are merged over the base query parameters (later keys win).
    The following are retried up to ``RETRY`` times with jittered exponential
    backoff (capped at 60 s):

    * HTTP 429, honouring a numeric ``Retry-After`` header;
    * HTTP 5xx;
    * timeouts and connection errors.

    Raises:
        requests.HTTPError: for non-retryable HTTP errors (e.g. 4xx other than 429).
        RuntimeError: when every attempt failed; chained to the last
            transport/server error when there was one.
    """
    base = {"action": "query", "format": "json", "utf8": 1}
    last_err = None
    for i in range(RETRY):
        backoff = min(60, 2**i)
        try:
            r = requests.get(WIKI, params={**base, **params},
                             headers=HEADERS_WIKI, timeout=WTIME)
        except (requests.exceptions.Timeout,
                requests.exceptions.ConnectionError) as err:
            last_err = err
            time.sleep(backoff + random.random())
            continue
        if r.status_code == 429:
            # Retry-After may be delta-seconds or an HTTP-date; only the
            # numeric form is used, otherwise fall back to our own backoff.
            retry_after = str(r.headers.get("Retry-After", "")).strip()
            delay = int(retry_after) if retry_after.isdigit() else backoff
            time.sleep(delay + random.random())
            continue
        if r.status_code >= 500:
            # Transient server-side failure: retry instead of aborting the run.
            last_err = requests.exceptions.HTTPError(
                f"{r.status_code} Server Error for url: {r.url}", response=r)
            time.sleep(backoff + random.random())
            continue
        r.raise_for_status()
        return r.json()
    raise RuntimeError("Wiki request exceeded retries") from last_err
def wiki_search(q: str, k: int = TOP_K) -> List[Dict]:
    """Full-text search Wikipedia for ``q`` and return the raw ``list=search`` hits.

    ``k`` is clamped to ``MAX_SRLIMIT``. Returns an empty list when the
    response has no ``query.search`` field.
    """
    search_params = {
        "list": "search",
        "srsearch": q,
        "srlimit": min(k, MAX_SRLIMIT),
    }
    response = _wiki(search_params)
    return response.get("query", {}).get("search", [])
def wiki_content(pid: int) -> str:
    """Return the plain-text extract of the whole article with page id ``pid``.

    Returns ``""`` when the page is missing or has no extract.
    """
    # Do not pass ``"exintro": False``. MediaWiki treats a boolean parameter
    # as true whenever it is present, regardless of its value, and requests
    # would send "exintro=False". That would restrict the extract to the lead
    # section. Omitting the parameter yields the full article text.
    response = _wiki({"pageids": str(pid), "prop": "extracts",
                      "explaintext": True})
    pages = response.get("query", {}).get("pages", {})
    return pages.get(str(pid), {}).get("extract", "")
def sents(text: str) -> List[str]:
    """Split ``text`` into rough sentences.

    Whitespace runs are collapsed to single spaces. The text is cut at every
    run of ``.``, ``!`` or ``?`` (the punctuation itself is dropped). Only
    pieces longer than 10 characters after stripping are kept.
    """
    collapsed = re.sub(r"\s+", " ", text.strip())
    fragments = (frag.strip() for frag in re.split(r"[.!?]+", collapsed))
    return [frag for frag in fragments if len(frag) > 10]
@torch.no_grad()
def rel_score(q: str, sent: str) -> float:
    """Return the cosine similarity between query ``q`` and sentence ``sent``.

    The query is wrapped in the "Instruct: ...\\nQuery: ..." format and the
    sentence is embedded as-is. Both texts are encoded together in one call
    to ``SENT_MODEL``.
    """
    instruction = "Given a web search query, retrieve relevant passages that answer the query"
    query_text = f"Instruct: {instruction}\nQuery: {q}"
    q_vec, s_vec = SENT_MODEL.encode([query_text, sent],
                                     convert_to_tensor=True, device=DEVICE)
    similarity = torch.nn.functional.cosine_similarity(q_vec, s_vec, dim=0)
    return float(similarity)
def passages_for(q: str, content: str) -> List[Dict]:
    """Score every sentence of ``content`` against ``q`` and return the relevant ones.

    A sentence is relevant when its cosine similarity to the
    instruction-prefixed query is at least ``RELEV_TH``. For each relevant
    sentence the result holds a dict with these keys:

    * ``sentence``: the sentence text;
    * ``relevance_score``: its similarity as a float;
    * ``passage``: the sentence plus up to ``PASSLEN // 2`` neighbours on each
      side, joined by spaces;
    * ``idx``: the sentence's position.

    Results are in sentence order. The list is empty when ``content`` has no
    usable sentences.
    """
    st = sents(content)
    if not st:
        return []
    task = "Given a web search query, retrieve relevant passages that answer the query"
    # Embed the query once and all sentences in batches. Calling rel_score()
    # per sentence costs one forward pass per sentence and re-embeds the same
    # query every time.
    with torch.no_grad():
        q_emb = SENT_MODEL.encode([f"Instruct: {task}\nQuery: {q}"],
                                  convert_to_tensor=True, device=DEVICE)
        s_embs = SENT_MODEL.encode(st, convert_to_tensor=True, device=DEVICE)
    # (1, D) against (N, D) broadcasts to one similarity per sentence.
    scores = torch.nn.functional.cosine_similarity(q_emb, s_embs, dim=1).tolist()
    half = PASSLEN // 2
    out = []
    for i, (s, r) in enumerate(zip(st, scores)):
        if r >= RELEV_TH:
            a, b = max(0, i - half), min(len(st), i + half + 1)
            out.append({"sentence": s, "relevance_score": float(r),
                        "passage": " ".join(st[a:b]), "idx": i})
    return out
def openrouter(msgs: List[Dict], temp=0.2, maxtok=64, timeout=70,
               schema: Dict|None=None) -> str:
    """Send a chat-completion request to OpenRouter and return the first reply's content.

    Args:
        msgs: chat messages (``{"role": ..., "content": ...}`` dicts).
        temp: sampling temperature.
        maxtok: maximum number of output tokens.
        timeout: HTTP timeout in seconds.
        schema: optional JSON Schema for structured output. A bare schema is
            wrapped in the ``{"name", "strict", "schema"}`` envelope that the
            ``json_schema`` response format expects. A dict that already has
            ``name`` and ``schema`` keys is sent unchanged. If the provider
            rejects the request with HTTP 400, it is retried once in plain
            ``json_object`` mode.

    Raises:
        requests.HTTPError: when the final response is not 2xx.
    """
    payload = {"model": OR_MODEL, "messages": msgs, "temperature": temp,
               "max_tokens": maxtok}
    if schema:
        if "schema" in schema and "name" in schema:
            envelope = schema
        else:
            envelope = {"name": "response", "strict": True, "schema": schema}
        payload["response_format"] = {"type": "json_schema",
                                      "json_schema": envelope}
    r = requests.post(OR_URL, headers=HEADERS_OR, json=payload, timeout=timeout)
    if r.status_code == 400 and schema:
        # Some providers reject json_schema; fall back to generic JSON mode.
        payload["response_format"] = {"type": "json_object"}
        r = requests.post(OR_URL, headers=HEADERS_OR, json=payload, timeout=timeout)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]
def decompose(q: str) -> List[str]:
    """Split ``q`` into at most ``SUB_MAX`` sub-queries using the LLM.

    Each non-empty line of the reply becomes one sub-query after its leading
    list marker is removed. Markers are a bullet (``*``, ``-``, ``•``) or a
    number followed by ``.``/``)`` and whitespace. The reply is trimmed
    before the marker check, so indented bullets are handled too. Falls back
    to ``[q]`` if the LLM call fails or the reply has no usable lines.
    """
    p = build_query_decomposition_prompt(q)
    try:
        res = openrouter([{"role":"system","content":"You are a helpful assistant for query decomposition."},
                          {"role":"user","content":p}], temp=0.2, maxtok=256)
    except Exception:
        # Best-effort: keep going with the undecomposed query, but leave a trace.
        logging.getLogger(__name__).warning(
            "query decomposition failed; using the original query", exc_info=True)
        return [q]
    subs = [re.sub(r"^(?:[*\-•]\s*|\d+[.)]\s+)", "", line.strip()).strip()
            for line in (res or "").splitlines()]
    # Drop lines that were only a marker, so empty searches are never issued.
    subs = [s for s in subs if s]
    return subs[:SUB_MAX] or [q]
# ─────────────────────── 문서/패시지 수집 (STEP-1) ───────────────────────
def _fetch_page_text(pid: int, cache: Dict[int, str]) -> str:
    """Return the article text for ``pid``, calling the API at most once per cache."""
    if pid not in cache:
        cache[pid] = wiki_content(pid)
        # Throttle after every real API call, including pages that turn out
        # to be empty or irrelevant.
        time.sleep(WAIT_REQ + random.random()*0.2)
    return cache[pid]


def _top_passage(query: str, pid: int, cache: Dict[int, str]) -> Dict | None:
    """Return the highest-scoring passage of page ``pid`` for ``query``, or None."""
    content = _fetch_page_text(pid, cache)
    if not content:
        return None
    passages = passages_for(query, content)
    if not passages:
        return None
    return max(passages, key=lambda p: p["relevance_score"])


def supporting_docs(rec: Dict) -> Dict:
    """Collect supporting Wikipedia passages for each clarified query of ``rec``.

    For every entry of ``rec["clarified_queries"]`` (1-based position n):

    1. The query is decomposed into sub-queries. For each sub-query, the
       single best-scoring page not yet used for this clarified query is kept.
    2. The set is topped up to 10 pages using search hits for the clarified
       query itself.

    Page texts are cached for the whole record, so a page is downloaded at
    most once even when it shows up for several queries.

    Returns:
        A dict ``{"clarified_query_<n>": {"query": str, "sub_queries": [str],
        "documents": {str(page_id): passage}}}``.
    """
    target_docs = 10
    out = {}
    page_cache: Dict[int, str] = {}

    for idx, cq in enumerate(rec["clarified_queries"], 1):
        used_page_ids: Set[int] = set()   # avoid duplicate pages per clarified query
        documents: Dict[str, str] = {}    # str(page_id) -> best passage
        sub_queries_list: List[str] = []

        # Step 1: the single best new page for each sub-query.
        for sq in decompose(cq):
            sub_queries_list.append(sq)
            best_pid, best_text, best_score = None, "", -1.0
            for hit in wiki_search(sq, k=10):
                pid = hit["pageid"]
                if pid in used_page_ids:
                    continue
                top = _top_passage(sq, pid, page_cache)
                # Strict ">" keeps the earliest hit on ties.
                if top is not None and top["relevance_score"] > best_score:
                    best_pid = pid
                    best_text = top["passage"]
                    best_score = top["relevance_score"]
            if best_pid is not None:
                documents[str(best_pid)] = best_text
                used_page_ids.add(best_pid)

        # Step 2: top up with hits for the clarified query itself.
        needed = target_docs - len(documents)
        if needed > 0:
            collected = 0
            for hit in wiki_search(cq, k=min(needed * 2, 20)):
                if collected >= needed:
                    break
                pid = hit["pageid"]
                if pid in used_page_ids:
                    continue
                top = _top_passage(cq, pid, page_cache)
                if top is None:
                    continue
                documents[str(pid)] = top["passage"]
                used_page_ids.add(pid)
                collected += 1

        out[f"clarified_query_{idx}"] = {
            "query": cq,
            "sub_queries": sub_queries_list,
            "documents": documents,
        }

    return out
# ─────────────────────── 단답 생성 (STEP-2) ───────────────────────
# JSON Schema for the short-answer reply: the index of the chosen passage plus
# a short answer string. The short-answer code treats selected_index == -1 or
# an "UNKNOWN" answer as "no answer found".
JSON_SC = {
    "type": "object",
    "properties": {
        "selected_index": {"type": "integer"},
        "short_answer": {"type": "string"},
    },
    "required": ["selected_index", "short_answer"],
    "additionalProperties": False,
}
def choose_passage(q:str,title:str,docs:List[Dict])->Tuple[int,str]|None:
    if not docs: return None
    pas = docs[:SHORT_MAX_PASSAGES]
    block = "\n".join(f"[{i}] {p['passage']}" for i,p in enumerate(pas))
    pr = f"""You are an open-book QA assistant.
Respond JSON only: {JSON_SC}
Question: {q}
Article title: {title}
Passages:
{block}
JSON:"""
    for t in range(1, MAX_LL_TRY+1):
        try:
            js = json.loads(openrouter([{"role":"system","content":"Return ONLY valid JSON."},
                                        {"role":"user","content":pr}],
                                       temp=0.0, maxtok=64, timeout=LL_TIMEOUT,
                                       schema=JSON_SC))
            if js["selected_index"]!=-1 and js["short_answer"].upper()!="UNKNOWN":
                return js["selected_index"], js["short_answer"].strip()
            return None
        except Exception as e:
            if t==MAX_LL_TRY: return None
            time.sleep(2**(t-1))

def short_answers(rec: Dict) -> Tuple[bool, Dict]:
    sel, ok = {}, True
    
    for i in (1, 2):
        cq_key = f"clarified_query_{i}"
        cqentry = rec["supporting_docs"][cq_key]
        
        # documents에서 passage들을 가져와서 choose_passage용 형태로 변환
        passages_for_selection = []
        for page_id, passage_text in cqentry["documents"].items():
            passages_for_selection.append({
                "passage": passage_text,
                "page_id": int(page_id)
            })
        
        # choose_passage 함수를 위해 임시로 title을 추가 (실제로는 사용되지 않음)
        best = None
        if passages_for_selection:
            # choose_passage 함수 호출을 위한 형태로 변환
            block = "\n".join(f"[{idx}] {p['passage']}" for idx, p in enumerate(passages_for_selection))
            
            # choose_passage와 유사한 로직으로 직접 처리
            pr = f"""You are an open-book QA assistant.
Respond JSON only: {JSON_SC}
Question: {rec["clarified_queries"][i-1]}
Passages:
{block}
JSON:"""
            
            for t in range(1, MAX_LL_TRY+1):
                try:
                    js = json.loads(openrouter([{"role":"system","content":"Return ONLY valid JSON."},
                                               {"role":"user","content":pr}],
                                              temp=0.0, maxtok=64, timeout=LL_TIMEOUT,
                                              schema=JSON_SC))
                    
                    if (js["selected_index"] != -1 and 
                        js["selected_index"] < len(passages_for_selection) and
                        js["short_answer"].upper() != "UNKNOWN"):
                        
                        selected_passage = passages_for_selection[js["selected_index"]]
                        best = {
                            "query": rec["clarified_queries"][i-1],
                            "page_id": selected_passage["page_id"],
                            "title": None,  # title 정보가 없으므로 None
                            "passage": selected_passage["passage"],
                            "short_answer": js["short_answer"].strip()
                        }
                        break
                except Exception as e:
                    if t == MAX_LL_TRY:
                        break
                    time.sleep(2**(t-1))
        
        if not best:
            ok = False
            best = {
                "query": rec["clarified_queries"][i-1],
                "page_id": None,
                "title": None,
                "passage": "",
                "short_answer": "UNKNOWN"
            }
        
        sel[f"clarified_query_{i}"] = best
    
    return ok, sel

# ─────────────────────── 장답 생성 (STEP-3) ───────────────────────
# JSON Schema for the long-answer reply: a single required "long_answer" string.
LONG_SCHEMA = {
    "type": "object",
    "properties": {
        "long_answer": {"type": "string"},
    },
    "required": ["long_answer"],
    "additionalProperties": False,
}
def long_answer(orig_q,cq1,a1,cq2,a2)->str:
    """Fuse two short answers into a 1-3 sentence long answer with the LLM.

    Args:
        orig_q: the original query.
        cq1, a1: first clarified query and its short answer.
        cq2, a2: second clarified query and its short answer.

    Returns:
        The stripped ``long_answer`` from the model. Returns ``"UNKNOWN"`` if
        all ``MAX_LL_TRY`` attempts fail, whether from a request error,
        invalid JSON, a missing key or an empty answer.
    """
    # json.dumps shows the model real JSON (false, double quotes), not a Python repr.
    prompt = f"""You are an expert QA assistant.
Combine A1 and A2 into a coherent long answer (1-3 sentences). No new facts.
Return JSON: {json.dumps(LONG_SCHEMA)}
OQ: {orig_q}
{cq1} → A1={a1}
{cq2} → A2={a2}
JSON:"""
    for t in range(1, MAX_LL_TRY+1):
        try:
            js = json.loads(openrouter([{"role":"system","content":"Return ONLY valid JSON."},
                                        {"role":"user","content":prompt}],
                                       temp=0.0, maxtok=128, timeout=LL_TIMEOUT,
                                       schema=LONG_SCHEMA))
            answer = js["long_answer"].strip()
        except Exception:
            answer = ""
        if answer:
            return answer
        if t < MAX_LL_TRY:
            time.sleep(2**(t-1))
    return "UNKNOWN"
# ──────────────────────────────── IO ────────────────────────────────
def load_jsonl(p:str)->List[Dict]:
    """Read the JSON Lines file ``p`` into a list of objects, skipping blank lines.

    The file is decoded as UTF-8, matching what save_jsonl writes,
    independently of the platform's locale encoding.
    """
    text = Path(p).read_text(encoding="utf-8")
    return [json.loads(line) for line in text.splitlines() if line.strip()]
def save_jsonl(rows: List[Dict], p: str):
    """Write ``rows`` to ``p`` as UTF-8 JSON Lines, one object per line.

    Missing parent directories are created and an existing file is
    overwritten. Non-ASCII characters are written verbatim
    (``ensure_ascii=False``).
    """
    target = Path(p)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = (json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
    with target.open("w", encoding="utf-8") as fh:
        fh.writelines(serialized)
# ──────────────────────────────── MAIN ────────────────────────────────
def main():
    """Command-line entry point for the docs → short → long pipeline.

    --step docs   collect supporting documents, then save.
    --step short  add short answers (input must carry supporting_docs); keep
                  only records where both clarified queries were answered, then save.
    --step long   add long answers (input must carry selected_support).
    --step all    run the three steps in sequence (default).
    --limit N     process only the first N input records (default 4000;
                  a value <= 0 processes every record).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--step", choices=["docs", "short", "long", "all"], default="all")
    ap.add_argument("--in", dest="inp", default="dataset/musique_generalize_clarified_sample.jsonl")
    ap.add_argument("--out", dest="outp", default="dataset/musique_generalize_long_answers_decomposition.jsonl")
    ap.add_argument("--limit", type=int, default=4000,
                    help="max number of input records to process (<= 0 for all)")
    args = ap.parse_args()
    data = load_jsonl(args.inp)
    if args.limit > 0:
        data = data[:args.limit]
    if args.step in ("docs", "all"):
        for r in tqdm(data, desc="Collecting docs"):
            r["supporting_docs"] = supporting_docs(r)
        if args.step == "docs":
            save_jsonl(data, args.outp)
            return
    if args.step in ("short", "all"):
        tmp = []
        for r in tqdm(data, desc="Generating short answers"):
            ok, sel = short_answers(r)
            if ok:
                r["selected_support"] = sel
                tmp.append(r)
        data = tmp
        if args.step == "short":
            save_jsonl(data, args.outp)
            return
    if args.step in ("long", "all"):
        for r in tqdm(data, desc="Generating long answers"):
            ss = r["selected_support"]
            r["long_answer"] = long_answer(r["original_query"],
                                           ss["clarified_query_1"]["query"],
                                           ss["clarified_query_1"]["short_answer"],
                                           ss["clarified_query_2"]["query"],
                                           ss["clarified_query_2"]["short_answer"])
    save_jsonl(data, args.outp)
    print(f":white_check_mark: saved {len(data)} records → {args.outp}")


if __name__ == "__main__":
    main()

'''
Example usage (runs every step: docs -> short -> long):
python dataset_generation.py --in dataset/train_musique_semantic_clarified.jsonl --out dataset/train_semantic.jsonl

'''