# -*- coding: utf-8 -*-
"""Perplexity in Local Languages MLSFT.ipynb

Automatically generated by Colab.

This notebook:
1. Defines languages for Perplexity evaluation and extracts relevant data from Wikipedia datasets;
2. Identifies model and language pairs required for evaluation;
3. Runs language-relevant Perplexity evaluation for each model;
4. Saves results to a Google Sheet stored in Google Drive.

Note: this code was run in Google Colab on an L4 GPU, using documents saved in Google Drive.
"""

# Check if datasets are available in the languages of study.
# This preflight runs at module level: it tries to load one sample per
# language and prints a summary, before the main evaluation further down.
import os
# Install/upgrade `datasets` (pinned below 3.0.0) before it is imported.
os.system('pip install -q -U "datasets<3.0.0"')

from datasets import load_dataset
# Language code -> display name for every language that is checked.
TEST_LANGUAGES = {
    'en': 'English (WikiText-2)',
    'pt': 'Portuguese',
    'ga': 'Irish',
    'zh': 'Chinese',
    'es': 'Spanish',
    'da': 'Danish',
    'tl': 'Tagalog',
    'el': 'Greek',
    'hi': 'Hindi'
}
print("Wikipedia Data Test")

# Per-language outcome string: "SUCCESS...", "FAILED" or "ERROR: ...".
results = {}

for lang_code, lang_name in TEST_LANGUAGES.items():
    print(f"\n[{lang_name}] Testing {lang_code}...")

    try:
        if lang_code == 'en':
            # English is checked against the WikiText-2 test split rather than
            # a Wikipedia dump.
            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
            sample = ds[0]['text']
            print(f"  ✓ WikiText-2 loaded successfully")
            print(f"    Sample: {sample[:100]}...")
            results[lang_code] = "SUCCESS"
        else:
            # Dump dates are tried in order; the first one that yields a sample wins.
            date_stamps = ["20231101", "20220301"]
            success = False

            for date_stamp in date_stamps:
                try:
                    config = f"{date_stamp}.{lang_code}"
                    print(f"  Trying: wikimedia/wikipedia @ {config}")

                    # Streaming mode: only the first record is pulled below.
                    ds = load_dataset(
                        "wikimedia/wikipedia",
                        config,
                        split="train",
                        streaming=True,
                        trust_remote_code=True
                    )

                    # Try to get first sample
                    sample = next(iter(ds))
                    print(f"  ✓ SUCCESS with {config}")
                    print(f"    Title: {sample.get('title', 'N/A')}")
                    print(f"    Text preview: {sample['text'][:100]}...")
                    results[lang_code] = f"SUCCESS ({config})"
                    success = True
                    break

                except Exception as e:
                    # A failure for one dump date is reported and the next date is tried.
                    print(f"  Failed {config}: {str(e)[:100]}")

            if not success:
                print(f"  All attempts failed for {lang_name}")
                results[lang_code] = "FAILED"

    except Exception as e:
        # Any error outside the per-date retry (e.g. the WikiText-2 load) lands here.
        print(f"  ERROR: {str(e)[:200]}")
        results[lang_code] = f"ERROR: {str(e)[:50]}"

print("Summary")

# One line per language; the icon is derived from the status text.
for lang_code, status in results.items():
    lang_name = TEST_LANGUAGES[lang_code]
    icon = "✓" if "SUCCESS" in status else "✗"
    print(f"{icon} {lang_name} ({lang_code}): {status}")

# Language codes whose status marks a failed or errored load.
failed = [k for k, v in results.items() if "FAILED" in v or "ERROR" in v]
if failed:
    print(f"WARNING: {len(failed)} languages failed to load:")
    for lang in failed:
        print(f"  - {TEST_LANGUAGES[lang]} ({lang})")
    print("These languages will cause errors in the main evaluation.")
else:
    print("All languages loaded successfully! You're ready to run.")

# Local Language Perplexity Evaluation

