import ast
import json
import os
import re
import string
import time
from typing import Any, Dict, List, Tuple

import google.generativeai as genai
import openai
from openai import OpenAI

# ------------------------------- #
#   0. SETUP: LLM API connections #
# ------------------------------- #
# API keys are read from the environment instead of being hard-coded in source.
# Export OPENAI_API_KEY (GPT/OpenRouter path) and/or GOOGLE_API_KEY (Gemini
# path) before running; a missing key defaults to an empty string.
API_KEY = os.environ.get("OPENAI_API_KEY", "")

os.environ['OPENAI_API_KEY'] = API_KEY
openai.api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=API_KEY)

# ----------------------------------------------------------
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
genai.configure(api_key=GOOGLE_API_KEY)
# ----------------------------------------------------------
# ---------- openrouter --------------
#
# client = OpenAI(
#   base_url="https://openrouter.ai/api/v1",
#   api_key=os.environ.get("OPENROUTER_API_KEY", ""),
# )

# --------- deepseek ----------------------
# client = OpenAI(api_key=os.environ.get("DEEPSEEK_API_KEY", ""), base_url="https://api.deepseek.com")
# ------------------------------------------

# Backend used by every agent through call_llm: "gpt" (OpenAI-compatible
# `client`) or "gemini".
# model_type="gpt"
model_type="gemini"

# Model name used only by the OpenAI-compatible path (get_completion); the
# Gemini model name is set inside get_completion_gemmini.
model="gpt-4o"
# model =  "deepseek/deepseek-chat-v3.1:free"
# model = "deepseek-chat" #deepseek api
# model = "deepseek/deepseek-r1:free"
# ----------------------------------------------------------

# ------------------------ 
def call_llm(prompt, temperature=0.3, model_type="gpt", max_tokens=512):
    """
    Dispatch a chat prompt to the selected LLM backend and return its reply.

    Args:
        prompt: list of {"role": ..., "content": ...} messages. The Gemini path
            reads the system message from prompt[0] and the user message from
            prompt[1].
        temperature: sampling temperature forwarded to the backend.
        model_type: "gpt" or "openrouter" -> get_completion (OpenAI client);
            "gemini" -> get_completion_gemmini.
        max_tokens: output-token cap, forwarded to BOTH backends so that callers
            such as qa_agent get the same limit regardless of backend.

    Returns:
        The model's text reply; "" for an unknown model_type. The Gemini path
        may return None when the generation is refused.
    """
    if model_type in ("gpt", "openrouter"):
        # Fixed pause before each request — presumably to stay under API rate limits.
        time.sleep(4)
        return get_completion(prompt, temperature=temperature, max_tokens=max_tokens)
        # return get_completion_openrouter(prompt, temperature=temperature)

    if model_type == "gemini":
        time.sleep(4)
        return get_completion_gemmini(prompt, temperature=temperature, max_tokens=max_tokens)

    print("Error at call_llm: Model not specified.\n")
    return ""

def get_completion(messages: List[Dict[str, str]], model=model, temperature=0.1, max_tokens=512):
    """
    Send chat messages through the OpenAI-compatible `client` and return the reply text.

    Transient failures (rate limits, server errors, network problems) are retried
    indefinitely with exponential backoff capped at 60 s. Errors whose HTTP status
    says a retry cannot succeed (400 bad request / context too long, 401 bad key,
    403 no access, 404 unknown model, 422 unprocessable) are re-raised right away
    instead of looping forever.

    Args:
        messages: chat messages in OpenAI format.
        model: model name; defaults to the module-level `model`.
        temperature: sampling temperature.
        max_tokens: output-token cap.

    Returns:
        `resp.choices[0].message.content` of the first successful response.
    """
    non_retryable_statuses = {400, 401, 403, 404, 422}
    delay = 1
    print("call llm - gpt/openrouter")
    while True:
        try:
            resp = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return resp.choices[0].message.content
        except Exception as e:
            # openai's APIStatusError subclasses expose the HTTP status as `status_code`;
            # connection errors have none and stay retryable.
            if getattr(e, "status_code", None) in non_retryable_statuses:
                print("[OpenAI error]", e, "—not retryable")
                raise
            print("[OpenAI error]", e, "—retrying in", delay, "s")
            time.sleep(delay)
            delay = min(delay * 2, 60)

