import argparse
import json
import os
import re
import time
from pathlib import Path
from typing import Any, Dict

from datasets import get_dataset_config_names, load_dataset
from openai import OpenAI
from tqdm import tqdm

from nanoGPT.typo_gen import typo_generator

# Command-line options for one evaluation run.  ``--noise`` is the value handed
# to ``typo_generator`` for every question and ``--seed`` is its seed.
argparser = argparse.ArgumentParser()
argparser.add_argument("--noise", default=0, type=int)
argparser.add_argument("--seed", default=42, type=int)
args = argparser.parse_args()

# NOISE is stored as text (it is converted back with int() where a number is
# required); SEED stays an integer.
SEED, NOISE = args.seed, f"{args.noise}"

# Credentials for the OpenAI client.  The key is taken from the
# OPENAI_API_KEY environment variable so that a real secret never has to be
# written into this file; the literal placeholder is only used as a fallback
# when the variable is not set.
api_key = os.environ.get("OPENAI_API_KEY", "<API_KEY>")
client = OpenAI(api_key=api_key)

# Chat model under evaluation; also used as the name of the output directory.
MODEL_NAME = "gpt-4o-mini"

# Output files are written to a directory named after the model (created on
# first use, reused afterwards).
OUTPUT_DIR = Path(MODEL_NAME)
OUTPUT_DIR.mkdir(exist_ok=True)

# Every output file name carries the same "<model>_<noise>_<seed>" suffix so
# that runs with different settings do not overwrite each other.
_RUN_TAG = f"{MODEL_NAME}_{NOISE}_{SEED}"
RESULTS_JSONL = OUTPUT_DIR / ("mmlu_raw_" + _RUN_TAG + ".jsonl")
SUMMARY_CSV = OUTPUT_DIR / ("mmlu_summary_" + _RUN_TAG + ".csv")
DETAIL_CSV = OUTPUT_DIR / ("mmlu_details_" + _RUN_TAG + ".csv")

# Configuration names published for the MMLU dataset; the evaluation loop
# iterates over these.
_MMLU_REPO = "cais/mmlu"
SUBJECTS = get_dataset_config_names(_MMLU_REPO)


# Answer labels, in the order in which the choices are listed in the prompt.
LETTER_CHOICES = list("ABCD")

# Instruction sent as the system message with every question.
SYSTEM_PROMPT = " ".join(
    (
        "You are an expert test taker.",
        "Answer the multiple choice question.",
        "Respond with only the letter (A, B, C, or D).",
    )
)

def load_subject_dataset(subject: str):
    """Load the ``cais/mmlu`` configuration *subject* and return its ``test`` split."""
    return load_dataset("cais/mmlu", subject)["test"]

def row_to_question(row: Dict[str, Any]) -> Dict[str, Any]:
    """Normalise one dataset row into a question record.

    Two row layouts are understood:

    * ``answerA`` .. ``answerD`` plus ``correct``, where ``correct`` is used as
      an index into ``LETTER_CHOICES``;
    * ``choices`` plus ``answer``, where ``answer`` is either an integer index
      into ``LETTER_CHOICES`` or a value whose first character (after
      stripping and upper-casing) is the answer letter.

    The question text is passed through ``typo_generator`` using the
    configured ``NOISE`` and ``SEED``; the choices are left untouched.

    Returns:
        A dict with the keys ``question``, ``choices`` and ``answer``.

    Raises:
        ValueError: if the row matches neither layout.  (Previously this case
            fell through to the ``return`` with unbound locals.)
    """
    flat_keys = ("answerA", "answerB", "answerC", "answerD", "correct")
    if all(k in row for k in flat_keys):
        choices = [row[k] for k in flat_keys[:4]]
        answer_letter = LETTER_CHOICES[row["correct"]]
    elif "choices" in row and "answer" in row:
        choices = row["choices"]
        raw_answer = row["answer"]
        if isinstance(raw_answer, int):
            answer_letter = LETTER_CHOICES[raw_answer]
        else:
            ans_raw = str(raw_answer).strip().upper()
            # Keep only the leading letter when it is a valid label; otherwise
            # keep the cleaned-up value as is.
            if ans_raw and ans_raw[0] in LETTER_CHOICES:
                answer_letter = ans_raw[0]
            else:
                answer_letter = ans_raw
    else:
        raise ValueError(
            f"Unrecognised question row layout; keys present: {sorted(row)}"
        )
    return {
        "question": typo_generator(row["question"], int(NOISE), seed=SEED),
        "choices": choices,
        "answer": answer_letter,
    }