import os, sys, time, torch, gc, gspread, re, shutil, pandas as pd
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from google.auth import default
from google.colab import auth
from typing import Tuple, List, Optional
import traceback

# Configuration

# Input
# Key of the Google Sheet that get_input_tasks() opens; its first worksheet is read.
INPUT_SHEET_KEY = 'YOUR SHEET ID' # Update to your Sheet or other model listing location, including target evaluation language listed

# Output
# Spreadsheet title and tab title that get_output_worksheet() opens or creates.
OUTPUT_SHEET_NAME = 'Multilingual_PPL_Results'
OUTPUT_TAB_NAME = 'Targeted_Evals_TokenLevel'

# Temp Cache
# Passed as cache_dir when loading tokenizers/models and deleted by cleanup_disk_cache().
TEMP_MODEL_DIR = "./temp_model_cache"

# Language Mapping
# Language names and upper-case codes -> two-letter language code.
# get_input_tasks() uses it to recognise a language in sheet cells and model names.
LANG_MAP = {
    'English': 'en', 'EN': 'en',
    'Portuguese': 'pt', 'PT': 'pt',
    'Irish': 'ga', 'GA': 'ga', 'Gaelic': 'ga',
    'Chinese': 'zh', 'ZH': 'zh',
    'Spanish': 'es', 'ES': 'es',
    'Danish': 'da', 'DA': 'da',
    'Tagalog': 'tl', 'TL': 'tl',
    'Greek': 'el', 'EL': 'el',
    'Hindi': 'hi', 'HI': 'hi'
}

# Sorted unique language codes; models whose path contains 'unsloth' are
# evaluated on every one of them.
ALL_LANGUAGES = sorted(list(set(LANG_MAP.values())))

# Wikipedia dataset language codes (may be exactly the same, but is not always!)
# Looked up by get_eval_data(); currently every entry maps a code to itself.
WIKI_LANG_MAP = {
    'en': 'en',
    'pt': 'pt',
    'ga': 'ga',
    'zh': 'zh',
    'es': 'es',
    'da': 'da',
    'tl': 'tl',
    'el': 'el',
    'hi': 'hi'
}

# Eval config
# Cap on the number of tokens kept per language; get_eval_data() also collects
# roughly TARGET_TOKENS * 6 characters of text before tokenizing.
TARGET_TOKENS = 250000
# Upper bound on the window length in compute_ppl() (further capped by the
# model's max_position_embeddings when that attribute exists).
MAX_LENGTH = 2048
# Step between window starts in compute_ppl(); equal to MAX_LENGTH here.
STRIDE = 2048
# Number of model-load attempts and the sleep, in seconds, between them.
MAX_RETRIES = 3
RETRY_DELAY = 5

# Memory management
# NOTE(review): MAX_MODEL_SIZE_GB is not referenced anywhere in the visible code;
# run_targeted_evals() compares the estimated size against a literal 15 instead.
MAX_MODEL_SIZE_GB = 20
# When True, models with an estimated size above that threshold are loaded in 8-bit.
USE_8BIT_FOR_LARGE_MODELS = True

# Helper
def install_dependencies():
    """Install or upgrade the pip packages this script relies on.

    Runs ``pip install -q -U`` through the shell for ``datasets`` (pinned below
    3.0.0), ``transformers``, ``torch``, ``gspread``, ``bitsandbytes``,
    ``accelerate`` and ``peft``.

    A non-zero exit status from pip is reported as a warning instead of being
    followed by a misleading "installed" message. The function never raises on
    a failed install, so a runtime that already has the packages can proceed.

    NOTE: modules that are already imported in the running interpreter are not
    reloaded; an upgraded version only takes effect after a runtime restart.
    """
    print("Installing dependencies...")
    exit_status = os.system('pip install -q -U "datasets<3.0.0" transformers torch gspread bitsandbytes accelerate peft')
    if exit_status != 0:
        print(f"   [!] pip exited with status {exit_status}; continuing with the packages already installed.\n")
        return
    print("Dependencies installed.\n")