# --------------
def get_completion_gemmini(prompt, temperature=0.1, max_tokens = 512):
    """
    Query Gemini with a [system, user] chat prompt and return the reply text.

    Args:
        prompt: list whose first item is the system message and second item is
            the user message, each a dict with a 'content' key.
        temperature: sampling temperature.
        max_tokens: passed to Gemini as max_output_tokens.

    Returns:
        The reply text, or None when Gemini blocks the prompt, stops the
        candidate for a non-normal reason, or returns no text parts.
    """
    print("gemini")

    # prompt[0] is the system message, prompt[1] the user message.
    system_content = prompt[0]['content']
    user_content = prompt[1]['content']

    # See https://ai.google.dev/api/python/google/generativeai/GenerativeModel
    generation_config = {
        "temperature": temperature,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": max_tokens,
        # "stop_sequences": [";"],
        "response_mime_type": "text/plain",
        # "response_mime_type": "application/json"
    }

    gemini_model = genai.GenerativeModel(
        model_name="gemini-2.5-flash-lite",
        # model_name="gemini-2.0-flash",
        generation_config=generation_config,
        system_instruction=system_content,
        # safety_settings = Adjust safety settings
        # See https://ai.google.dev/gemini-api/docs/safety-settings
    )

    chat_session = gemini_model.start_chat(history=[])

    try:
        response = chat_session.send_message(user_content)
    except (genai.types.StopCandidateException, genai.types.BlockedPromptException) as e:
        print("Gemini refused to generate a response:", e)
        return None

    try:
        return response.text
    except ValueError as e:
        # `.text` raises ValueError when the candidate carries no text parts
        # (e.g. filtered output or truncation before any text was produced).
        print("Gemini returned no text:", e)
        return None

# ------------------------------- #
# 1. Question Analyzer & Zero-shot Selector #
# ------------------------------- #

def question_analyzer_agent(question):
    """
    Decompose a multi-hop question into ordered subquestions using the LLM.

    The model is asked for a JSON object {"Subquestions": [...]}. Markdown code
    fences and any prose around the outermost {...} are stripped before parsing.

    Args:
        question: the multi-hop question text.

    Returns:
        list[str]: the subquestions, or [] if the reply is empty or cannot be
        parsed into a list under the "Subquestions" key.
    """
    system_header = (
        """You are a Question Analyzer Agent in a multi-hop question answering system.
Your task is to analyze a complex, multi-hop question and break it down into a sequence of concise, meaningful subquestions that reflect the reasoning steps required to answer the original question.

To do this, identify and extract subquestions based on:
	•	Key entities (persons, organizations, locations, dates, etc.)
	•	Important nouns or noun phrases (e.g., titles, concepts, objects)
	•	Logical or temporal relationships (e.g., comparisons, causality, sequences)
	•	Specific conditions or constraints in the question

Each subquestion should focus on retrieving or verifying a specific piece of evidence that contributes to the final answer.

Output Format:
Return an ordered list of subquestions that represent a clear reasoning path from the question to the answer. Keep each subquestion short, specific, and unambiguous.

Example Input:
“Which actor played the brother of the character who was portrayed by the same actress that starred in Legally Blonde?”

Example Output:
	1.	Who starred in Legally Blonde?
	2.	What character did that actress portray in another film?
	3.	Who played the brother of that character?"""
    )
    prompt = [
        {"role": "system", "content": system_header},
        {"role": "user", "content": f"For the following question, return a JSON object with a list of extracted elements like this: {{\"Subquestions\": [\"...\", \"...\"]}}.\nQuestion: {question}"}
    ]
    response = call_llm(prompt, temperature=0.1, model_type=model_type)
    if not response:
        # call_llm can return "" (unknown backend) or None (Gemini refusal).
        print("Failed to parse question analyzer output: empty response")
        return []

    cleaned = response.replace("```json", "").replace("```", "").strip()
    # Models sometimes wrap the JSON object in prose; keep only the outermost {...}.
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        cleaned = cleaned[start:end + 1]

    try:
        subqs = json.loads(cleaned)["Subquestions"]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        print("Failed to parse question analyzer output:", e)
        return []

    if not isinstance(subqs, list):
        print("Failed to parse question analyzer output: 'Subquestions' is not a list")
        return []
    # Callers join these with " | ", so guarantee strings.
    return [str(q) for q in subqs]

