import json
import os
import time

from datasets import load_dataset
from google import genai
from google.genai import types
from tqdm import tqdm

# Registers tqdm's pandas integration (`progress_apply`). Nothing in the visible
# script calls it; NOTE(review): confirm it is still needed.
tqdm.pandas()

# Vertex AI client shared by every API call in this script. The project ID comes
# from the GOOGLE_CLOUD_PROJECT environment variable instead of a hard-coded
# placeholder; if the variable is unset, the original empty string is passed.
client = genai.Client(
    vertexai=True,
    project=os.environ.get("GOOGLE_CLOUD_PROJECT", ""),
    location="us-east1",
)



# Gemini model identifiers passed to the API.
# STRONG_MODEL voices both debaters; WEAK_MODEL, a less capable model, is the judge.
STRONG_MODEL = "gemini-2.5-flash"
WEAK_MODEL = "gemini-2.0-flash-lite-001"

def gemini_generate(model, prompt, temperature, max_tokens, retries=5, wait_time=60):
    """
    Send ``prompt`` to Gemini ``model`` through the module-level ``client`` and return the reply text.

    An attempt fails if the API call raises or if ``response.text`` is None.
    Each failure is printed. While attempts remain, the function sleeps
    ``wait_time`` seconds and tries again. After the last failed attempt it
    returns immediately instead of sleeping.

    Args:
        model: Model identifier passed to ``client.models.generate_content``.
        prompt: Prompt text, sent as the only element of ``contents``.
        temperature: Sampling temperature.
        max_tokens: Maximum number of output tokens.
        retries: Total number of attempts before giving up (default 5).
        wait_time: Seconds to wait between attempts (default 60).

    Returns:
        The response text, or None if every attempt failed (or ``retries`` <= 0).
    """
    for attempt in range(1, retries + 1):
        try:
            response = client.models.generate_content(
                model=model,
                config=types.GenerateContentConfig(
                    temperature=temperature,
                    max_output_tokens=max_tokens,
                ),
                contents=[prompt],
            )
            answer = response.text
            if answer is None:
                # Treat an empty (None) reply like an API error so it is retried.
                raise Exception("The respond from gemini is None!")
            return answer
        except Exception as e:
            if attempt < retries:
                print(f"API Error: {e}. Retrying ({attempt}/{retries})...")
                time.sleep(wait_time)  # Back off before the next attempt.
            else:
                # Last attempt: no retry follows, so don't sleep.
                print(f"API Error: {e}. Giving up after {retries} attempts.")
    return None