def estimate_model_size(model_path: str) -> float:
    """
    Rough estimate of model size in GB based on naming conventions.

    Looks for a parameter count written as ``<number>b`` in the (lower-cased)
    model path, e.g. ``7b``, ``70B`` or ``1.5b``, and assumes 2 bytes per
    parameter (16-bit weights).

    Args:
        model_path: Model identifier or path; converted with ``str()`` first.

    Returns:
        Estimated size in GB (GiB, i.e. bytes / 1024**3). Falls back to 5.0
        when no parameter count can be found in the name.
    """
    import re

    default_size_gb = 5.0
    # - The optional decimal part keeps "1.5b" from being read as "5b".
    # - The negative lookahead rejects a 'b' that starts a longer word, so
    #   quantisation tags such as "8bit"/"4bit" are not taken as a size.
    match = re.search(r'(\d+(?:\.\d+)?)b(?![a-z])', str(model_path).lower())
    if match is None:
        return default_size_gb

    params_b = float(match.group(1))
    return (params_b * 1e9 * 2) / (1024**3)

def clean_model_id(raw_path: str) -> str:
    """Reduce a Hugging Face URL or raw cell value to a bare model id.

    The value is converted with ``str()`` and stripped of surrounding
    whitespace. A leading ``https://huggingface.co/``, ``http://huggingface.co/``
    or scheme-less ``huggingface.co/`` (matched case-insensitively) is removed,
    as are trailing slashes. Values without such a prefix are returned stripped
    but otherwise unchanged.

    Args:
        raw_path: Cell content or model identifier.

    Returns:
        The cleaned model id, e.g. ``"org/model"``.
    """
    cleaned = str(raw_path).strip()
    lowered = cleaned.lower()
    # The task parser accepts any cell containing "huggingface.co/", so every
    # common spelling of the URL prefix has to be handled here, not just https.
    for prefix in (
        "https://huggingface.co/",
        "http://huggingface.co/",
        "huggingface.co/",
    ):
        if lowered.startswith(prefix):
            cleaned = cleaned[len(prefix):]
            break
    # A trailing slash copied from the browser is not part of a repo id.
    return cleaned.rstrip("/")

def connect_to_sheets():
    """Authenticate and return an authorised gspread client.

    Calls ``auth.authenticate_user()`` on a best-effort basis: if it fails, the
    reason is printed and the function continues with whatever credentials
    ``google.auth.default()`` can find.

    Returns:
        The client object returned by ``gspread.authorize``.

    Raises:
        Whatever ``default()`` or ``gspread.authorize`` raise when no usable
        credentials are available.
    """
    try:
        auth.authenticate_user()
    except Exception as exc:
        # Catch Exception (not a bare except) so Ctrl-C still interrupts, and
        # report the reason instead of hiding it.
        print(f"   [!] Colab authentication skipped: {str(exc)[:200]}")
    creds, _ = default()
    return gspread.authorize(creds)