def zero_shot_selector_agent(question, question_keywords, context):
    """
    Ask the LLM, in a single shot, for the minimal set of supporting sentences.

    Args:
        question: the multi-hop question.
        question_keywords: subquestions from the analyzer, formatted into the prompt as-is.
        context: the passages, formatted into the prompt as-is.

    Returns:
        The raw LLM reply (meant to be a list of [title, sentence_index] pairs;
        get_initial_evidence parses it with parse_response).
    """
    instructions = (
        "You are an expert evidence selection agent for question answering. "
        "Given a question and context paragraphs, select the minimum set of sentence(s) "
        "from the context that are STRICTLY NECESSARY to answer the question. "
        "Return only a pruned list of ['title', sentence_index_no] pairs — no explanations. The sentence index starts from 0 for each paragraph. i.e: [['pasage title A', 0], ['passage title B', 2], ...]"
    )
    request = "\n".join([
        f"Question: {question}",
        f"Important Sub Questions: {question_keywords}",
        f"Context: {context}",
        "Select the minimum sufficient set of supporting facts.",
    ])
    messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": request},
    ]
    return call_llm(messages, temperature=0.1, model_type=model_type)

def parse_response(response):
    """
    Parse an LLM reply into a list of [title, sentence_index] evidence pairs.

    Markdown code fences are removed and the text between the first '[' and the
    last ']' is evaluated with ast.literal_eval (safe: literals only). Leading
    labels such as "Answer:" are therefore ignored, while colons inside titles
    (e.g. "Star Wars: A New Hope") are preserved.

    Normalizations:
        - a single bare pair ['Title', 3] is treated as one item;
        - tuples become lists;
        - an index wrapped as [3] or written as "3" becomes the int 3;
        - items that are not a (str title, int index) pair are dropped.

    Args:
        response: raw reply text; None or "" is accepted.

    Returns:
        list of [title, index] lists; [] if nothing could be parsed.
    """
    if not response:
        print("Could not parse agent response: empty response")
        return []

    cleaned = response.replace("```json", "").replace("```", "").strip()
    start, end = cleaned.find("["), cleaned.rfind("]")
    if start == -1 or end < start:
        print("Could not parse agent response: no list found")
        return []

    try:
        parsed = ast.literal_eval(cleaned[start:end + 1])
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
        print("Could not parse agent response:", e)
        return []

    if not isinstance(parsed, (list, tuple)):
        print("Could not parse agent response: not a list")
        return []
    # A lone pair like ['Title', 3] is one evidence item, not a list of items.
    if len(parsed) == 2 and isinstance(parsed[0], str):
        parsed = [parsed]

    fixed = []
    for item in parsed:
        if not isinstance(item, (list, tuple)) or len(item) != 2:
            continue
        title, idx = item
        if isinstance(idx, (list, tuple)) and len(idx) == 1:
            idx = idx[0]
        if isinstance(idx, str) and idx.strip().isdigit():
            idx = int(idx)
        if isinstance(title, str) and isinstance(idx, int) and not isinstance(idx, bool):
            fixed.append([title, idx])
    return fixed

def get_initial_evidence(question, question_keywords, context):
    """
    Seed the evidence set from the zero-shot selector.

    Args:
        question: the multi-hop question.
        question_keywords: subquestions passed through to the selector prompt.
        context: iterable of (title, sentences) pairs.

    Returns:
        The parsed [title, sentence_index] pairs from the selector. If parsing
        yields nothing, every sentence of every passage is returned instead.
    """
    selected = parse_response(zero_shot_selector_agent(question, question_keywords, context))
    if selected:
        return selected
    # Fallback: include all sentences, but this should be rare
    return [
        [title, position]
        for title, sentences in context
        for position in range(len(sentences))
    ]

