import pandas as pd
import numpy as np
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

# Path to the input CSV of lottery-choice questions. It is an empty
# placeholder here and must be set to a real path before running: the script
# reads it with pd.read_csv and exits if the file is not found.
FILENAME = ""  

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
from utility_functions import *

# --- Load and validate the input data ---
# Columns the rest of the script relies on: for each option (a / b) two
# rewards and the probability of the first reward, plus the question text
# that is inserted into the prompt.
required_columns = ['reward1_a', 'p_a', 'reward2_a', 'reward1_b', 'p_b', 'reward2_b', 'prompt_text']

# Keep the try body minimal: only the read itself can raise FileNotFoundError.
try:
    df = pd.read_csv(FILENAME)
except FileNotFoundError:
    print(f"❌ ERROR: Input file '{FILENAME}' not found. Please make sure the file is in the correct directory.")
    # sys.exit with a non-zero status instead of the site-provided exit():
    # the latter is meant for interactive use and would report success (0).
    sys.exit(1)

print(f"✅ Successfully loaded '{FILENAME}'.")

# Report exactly which columns are absent so the CSV can be fixed quickly.
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    print(f"❌ ERROR: Your CSV file is missing required column(s): {missing_columns}")
    sys.exit(1)




# Ground-truth decision models to simulate. Each key is the human-readable
# label used in the console output, the results table and the plot legend.
# Each value is a pair: (utility-family tag, keyword arguments that are
# forwarded to that family's utility function).
GROUND_TRUTHS = {
    "CRRA (θ=0.71)": ("crra", {"theta": 0.71}),
    "CRRA (θ=-5)": ("crra", {"theta": -5.0}),
    "CRRA (θ=1)": ("crra", {"theta": 1.0}),
    "CARA (α=0.1)": ("cara", {"alpha": 0.1}),
    "CARA (α=2)": ("cara", {"alpha": 2.0}),
    "Prospect Theory (α=0.88, β=0.88, λ=2.25, reference_point=500)": (
        "prospect_theory",
        {"alpha": 0.88, "beta": 0.88, "lam": 2.25, "reference_point": 500},
    ),
}

def compute_ground_truth(row, func, params):
    """Return 'A' or 'B' for the option with the higher expected utility.

    Each option pays its first reward with probability ``p_*`` and its second
    reward otherwise. Rewards are mapped through ``func(reward, **params)``
    before being weighted.

    Args:
        row: Mapping-like record with the keys 'p_a', 'reward1_a',
            'reward2_a', 'p_b', 'reward1_b' and 'reward2_b'.
        func: Utility function called as ``func(reward, **params)``.
        params: Keyword arguments forwarded to ``func``.

    Returns:
        'A' if option A's expected utility is strictly greater than option
        B's, otherwise 'B' (so ties resolve to 'B').
    """
    def expected_utility(prob_key, first_key, second_key):
        # Probability-weighted utility of one two-outcome option.
        prob = row[prob_key]
        first_utility = func(row[first_key], **params)
        second_utility = func(row[second_key], **params)
        return prob * first_utility + (1 - prob) * second_utility

    utility_a = expected_utility('p_a', 'reward1_a', 'reward2_a')
    utility_b = expected_utility('p_b', 'reward1_b', 'reward2_b')
    if utility_a > utility_b:
        return 'A'
    return 'B'

# --- 3. Load Model and Tokenizer ---
# Local checkpoint directory of the model under test.
MODEL_ID = "/home/ykwang/mtdata/models/Llama-3.1-8B"
# GPU index 0 when CUDA is available, -1 (CPU) otherwise.
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# NOTE(review): float16 is also used on the CPU fallback path — confirm the
# installed torch build supports half-precision inference on CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device if device != -1 else None,
    torch_dtype=torch.float16,
)
# `temperature` only influences generation when sampling is enabled, so
# request sampling explicitly instead of relying on whatever the checkpoint's
# generation config happens to specify.
llm_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=25,
    do_sample=True,
    temperature=0.8,
)