def get_input_tasks(gs_client) -> List[Tuple[str, str]]:
    """
    Load (model, language) evaluation tasks from the input Google Sheet.

    Logic:
    - Uses get_all_values() on the first worksheet to handle sheets without
      headers.
    - For each row, the first cell that contains 'huggingface.co/' or that has
      a '/' and more than 5 characters is taken as the model path (cleaned
      with clean_model_id); rows without such a cell are skipped.
    - If the model path contains 'unsloth': expands to ALL supported languages.
    - Else: the target language is taken from the first cell that matches a
      LANG_MAP name or code (case-insensitively), then from a '-XX-' / '_XX_'
      style code in the model name, and finally defaults to 'en'.

    Args:
        gs_client: Authorised gspread client.

    Returns:
        Sorted list of unique (model_path, lang_code) tuples. An empty list is
        returned when the sheet is empty or cannot be read (the error is
        printed, not raised).
    """
    try:
        sh = gs_client.open_by_key(INPUT_SHEET_KEY)
        ws = sh.get_worksheet(0)
        rows = ws.get_all_values()

        if not rows:
            print("   [!] Input sheet is empty.")
            return []

        print(f"Successfully loaded input sheet: {len(rows)} rows found.")

        # Case-insensitive lookup built once. Keys are the lower-cased names and
        # codes, so a cell such as "PT" or "pt" resolves to 'pt' instead of
        # matching a key and then failing the case-sensitive LANG_MAP lookup.
        lang_lookup = {name.lower(): code for name, code in LANG_MAP.items()}
        for code in LANG_MAP.values():
            lang_lookup.setdefault(code, code)

        # One compiled pattern per language code, in LANG_MAP order, used to
        # spot a code delimited by '-' or '_' in the upper-cased model name.
        name_patterns = [
            (code, re.compile(f"[-_]{code.upper()}([-_]|$)"))
            for code in dict.fromkeys(LANG_MAP.values())
        ]

        tasks = []
        for row in rows:
            model_path = None
            for cell_value in row:
                val = str(cell_value).strip()
                if "huggingface.co/" in val or (val.count("/") >= 1 and len(val) > 5):
                    model_path = clean_model_id(val)
                    break

            if not model_path:
                continue

            if 'unsloth' in model_path.lower():
                tasks.extend((model_path, lang_code) for lang_code in ALL_LANGUAGES)
                continue

            # 1) An explicit language cell in the row wins.
            target_lang = None
            for cell_value in row:
                target_lang = lang_lookup.get(str(cell_value).strip().lower())
                if target_lang:
                    break

            # 2) Otherwise look for a language code embedded in the model name.
            if not target_lang:
                name_upper = model_path.upper()
                for code, pattern in name_patterns:
                    if pattern.search(name_upper):
                        target_lang = code
                        break

            # 3) Fall back to English.
            tasks.append((model_path, target_lang or 'en'))

        unique_tasks = sorted(set(tasks))
        print(f"Parsed {len(unique_tasks)} unique tasks.\n")
        return unique_tasks

    except Exception as e:
        print(f"Error reading Input Sheet: {e}")
        return []

def get_output_worksheet(gs_client):
    """Return the results worksheet, creating the spreadsheet/tab if missing.

    Opens the spreadsheet titled OUTPUT_SHEET_NAME (creating it when it does
    not exist) and then the tab OUTPUT_TAB_NAME. A newly created tab gets the
    seven-column header row in A1:G1.

    Only "not found" errors trigger creation. Any other failure (for example a
    transient API error) propagates, so that an existing results sheet is not
    shadowed by a freshly created empty one.

    Args:
        gs_client: Authorised gspread client.

    Returns:
        The gspread worksheet used for results.
    """
    try:
        sh = gs_client.open(OUTPUT_SHEET_NAME)
    except gspread.exceptions.SpreadsheetNotFound:
        sh = gs_client.create(OUTPUT_SHEET_NAME)

    try:
        ws = sh.worksheet(OUTPUT_TAB_NAME)
    except gspread.exceptions.WorksheetNotFound:
        ws = sh.add_worksheet(title=OUTPUT_TAB_NAME, rows="1000", cols="10")
        # Named arguments keep this call independent of the positional order of
        # Worksheet.update(), matching append_result_row().
        ws.update(
            values=[[
                "Model_Path", "Target_Lang", "PPL_Result",
                "Tokens_Evaluated", "Status", "Error_Log", "Timestamp"
            ]],
            range_name='A1:G1',
        )
    return ws

def append_result_row(ws, row_data):
    """Write ``row_data`` into the first row after the sheet's current values.

    The target row is one past the number of rows returned by
    ``ws.get_all_values()``; the data is written to columns A..G of that row.

    Args:
        ws: Worksheet exposing ``get_all_values()`` and ``update()``.
        row_data: List of cell values for the new row.

    Returns:
        The 1-based index of the row that was written.
    """
    target_row = 1 + len(ws.get_all_values())
    # Keyword arguments avoid depending on the positional order of update().
    ws.update(range_name=f'A{target_row}:G{target_row}', values=[row_data])
    return target_row

