import json, ast, re, string, time, os
from typing import List, Dict, Tuple
import openai
from openai import OpenAI
import google.generativeai as genai
# ================================
# 0) API keys and clients (read from env)
# ================================
# Keys come from the environment (export OPENAI_API_KEY / GOOGLE_API_KEY)
# instead of being pasted into the source, where they would leak through
# version control or shared copies of this script.
# The empty default lets the module import without an OpenAI key (e.g. for
# Gemini-only runs); GPT requests then fail authentication.
API_KEY = os.environ.get("OPENAI_API_KEY", "")

openai.api_key = API_KEY
client = OpenAI(api_key=API_KEY)

# ----------------------------------------------------------
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# ----------------------------------------------------------
# ----------------------------------------------------------
# Alternative OpenAI-compatible endpoints (swap one in to replace `client`):
#
#   OpenRouter:
#   client = OpenAI(
#     base_url="https://openrouter.ai/api/v1",
#     api_key= #API KEY
#   )
#
#   DeepSeek:
#   client = OpenAI(api_key=<API_KEY>, base_url="https://api.deepseek.com")
# ----------------------------------------------------------

# Backend selector passed to call_llm(): "gpt" / "openrouter" use
# get_completion(), "gemini" uses get_completion_gemmini().
# model_type = "gpt"
model_type = "gemini"

# Model name for the OpenAI-compatible path only; the Gemini path hard-codes
# its own model name inside get_completion_gemmini().
model = "gpt-4o"
# model = "deepseek/deepseek-chat-v3.1:free"
# model = "deepseek-chat"  # DeepSeek API
# model = "deepseek/deepseek-r1:free"
# ----------------------------------------------------------

def call_llm(prompt, temperature=0.3, model_type="gpt", max_tokens=512):
    """Dispatch a chat prompt to the selected LLM backend and return its text.

    Args:
        prompt: List of chat messages ({"role": ..., "content": ...}). The
            Gemini path reads prompt[0] as the system message and prompt[1]
            as the user turn.
        temperature: Sampling temperature forwarded to the backend.
        model_type: "gpt" or "openrouter" -> get_completion();
            "gemini" -> get_completion_gemmini(). Any other value only logs
            an error.
        max_tokens: Output-token cap, forwarded to whichever backend is used.

    Returns:
        The reply text. "" for an unknown model_type or when the backend
        produced no text (get_completion_gemmini returns None on refusal),
        so callers can always use str methods on the result.
    """
    if model_type in ("gpt", "openrouter"):
        time.sleep(4)  # fixed pause before each request, presumably for rate limiting
        response = get_completion(prompt, temperature=temperature, max_tokens=max_tokens)
    elif model_type == "gemini":
        time.sleep(4)
        # Forward max_tokens here as well so per-call caps (e.g. qa_agent's)
        # apply on both backends.
        response = get_completion_gemmini(prompt, temperature=temperature, max_tokens=max_tokens)
    else:
        print("Error at call_llm: Model not specified.\n")
        response = ""
    return response or ""

def get_completion(messages: List[Dict[str, str]], model=model, temperature=0.1, max_tokens=512, max_retries=None):
    """Send a chat-completion request through the module-level OpenAI-compatible `client`.

    Transient failures (rate limits, timeouts, server errors, ...) are retried
    with exponential backoff: 1 s, doubling, capped at 60 s between attempts.

    Args:
        messages: Chat messages, e.g. [{"role": "system", ...}, {"role": "user", ...}].
        model: Model name; defaults to the module-level `model` as it was when
            this function was defined.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.
        max_retries: Maximum number of retries after failed attempts; None
            (the default) retries transient errors indefinitely.

    Returns:
        The `content` of the first choice's message.

    Raises:
        openai.AuthenticationError, openai.PermissionDeniedError,
        openai.NotFoundError, openai.BadRequestError: re-raised immediately,
            because resending the identical request cannot succeed (retrying
            them used to hang the run forever).
        Exception: the last error once `max_retries` is exhausted.
    """
    delay = 1
    retries = 0
    print("call llm - gpt/openrouter")
    while True:
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return resp.choices[0].message.content
        except (
            openai.AuthenticationError,
            openai.PermissionDeniedError,
            openai.NotFoundError,
            openai.BadRequestError,
        ):
            raise
        except Exception as e:
            if max_retries is not None and retries >= max_retries:
                raise
            retries += 1
            print("[OpenAI error]", e, "—retrying in", delay, "s")
            time.sleep(delay)
            delay = min(delay * 2, 60)