# --- 4. Define Prompt Function ---
def make_prompt(train_examples, test_row):
    """Build the few-shot prompt for one test question.

    Args:
        train_examples: DataFrame of in-context examples; each row supplies
            'prompt_text' and 'ground_truth'. When it is empty, the examples
            section is omitted entirely.
        test_row: Record whose 'prompt_text' is the question to answer.

    Returns:
        The prompt string: an instruction, the optional worked examples, and
        the test question ending in "Choice: " for the model to complete.
    """
    segments = [
        "You are a decision-making assistant. Follow the examples' risk attitude and choose the option (A or B).\n\n"
    ]

    if not train_examples.empty:
        segments.append("Here are some examples:\n")
        segments.extend(
            f"Question: {example['prompt_text']}\nChoice: {example['ground_truth']}\n\n"
            for _, example in train_examples.iterrows()
        )

    segments.append(
        f"Now predict the choice for the next question:\nQuestion: {test_row['prompt_text']}\nChoice: "
    )
    return "".join(segments)


# Experiment grid: every context size from 0 to 40 in-context examples, with
# N_RUNS independently sampled trials per size.
N_CONTEXT_LIST = list(range(0, 41))
N_RUNS = 50
# Maps each ground-truth label to a DataFrame of (n_context, accuracy) rows.
all_results = {}


for name, (func_type, params) in GROUND_TRUTHS.items():
    banner = "=" * 50
    print("\n" + banner)
    print(f"--- Running Experiment for Ground Truth: {name} ---")
    print(banner)

    # Resolve the family tag to its utility function; any tag other than
    # 'crra' or 'cara' falls through to the prospect-theory value function.
    utility_func = (
        crra_utility if func_type == 'crra'
        else cara_utility if func_type == 'cara'
        else prospect_theory_value
    )

    # Relabel the whole dataset with the choice this ground-truth model makes.
    df['ground_truth'] = df.apply(
        lambda record: compute_ground_truth(record, utility_func, params),
        axis=1,
    )

    accuracy_rows = []
    for n_context in N_CONTEXT_LIST:
        hits = []
        for _ in tqdm(range(N_RUNS), desc=f"Context {n_context}"):
            # Draw the in-context examples, then one test question from the
            # rows that were not used as examples.
            if n_context > 0:
                train_examples = df.sample(n=n_context)
                test_pool = df.drop(train_examples.index)
            else:
                train_examples = pd.DataFrame()
                test_pool = df
            test_row = test_pool.sample(n=1).iloc[0]

            prompt = make_prompt(train_examples, test_row)
            generations = llm_pipeline(prompt, pad_token_id=tokenizer.eos_token_id)
            # The pipeline output echoes the prompt; keep only the completion.
            completion = generations[0]['generated_text'][len(prompt):].strip()

            # NOTE: any completion without 'A' in its first two characters is
            # scored as 'B'; unparseable outputs are not skipped.
            pred = 'A' if 'A' in completion[:2] else 'B'
            hits.append(int(pred == test_row['ground_truth']))

        # Guard against an empty trial list (e.g. N_RUNS == 0).
        accuracy = np.mean(hits) if hits else float('nan')
        accuracy_rows.append({'n_context': n_context, 'accuracy': accuracy})
        print(f"Context {n_context}: {accuracy * 100:.2f}%")

    all_results[name] = pd.DataFrame(accuracy_rows)



# --- Persist the combined results ---
# Stack the per-model frames; the dict keys become the 'Ground Truth Model'
# column and each frame's row position becomes 'run_index'.
stacked_results = pd.concat(all_results, names=['Ground Truth Model', 'run_index'])
combined_results_df = stacked_results.reset_index()

output_dir = "in_context_comparison"
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "multi_utility_in_context_results.csv")
combined_results_df.to_csv(output_file, index=False)
print(f"\nCombined results saved to '{output_file}'.")


# --- Plot accuracy against context size, one curve per ground-truth model ---
plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(12, 8))

for model_label, curve in all_results.items():
    ax.plot(curve['n_context'], curve['accuracy'], marker='o', linestyle='-', label=model_label)

ax.set_xlabel("Number of In-Context Examples", fontsize=14)
ax.set_ylabel("Accuracy", fontsize=14)
ax.set_title("LLM In-Context Learning for Different Utility Functions", fontsize=18, weight='bold')
ax.legend(title="Ground Truth Model", fontsize=11)
ax.grid(True, which='both', linestyle='--', linewidth=0.5)
ax.set_xticks(N_CONTEXT_LIST)
# Fixed y-range; accuracies below 0.4 fall outside the visible area.
ax.set_ylim(0.4, 1.0)

plot_file = os.path.join(output_dir, "multi_utility_in_context_accuracy.png")
fig.savefig(plot_file, dpi=300)
print(f"Comparison plot saved to '{plot_file}'")