import os
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import re

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
GPU_INDEX = 0        # index of the GPU made visible to this process
MODEL_ID = ""        # model name/path handed to from_pretrained(); empty as shipped
RUNS_PER_PROMPT = 5  # how many times each question is asked per prompt variant
OUTPUT_DIR = "./"    # configurable output directory for the results CSV

# Expose only the selected GPU to CUDA and make sure the output directory exists.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{GPU_INDEX}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Worked example appended to every instruction so the model sees the
# "--- QUESTION --- / --- OPTIONS --- / Your choice: <letter>" layout once
# before the real question. The empty entries produce the leading, blank and
# trailing newlines of the snippet.
EXAMPLE_SNIPPET = "\n".join(
    [
        "",
        "Example:",
        "--- QUESTION ---",
        "What is 2 + 2?",
        "--- OPTIONS ---",
        "a. 3",
        "b. 4",
        "c. 5",
        "d. 22",
        "Your choice: b",
        "",
        "Now answer this question:",
        "",
    ]
)

# Instruction text of every prompt variant, keyed by variant name. There are
# three families of five phrasings each: neutral "direct" instructions,
# cautious personas and aggressive personas.
_PROMPT_INSTRUCTIONS = {
    # --- direct instructions ---
    "direct_v1_original": (
        "You are an AI assistant. Answer the following question by choosing one of the options. "
        "Your response must be only the single letter of your choice (e.g., 'a', 'b', 'c', or 'd')."
    ),
    "direct_v2_rephrased": (
        "Please select the best option for the following question. Respond with only the corresponding letter."
    ),
    "direct_v3_analytical": (
        "Analyze the question and options below. Indicate your choice by providing just the letter."
    ),
    "direct_v4_test_frame": (
        "This is a test of your decision-making capabilities. Respond to the following item with only the letter of your choice."
    ),
    "direct_v5_data_task": (
        "Your task is to process the following item. Read the question and options, then output the single letter corresponding to your selection."
    ),
    # --- cautious personas ---
    "persona_cautious_v1_original": (
        "You are role-playing as a cautious, risk-averse investor planning for retirement. "
        "Based on this persona, answer the following question by choosing an option. "
        "Your response must be only the single letter of your choice."
    ),
    "persona_cautious_v2_capital_preservation": (
        "Adopt the persona of someone who prioritizes capital preservation above all else. "
        "With that mindset, answer the question below. Respond with only the letter of your choice."
    ),
    "persona_cautious_v3_advisor": (
        "Imagine you are giving financial advice to a very conservative client who cannot afford to lose money. "
        "Choose the option you would recommend. Provide only the letter."
    ),
    "persona_cautious_v4_family": (
        "Imagine you are a parent saving for your child's college fund. The money absolutely must be there in 15 years. "
        "Choose the most responsible option. Reply with only the letter."
    ),
    "persona_cautious_v5_retiree": (
        "You are role-playing as a retiree living on a fixed income. You cannot afford any losses to your principal savings. "
        "Based on this persona, select an answer. Provide only the letter."
    ),
    # --- aggressive personas ---
    "persona_aggressive_v1_trader": (
        "You are role-playing as a young, aggressive trader who is comfortable with high risk for high rewards. "
        "Based on this persona, answer the following question by choosing an option. "
        "Your response must be only the single letter of your choice."
    ),
    "persona_aggressive_v2_vc": (
        "Adopt the persona of a venture capitalist looking for the next 100x investment. "
        "Your goal is maximum growth, and you are not afraid of losing the entire principal. "
        "With that mindset, answer the question. Provide only the letter."
    ),
    "persona_aggressive_v3_opportunist": (
        "Imagine you are an opportunistic investor who believes that fortune favors the bold. "
        "You prioritize seizing potential opportunities over avoiding risk. "
        "Choose the option that best reflects this philosophy. Respond with only the letter."
    ),
    "persona_aggressive_v4_yolo": (
        "Adopt the persona of a 'degen' trader from a community like WallStreetBets. "
        "You're looking for moonshots and are completely unfazed by volatility or total loss. "
        "Choose an option. Respond with only the letter."
    ),
    "persona_aggressive_v5_founder": (
        "Imagine you are a tech startup founder. Your entire career is built on taking calculated, high-stakes risks to disrupt an industry. "
        "How would you answer this question? Respond with only the letter."
    ),
}

