# -*- coding: utf-8 -*-
"""Perplexity Evaluation English.ipynb

Automatically generated by Colab.

This notebook:

1. Identifies models from Hugging Face hub meeting requirements for Perplexity evaluation;
2. Runs WikiText2 evaluation on each in-scope model;
3. Appends the results to a Google Sheet stored in Google Drive.

## Set up and config
"""

# Install the evaluation stack into the notebook runtime (IPython shell magics).
# NOTE(review): the two git installs below are unpinned, so the installed code can
# differ between runs — consider pinning commits/tags if reproducibility matters.

# unsloth: force-reinstalled from its GitHub repository, without its dependencies.
!pip install --force-reinstall --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git
# transformers: upgraded to the current state of its GitHub repository.
!pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"
# Supporting libraries, installed quietly and without pulling in their dependencies.
!pip install -q --no-deps bitsandbytes accelerate peft trl triton cut_cross_entropy unsloth_zoo
# lm_eval (imported below) with its "transformers" extra.
!pip install -q "lm_eval[transformers]"
# Hugging Face Hub client (used below for login and model listing).
!pip install -q huggingface_hub
# Tokenizer/dataset packages; datasets is constrained to >=3.4.1,<4.0.0.
!pip install -q sentencepiece protobuf "datasets>=3.4.1,<4.0.0"


from google.colab import drive
from google.colab import auth
from huggingface_hub import notebook_login
import os
import pandas as pd
import gc
import torch
import gspread
from google.auth import default
from oauth2client.service_account import ServiceAccountCredentials
from datetime import datetime
from huggingface_hub import HfApi
import lm_eval
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import traceback

# Mount Google Drive and authenticate this session with the user's Google
# account; the resulting default credentials are reused for Google Sheets.
drive.mount('/content/drive')
auth.authenticate_user()
creds, _ = default()
gc_sheets = gspread.authorize(creds)

SHEET_NAME = "Perplexity Results"  # Spreadsheet title (an existing sheet with this title is reused)
WORKSHEET_NAME = "Results"  # Tab inside the spreadsheet that receives one row per evaluated model

# Open the results spreadsheet, creating it on first use.
try:
    spreadsheet = gc_sheets.open(SHEET_NAME)
    print(f"Opened existing sheet: {SHEET_NAME}")
except gspread.SpreadsheetNotFound:
    spreadsheet = gc_sheets.create(SHEET_NAME)
    # NOTE(review): this grants write access to anyone who has the link.
    # Restrict it (e.g. perm_type='user' with explicit e-mail addresses) if the
    # results should not be publicly editable.
    spreadsheet.share('', perm_type='anyone', role='writer')
    print(f"Created new sheet: {SHEET_NAME}")
    print(f"Access it at: {spreadsheet.url}")

# Open the results tab, creating it (with a header row) on first use.
try:
    worksheet = spreadsheet.worksheet(WORKSHEET_NAME)
except gspread.WorksheetNotFound:
    worksheet = spreadsheet.add_worksheet(title=WORKSHEET_NAME, rows=1000, cols=10)
    # Keyword arguments are used because the positional order of
    # (range_name, values) was swapped between gspread 5.x and 6.x.
    worksheet.update(
        range_name='A1:C1',
        values=[['adapter', 'base_model', 'perplexity']],
    )
    print(f"Created worksheet: {WORKSHEET_NAME}")
print(f"Sheet URL: {spreadsheet.url}")

print("\nPlease login to Hugging Face:")
notebook_login()

"""## Run Perplexity Evaluation"""

# Model family to evaluate. Options: "gemma", "llama", "qwen"
TARGET_FAMILY = "gemma"

# Filters that select the in-scope models on the Hugging Face Hub.
# The suffixes are (anonymized) identifiers used for the paper: numbers and
# letters signify the number of fine-tuning epochs, random-seed variants, and
# ablation runs.
VALID_SUFFIXES = ("1a", "1b", "2a", "2b", "3a", "3b")
# Keyword that identifies the models from our model library that are relevant
# to this paper. Replace the placeholder before running.
# NOTE: the scan below matches this against the lower-cased repository id.
REQ_KEYWORD = "ENTER"

# Hugging Face account (or organisation) that owns the models.
# Replace the placeholder with your username before running.
USERNAME = "ENTER"
# Title of the results spreadsheet; keep identical to the value used during setup.
SHEET_NAME = "Perplexity Results"