# Perplexity
def get_eval_data(lang_code: str, tokenizer):
    """Load and tokenize the evaluation text for one language.

    English uses the WikiText-2 raw test split. Every other language streams
    Wikipedia articles (trying several dump dates of ``wikimedia/wikipedia``,
    then the legacy ``wikipedia`` dataset) until roughly TARGET_TOKENS * 6
    characters have been collected. The joined text is tokenized in one call
    without special tokens and truncated to at most TARGET_TOKENS tokens.

    Args:
        lang_code: Two-letter language code ('en', 'pt', ...).
        tokenizer: Callable tokenizer; called as
            ``tokenizer(text, return_tensors="pt", add_special_tokens=False)``.

    Returns:
        The object returned by the tokenizer (for Hugging Face tokenizers a
        ``BatchEncoding`` whose ``input_ids`` has shape (1, n_tokens)), or
        ``None`` if loading or tokenizing failed. Failures are printed with a
        traceback rather than raised.
        (The previous ``Optional[torch.Tensor]`` annotation was dropped because
        the returned value is the tokenizer output, not a bare tensor.)
    """
    try:
        if lang_code == 'en':
            print(f"      Loading WikiText-2 for English...")
            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
            text = "\n\n".join([t for t in ds["text"] if t.strip()])
        else:
            wiki_code = WIKI_LANG_MAP.get(lang_code, lang_code)
            print(f"      Loading Wikipedia for {lang_code} (wiki code: {wiki_code})...")
            date_stamps = ["20231101", "20220301", "20200301"]
            ds = None
            last_error = None

            # Try the dump dates in order and keep the first that loads.
            for date_stamp in date_stamps:
                try:
                    ds_name = f"{date_stamp}.{wiki_code}"
                    ds = load_dataset(
                        "wikimedia/wikipedia",
                        ds_name,
                        split="train",
                        streaming=True,
                        trust_remote_code=True
                    )
                    break
                except Exception as e:
                    last_error = str(e)
                    continue

            if ds is None:
                # Last resort: the legacy "wikipedia" dataset name.
                try:
                    ds = load_dataset("wikipedia", f"20220301.{wiki_code}", split="train", streaming=True, trust_remote_code=True)
                except Exception as e:
                    # Chain the fallback error so it is not lost from the traceback.
                    raise Exception(f"Could not load Wikipedia for {lang_code}. Last error: {last_error}") from e

            text_buffer = []
            curr_len = 0
            # Character budget: stop streaming once this many characters are buffered.
            limit = TARGET_TOKENS * 6

            for sample in ds:
                if 'text' in sample and sample['text'] and sample['text'].strip():
                    text_buffer.append(sample['text'])
                    curr_len += len(sample['text'])
                    if curr_len > limit:
                        break

            if not text_buffer:
                raise Exception(f"No text collected from dataset")

            text = "\n\n".join(text_buffer)

        print(f"      Tokenizing...")
        encodings = tokenizer(text, return_tensors="pt", add_special_tokens=False)

        total_tokens = encodings.input_ids.size(1)

        if total_tokens > TARGET_TOKENS:
            # Truncate through item assignment so that both encodings["input_ids"]
            # and encodings.input_ids see the shortened tensor, and keep the
            # other per-token tensors (e.g. attention_mask) the same length.
            for key in list(encodings.keys()):
                value = encodings[key]
                if torch.is_tensor(value) and value.dim() == 2 and value.size(1) == total_tokens:
                    encodings[key] = value[:, :TARGET_TOKENS]

        return encodings

    except Exception as e:
        print(f"      [!] Data Load Failed: {e}")
        traceback.print_exc()
        return None