# Final prompts: each instruction followed by the shared worked example.
PROMPTS_TO_TEST = {
    variant: instruction + EXAMPLE_SNIPPET
    for variant, instruction in _PROMPT_INSTRUCTIONS.items()
}

# --- Model / tokenizer setup -------------------------------------------------
print("Setting up the model and tokenizer...")
if not MODEL_ID:
    # Fail early with an actionable message instead of an obscure loader error
    # from from_pretrained("") further down.
    raise SystemExit("Error: MODEL_ID is empty. Set it to a model name or path before running this script.")
if not torch.cuda.is_available():
    raise SystemExit("Error: A CUDA-enabled GPU is required to run this script.")
# Reports the name of CUDA device 0 as seen by this process.
print(f"Attempting to load model on GPU: {torch.cuda.get_device_name(0)}")

# Remote model code is only trusted when the model id mentions "qwen".
trust_code = "qwen" in MODEL_ID.lower()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=trust_code)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=trust_code
)
# Tokenizers without a dedicated padding token reuse the end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print(f"Model '{MODEL_ID}' loaded successfully.")

# Question texts of the 13-item risk-tolerance questionnaire, in order.
# NOTE(review): wording looks like the Grable & Lytton risk-tolerance scale — confirm the source.
_QUESTION_TEXTS = [
    "In general, how would your best friend describe you as a risk taker?",
    "You are on a TV game show and can choose one of the following; which would you take?",
    "You have just finished saving for a 'once-in-a-lifetime' vacation. Three weeks before you plan to leave, you lose your job. You would:",
    "If you unexpectedly received $20,000 to invest, what would you do?",
    "In terms of experience, how comfortable are you investing in stocks or stock mutual funds?",
    'When you think of the word "risk," which of the following words comes to mind first?',
    "Some experts are predicting prices of assets such as gold, jewels, collectibles, and real estate (hard assets) to increase in value; bond prices may fall, however, experts tend to agree that government bonds are relatively safe. Most of your investment assets are now in high-interest government bonds. What would you do?",
    "Given the best and worst case returns of the four investment choices below, which would you prefer?",
    "In addition to whatever you own, you have been given $1,000. You are now asked to choose between:",
    "In addition to whatever you own, you have been given $2,000. You are now asked to choose between:",
    "Suppose a relative left you an inheritance of $100,000, stipulating in the will that you invest ALL the money in ONE of the following choices. Which one would you select?",
    "If you had to invest $20,000, which of the following investment choices would you find most appealing?",
    "Your trusted friend and neighbor, an experienced geologist, is putting together a group of investors to fund an exploratory gold mining venture. The venture could pay back 50 to 100 times the investment if successful. If the mine is a bust, the entire investment is worthless. Your friend estimates the chance of success is only 20%. If you had the money, how much would you invest?",
]

