# -*- coding: utf-8 -*-
"""TinyMMLU Eval.ipynb

Automatically generated by Colab.

This notebook:
1. Reads the list of models to evaluate (Hugging Face model URLs) from a Google Sheet stored in Drive;
2. Runs the TinyMMLU evaluation on each model;
3. Saves the evaluation results to a Google Sheet in Drive.

Note: this code was run in Google Colab, reading documents saved in Google Drive, on an L4 GPU.
"""

import os
import gc
import json
import torch
import warnings
import numpy as np
import time
from tqdm.auto import tqdm
from google.colab import auth, drive, userdata
from google.auth import default
import gspread
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import shutil
from pathlib import Path
import tinyBenchmarks as tb

# Authenticate with the Hugging Face Hub; login() is called without a token,
# so it prompts interactively.
# NOTE(review): `userdata` is imported above but unused; reading a Colab secret
# via login(token=userdata.get(...)) would avoid the prompt -- consider.
print("HF Token:")
login()

# Mount Google Drive (RAW_RESULTS_PATH points into it) and authorize gspread
# with the Colab user's default Google credentials for the Sheets used below.
drive.mount('/content/drive')
auth.authenticate_user()
creds, _ = default()
gc_sheets = gspread.authorize(creds)


# Configuration
# URL of the Google Sheet whose first worksheet's first column lists the
# Hugging Face model URLs to evaluate.
INPUT_SHEET_URL = 'INSERT' # Your link to the list of models to be evaluated
# Title of the results spreadsheet: opened by title, and created if not found.
RESULTS_SHEET_TITLE = "YOUR FILE" # Point to a Sheet to save your file
# JSONL log with one record per processed model; it is also read at start-up
# to skip models that already have a record.
RAW_RESULTS_PATH = '/content/drive/MyDrive/tiny_mmlu_raw_results.jsonl' # Adjust to the location for saving raw results

# Setup Dependencies
# NOTE(review): these are IPython shell escapes (Colab/Jupyter only, not plain
# Python). On a fresh runtime they must run BEFORE the imports at the top of
# the file: `import tinyBenchmarks as tb` fails until the first install runs.
!pip install -q git+https://github.com/felipemaiapolo/tinyBenchmarks.git
!pip install -q transformers accelerate bitsandbytes gspread google-auth


# Load data
# Read the model URLs from column A of the first worksheet of the input sheet.
print("Loading Model List")
try:
    input_sh = gc_sheets.open_by_url(INPUT_SHEET_URL)
    input_worksheet = input_sh.get_worksheet(0)
    # Keep only Hugging Face links. Blank/None cells are skipped (instead of
    # raising and wiping the whole list), surrounding whitespace is stripped,
    # and duplicates are dropped while preserving sheet order.
    model_urls = list(dict.fromkeys(
        cell.strip()
        for cell in input_worksheet.col_values(1)
        if cell and 'huggingface.co' in cell
    ))
    print(f"Found {len(model_urls)} models.")
except Exception as e:
    # Best effort: report the problem and continue with no models to evaluate.
    print(f"Error loading input sheet: {e}")
    model_urls = []

# Prepare output sheet
# Header row of the results sheet; one data row is appended per evaluated model.
RESULTS_HEADER = ["Model URL", "Model ID", "TinyMMLU Score (IRT)", "Raw Accuracy"]

# Reuse the spreadsheet with this title if it exists, otherwise create it.
try:
    results_sh = gc_sheets.open(RESULTS_SHEET_TITLE)
except gspread.SpreadsheetNotFound:
    results_sh = gc_sheets.create(RESULTS_SHEET_TITLE)
    print(f"Created new sheet: {RESULTS_SHEET_TITLE}")
else:
    print(f"Appending to existing sheet: {RESULTS_SHEET_TITLE}")

results_worksheet = results_sh.sheet1
# Write the header whenever row 1 is empty. This covers a freshly created sheet
# and one the user pre-created but left blank, which previously never got one.
if not results_worksheet.row_values(1):
    results_worksheet.append_row(RESULTS_HEADER)

# Load the tinyMMLU test split. Downstream code reads 'question', 'choices'
# and 'answer' from each example (see format_prompt / evaluate_model).
print("Loading Dataset...")
dataset = load_dataset("tinyBenchmarks/tinyMMLU", split="test")