# ------------------------------- #
#   2. Selector & Adder Agents    #
# ------------------------------- #
def selector_agent(question, question_info, candidates, current_evidence):
    """
    Ask the LLM to prune clearly irrelevant items from the current evidence.

    Args:
        question: the multi-hop question.
        question_info: subquestions, formatted into the prompt as-is.
        candidates: the passages, formatted into the prompt as-is.
        current_evidence: [title, sentence_index] pairs to prune.

    Returns:
        The raw LLM reply (to be parsed with parse_response).
    """
    instructions = (
        "You are a **Selector Agent** in a multi-hop QA system. Your goal is to MAXIMIZE PRECISION and MINIMIZE FALSE POSITIVES.\n\n"
        "You are given:\n"
        "- A complex multi-hop question\n"
        "- Subquestions that represent the reasoning steps\n"
        "- Candidate context: a list of facts, each represented by ['title', sentence_index]\n"
        "- CURRENT SELECTED EVIDENCE: sentences previously marked as potentially relevant\n\n"
        "Your job is to carefully REMOVE only those evidence items that are DEFINITELY irrelevant for answering ANY of the subquestions.\n\n"
        "DO NOT remove sentences that are:\n"
        "- Partially relevant\n"
        "- Possibly needed to bridge between subquestions\n"
        "- Contain named entities, dates, or events referenced in the question\n\n"
        "DO NOT add or regenerate. Work strictly with the given CURRENT SELECTED EVIDENCE.\n"
        "Return only a pruned list of ['title', sentence_index_no] pairs — no explanations. The sentence index starts from 0 for each paragraph."
    )
    request_lines = [
        f"Question: {question}",
        f"Subquestions: {question_info}",
        f"Candidates: {candidates}",
        f"Current Selected Evidence: {current_evidence}",
        "Return only the updated list with clearly irrelevant sentences removed:",
    ]
    messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": "\n".join(request_lines)},
    ]
    return call_llm(messages, temperature=0.1, model_type=model_type)

def adder_agent(question, question_info, candidates, current_evidence):
    """
    Ask the LLM to extend the current evidence with likely-supporting sentences.

    Args:
        question: the multi-hop question.
        question_info: subquestions, formatted into the prompt as-is.
        candidates: the passages, formatted into the prompt as-is.
        current_evidence: already-selected [title, sentence_index] pairs.

    Returns:
        The raw LLM reply (to be parsed with parse_response); the prompt asks
        for the current items plus any additions.
    """
    instructions = (
        "You are an **Adder Agent** in a multi-hop QA system. Your goal is to MAXIMIZE RECALL while MINIMIZING FALSE POSITIVES.\n\n"
        "You are given:\n"
        "- A complex multi-hop question\n"
        "- Subquestions that represent reasoning steps\n"
        "- Candidate context: Passages and sentences from where you can get the answer\n"
        "- CURRENT SELECTED EVIDENCE: a list of already selected evidence items, each represented by ['title', sentence_index] \n\n"
        "Your job is to find and ADD ONLY those candidate sentences that are LIKELY to support answering any subquestion. "
        "DO NOT add sentences that:\n"
        "- Are vague, unrelated, or overly general\n"
        "- Contain off-topic facts or unrelated entities\n"
        "- Repeat or overlap significantly with existing evidence\n\n"
        "Focus on:\n"
        "- Bridging facts that connect entities between subquestions\n"
        "- Sentences containing named entities, dates, definitions, or relationships mentioned in the subquestions\n"
        "- Factual sentences that clearly contribute to a reasoning chain\n\n"
        "DO NOT remove or modify existing evidence. Return a list that includes both current and newly added items. Format as a list of ['title', sentence_index_no] pairs — no explanations. The sentence index starts from 0 for each paragraph."
    )
    request_lines = [
        f"Question: {question}",
        f"Subquestions: {question_info}",
        f"Candidates: {candidates}",
        f"Current Selected Evidence: {current_evidence}",
        "Return the updated evidence list with only relevant additions:",
    ]
    messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": "\n".join(request_lines)},
    ]
    return call_llm(messages, temperature=0.1, model_type=model_type)