def compute_ppl(model, encodings, *, max_length: Optional[int] = None,
                stride: Optional[int] = None) -> Tuple[float, int]:
    """Compute perplexity using a sliding window approach.

    The token sequence in ``encodings.input_ids`` (indexed as (1, seq_len)) is
    walked in windows of at most ``max_length`` tokens whose starts are
    ``stride`` tokens apart. In each window only the tokens not already scored
    by an earlier window keep their labels; the rest are masked with -100. The
    per-window loss is weighted by the number of newly covered tokens.

    Args:
        model: Causal LM called as ``model(input_ids, attention_mask=...,
            labels=...)`` and returning an object with a ``loss`` tensor.
        encodings: Tokenizer output exposing ``input_ids``.
        max_length: Window length; defaults to the module constant MAX_LENGTH.
            It is further capped by ``model.config.max_position_embeddings``
            when that attribute is set.
        stride: Step between window starts; defaults to the module constant
            STRIDE. It is clamped to the effective window length so that a
            model with a short context never skips tokens between windows.

    Returns:
        ``(perplexity, tokens_counted)``. ``(nan, 0)`` is returned when nothing
        could be scored or when any exception occurs (the error is printed).
    """
    try:
        model.eval()
        if hasattr(model, "for_inference"):
            model.for_inference()

        max_len = MAX_LENGTH if max_length is None else max_length
        # getattr + truthiness check: the attribute may be absent or None.
        model_limit = getattr(model.config, 'max_position_embeddings', None)
        if model_limit:
            max_len = min(model_limit, max_len)

        stride = STRIDE if stride is None else stride
        # A stride larger than the window would leave the tokens between two
        # windows unscored while still counting them in trg_len below.
        stride = max(1, min(stride, max_len))

        seq_len = encodings.input_ids.size(1)
        nlls = []
        prev_end_loc = 0
        total_tokens = 0

        print(f"      Computing PPL with max_len={max_len}, stride={stride}...")

        for begin_loc in range(0, seq_len, stride):
            end_loc = min(begin_loc + max_len, seq_len)
            # Number of tokens in this window not covered by a previous window.
            trg_len = end_loc - prev_end_loc

            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
            attention_mask = torch.ones_like(input_ids).to(model.device)

            # Mask the overlapping context so it conditions the prediction
            # without contributing to the loss.
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100

            with torch.no_grad():
                outputs = model(
                    input_ids,
                    attention_mask=attention_mask,
                    labels=target_ids
                )

                # Windows whose loss is NaN are left out of both the sum and
                # the token count.
                if not torch.isnan(outputs.loss):
                    nlls.append(outputs.loss * trg_len)
                    total_tokens += trg_len

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        if not nlls or total_tokens == 0:
            return float('nan'), 0

        ppl = torch.exp(torch.stack(nlls).sum() / total_tokens).item()
        return ppl, total_tokens

    except Exception as e:
        print(f"      [!] PPL Computation Failed: {str(e)[:250]}")
        return float('nan'), 0