def format_prompt(example):
    """Build a zero-shot multiple-choice prompt for one example.

    The prompt is a "Question: ..." line, one "<letter>. <choice>" line per
    choice (letters A-D; more than four choices raises IndexError), and a
    trailing "Answer:" with no newline after it.
    """
    letters = ["A", "B", "C", "D"]
    pieces = [f"Question: {example['question']}"]
    pieces += [f"{letters[idx]}. {choice}" for idx, choice in enumerate(example['choices'])]
    pieces.append("Answer:")
    return "\n".join(pieces)

def get_answer_token_ids(tokenizer):
    """
    Map each answer index (0-3) to one token id for the letters A, B, C, D.

    For each letter both " X" (leading space) and "X" are encoded without
    special tokens. The first one that encodes to exactly one token wins,
    checked in that order. If neither does, the last token of the bare "X"
    encoding is used.
    """
    def pick(letter):
        spaced = tokenizer.encode(" " + letter, add_special_tokens=False)
        bare = tokenizer.encode(letter, add_special_tokens=False)
        if len(spaced) == 1:
            return spaced[0]
        if len(bare) == 1:
            return bare[0]
        return bare[-1]

    return {idx: pick(letter) for idx, letter in enumerate(["A", "B", "C", "D"])}

def evaluate_model(model_id, dataset):
    """Score a causal LM on multiple-choice examples via next-token logits.

    Loads *model_id* with transformers. For every example, it builds the
    prompt with ``format_prompt``, runs one forward pass, and compares the
    final-position logits of the four answer-letter tokens chosen by
    ``get_answer_token_ids``. The highest-scoring letter (first one on ties)
    is the prediction.

    Args:
        model_id: Hugging Face repo id passed to ``from_pretrained``.
        dataset: iterable of examples providing ``question``, ``choices`` and
            an ``answer`` compared against the predicted index 0-3
            (presumably an int index -- as in tinyMMLU).

    Returns:
        ``np.ndarray`` of 0/1 correctness flags in dataset order, or ``None``
        if loading or evaluation raised (the error is printed).
    """
    # Pre-bind so the ``finally`` cleanup can always run, even if loading fails.
    model = None
    tokenizer = None
    try:
        # NOTE(review): trust_remote_code=True executes repo-provided Python for
        # every model listed in the input sheet; only use with trusted repos.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Check CUDA first: older torch versions raise from is_bf16_supported()
        # on CPU-only runtimes instead of returning False.
        use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        dtype = torch.bfloat16 if use_bf16 else torch.float16

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        model.eval()

        answer_map = get_answer_token_ids(tokenizer)
        # Hoisted out of the loop: token ids for options 0..3 in order.
        answer_ids = [answer_map[i] for i in range(4)]
        results = []

        with torch.no_grad():
            for example in dataset:
                inputs = tokenizer(format_prompt(example), return_tensors="pt").to(model.device)
                logits = model(**inputs).logits[0, -1, :]
                # One gather and one host transfer instead of four .item()
                # syncs; the upcast to fp32 is exact, so argmax is unchanged.
                option_scores = logits[answer_ids].float().tolist()
                predicted_idx = int(np.argmax(option_scores))
                results.append(1 if predicted_idx == example['answer'] else 0)

        return np.array(results)

    except Exception as e:
        # Best effort: one bad model must not stop the whole sweep.
        print(f"Error evaluating {model_id}: {e}")
        return None

    finally:
        # Release the model on success and failure alike, then return cached
        # GPU memory before the next model is loaded.
        del model, tokenizer
        torch.cuda.empty_cache()
        gc.collect()

# Loop