def qa_agent(question: str, evidence: str) -> str:
    """
    Generate a short answer to the question from the given evidence text.

    Args:
        question: question text (callers may append subquestions to it).
        evidence: plain-text evidence inserted verbatim into the prompt.

    Returns:
        The stripped answer (the prompt asks for 'Not Answerable' when the
        evidence is insufficient), or "" when the backend returns no text,
        e.g. a Gemini refusal.
    """
    system_header = (
        "You are a question answering agent. Given a question and the supporting evidence, "
        "provide a concise, factual and short answer only based on the evidence without other words. "
        "If the answer cannot be determined from the evidence, reply with 'Not Answerable'."
    )

    prompt = [
        {"role": "system", "content": system_header},
        {"role": "user", "content": (
            f"Question: {question}\n\n"
            f"Evidence:\n{evidence}\n\n"
            "Answer:"
        )}
    ]

    response = call_llm(prompt, max_tokens=10, model_type=model_type)
    # call_llm may return None (Gemini refusal); normalize so the `str` contract holds.
    return (response or "").strip()

def flatten_context(context: List[List]) -> List[list]:
    """
    Enumerate every sentence of the context as a [title, sentence_index] pair.

    Args:
        context: [[title, [sentence, ...]], ...] pairs.

    Returns:
        [[title, 0], [title, 1], ...] for each passage in order; indices are
        0-based per passage (so items are [str, int], not [str, str]).
    """
    return [
        [title, idx]
        for title, sentences in context
        for idx in range(len(sentences))
    ]

def extract_textual_evidence(context: List[List], evidence: List[List]) -> str:
    """
    Render selected evidence sentences as text, grouped by passage title.

    Args:
        context: HotpotQA-style [[title, [sentence, ...]], ...] pairs.
        evidence: [[title, sentence_index], ...] pairs; indices are 0-based
            within each passage.

    Returns:
        One block per title (in first-seen order): a "Title:" line followed by
        "- sentence" lines in evidence order, blocks separated by a blank line.
        Items with an unknown title, a non-int or out-of-range index, or that
        repeat an earlier item are skipped. Returns "" if nothing is selected.
    """
    title_to_sentences = {title: sentences for title, sentences in context}
    grouped = {}
    seen = set()

    for title, sent_idx in evidence:
        sentences = title_to_sentences.get(title)
        # Non-int indices (e.g. "2" from a sloppy LLM reply) would raise in the range check.
        if sentences is None or not isinstance(sent_idx, int):
            continue
        if not 0 <= sent_idx < len(sentences) or (title, sent_idx) in seen:
            continue
        seen.add((title, sent_idx))
        grouped.setdefault(title, []).append(sentences[sent_idx])

    # Format: Title:\n - Sentence1\n - Sentence2 ...
    text_blocks = [
        f"{title}:\n" + "\n".join(f"- {s}" for s in sentences)
        for title, sentences in grouped.items()
    ]
    return "\n\n".join(text_blocks)

def normalize_answer(s):
    """
    Canonicalize an answer string for exact-match comparison.

    Steps, in order: lowercase, delete ASCII punctuation (string.punctuation),
    replace the whole words "a", "an", "the" with a space, then collapse all
    whitespace runs to single spaces and trim the ends.
    """
    text = s.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return " ".join(text.split())

# def exact_match_score(prediction, ground_truth):
#     return (normalize_answer(prediction) == normalize_answer(ground_truth))

def evaluate_qa(predicted_answer: str, gold_answer: str) -> float:
    """
    Exact-match score after normalize_answer: 1.0 on a match, otherwise 0.0.
    """
    is_match = normalize_answer(predicted_answer) == normalize_answer(gold_answer)
    return float(is_match)