def build_prompt(q: Dict[str, Any], subject: str) -> str:
    """Render a question record as the user message sent to the model.

    The result is a newline-joined block: a ``Subject:`` line, the question
    text, one ``<letter>. <choice>`` line per entry of ``q["choices"]`` and a
    closing ``Answer:`` line.
    """
    header = [f"Subject: {subject}", q["question"]]
    options = [
        f"{LETTER_CHOICES[pos]}. {text}" for pos, text in enumerate(q["choices"])
    ]
    return "\n".join(header + options + ["Answer:"])


# Alternative call_model for a local model with an SAE (disabled, kept for reference):
# from llm_infra import ObservableModel
# from sae_lens import SAE
# 
# model = ObservableModel.from_pretrained(MODEL_NAME)  
# sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release = "SAE_RELEASE",
#     sae_id = "SAE_ID",
#     device = "cuda",
# ) 
# def call_model(prompt: str) -> str:
#     response = model.generate(prompt, max_new_tokens=1, hookpoint="model.layers.20")
#     sae_latents = sae.encode(response['input']['activations'])
#     l0 = (sae_latents > 0).any(dim=1).sum().item()
#     return response['output']['text'], l0
#

def _extract_letter(text: str) -> str:
    """Pick the answer letter out of the model's reply, or return *text* as is."""
    # Prefer a letter that stands on its own ("B", "B.", "(B)", "Answer: B") so
    # that a capital inside a word (the "A" of "Answer") is not mistaken for
    # the answer.
    standalone = re.search(rf"\b[{re.escape(''.join(LETTER_CHOICES))}]\b", text)
    if standalone:
        return standalone.group(0)
    # Otherwise fall back to the first valid letter appearing anywhere.
    for ch in text:
        if ch in LETTER_CHOICES:
            return ch
    return text


def call_model(prompt: str, max_retries: int = 5, backoff: float = 2.0) -> str:
    """Ask the chat model to answer *prompt* and return the letter it chose.

    The request is sent with ``SYSTEM_PROMPT`` as the system message,
    ``temperature=0`` and ``max_tokens=4``.

    Args:
        prompt: The user message, as produced by ``build_prompt``.
        max_retries: Maximum number of attempts before giving up.
        backoff: Base of the exponential wait; attempt ``k`` (0-based) is
            followed by a pause of ``backoff ** k`` seconds.

    Returns:
        One of ``LETTER_CHOICES`` when a letter can be found in the reply,
        otherwise the stripped reply text (possibly empty).

    Raises:
        RuntimeError: if every attempt failed; the last underlying error is
            attached as ``__cause__``.
    """
    last_error = None
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0,
                max_tokens=4,
            )
            content = response.choices[0].message.content
        except Exception as e:  # any failure of the call is treated as transient
            last_error = e
            if attempt + 1 >= max_retries:
                # No further attempt will be made, so do not sleep first.
                print(f"Error attempt {attempt+1}: {e} | giving up")
                break
            wait = backoff ** attempt
            print(f"Error attempt {attempt+1}: {e} | retrying in {wait:.1f}s")
            time.sleep(wait)
        else:
            # The message content may be None; treat that as an empty reply
            # instead of failing and retrying.
            return _extract_letter((content or "").strip())
    raise RuntimeError("Model call failed after retries") from last_error


# Evaluate every subject and write one JSON record per question to
# RESULTS_JSONL (the file is overwritten on each run).
#
# Two configuration names are not individual subjects and are skipped:
# "all" (the aggregate configuration) and "auxiliary_train", which per the
# dataset card only provides a train split and would therefore fail when the
# test split is requested.
_NON_SUBJECT_CONFIGS = {"all", "auxiliary_train"}

with open(RESULTS_JSONL, "w") as f_out:
    for subject in SUBJECTS:
        if subject in _NON_SUBJECT_CONFIGS:
            continue
        ds = load_subject_dataset(subject)
        for idx in tqdm(range(len(ds)), desc=subject):
            q = row_to_question(ds[idx])
            prompt = build_prompt(q, subject)
            pred = call_model(prompt)
            rec = {
                "subject": subject,
                "question": q["question"],  # the noised text actually sent
                "choices": q["choices"],
                "gold": q["answer"],
                "prediction": pred,
                "correct": pred == q["answer"],
            }
            f_out.write(json.dumps(rec) + "\n")