def get_completion_gemmini(prompt, temperature=0.1, max_tokens=512):
    """Send a two-message chat prompt to Gemini and return the reply text.

    Args:
        prompt: [system_message, user_message] dicts with a "content" key. The
            system content becomes `system_instruction`; the user content is
            sent as the single turn of a fresh chat session.
        temperature: Sampling temperature.
        max_tokens: Passed to Gemini as `max_output_tokens`.

    Returns:
        The response text, or None when Gemini blocks the prompt, stops the
        candidate, or returns a candidate without text.
    """
    print("gemini")

    system_content = prompt[0]['content']
    user_content = prompt[1]['content']

    # See https://ai.google.dev/api/python/google/generativeai/GenerativeModel
    generation_config = {
        "temperature": temperature,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": max_tokens,
        # "stop_sequences": [";"],
        "response_mime_type": "text/plain",
        # "response_mime_type": "application/json"
    }

    model = genai.GenerativeModel(
        model_name="gemini-2.5-flash-lite",
        # model_name="gemini-2.0-flash",
        generation_config=generation_config,
        system_instruction=system_content,
        # safety_settings = Adjust safety settings
        # See https://ai.google.dev/gemini-api/docs/safety-settings
    )

    chat_session = model.start_chat(history=[])

    try:
        response = chat_session.send_message(user_content)
    except (genai.types.StopCandidateException, genai.types.BlockedPromptException) as e:
        print("Gemini refused to generate a response:", e)
        return None

    try:
        return response.text
    except ValueError as e:
        # `.text` raises ValueError when the candidate carries no text parts
        # (e.g. empty or filtered output); treat it like a refusal.
        print("Gemini returned no text:", e)
        return None