# ------------------------------- #
#        3. Agentic Loop          #
# ------------------------------- #
def run_agent_loop(question, question_keywords, candidates, initial_evidence, n_rounds=3):
    """
    Refine evidence with n_rounds of Adder -> Selector LLM passes.

    An empty parse result is treated as an agent/parse failure, never as a
    real answer:
      - the adder is told not to remove items, so an empty result keeps the
        evidence from before that adder call;
      - an empty selector result falls back to `initial_evidence`.

    Args:
        question: the multi-hop question.
        question_keywords: subquestions forwarded to both agents.
        candidates: the passages forwarded to both agents.
        initial_evidence: seed [title, sentence_index] pairs.
        n_rounds: number of adder/selector rounds.

    Returns:
        The evidence list after the last round (initial_evidence if n_rounds <= 0).
    """
    evidence = initial_evidence
    for round_num in range(n_rounds):
        print(f"\n--- Iteration {round_num + 1} ---")

        # Adder agent (adds missing evidence)
        added = parse_response(adder_agent(question, question_keywords, candidates, evidence))
        if added:
            evidence = added
        else:
            print("❌ No evidence returned by adder. Keeping previous evidence.")
        print("After Adder:", evidence)

        # Selector agent (removes irrelevant evidence)
        selected = parse_response(selector_agent(question, question_keywords, candidates, evidence))
        print("After Selector:", selected)
        if selected:
            evidence = selected
        else:
            print("❌ No evidence left after selection. Falling back to initial evidence.")
            evidence = initial_evidence

    return evidence

# ------------------------------- #
#     4. Evaluation Utilities     #
# ------------------------------- #
def calculate_performance(pred: List, gold: List[List[str]]) -> Tuple[bool, Dict[str, Any]]:
    """
    Score predicted supporting facts against gold ones as sets.

    Items are compared as tuples, so [title, idx] pairs must match exactly,
    and duplicates count once.

    Args:
        pred: predicted [title, sentence_index] pairs.
        gold: gold [title, sentence_index] pairs.

    Returns:
        (True, scores) where scores holds precision, recall, f1,
        false_positive_rate (share of predictions not in gold, i.e.
        1 - precision whenever there is at least one prediction), all rounded
        to 4 decimals, plus the true/false positive and false negative items
        as lists of tuples. (False, default_scores()) if the items cannot be
        turned into hashable tuples.
    """
    try:
        gold_set = {tuple(x) for x in gold}
        pred_set = {tuple(x) for x in pred}
    except TypeError as e:
        print("Error processing prediction or gold:", pred, gold)
        print("Exception:", e)
        return False, default_scores()

    hits = gold_set & pred_set
    extras = pred_set - gold_set
    misses = gold_set - pred_set
    tp, fp, fn = len(hits), len(extras), len(misses)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    false_positive_rate = fp / (tp + fp) if (tp + fp) > 0 else 0.0
    return True, {
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
        "false_positive_rate": round(false_positive_rate, 4),
        "true_positives": list(hits),
        "false_positives": list(extras),
        "false_negatives": list(misses),
    }

def default_scores():
    """
    Build a fresh all-zero score dict with the same keys calculate_performance returns.

    Each call returns new list objects, so callers may mutate the result freely.
    """
    zeroed = dict.fromkeys(("precision", "recall", "f1", "false_positive_rate"), 0.0)
    zeroed.update({key: [] for key in ("true_positives", "false_positives", "false_negatives")})
    return zeroed