def run_debate(question, correct_answer, distractor_answer=None):
    """
    Run one Uncertainty-Aware Role-Switching debate and ask the judge for a verdict.

    Debater A argues for ``correct_answer``. Debater B argues for
    ``distractor_answer``, or for the placeholder "an alternative answer" when
    that argument is falsy. After the two opening statements, each debater in
    turn rebuts, argues the opponent's answer, and reflects on its confidence.
    Every follow-up prompt is the running transcript plus the turn's
    instruction. Debaters use STRONG_MODEL; the judge (WEAK_MODEL) reads the
    whole transcript and replies with the answer it believes is correct.

    Returns:
        ``(judge_reply, transcript)``: the judge's stripped reply and the
        debate transcript (which excludes the judge's own turn). Returns
        ``(None, None)`` if any model call returns None or an exception is
        raised along the way.
    """
    def debater(prompt):
        # Every debater turn uses the strong model with identical sampling settings.
        return gemini_generate(model=STRONG_MODEL, prompt=prompt, temperature=0.7, max_tokens=2048)

    try:
        stance_b = distractor_answer or "an alternative answer"

        # Opening statements: each side argues its assigned answer without seeing the other.
        openings = []
        for side, stance in (("A", correct_answer), ("B", stance_b)):
            reply = debater(f"You are Debater {side}. Your task is to convince the judge that the correct answer to the following question is: {stance}. Question: {question} Provide your answer and detailed reasoning.")
            if reply is None:
                return None, None
            openings.append(f"Debater {side} (answer and reasoning): {reply.strip()}")
        transcript = f"Question: {question}\n" + "\n".join(openings) + "\n"

        # Follow-up turns in order: rebuttals, role switches, uncertainty reflections.
        # Each entry is (transcript label, instruction appended after the transcript).
        follow_ups = (
            ("Debater A (rebuttal)",
             "You are Debater A. You have seen Debater B's argument above. Rebut Debater B's argument by identifying any errors or weak points, and defend your original answer."),
            ("Debater B (rebuttal)",
             "You are Debater B. You have seen Debater A's argument and rebuttal above. Rebut Debater A's argument by identifying any errors or weak points in it, and reinforce your original answer."),
            ("Debater A (role-switch argument in favor of B's answer)",
             f"Now, switch roles. You are still Debater A, but for this turn pretend that Debater B's answer is actually correct. Argue in favor of Debater B's answer ('{stance_b}') as convincingly as possible."),
            ("Debater B (role-switch argument in favor of A's answer)",
             f"Now, you are still Debater B, but pretend that Debater A's answer is actually correct. Argue in favor of Debater A's answer ('{correct_answer}') as convincingly as possible."),
            ("Debater A (uncertainty analysis)",
             "You are Debater A. Reflect on the debate so far. How confident are you that your original answer is correct (e.g., in percentage)? Identify any remaining uncertainties and explain what information would help resolve them."),
            ("Debater B (uncertainty analysis)",
             "You are Debater B. Reflect on the debate so far. How confident are you that your original answer is correct (e.g., in percentage)? Identify any remaining uncertainties and explain what information would help resolve them."),
        )
        for label, instruction in follow_ups:
            reply = debater(transcript + instruction)
            if reply is None:
                return None, None
            transcript += f"{label}: {reply.strip()}\n"

        # The judge is the weaker model and samples at a lower temperature.
        verdict = gemini_generate(
            model=WEAK_MODEL,
            prompt=transcript + "You are the Judge. Read the full debate above. Based on the arguments and uncertainties expressed, determine which answer is more likely to be correct. Output only the final answer you conclude is correct (no explanation).",
            temperature=0.3,
            max_tokens=2048,
        )
        if verdict is None:
            return None, None
        return verdict.strip(), transcript

    except Exception as e:
        print(f"Debate execution failed: {e}")
        return None, None


# OpenBookQA, "main" configuration; its validation split (500 examples) is evaluated below.
openbook = load_dataset("allenai/openbookqa", "main")
openbook_val = openbook["validation"]

# CommonsenseQA (validation split: 1221 examples) is currently disabled.
#commonsense = load_dataset("tau/commonsense_qa")
#commonsense_val = commonsense["validation"]

# Function to evaluate debate accuracy on a given dataset split
def evaluate_debate_on_dataset(dataset):
    """
    Run one debate per example in ``dataset`` and score the judge's verdicts.

    Each example needs ``choices`` (a dict with parallel ``label`` and ``text``
    lists), ``answerKey``, and either ``question_stem`` (OpenBookQA) or
    ``question`` (CommonsenseQA). Debater A defends the correct option and
    Debater B the first incorrect one. Examples whose debate fails are skipped:
    they are excluded from every returned list and from the accuracy
    denominator, and the number skipped is printed.

    Returns:
        ``(accuracy, final_answers, correct_flags, records)``:
        - ``accuracy``: fraction of completed debates judged correct (0.0 if none completed).
        - ``final_answers``: the judge's raw replies.
        - ``correct_flags``: one bool per completed debate.
        - ``records``: one JSON-serializable dict per completed debate.
    """
    debate_final_answers = []
    debate_correct_flags = []
    debate_transcript = []
    skipped = 0
    for sample in tqdm(dataset):
        # OpenBookQA calls the question field "question_stem"; CommonsenseQA calls it "question".
        if "question_stem" in sample:
            question_text = sample["question_stem"]
        else:
            question_text = sample["question"]
        choices = sample["choices"]         # dict with parallel "label" and "text" lists
        answer_key = sample["answerKey"]    # label of the correct option, e.g. "A"
        labels = choices["label"]
        texts = choices["text"]

        # Correct option text; fall back to the raw key if it is not among the labels.
        if answer_key in labels:
            correct_text = texts[labels.index(answer_key)]
        else:
            correct_text = answer_key
        # Debater B defends the first option that is not the correct one.
        distract_text = next(
            (text for lab, text in zip(labels, texts) if lab != answer_key),
            "an alternative answer",
        )

        final_answer, transcript = run_debate(question_text, correct_text, distract_text)
        if final_answer is None or transcript is None:
            skipped += 1
            print(f"[SKIP] Debate generation failed for question: {question_text[:80]}...")
            continue
        debate_final_answers.append(final_answer)

        predicted_text = _match_prediction(final_answer, labels, texts)
        # Normalize both sides identically. Otherwise punctuation on the option
        # text (e.g. a trailing period) can turn a correct verdict into a miss.
        is_correct = _normalize_answer(predicted_text) == _normalize_answer(correct_text)
        debate_correct_flags.append(is_correct)

        debate_transcript.append({
            "question": question_text,
            "choices": {"labels": labels, "texts": texts},
            "correct_answer": correct_text,
            "distractor_answer": distract_text,
            "judge_final_answer": final_answer,
            "predicted_text_normalized": predicted_text,
            "is_correct": is_correct,
            "transcript": transcript,
        })

    if skipped:
        print(f"[INFO] {skipped} debate(s) skipped; accuracy covers the remaining {len(debate_correct_flags)}.")
    accuracy = sum(debate_correct_flags) / len(debate_correct_flags) if debate_correct_flags else 0.0
    return accuracy, debate_final_answers, debate_correct_flags, debate_transcript