# ================================
# 0.1) MuSiQue loader (JSONL)
# ================================
def read_jsonl(path: str):
    """Lazily yield one parsed JSON value per non-blank line of the file at *path*.

    Each line is stripped of surrounding whitespace before parsing; lines that
    are empty after stripping are skipped. The file stays open until the
    generator is exhausted or closed.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw_line.strip() for raw_line in handle)
        yield from (json.loads(record) for record in stripped if record)

def musique_context_gold(ex: dict) -> Tuple[List[Tuple[int, str, str]], List[int]]:
    """Split a MuSiQue example into its paragraph context and gold indices.

    Args:
        ex: One example dict; its "paragraphs" entries are read for "idx",
            "title", "paragraph_text" and "is_supporting".

    Returns:
        (context, gold_idx): context is a list of (idx, title, paragraph_text)
        tuples in input order, with idx cast to int, the title defaulting to ""
        and the text stripped (None becomes ""). gold_idx lists the idx of
        every paragraph whose is_supporting flag is truthy.
    """
    context: List[Tuple[int, str, str]] = []
    gold: List[int] = []
    for para_info in ex.get("paragraphs", []):
        entry = (
            int(para_info.get("idx")),
            para_info.get("title", ""),
            (para_info.get("paragraph_text") or "").strip(),
        )
        context.append(entry)
        if para_info.get("is_supporting", False):
            gold.append(entry[0])
    return context, gold

# ================================
# 1) Question analyzer
# ================================
def question_analyzer_agent(question: str) -> List[str]:
    """Ask the LLM to decompose *question* into an ordered list of subquestions.

    The reply is expected to be JSON of the form {"Subquestions": [...]},
    optionally wrapped in markdown code fences (which are removed first).

    Returns:
        The "Subquestions" value of the parsed reply, or [] when the reply
        cannot be parsed or lacks that key.
    """
    instructions = (
        "You are a Question Analyzer Agent in a multi-hop QA system. "
        "Break the question into a small ordered list of subquestions. "
        'Return JSON like {"Subquestions": ["...", "..."]}.'
    )
    user_turn = f'Question: {question}\nReturn: {{"Subquestions": ["...", "..."]}}'
    messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": user_turn},
    ]
    raw_reply = call_llm(messages, temperature=0.1, model_type=model_type)
    try:
        unfenced = raw_reply.replace("```json", "").replace("```", "").strip()
        return json.loads(unfenced)["Subquestions"]
    except Exception as err:
        print("Failed to parse question analyzer output:", err)
        return []

# ================================
# 1b) Initial selector (INDICES)
# ================================
def zero_shot_selector_agent(question, subqs, context):
    """Ask the LLM, in a single shot, which paragraph indices answer *question*.

    *subqs* is accepted for signature parity with the other agents but is not
    included in the prompt. Returns the raw LLM reply, which the prompt asks
    to be a list of integers such as [0, 3, 7].
    """
    instructions = (
        "You are an expert evidence selection agent for QA. "
        "Given a question and a list of paragraph INDICES, select the MINIMUM set of indices "
        "strictly necessary to answer the question. Return ONLY a Python list of integers like [3, 7] in json format."
    )
    user_lines = [
        f"Question: {question}",
        f"Context (idx, title, paragrph): {context}",
        "Selected paragraph list: <python list of integers>",
    ]
    messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": "\n".join(user_lines)},
    ]
    return call_llm(messages, temperature=0.1, model_type=model_type)

# Matches one flat bracketed span such as "[3, 7]" inside free text.
_BRACKETED_LIST_RE = re.compile(r"\[[^\[\]]*\]")
# Dict keys an LLM commonly uses to wrap the index list, checked in this order.
_INDEX_LIST_KEYS = ("answer", "indices", "passages", "selected", "result")

def _parse_literal(text: str):
    """Parse *text* as JSON, falling back to a Python literal; None if neither works."""
    try:
        return json.loads(text)
    except ValueError:
        pass
    try:
        return ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None

def _as_index_list(obj):
    """Return the candidate index list carried by *obj* (a list, or a dict wrapping one), else None."""
    if isinstance(obj, list):
        return obj
    if isinstance(obj, dict):
        for key in _INDEX_LIST_KEYS:
            if isinstance(obj.get(key), list):
                return obj[key]
    return None

def parse_int_list(raw: str, allowed: set) -> List[int]:
    """Extract a list of paragraph indices from a raw LLM reply.

    Markdown code fences are removed first. The whole reply is then parsed as
    JSON or a Python literal; a bare list is used directly, and a dict is
    searched for one of the keys in _INDEX_LIST_KEYS. If the whole reply does
    not parse to such a list, each flat "[...]" span inside it is tried in
    order (handles echoes like "Updated paragraph list: [2, 9]").

    Args:
        raw: The LLM reply; None or "" yields [].
        allowed: Indices to keep; any other value is dropped.

    Returns:
        The integer entries of the first list found that are in *allowed*,
        in reply order (duplicates are kept). [] when nothing parses.
    """
    if not raw:
        return []
    cleaned = raw.replace("```json", "").replace("```", "").strip()
    for candidate in [cleaned, *_BRACKETED_LIST_RE.findall(cleaned)]:
        values = _as_index_list(_parse_literal(candidate))
        if values is not None:
            # bool is a subclass of int; exclude it so `True` is not read as index 1.
            return [
                x for x in values
                if isinstance(x, int) and not isinstance(x, bool) and x in allowed
            ]
    print("Could not parse int list:", cleaned[:200])
    return []

def dedup_preserve_order(seq):
    """Return the items of *seq* as a new list with later duplicates removed.

    The first occurrence of each (hashable) item keeps its original position.
    """
    # dict preserves insertion order and keeps the first key seen.
    return list(dict.fromkeys(seq))

def get_initial_indices(question, subqs, context, fallback_topk=20):
    """Choose the starting paragraph indices with the zero-shot selector.

    The LLM reply is parsed into integers restricted to the idx values that
    actually occur in *context*. If nothing usable comes back, the first
    *fallback_topk* context indices (in context order) are used, so later
    refinement still has candidates.

    Returns:
        A de-duplicated list of indices, order preserved.
    """
    context_indices = [idx for idx, _, _ in context]
    raw = zero_shot_selector_agent(question, subqs, context)
    ints = parse_int_list(raw, allowed=set(context_indices))
    if not ints:
        ints = context_indices[:fallback_topk]  # deterministic fallback
    return dedup_preserve_order(ints)

# ================================
# 2) Selector & Adder (indices)
# ================================
def selector_agent(question, subqs, context, current_indices):
    """Ask the LLM to prune the current selection down to relevant paragraphs.

    Precision-oriented counterpart of adder_agent. Returns the raw LLM reply,
    which the prompt asks to be the updated list of paragraph indices.
    """
    system_header = (
        "You are a **Selector Agent** in a multi-hop QA system. Your task is to retrive the required paragraphs from a set of candidate paragraphs for answering aquestion. Your goal is to MAXIMIZE PRECISION and MINIMIZE FALSE POSITIVES.\n\n"
        "You are given:\n"
        "- A complex multi-hop question\n"
        "- Subquestions that represent reasoning steps\n"
        "- Candidate context: Passages in python tupple. Format: (paragraph_index, title, paragraph)\n"
        # The selection is sent as a plain list of integer paragraph indices
        # ("Current Selected Indices" below), so describe it that way.
        "- CURRENT SELECTED Passage Indices from the Candidate Context: a list of the integer paragraph_index values that are already selected\n\n"
        "Your job is to carefully REMOVE only those paragraphs that are DEFINITELY irrelevant for answering ANY of the subquestions.\n\n"
        "And SELECT ONLY those candidate paragraphs that are LIKELY to support answering any subquestion.\n"
        "Append only relevant indices, no duplicates. \n"
        "Return the updated paragraph list only (Python list of integers) without extra words in json format.\n"
    )
    prompt = [
        {"role": "system", "content": system_header},
        {"role": "user", "content": (
            f"Question: {question}\n"
            f"Subquestions: {subqs}\n"
            f"Context (idx, title, paragrph): {context}\n"
            f"Current Selected Indices: {current_indices}\n"
            "Updated paragraph list: <python list of integers>"
        )}
    ]
    return call_llm(prompt, temperature=0.1, model_type=model_type)

def adder_agent(question, subqs, context, current_indices):
    """Ask the LLM to extend the current selection with further relevant paragraphs.

    Recall-oriented counterpart of selector_agent. Returns the raw LLM reply,
    which the prompt asks to be the updated list of paragraph indices.
    """
    system_header = (
        "You are an **Adder Agent** in a multi-hop QA system. Your task is to retrive the required paragraphs from a set of candidate paragraphs for answering aquestion. Your goal is to MAXIMIZE RECALL while MINIMIZING FALSE POSITIVES.\n\n"
        "You are given:\n"
        "- A complex multi-hop question\n"
        "- Subquestions that represent reasoning steps\n"
        "- Candidate context: Passages in python tupple. Format: (paragraph_index, title, paragraph)\n"
        # The selection is sent as a plain list of integer paragraph indices
        # ("Current Selected Indices" below), so describe it that way.
        "- CURRENT SELECTED Passage Indices from the Candidate Context: a list of the integer paragraph_index values that are already selected\n\n"
        "Your job is to find and ADD ONLY those candidate paragraphs that are LIKELY to support answering any subquestion.\n"
        "Append only relevant indices, no duplicates. \n"
        "Return the updated paragraph list only (Python list of integers) without extra words in json format.\n"
    )
    prompt = [
        {"role": "system", "content": system_header},
        {"role": "user", "content": (
            f"Question: {question}\n"
            f"Subquestions: {subqs}\n"
            f"Context (idx, title, paragrph): {context}\n"
            f"Current Selected Indices: {current_indices}\n"
            "Updated paragraph list: <python list of integers>"
        )}
    ]
    return call_llm(prompt, temperature=0.1, model_type=model_type)

def run_agent_loop(question, subqs, context, initial_indices, n_rounds=3):
    """Refine a paragraph selection with alternating Adder / Selector passes.

    Each round the Adder's reply is appended to the selection, and then the
    Selector's reply replaces it. Replies are parsed against the idx values
    present in *context*; an empty or unparseable reply leaves the selection
    unchanged. If the selection is empty after a round, the de-duplicated
    *initial_indices* are restored and the loop stops early.

    Returns:
        The final de-duplicated list of selected indices.
    """
    selected = dedup_preserve_order(initial_indices)
    # Valid indices come from the context itself rather than a hard-coded 0..19.
    all_set = {idx for idx, _, _ in context}
    for r in range(n_rounds):
        print(f"\n--- Iteration {r+1} ---")

        add_raw = adder_agent(question, subqs, context, selected)
        add_list = parse_int_list(add_raw, allowed=all_set)
        # An empty reply simply adds nothing.
        selected = dedup_preserve_order(selected + add_list)
        print("After Adder:", selected)

        sel_raw = selector_agent(question, subqs, context, selected)
        sel_list = parse_int_list(sel_raw, allowed=all_set)
        if sel_list:
            selected = dedup_preserve_order(sel_list)
        print("After Selector:", selected)

        if not selected:
            print("❌ Selector dropped everything; reverting to initial and stopping.")
            selected = dedup_preserve_order(initial_indices)
            break
    return selected

# ================================
# 3) Text assembly from indices
# ================================
def build_text_from_indices(context: List[Tuple[int, str, str]], selected_indices: List[int]) -> str:
    """Render the paragraphs named by *selected_indices* as an evidence string.

    Paragraphs appear in the order of *selected_indices*, each formatted as
    "[idx] title:\\n- text" and separated by a blank line. Indices not found
    in *context* are skipped; if *context* repeats an idx, its last entry wins.
    """
    lookup = {idx: (title, para) for idx, title, para in context}
    rendered = (
        "[{}] {}:\n- {}".format(i, *lookup[i])
        for i in selected_indices
        if i in lookup
    )
    return "\n\n".join(rendered)

def build_text_all(context: List[Tuple[int, str, str]]) -> str:
    """Render every paragraph with non-empty text as one evidence string.

    Uses the same "[idx] title:\\n- text" layout as build_text_from_indices,
    in context order, with a blank line between paragraphs.
    """
    blocks = []
    for idx, title, para in context:
        if not para:
            continue
        blocks.append(f"[{idx}] {title}:\n- {para}")
    return "\n\n".join(blocks)

# ================================
# 4) QA + metrics
# ================================
# def qa_agent(question: str, evidence: str) -> str:
#     system_header = (
#         "You are a QA agent. Answer concisely using ONLY the provided evidence. "
#         "If not answerable, reply exactly: Not Answerable"
#     )
#     prompt = [
#         {"role": "system", "content": system_header},
#         {"role": "user", "content": f"Question: {question}\n\nEvidence:\n{evidence}\n\nAnswer:"}
#     ]
#     return call_llm(prompt, max_tokens=8).strip()

def qa_agent(question: str, evidence: str) -> str:
    """Answer *question* from *evidence* only, with a 16-token output cap.

    Returns:
        The stripped LLM answer. The prompt asks for 'Not Answerable' when the
        evidence is insufficient. An empty string is returned if the LLM
        produced no text.
    """
    system_header = (
        "You are a question answering agent. Given a question and the supporting evidence, "
        "provide a concise, factual and short answer based only on the evidence without other words. "
        "If the answer cannot be determined from the evidence, reply with 'Not Answerable'."
    )

    prompt = [
        {"role": "system", "content": system_header},
        {"role": "user", "content": (
            f"Question: {question}\n\n"
            f"Evidence:\n{evidence}\n\n"
            "Answer:"
        )}
    ]

    answer = call_llm(prompt, max_tokens=16, model_type=model_type)
    # The Gemini path can yield None (refusal); treat that as an empty answer
    # rather than raising on .strip() and losing the whole evaluation record.
    return (answer or "").strip()

# Translation table that deletes ASCII punctuation; built once instead of
# rebuilding set(string.punctuation) for every character.
_PUNCT_DELETE_TABLE = str.maketrans("", "", string.punctuation)

def normalize_answer(s):
    """Normalize an answer string for exact-match comparison (SQuAD-style).

    Steps, in order: lowercase; delete ASCII punctuation; replace the whole
    words "a", "an", "the" with a space; collapse whitespace runs to single
    spaces with no leading or trailing space.
    """
    text = s.lower()
    text = text.translate(_PUNCT_DELETE_TABLE)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())

def evaluate_qa(predicted_answer: str, gold_answer: str) -> float:
    """Exact match: 1.0 if both answers normalize to the same string, else 0.0."""
    same = normalize_answer(predicted_answer) == normalize_answer(gold_answer)
    return float(same)

def retrieval_metrics(pred_idx: List[int], gold_idx: List[int]) -> Dict[str, object]:
    """Set-based retrieval scores of predicted vs. gold paragraph indices.

    Duplicates in either list are ignored. Ratios are rounded to 4 decimals
    in the returned dict; F1 is computed from the unrounded precision and
    recall. Every ratio is 0.0 when its denominator is zero.

    Note: "false_positive_rate" is FP / (TP + FP), the share of predicted
    indices that are wrong (a false-discovery rate). It is not the classical
    FP / (FP + TN).

    Returns:
        Dict with rounded "precision", "recall", "f1" and "false_positive_rate",
        plus sorted index lists "true_positives", "false_positives" and
        "false_negatives".
    """
    predicted, gold = set(pred_idx), set(gold_idx)
    hits = predicted & gold
    spurious = predicted - gold
    missed = gold - predicted
    tp, fp, fn = len(hits), len(spurious), len(missed)

    n_pred = tp + fp
    n_gold = tp + fn
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    fpr = fp / n_pred if n_pred else 0.0

    return {
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
        "false_positive_rate": round(fpr, 4),
        "true_positives": sorted(hits),
        "false_positives": sorted(spurious),
        "false_negatives": sorted(missed),
    }

# ================================
# 5) Save helper
# ================================
def append_to_json_file(file_path, question_id, value):
    """Insert or overwrite `question_id -> value` in the JSON object stored at *file_path*.

    The file holds one JSON object keyed by question id. If it does not exist,
    it is created, along with any missing parent directory. The whole object
    is rewritten on each call. The write goes to a temporary file that is then
    moved into place with os.replace, so an interruption mid-write leaves the
    previous results file intact instead of a truncated one.
    """
    if os.path.exists(file_path):
        with open(file_path, "r", encoding="utf-8") as jf:
            data = json.load(jf)
    else:
        parent = os.path.dirname(file_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        data = {}
        print(f"Created {file_path}")

    data[question_id] = value

    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as jf:
        json.dump(data, jf, indent=2)
    os.replace(tmp_path, file_path)

# ================================
# 6) Main
# ================================
if __name__ == "__main__":
    # musique_path = "musique/musique_full_v1.0_dev.jsonl"
    musique_path = "musique/musique_dev_500_samples.jsonl"  # change to your path

    # Evaluate items[START:START + N] of the loaded file.
    N = 70
    START = 431
    END = START + N
    N_ROUNDS = 3  # Adder/Selector refinement rounds per question
    SKIP_UNANSWERABLE_FOR_QA = True
    FALLBACK_TOPK = 20  # forwarded to get_initial_indices as fallback_topk

    # llm = "gpt-4o-mini"
    llm = "gemini-2.5-flash-lite"
    # llm = "deepseek"
    # llm = "gpt-4o"

    # NOTE: `llm` only labels the output file. The backend actually queried is
    # chosen by the module-level `model_type` / `model` settings, so keep the
    # two in sync.
    output_file = f"outputs/musique_dev_eval_final_sample_{llm}.json"
    # output_file = f"outputs/musique_dev_eval_final_sample_{llm}_wo_qanalyzer.json"
    # output_file = "outputs/musique_dev_eval_final_gpt-4o.json"  # change to avoid overwriting

    # Load
    all_items = list(read_jsonl(musique_path))
    print("Loaded MuSiQue dev:", len(all_items))
    items = all_items[START:END]

    # Accumulators
    total_p = total_r = total_f1 = total_fpr = 0.0
    total_acc_full = total_acc_retr = total_acc_oracle = 0.0
    retr_count = 0
    qa_count = 0

    for i, ex in enumerate(items, START):
        qid = ex.get("id", f"idx_{i}")
        question = ex.get("question", "")
        gold_answer = (ex.get("answer") or "").strip()
        answerable = bool(ex.get("answerable", True))

        if not answerable:
            print(f"\n--- Skipping unanswerable question {qid} ---")
            continue

        context, gold_idx = musique_context_gold(ex)
        all_idx = [idx for idx, _, _ in context]
        # Look titles up by paragraph idx rather than by list position.
        idx2title = {idx: title for idx, title, _ in context}

        print("\n-------------------------------")
        print(f"Index: {i}  ID: {qid}")
        print("Question:", question)
        print("Gold indices:", gold_idx)
        print("All indices:", all_idx)

        gold_passages = [idx2title.get(idx, "") for idx in gold_idx]  # titles
        print("Gold passages:", gold_passages)  # titles

        # Analyze
        subqs = question_analyzer_agent(question)
        # subqs = []
        print("Subquestions:", subqs)

        # Initial retrieve (indices)
        init_idx = get_initial_indices(question, subqs, context, fallback_topk=FALLBACK_TOPK)
        print("Initial indices:", init_idx)
        init_passages = [idx2title.get(idx, "") for idx in init_idx]  # titles
        print("Initial passages:", init_passages)  # titles

        # Agentic refinement
        final_idx = run_agent_loop(question, subqs, context, init_idx, n_rounds=N_ROUNDS)
        final_passages = [idx2title.get(idx, "") for idx in final_idx]  # titles
        print("Final passages:", final_passages)  # titles
        print("Final indices:", final_idx)
        print("Gold indices:", gold_idx)

        if not final_idx:
            print("❌ No passages retrieved; skipping.")
            continue

        # Retrieval metrics (index-level)
        scores = retrieval_metrics(final_idx, gold_idx)
        retr_count += 1
        total_p += scores["precision"]
        total_r += scores["recall"]
        total_f1 += scores["f1"]
        total_fpr += scores["false_positive_rate"]
        print(f"Retrieval — P:{scores['precision']} R:{scores['recall']} F1:{scores['f1']} FPR:{scores['false_positive_rate']}")
        print("TP:", scores["true_positives"])
        print("FN:", scores["false_negatives"][:10])
        print("FP:", scores["false_positives"][:10])
        print("AVG so far — P:", round(total_p/retr_count,4), "R:", round(total_r/retr_count,4),
              "F1:", round(total_f1/retr_count,4), "FPR:", round(total_fpr/retr_count,4))

        # QA: full vs retrieved-only
        do_qa = (answerable or not SKIP_UNANSWERABLE_FOR_QA) and bool(gold_answer)
        # do_qa = False
        qa_error = None

        if do_qa:
            try:
                full_text = build_text_all(context)
                retr_text = build_text_from_indices(context, final_idx)
                gold_text = build_text_from_indices(context, gold_idx)

                full_ans = qa_agent(question, full_text)
                retr_ans = qa_agent(question, retr_text)
                oracle_ans = qa_agent(question, gold_text)  # for reference
            except Exception as e:
                # Don't drop the question: its retrieval scores are already in
                # the running averages, so still persist the record below
                # (without QA fields).
                print("❌ QA failed:", e)
                qa_error = str(e)
                do_qa = False

        if do_qa:
            acc_full = evaluate_qa(full_ans, gold_answer)
            acc_retr = evaluate_qa(retr_ans, gold_answer)
            acc_oracle = evaluate_qa(oracle_ans, gold_answer)

            qa_count += 1
            total_acc_full += acc_full
            total_acc_retr += acc_retr
            total_acc_oracle += acc_oracle

            print("Gold Answer:", gold_answer)
            print(f"QA (Full):      {full_ans} | EM: {acc_full}")
            print(f"QA (Retrieved): {retr_ans} | EM: {acc_retr}")
            print(f"QA (Oracle):    {oracle_ans} | EM: {acc_oracle}")
            print(f"QA AVG — Full: {round(total_acc_full/qa_count,4)} | Retrieved: {round(total_acc_retr/qa_count,4)} | Oracle: {round(total_acc_oracle/qa_count,4)}")
        elif qa_error is None:
            print("QA EM skipped (unanswerable or empty gold).")

        # Persist record
        record = {
            "index": i,
            "id": qid,
            "question": question,
            "question_analyzer": subqs,
            "gold_answer": gold_answer,
            "answerable": answerable,
            "gold_indices": gold_idx,
            "initial_indices": init_idx,
            "final_indices": final_idx,
            "gold_passages": gold_passages,
            "initial_passages": init_passages,
            "final_passages": final_passages,
            "precision": scores["precision"],
            "recall": scores["recall"],
            "f1": scores["f1"],
            "false_positive_rate": scores["false_positive_rate"],
            "true_positives": scores["true_positives"],
            "false_positives": scores["false_positives"],
            "false_negatives": scores["false_negatives"],
        }
        if do_qa:
            record.update({
                "answer_full_context": full_ans,
                "answer_retrieved": retr_ans,
                "match_full": acc_full,
                "match_retrieved": acc_retr,
                "answer_oracle": oracle_ans,
                "match_oracle": acc_oracle,
            })
        if qa_error is not None:
            record["qa_error"] = qa_error
        append_to_json_file(output_file, qid, record)

    # Summary
    print("\n=========== Summary ===========")
    if retr_count:
        print(f"Retrieval AVG — P:{round(total_p/retr_count,4)} R:{round(total_r/retr_count,4)} F1:{round(total_f1/retr_count,4)} FPR:{round(total_fpr/retr_count,4)}")
    else:
        print("No retrieval items scored.")
    if qa_count:
        print(f"QA EM AVG — Full:{round(total_acc_full/qa_count,4)} Retrieved:{round(total_acc_retr/qa_count,4)} Oracle:{round(total_acc_oracle/qa_count,4)}")
    else:
        print("No QA scored (skipped).")


# Usage: python3 test_musique.py