def append_to_jsonl(output_path: str, new_entry: dict):
    """
    Append one JSON-encoded record as a single line to a JSONL file.

    The file and its parent directory are created when missing, so the first
    record is written rather than silently dropped.

    Args:
        output_path: path of the .jsonl file.
        new_entry: JSON-serializable record.
    """
    directory = os.path.dirname(output_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(output_path, 'a') as f:
        f.write(json.dumps(new_entry) + '\n')

def append_to_json_file(file_path, question_id, value):
    """
    Store `value` under `question_id` in a JSON object file, creating it if needed.

    The whole object is read, updated, and written back. The write goes to a
    sibling ".tmp" file that is then moved over the original with os.replace,
    so an interruption mid-write cannot corrupt previously saved results.
    A missing parent directory is created; an existing but empty file is
    treated as {}.

    Args:
        file_path: path of the JSON file holding a single top-level object.
        question_id: key to (over)write; json.dump stringifies non-str keys.
        value: JSON-serializable value.

    Raises:
        json.JSONDecodeError: if an existing, non-empty file is not valid JSON
            (it is left untouched rather than overwritten).
    """
    directory = os.path.dirname(file_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    is_new = not os.path.exists(file_path)
    data = {}
    if not is_new and os.path.getsize(file_path) > 0:
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)

    data[question_id] = value

    tmp_path = file_path + ".tmp"
    with open(tmp_path, 'w') as json_file:
        json.dump(data, json_file, indent=4)
    os.replace(tmp_path, file_path)

    if is_new:
        print(f"Created new JSON file at {file_path}")
    print(f"Appended entry with question_id {question_id} to {file_path}")

# ------------------------------- #
#        5. Main Pipeline         #
# ------------------------------- #
if __name__ == "__main__":
    # Load the 2WikiMultihopQA evaluation samples (records with _id, question,
    # answer, supporting_facts and context fields).
    # file_path = '2wikimultihop/2wikimultihop-dev.json'
    file_path = '2wikimultihop/2wiki_dev_eval_samples.json'

    with open(file_path, 'r') as f:
        data = json.load(f)
    print("Loaded 2wiki dataset with", len(data), "examples.")

    N = 50  # For testing; use 50+ for large eval
    start = 0
    end = start + N
    data = data[start:end]

    total_precision = 0.0
    total_recall = 0.0
    total_f1 = 0.0
    total_fpr = 0.0
    total_acc_full = 0.0
    total_acc_retrieved = 0.0
    total_acc_oracle = 0.0
    # Number of items whose retrieval AND QA metrics were accumulated; every
    # running and final average divides by this.
    count = 0

    # NOTE(review): `llm` only labels the output file. The backend actually used
    # is selected by the module-level `model_type` / `model` settings — keep
    # this label in sync with them.
    # llm = "gpt-4o-mini"
    # llm = "gemini-2.5-flash-lite"
    # llm = "deepseek"
    llm = "gpt-4o"

    # output_file = "qa_outputs.jsonl"
    # output_file = "qa_outputs-2.json"
    # output_file = "qa_outputs-0-1k.json"
    # output_file = "qa_outputs-1-2k.json"
    # output_file = "qa_outputs-0-5h-2wiki.json"
    output_file = f"outputs/2wiki_dev_eval_final_sample_{llm}.json"
    # output_file = f"outputs/2wiki_dev_eval_final_sample_{llm}_wo_qanalyzer.json"

    # Make sure the output directory exists before the first write.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    for i, item in enumerate(data):
        item_id = item.get('_id', 'N/A')
        question = item.get('question', 'N/A')
        answer = item.get('answer', 'N/A')
        gold_supporting_facts = item.get('supporting_facts', [])
        context = item.get('context', [])
        print("\n-------------------------------")
        index = i + start
        print("Index:", index, "ID:", item_id)
        print("Question:", question)
        print("Answer:", answer)

        question_info = question_analyzer_agent(question)
        # question_info = []
        print("Question Analyzer:", question_info)

        # 1. Use zero-shot selector to seed initial evidence
        initial_evidence = get_initial_evidence(question, question_info, context)
        print("Initial Evidence (Zero-shot LLM):", initial_evidence)

        # 2. Refine with agentic loop
        final_evidence = run_agent_loop(question, question_info, context, initial_evidence, n_rounds=3)
        print("-------------------------------\n")
        print("Final Retrieved Evidence:", final_evidence)
        print("GOLD - Supporting Facts:", gold_supporting_facts)

        if not final_evidence:
            print("❌ No evidence retrieved. Skipping evaluation for this item.")
            continue

        # 3. Evaluate retrieval
        success, scores = calculate_performance(final_evidence, gold_supporting_facts)

        if not success:
            print("❌ Failed to calculate performance for this item. Skipping.")
            continue

        print(f"Precision: {scores['precision']}")
        print(f"Recall: {scores['recall']}")
        print(f"F1 Score: {scores['f1']}")
        print(f"False Positive Rate: {scores['false_positive_rate']}")
        print("✅ Correct (TP):", scores["true_positives"])
        print("❌ Missing (FN):", scores["false_negatives"])
        print("🚨 Extra (FP):", scores["false_positives"])

        # 4. Answer with retrieved, full and gold (oracle) context. The QA prompt
        # gets the subquestions appended; `question` keeps the dataset text.
        try:
            if question_info:
                qa_question = question + "\nSubquestions: " + " | ".join(question_info)
            else:
                qa_question = question
            print("question for QA:", qa_question)

            full_context = flatten_context(context)
            retrieved_context_text = extract_textual_evidence(context, final_evidence)
            full_context_text = extract_textual_evidence(context, full_context)
            gold_context_text = extract_textual_evidence(context, gold_supporting_facts)

            retrieved_answer = qa_agent(qa_question, retrieved_context_text)
            full_answer = qa_agent(qa_question, full_context_text)
            oracle_answer = qa_agent(qa_question, gold_context_text)
        except Exception as e:
            # Nothing has been accumulated for this item yet, so skipping keeps
            # all totals aligned with `count`.
            print("Error during QA generation:", e)
            continue

        gold_answer = item.get("answer", "").strip()

        acc_full = evaluate_qa(full_answer, gold_answer)
        acc_retrieved = evaluate_qa(retrieved_answer, gold_answer)
        acc_oracle = evaluate_qa(oracle_answer, gold_answer)

        print(f"Gold Answer: {gold_answer}")
        print(f"Answer (Full Context): {full_answer} | Match: {acc_full}")
        print(f"Answer (Retrieved Evidence): {retrieved_answer} | Match: {acc_retrieved}")
        print(f"Answer (Oracle Evidence): {oracle_answer} | Match: {acc_oracle}")

        # Accumulate every metric together so all averages share one denominator.
        count += 1
        total_precision += scores['precision']
        total_recall += scores['recall']
        total_f1 += scores['f1']
        total_fpr += scores['false_positive_rate']
        total_acc_full += acc_full
        total_acc_retrieved += acc_retrieved
        total_acc_oracle += acc_oracle

        print("AVG RECALL so far: ", round(total_recall / count, 4))
        print("AVG F1 so far: ", round(total_f1 / count, 4))
        print("AVG False Positive Rate so far: ", round(total_fpr / count, 4))
        print("Average Precision so far: ", round(total_precision / count, 4))
        print(f"QA Accuracy (Full Context): {round(total_acc_full / count, 4)}")
        print(f"QA Accuracy (Retrieved Evidence): {round(total_acc_retrieved / count, 4)}")
        print(f"QA Accuracy (Oracle Evidence): {round(total_acc_oracle / count, 4)}")
        print("-------------------------------\n\n")

        output_record = {
            "index": index,
            "id": item_id,
            "question": question,
            "qa_question": qa_question,
            "question_analyzer": question_info,
            "gold_answer": gold_answer,
            "answer_full_context": full_answer,
            "answer_retrieved": retrieved_answer,
            "answer_oracle": oracle_answer,
            "match_full": acc_full,
            "match_retrieved": acc_retrieved,
            "match_oracle": acc_oracle,
            "gold_supporting_facts": gold_supporting_facts,
            "initial_evidence": initial_evidence,
            "final_evidence": final_evidence,
            "precision": scores["precision"],
            "recall": scores["recall"],
            "f1": scores["f1"],
            "false_positive_rate": scores["false_positive_rate"],
            "true_positives": scores["true_positives"],
            "false_positives": scores["false_positives"],
            "false_negatives": scores["false_negatives"],
        }

        # append_to_jsonl(output_file, output_record)
        append_to_json_file(output_file, item_id, output_record)

    print(f"Total Examples Processed: {count}")
    if count:
        print("Final AVG Precision:", round(total_precision / count, 4))
        print("Final AVG Recall:", round(total_recall / count, 4))
        print("Final AVG F1:", round(total_f1 / count, 4))
        print("Final AVG False Positive Rate:", round(total_fpr / count, 4))
        print("Final QA Accuracy (Full Context):", round(total_acc_full / count, 4))
        print("Final QA Accuracy (Retrieved Evidence):", round(total_acc_retrieved / count, 4))
        print("Final QA Accuracy (Oracle Evidence):", round(total_acc_oracle / count, 4))
    
    # --------------------------------------------------------------------- #
# --------------------------------------------------------------------- #
# source venv/bin/activate
# 

# python3 test_2wiki.py