def _normalize_answer(text):
    """Canonicalize an answer for comparison: trim whitespace, surrounding quotes/periods, and casefold."""
    return text.strip().strip(".\"'").strip().casefold()


def _match_prediction(final_answer, labels, texts):
    """
    Map the judge's free-form reply onto one of the option texts.

    Matches are tried in order: exact option text, then option label
    ("A", "B", ...), then an option contained in a longer reply.

    Returns:
        The matched option text. If nothing matches, returns the reply with
        surrounding whitespace and periods removed.
    """
    reply = final_answer.strip().strip(".")
    norm_reply = _normalize_answer(final_answer)
    options = [(opt, _normalize_answer(opt)) for opt in texts]

    # 1. The reply is exactly one of the option texts.
    for opt, norm_opt in options:
        if norm_opt == norm_reply:
            return opt
    # 2. The reply is an option label.
    for lab, opt in zip(labels, texts):
        if _normalize_answer(lab) == norm_reply:
            return opt
    # 3. An option appears inside a longer reply ("the answer is X"). Prefer the
    #    longest contained option so "heat" does not shadow "heat energy". Empty
    #    options are skipped because "" is contained in every string. Ties keep
    #    list order.
    contained = [(opt, norm_opt) for opt, norm_opt in options if norm_opt and norm_opt in norm_reply]
    if contained:
        return max(contained, key=lambda pair: len(pair[1]))[0]
    # 4. No option matched; the caller scores the raw reply against the correct text.
    return reply

# Evaluate on the OpenBookQA validation set and save one JSON record per completed debate.
openbook_accuracy, openbook_final_answers, openbook_flags, openbook_transcript = evaluate_debate_on_dataset(openbook_val)
print(f"OpenBookQA Validation Accuracy: {openbook_accuracy*100:.2f}% ({sum(openbook_flags)}/{len(openbook_flags)} correct)")

with open("openbook_debate_transcript.json", "w", encoding="utf-8") as f:
    json.dump(openbook_transcript, f, ensure_ascii=False, indent=2)


# Evaluate on the CommonsenseQA validation set (disabled; needs `commonsense_val`,
# whose loading is commented out above). evaluate_debate_on_dataset returns FOUR
# values, so all four are unpacked here.
#commonsense_accuracy, cs_final_answers, cs_flags, cs_transcript = evaluate_debate_on_dataset(commonsense_val)
#print(f"CommonsenseQA Validation Accuracy: {commonsense_accuracy*100:.2f}% ({sum(cs_flags)}/{len(cs_flags)} correct)")
#with open("commonsense_debate_transcript.json", "w", encoding="utf-8") as f:
#    json.dump(cs_transcript, f, ensure_ascii=False, indent=2)