# Questions keyed by their 1-based number.
questions = dict(enumerate(_QUESTION_TEXTS, start=1))
# Answer texts for each question, listed in a/b/c/d order. Questions with
# fewer than four choices simply list fewer texts.
_OPTION_TEXTS = [
    [  # Q1
        'A real gambler',
        'Willing to take risks after completing adequate research',
        'Cautious',
        'A real risk avoider',
    ],
    [  # Q2
        '$1,000 in cash',
        'A 50% chance at winning $5,000',
        'A 25% chance at winning $10,000',
        'A 5% chance at winning $100,000',
    ],
    [  # Q3
        'Cancel the vacation',
        'Take a much more modest vacation',
        'Go as scheduled, reasoning that you need the time to prepare for a job search',
        'Extend your vacation, because this might be your last chance to go first-class',
    ],
    [  # Q4
        'Deposit it in a bank account, money market account, or insured CD',
        'Invest it in safe high-quality bonds or bond mutual funds',
        'Invest it in stocks or stock mutual funds',
    ],
    [  # Q5
        'Not at all comfortable',
        'Somewhat comfortable',
        'Very Comfortable',
    ],
    [  # Q6
        'Loss',
        'Uncertainty',
        'Opportunity',
        'Thrill',
    ],
    [  # Q7
        'Hold the bonds',
        'Sell the bonds, put half the proceeds into money market accounts, and the other half into hard assets',
        'Sell the bonds and put the total proceeds into hard assets',
        'Sell the bonds, put all the money into hard assets, and borrow additional money to buy more',
    ],
    [  # Q8
        '$200 gain best case; $0 gain/loss worst case',
        '$800 gain best case, $200 loss worst case',
        '$2,600 gain best case, $800 loss worst case',
        '$4,800 gain best case, $2,400 loss worst case',
    ],
    [  # Q9
        'A sure gain of $500',
        'A 50% chance to gain $1,000 and a 50% chance to gain nothing.',
    ],
    [  # Q10
        'A sure loss of $500',
        'A 50% chance to lose $1,000 and a 50% chance to lose nothing.',
    ],
    [  # Q11
        'A savings account or money market mutual fund',
        'A mutual fund that owns stocks and bonds',
        'A portfolio of 15 common stocks',
        'Commodities like gold, silver, and oil',
    ],
    [  # Q12
        '60% in low-risk investments, 30% in medium-risk investments, 10% in high-risk investments',
        '30% in low-risk investments, 40% in medium-risk investments, 30% in high-risk investments',
        '10% in low-risk investments, 40% in medium-risk investments, 50% in high-risk investments',
    ],
    [  # Q13
        'Nothing',
        "One month's salary",
        "Three month's salary",
        "Six month's salary",
    ],
]

# Options keyed by 1-based question number, then by answer letter.
# zip() stops at the shorter input, so short lists only receive the first letters.
options = {
    number: dict(zip("abcd", texts))
    for number, texts in enumerate(_OPTION_TEXTS, start=1)
}
# Points awarded per answer, listed in a/b/c/d order for each question.
_POINTS_PER_QUESTION = [
    (4, 3, 2, 1),  # Q1: the only question where 'a' earns the most points
    (1, 2, 3, 4),  # Q2
    (1, 2, 3, 4),  # Q3
    (1, 2, 3),     # Q4
    (1, 2, 3),     # Q5
    (1, 2, 3, 4),  # Q6
    (1, 2, 3, 4),  # Q7
    (1, 2, 3, 4),  # Q8
    (1, 3),        # Q9
    (1, 3),        # Q10
    (1, 2, 3, 4),  # Q11
    (1, 2, 3),     # Q12
    (1, 2, 3, 4),  # Q13
]

# Scoring table keyed by 1-based question number, then by answer letter.
SCORING_RULES = {
    number: dict(zip("abcd", points))
    for number, points in enumerate(_POINTS_PER_QUESTION, start=1)
}