def cleanup_memory():
    """Drop the notebook-level model objects and release cached GPU memory.

    Removes the module globals ``model``, ``tokenizer`` and ``lm_obj`` when
    they exist, runs the garbage collector, and — only when CUDA is
    available — empties the CUDA cache and collects CUDA IPC memory.
    """
    module_globals = globals()
    for stale_name in ('model', 'tokenizer', 'lm_obj'):
        # pop() with a default is a no-op for names that were never created
        module_globals.pop(stale_name, None)

    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# Collect the ids of all Hub models owned by USERNAME that belong to
# TARGET_FAMILY, contain REQ_KEYWORD and end with one of VALID_SUFFIXES.
api = HfApi()
print(f"Scanning models for user: {USERNAME} | Family: {TARGET_FAMILY}...")

# Repository ids are lower-cased before matching, so the filters must be
# lower-cased as well; otherwise an upper-case keyword can never match.
family_filter = TARGET_FAMILY.lower()
keyword_filter = REQ_KEYWORD.lower()
suffix_filter = tuple(suffix.lower() for suffix in VALID_SUFFIXES)

target_models = []
try:
    if not USERNAME:
        # An empty author filter would make the listing walk the entire Hub.
        raise ValueError("USERNAME is empty; set it before scanning.")

    # The listing is consumed inside the try block so that API errors raised
    # while iterating are reported here as well.
    for m in api.list_models(author=USERNAME):
        mid = m.modelId
        mid_lower = mid.lower()
        if (
            family_filter in mid_lower
            and keyword_filter in mid_lower
            and mid_lower.endswith(suffix_filter)
        ):
            target_models.append(mid)
except Exception as e:
    print(f"Error listing models: {e}")
    # Discard partial results so a failed scan never yields a partial run.
    target_models = []
else:
    print(f"Found {len(target_models)} valid '{TARGET_FAMILY}' models matching criteria.")
    print(f"Examples: {target_models[:3]}")

# Read the ids already recorded in the results worksheet so that an
# interrupted run can resume without re-evaluating them.
try:
    spreadsheet = gc_sheets.open(SHEET_NAME)
    worksheet = spreadsheet.worksheet(WORKSHEET_NAME)
    existing_data = worksheet.get_all_records()
    # The header row created during setup names the id column 'adapter';
    # 'model_id' is also accepted for sheets that use that header instead.
    # Every recorded row counts as completed, including rows that hold an
    # error marker rather than a score.
    completed_models = []
    for row in existing_data:
        recorded_id = row.get('model_id') or row.get('adapter')
        if recorded_id:
            completed_models.append(str(recorded_id))
    print(f"Resuming: Found {len(completed_models)} completed entries in sheet.")
except Exception as e:
    print(f"Could not read existing sheet (Starting fresh?): {e}")
    completed_models = []

def _extract_perplexity(metrics):
    """Return the preferred perplexity value from one task's metrics, or None.

    Preference order is ``perplexity``, ``word_perplexity``, ``byte_perplexity``.
    Recent lm_eval releases suffix metric keys with the filter name
    (e.g. ``word_perplexity,none``), so only the part before the first comma
    is compared. Matching on the exact metric name also keeps the
    ``*_stderr`` entries out.
    """
    by_name = {}
    for key, value in metrics.items():
        # setdefault keeps the first entry seen for each metric name
        by_name.setdefault(str(key).split(',', 1)[0], value)
    for preferred in ('perplexity', 'word_perplexity', 'byte_perplexity'):
        value = by_name.get(preferred)
        if value is not None:
            return value
    return None


# Evaluate every in-scope model that is not yet recorded in the worksheet and
# append one row per model: a score, a missing-metric marker, or an error marker.

# The quantization settings are identical for every model, so build them once.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# Set for constant-time "already done?" checks inside the loop.
already_done = set(completed_models)