# Resume support: collect the model ids that already have a record in the raw
# results file. Failed runs are recorded too ({"model_id": ..., "error": ...}),
# so they are skipped on resume as well; delete their lines to retry them.
completed_models = set()
if os.path.exists(RAW_RESULTS_PATH):
    malformed_lines = 0
    with open(RAW_RESULTS_PATH, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                completed_models.add(json.loads(line)['model_id'])
            except (json.JSONDecodeError, KeyError, TypeError):
                # Not a usable record (bad JSON, wrong shape, no model_id): ignore it.
                malformed_lines += 1
    if malformed_lines:
        print(f"Ignored {malformed_lines} malformed line(s) in {RAW_RESULTS_PATH}")

print(f"Resuming: {len(completed_models)} models already evaluated.")

def _model_id_from_url(url):
    """Return the Hugging Face repo id ("org/name", or a legacy bare "name") in *url*.

    Tolerates http vs https, a missing scheme, "www.", surrounding whitespace,
    trailing slashes, query strings and web-UI suffixes after an "org/name"
    id (e.g. ".../tree/main"). Returns "" when the URL has no path.
    """
    from urllib.parse import urlparse

    url = url.strip()
    if "://" not in url:
        url = "https://" + url
    segments = [seg for seg in urlparse(url).path.split("/") if seg]
    return "/".join(segments[:2])


def _append_row_with_retry(worksheet, row, attempts=3, delay_s=5):
    """Append *row* to *worksheet*, sleeping *delay_s* seconds between failed attempts.

    Sheets API calls can fail transiently (e.g. quota limits), so a single
    failure should not lose the row. Returns True if the append succeeded.
    """
    for attempt in range(1, attempts + 1):
        try:
            worksheet.append_row(row)
            return True
        except Exception as e:
            print(f"Sheet update failed (attempt {attempt}/{attempts}): {e}")
            if attempt < attempts:
                time.sleep(delay_s)  # Wait for API quota
    return False


def _clear_hf_disk_cache(model_id):
    """Best-effort removal of the downloaded hub cache directory for *model_id*.

    The cache root is resolved from HF_HUB_CACHE, then HUGGINGFACE_HUB_CACHE,
    then $HF_HOME/hub (HF_HOME defaulting to $XDG_CACHE_HOME/huggingface or
    ~/.cache/huggingface), so a relocated cache is still cleaned.
    """
    hf_home = os.environ.get(
        "HF_HOME",
        os.path.join(os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")), "huggingface"),
    )
    hub_cache = (
        os.environ.get("HF_HUB_CACHE")
        or os.environ.get("HUGGINGFACE_HUB_CACHE")
        or os.path.join(hf_home, "hub")
    )
    cache_path = Path(os.path.expanduser(hub_cache)) / f"models--{model_id.replace('/', '--')}"
    try:
        if cache_path.is_dir():
            shutil.rmtree(cache_path)
            print(f"Deleted disk cache for {model_id}")
    except Exception as e:
        print(f"Could not clear disk cache: {e}")


for url in tqdm(model_urls):
    model_id = _model_id_from_url(url)
    if not model_id:
        print(f"Skipping URL without a model id: {url}")
        continue
    if model_id in completed_models:
        continue

    print(f"\nEvaluating: {model_id}")
    y_preds = evaluate_model(model_id, dataset)

    if y_preds is not None:
        # Cast NumPy scalars to plain floats so json and the Sheets API accept them.
        raw_accuracy = float(np.mean(y_preds))
        try:
            scores = tb.evaluate(y_preds, 'mmlu')
            irt_score = float(scores['mmlu']['irt'])
        except Exception as e:
            # Keep the sweep going, but make the fallback visible: the "IRT"
            # value for this model is then plain accuracy.
            print(f"IRT scoring failed for {model_id}; recording raw accuracy instead: {e}")
            irt_score = raw_accuracy

        # Save to JSONL
        with open(RAW_RESULTS_PATH, 'a', encoding='utf-8') as f:
            f.write(json.dumps({
                "model_id": model_id,
                "irt_score": irt_score,
                "raw_accuracy": raw_accuracy,
                "raw_predictions": y_preds.tolist(),
            }) + "\n")

        # Save to Sheet
        _append_row_with_retry(results_worksheet, [url, model_id, irt_score, raw_accuracy])
    else:
        with open(RAW_RESULTS_PATH, 'a', encoding='utf-8') as f:
            f.write(json.dumps({"model_id": model_id, "error": "Failed"}) + "\n")

    # Both outcomes now have a record on disk (resume treats either as done);
    # mark the model done so a duplicate URL later in this run is skipped too.
    completed_models.add(model_id)

    # evaluate_model drops its own model/tokenizer; release any GPU memory the
    # allocator still caches before the next model loads.
    torch.cuda.empty_cache()
    gc.collect()

    _clear_hf_disk_cache(model_id)