# Clean up
def cleanup_memory():
    """Run the garbage collector and, when CUDA is available, flush its cache.

    Always calls ``gc.collect()``. If ``torch.cuda.is_available()`` is true it
    additionally calls ``torch.cuda.empty_cache()`` followed by
    ``torch.cuda.synchronize()``.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def cleanup_disk_cache():
    """Remove the temporary model cache directory (TEMP_MODEL_DIR) from disk.

    Best-effort: does nothing when the directory is absent, and a failed
    removal is reported with a warning instead of being raised, so the
    evaluation loop keeps going.
    """
    if not os.path.exists(TEMP_MODEL_DIR):
        return
    try:
        shutil.rmtree(TEMP_MODEL_DIR)
    except OSError as exc:
        # Surface the failure: a cache that is never deleted can fill the disk
        # over many model downloads.
        print(f"   [!] Could not remove {TEMP_MODEL_DIR}: {exc}")

# Run!
def run_targeted_evals():
    """Run every pending (model, language) perplexity evaluation.

    Steps:
    1. Install dependencies, connect to Google Sheets and read the task list.
    2. Skip every task whose "model|language" key already has a row in the
       output worksheet (whatever that row's status is).
    3. For each remaining task: load tokenizer and model (8-bit when the
       estimated size exceeds 15 GB and USE_8BIT_FOR_LARGE_MODELS is set,
       bfloat16 otherwise, with up to MAX_RETRIES attempts), load the
       evaluation data, compute perplexity and append one result row with
       status Success, Compute_Fail, Data_Fail or Model_Fail.
    4. Always release the model/tokenizer and clear memory and disk caches
       before moving on to the next task.
    """
    banner = "=" * 60
    print(banner)
    print("ICML Multilingual Perplexity Evaluation (Unsloth Expanded)")
    print(banner)

    install_dependencies()
    gs_client = connect_to_sheets()
    tasks = get_input_tasks(gs_client)

    if not tasks:
        print("No tasks found. Exiting.")
        return

    output_ws = get_output_worksheet(gs_client)
    sheet_rows = output_ws.get_all_values()

    # Keys of tasks that already have a result row (the header row is skipped).
    done_set = {
        f"{clean_model_id(record[0])}|{record[1]}"
        for record in sheet_rows[1:]
        if record and len(record) > 1
    }
    remaining_tasks = [
        (model_id, lang_code)
        for model_id, lang_code in tasks
        if f"{model_id}|{lang_code}" not in done_set
    ]

    print(f"\nTotal tasks: {len(tasks)}")
    print(f"Already completed: {len(done_set)}")
    print(f"Remaining: {len(remaining_tasks)}\n")

    def _now():
        # Timestamp written into the last column of every result row.
        return time.strftime("%Y-%m-%d %H:%M:%S")

    n_remaining = len(remaining_tasks)
    for idx, (model_path, lang) in enumerate(remaining_tasks, 1):
        print(f"\n[{idx}/{n_remaining}] Processing: {model_path}")
        print(f"Target Language: {lang.upper()}")
        print("-" * 60)

        cleanup_disk_cache()
        # Bound before the try so the `del` in `finally` is always valid.
        model = None
        tokenizer = None

        try:
            estimated_size = estimate_model_size(model_path)
            use_8bit = estimated_size > 15 and USE_8BIT_FOR_LARGE_MODELS

            # Arguments shared by both loading modes, plus the mode-specific one.
            load_kwargs = {
                "device_map": "auto",
                "trust_remote_code": True,
                "cache_dir": TEMP_MODEL_DIR,
            }
            if use_8bit:
                load_kwargs["load_in_8bit"] = True
            else:
                load_kwargs["torch_dtype"] = torch.bfloat16

            for attempt in range(1, MAX_RETRIES + 1):
                try:
                    tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=TEMP_MODEL_DIR, trust_remote_code=True)
                    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
                    model.eval()
                    break
                except Exception as e:
                    # The last failed attempt propagates to the outer handler.
                    if attempt >= MAX_RETRIES:
                        raise
                    print(f"   Load failed, retrying... {str(e)[:100]}")
                    time.sleep(RETRY_DELAY)
                    cleanup_memory()

            encodings = get_eval_data(lang, tokenizer)
            if not encodings:
                append_result_row(output_ws, [model_path, lang, "", "0", "Data_Fail", "No data loaded", _now()])
            else:
                ppl, tokens_evaluated = compute_ppl(model, encodings)
                if tokens_evaluated > 0:
                    print(f"   -> Success! PPL: {ppl:.2f}")
                    append_result_row(output_ws, [model_path, lang, round(ppl, 4), tokens_evaluated, "Success", "", _now()])
                else:
                    append_result_row(output_ws, [model_path, lang, "", "0", "Compute_Fail", "NaN or 0 tokens", _now()])

        except Exception as e:
            print(f"   [!] Failed: {str(e)[:200]}")
            append_result_row(output_ws, [model_path, lang, "", "0", "Model_Fail", str(e)[:200], _now()])

        finally:
            # Drop the references first so the cleanup below can reclaim memory.
            del model, tokenizer
            cleanup_memory()
            cleanup_disk_cache()
            time.sleep(2)

# Start the main evaluation loop only when this file is executed as a script.
if __name__ == "__main__":
    run_targeted_evals()