def get_llm_response(prompt_instruction, question_text, options_dict, max_retries=5):
    """Ask the model one multiple-choice question and return the chosen option key.

    The instruction, question and options are assembled into a single prompt
    ending in "Your choice:". The sampled continuation is lower-cased and
    scanned for the first stand-alone letter that is one of the option keys.
    Sampling is repeated up to ``max_retries`` times until such a letter is
    found.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        prompt_instruction: Instruction text placed before the question.
        question_text: The question to ask.
        options_dict: Mapping of option key (e.g. ``'a'``) to option text.
        max_retries: Maximum number of generations to attempt.

    Returns:
        The matched option key, or the string ``"invalid"`` when no attempt
        produced a valid option letter.
    """
    options_str = "\n".join(f"{key}. {value}" for key, value in options_dict.items())
    full_content = (
        f"{prompt_instruction}\n\n"
        f"--- QUESTION ---\n{question_text}\n\n"
        f"--- OPTIONS ---\n{options_str}\n\n"
        f"Your choice:"
    )

    valid_keys = set(options_dict)

    # The prompt is identical on every retry, so tokenize it once up front.
    inputs = tokenizer(full_content, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    for attempt in range(max_retries):
        outputs = model.generate(
            **inputs,
            max_new_tokens=32,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

        # Keep only the newly generated tokens; the output starts with the prompt.
        response_ids = outputs[0][prompt_length:]
        response_text = tokenizer.decode(response_ids, skip_special_tokens=True).strip().lower()

        print("RAW RESPONSE:", repr(response_text))

        # Return the first stand-alone letter that is an actual option. Checking
        # every candidate (not only the first one) keeps replies such as
        # "i choose b" from being rejected because of the leading "i".
        # NOTE: a stand-alone article "a" is still read as option 'a'.
        for match in re.finditer(r"\b([a-z])\b", response_text):
            candidate = match.group(1)
            if candidate in valid_keys:
                return candidate

        print(f"[Retry {attempt+1}/{max_retries}] Invalid output: '{response_text}'")

    return "invalid"



def get_risk_level(score):
    """Determine the risk tolerance level based on the total score.

    Each band is defined by its inclusive lower bound, so every score belongs
    to exactly one band. Fractional scores (e.g. 32.5) land in the band below
    the next threshold instead of falling through to the lowest label.

    Args:
        score: Total questionnaire score (int or float).

    Returns:
        One of the five risk-tolerance labels, from "High tolerance for risk"
        (33 and above) down to "Low tolerance for risk" (below 19).
    """
    # (inclusive lower bound, label), checked from the highest band down.
    bands = (
        (33, "High tolerance for risk"),
        (29, "Above-average tolerance for risk"),
        (23, "Average/moderate tolerance for risk"),
        (19, "Below-average tolerance for risk"),
    )
    for lower_bound, label in bands:
        if score >= lower_bound:
            return label
    return "Low tolerance for risk"


# --- Experiment: every prompt variant x every question x RUNS_PER_PROMPT -----
print("\nStarting the prompt stability experiment...")
results_data = []
for prompt_name, prompt_instruction in PROMPTS_TO_TEST.items():
    print(f"\n===== TESTING PROMPT: '{prompt_name}' =====")
    for q_num, q_text in questions.items():
        # Scoring table for this question; a missing entry means every answer scores 0.
        question_scores = SCORING_RULES.get(q_num, {})
        for run in range(1, RUNS_PER_PROMPT + 1):
            print(f"  - Querying Q{q_num}, Run {run}/{RUNS_PER_PROMPT}...")
            # perf_counter() is the clock intended for measuring short
            # durations; unlike time.time() it is unaffected by system clock
            # adjustments.
            start_time = time.perf_counter()
            llm_answer = get_llm_response(prompt_instruction, q_text, options[q_num])
            duration = time.perf_counter() - start_time
            # NOTE: an "invalid" answer is not in the scoring table and therefore
            # scores 0, which lowers the total of the run it belongs to.
            score = question_scores.get(llm_answer, 0)
            print(f"    > Received answer: '{llm_answer}' | Score: {score} | Time: {duration:.2f}s")
            results_data.append(
                {
                    "model_id": MODEL_ID,
                    "prompt_name": prompt_name,
                    "question_number": q_num,
                    "run_number": run,
                    "llm_answer": llm_answer,
                    "score": score,
                    "response_time_s": duration,
                }
            )

print("\n===== Experiment Complete! =====")


# --- Aggregate, save and preview the results ---------------------------------
results_df = pd.DataFrame(results_data)
if not results_df.empty:
    print("\nCalculating total scores and risk levels...")
    # Sum the per-question scores of each (prompt, run) pair and broadcast the
    # total back onto every row of that pair.
    results_df['total_score'] = results_df.groupby(['prompt_name', 'run_number'])['score'].transform('sum')
    results_df['risk_level'] = results_df['total_score'].apply(get_risk_level)
    print("Calculations complete.")

# Model ids may contain "/" (e.g. "org/name"), which must not end up in a file name.
safe_model_name = MODEL_ID.replace("/", "_")
output_filename = os.path.join(OUTPUT_DIR, f"{safe_model_name}_stability_results.csv")
# 'utf-8-sig' writes UTF-8 with a leading byte-order mark.
results_df.to_csv(output_filename, index=False, encoding='utf-8-sig')

print(f"\nAll results have been saved to '{output_filename}'")
print("\nPreview of the final results data:")
preview_columns = ['prompt_name', 'question_number', 'run_number', 'llm_answer', 'score', 'total_score', 'risk_level']
# With no records the frame has no columns at all, so only select the columns
# that exist; otherwise the preview would raise KeyError.
available_columns = [column for column in preview_columns if column in results_df.columns]
print(results_df[available_columns].head())
print("\n--- Script Finished ---")