for i, model_id in enumerate(target_models, 1):
    print(f"{datetime.now().strftime('%H:%M:%S')} | Model {i}/{len(target_models)} | {model_id}")
    if model_id in already_done:
        print("Skipping (Already in sheet)")
        continue

    # Free whatever the previous iteration (or an earlier cell) left on the GPU.
    cleanup_memory()

    try:
        print(f"Loading Full Model: {model_id}")
        # NOTE(review): trust_remote_code=True executes code shipped in the
        # model repository; only use it with repositories you control.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )

        lm_obj = HFLM(
            pretrained=model,
            tokenizer=tokenizer,
            batch_size=1,
            max_length=4096,
            trust_remote_code=True
        )
        print("Calculating Perplexity (wikitext)...")
        results = lm_eval.simple_evaluate(
            model=lm_obj,
            tasks=["wikitext"],
            device="cuda:0",
            log_samples=False,
        )

        res_dict = (results or {}).get('results')
        if not res_dict:
            # Surface an empty evaluation as an error row instead of dropping it.
            raise ValueError("lm_eval returned no 'results' for this model.")

        # Prefer the wikitext task entry; fall back to the first task reported.
        target_key = next((k for k in res_dict if 'wiki' in k), next(iter(res_dict)))
        metrics = res_dict[target_key]
        ppl = _extract_perplexity(metrics)

        if ppl is not None:
            ppl = round(float(ppl), 4)
            print(f"SCORE: {ppl}")
            worksheet.append_row([model_id, "MERGED_MODEL", ppl])
        else:
            print(f"Metric missing. Available: {list(metrics.keys())}")
            worksheet.append_row([model_id, "Full Model", "MISSING_METRIC"])

    except Exception as e:
        print(f"CRASHED: {e}")
        traceback.print_exc()
        # A failing sheet write must not abort the remaining evaluations.
        try:
            worksheet.append_row([model_id, "ERROR", str(e)[:100]])
        except Exception as sheet_err:
            print(f"Could not record the failure in the sheet: {sheet_err}")

    finally:
        cleanup_memory()

"""## Evaluate base models (Instruction Tuned variants)"""

# Evaluate the unmodified base model on the same task and append it to the
# worksheet as the "BASELINE" row.
BASE_MODEL_ID = "google/gemma-3-4b-it" # Define base for evaluation
print(f"Establishing Baseline for: {BASE_MODEL_ID}")

try:
    print(f"🚀 Loading Base Model {BASE_MODEL_ID} using Transformers...")
    # NOTE(review): the fine-tuned models above are loaded with 4-bit NF4
    # quantization while the baseline is loaded in plain bfloat16 — confirm
    # that this difference is intended for the comparison.
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_ID,
        trust_remote_code=True
    )

    lm_obj = HFLM(
        pretrained=model,
        tokenizer=tokenizer,
        batch_size=1,
        # Same context window as the fine-tuned runs: rolling-window
        # perplexity depends on it, so the scores are only comparable when
        # both use the same value.
        max_length=4096,
        trust_remote_code=True
    )

    print("Running Evaluation (Task: wikitext)...")
    results = lm_eval.simple_evaluate(
        model=lm_obj,
        tasks=["wikitext"],
        device="cuda:0",
        log_samples=False,
    )

    try:
        if not results or 'results' not in results:
            raise ValueError("No 'results' key in output.")

        res_dict = results['results']

        # Prefer the wikitext task entry; fall back to the first task reported.
        target_key = next(
            (key for key in res_dict if 'wiki' in key.lower()),
            None,
        )
        if not target_key:
            target_key = list(res_dict.keys())[0]

        task_metrics = res_dict[target_key]
        # First perplexity metric that is not a standard-error entry.
        ppl = next(
            (
                value
                for key, value in task_metrics.items()
                if 'perplexity' in key.lower() and 'stderr' not in key.lower()
            ),
            None,
        )

        # Explicit None check: a numeric score must never be mistaken for
        # "missing" by a truthiness test.
        if ppl is not None:
            ppl = round(float(ppl), 4)
            print(f"Baseline score ({BASE_MODEL_ID}): {ppl}")

            worksheet.append_row([BASE_MODEL_ID, "BASELINE", ppl])
        else:
            print(f"Metric missing. Keys: {list(task_metrics.keys())}")

    except Exception as parse_err:
        print(f"Error parsing results: {parse_err}")
        traceback.print_exc()

except Exception as e:
    print(f"CRASHED: {e}")
    print("Traceback:")
    traceback.print_exc()

finally:
    # Shared helper: drops model/tokenizer/lm_obj, garbage-collects first and
    # only then empties the CUDA cache. Runs even if the cell is interrupted.
    cleanup_memory()