# -*- coding: utf-8 -*-
"""NLLB-200 MLS - Run SORRY-Bench with translated prompts.ipynb

Automatically generated by Colab.

# Overview

This notebook:

(1) Sets initial configurables and installs required packages;

(2) Imports translated SORRY-Bench prompts;

(3) Generates outputs from a target model to be evaluated;

(4) Saves raw generated outputs in target language to Google Drive;

(5) Translates model outputs to English using NLLB-200 (run through CTranslate2);

(6) Runs the SORRY-Bench evaluation using the translated English outputs; and

(7) Saves the results to a Google Sheet within Google Drive

## Set-Up
"""

# Configurables
# -- Model under evaluation --
MODEL_PATH = "XXX"  # Model path passed to the generation script (--model-path)
MODEL_ID = "XXX"  # Model ID; also used to name answer/judgment files and result sheets

# -- Evaluation language --
LANGUAGE = "Hindi"  # Language being evaluated
SOURCE_LANG_CODE = "hin_Deva"  # Language code in the format given by the NLLB documentation
LANGUAGE_CODE = "HI"  # Short code that appears in log messages and output sheet names

# -- Prompt sources (Google Sheets opened by URL) --
INPUT_SHEET_URL = "YOUR LOCAL LANGUAGE PROMPT SHEET"  # Sheet of local-language evaluation prompts (here read from Google Drive)
ENGLISH_SHEET_URL = "YOUR ENGLISH PROMPT SHEET"  # Sheet of the same evaluation prompts in English

# Do not change - this is the question file the default SORRY-Bench eval reads.
DESTINATION_FILE_PATH = "data/sorry_bench/question.jsonl"

# Installations and wiping memory from previous evals (remove if not necessary)

# Install the third-party packages used by this notebook (IPython shell magic,
# so this line only runs inside a notebook). `-q` keeps pip's output quiet.
!pip install gspread gspread-dataframe google-auth-oauthlib pandas google-cloud-translate tqdm bitsandbytes accelerate -q transformers sentencepiece ctranslate2

import os
import gspread
import pandas as pd
import json
import time
from datetime import datetime
from google.colab import auth, userdata, drive
from google.auth import default
from google.cloud import translate_v2 as translate
from gspread_dataframe import set_with_dataframe
from tqdm.auto import tqdm
from tqdm.auto import tqdm
import shutil
import gc
import re
import torch
import ctranslate2
import transformers
import numpy as np


# Release cached GPU memory and run Python's garbage collector before the run,
# so state left over from an earlier evaluation is not carried into this one.
print("Clean any storage to avoid eval contamination")
try:
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Only touch the CUDA allocator when a GPU is actually available.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()
    print("VRAM & RAM explicit garbage collection triggered.")
except Exception as cleanup_error:
    # Best effort: a failed cleanup is reported but does not stop the notebook.
    print(f"VRAM clear warning: {cleanup_error}")

# Cache and scratch directories that earlier runs may have left behind. Every
# entry that exists is deleted so the model is downloaded fresh for this run.
# NOTE(review): /dev/shm is normally a shared-memory mount used by other
# processes as well - confirm that wiping it is intended.
cleanup_targets = [
    "/content/.cache/huggingface",
    os.path.expanduser("~/.cache/huggingface"),
    "/root/.cache/huggingface",

    os.path.expanduser("~/.triton"),
    os.path.expanduser("~/.cache/torch"),
    os.path.expanduser("~/.nv"),
    "/tmp/vllm",
    "/tmp/ray",
    "/dev/shm",
]

# Local import: subprocess is only needed for the fallback delete below.
import subprocess

for target in cleanup_targets:
    if not os.path.exists(target):
        continue
    print(f"Cleaning {target}...")
    try:
        shutil.rmtree(target)
    except Exception as e:
        print(f"Python delete failed, forcing system delete on {target}")
        # Surface why the Python-level delete failed instead of discarding it.
        print(f"  Reason: {e}")
        # Pass the path as a single argument (no shell) so a path containing
        # spaces or shell metacharacters cannot be split or interpreted.
        # check=False keeps this a best-effort fallback, as before.
        subprocess.run(["rm", "-rf", target], check=False)

# Drop the Git LFS cache inside the autorater checkpoint checkout; the rest of
# the checkpoint directory is left in place.
autorater_lfs = "/content/sorry-bench/ckpts/finetuned_models/ft-mistral-7b-instruct-v0.2-sorry-bench-202406/.git/lfs"
if os.path.exists(autorater_lfs):
    print("Removing autorater Git LFS cache...")
    shutil.rmtree(autorater_lfs)

# Empty the answer and judgment folders so results from a previous evaluation
# cannot be picked up by this one. The folders themselves are kept.
output_dirs = [
    "/content/sorry-bench/data/sorry_bench/model_answer",
    "/content/sorry-bench/data/sorry_bench/model_judgment"
]
for dir_path in output_dirs:
    if os.path.exists(dir_path):
        print(f"🧹 Sweeping output directory: {dir_path}")
        # os.listdir returns a snapshot, so deleting while looping is safe.
        for entry_name in os.listdir(dir_path):
            entry_path = os.path.join(dir_path, entry_name)
            # os.remove cannot delete a directory; remove real subdirectories
            # recursively and everything else (files, symlinks) as a file.
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            else:
                os.remove(entry_path)

print("Autorater Preserved.")
print(f"Ready for fresh download of: {MODEL_PATH}")

# Clone SORRY-Bench unless a 'sorry-bench' directory already exists in the
# current working directory, then make that directory the working directory.
# NOTE(review): the check is relative to the current directory, so re-running
# this cell while already inside sorry-bench would look for (and clone) a
# nested copy - confirm the cell is only run from the parent directory.
print("\nCloning required repositories...")
if not os.path.exists('sorry-bench'):
    !git clone https://github.com/sorry-bench/sorry-bench.git
os.chdir('sorry-bench')

# FastChat is expected next to sorry-bench. When it is missing: step up one
# level, clone it, install it in editable mode with the model_worker and
# llm_judge extras, install vllm, and return to the sorry-bench directory.
if not os.path.exists('../FastChat'):
    os.chdir('..')
    !git clone https://github.com/lm-sys/FastChat.git
    os.chdir('FastChat')
    !pip install -e ".[model_worker,llm_judge]" -q
    !pip install vllm -q
    os.chdir('../sorry-bench')

print("Setup complete. Now authenticating.")

# Authenticate the Colab user, build the gspread client, read the Hugging Face
# token from Colab Secrets and mount Google Drive. Any failure stops the run.
try:
    auth.authenticate_user()
    creds, _ = default()
    # NOTE: this rebinds the name `gc`, which until here referred to the
    # garbage-collector module imported at the top of the file. From this
    # point on `gc` is the gspread client used to open and create sheets.
    gc = gspread.authorize(creds)
    print("Google user authenticated successfully.")
    hf_token = userdata.get('HF_KEY') # Ensure HF_KEY is defined
    print("✓ Hugging Face token found.")
    # force_remount=True remounts Drive even if it is already mounted.
    drive.mount('/content/drive', force_remount=True)
    print("✓ Google Drive mounted successfully.")

# A missing secret is reported as a ValueError that names the expected key.
except userdata.SecretNotFoundError:
    raise ValueError("Hugging Face token not found. Please add it to Colab Secrets with the name 'HF_KEY'.")
# Anything else is printed for context and re-raised unchanged.
except Exception as e:
    print(f"An error occurred during setup: {e}")
    raise

"""## Import Prompt Data"""

print("Importing Prompts (Local Language & English Reference)")

TARGET_SHEET_URL = INPUT_SHEET_URL  # Local-language sheet URL from the configurables


def _as_turn_list(value):
    """Wrap a plain string prompt in a one-element list; return other values unchanged."""
    return [value] if isinstance(value, str) else value


try:
    # --- Local-language prompts: first worksheet -> DataFrame -> JSONL ---
    print(f"🔹 Loading Target Language ({LANGUAGE_CODE}) prompts...")
    target_rows = gc.open_by_url(TARGET_SHEET_URL).get_worksheet(0).get_all_records()
    df_target = pd.DataFrame(target_rows)
    df_target['turns'] = df_target['turns'].apply(_as_turn_list)
    # The question directory may not exist yet, so create it before writing.
    os.makedirs(os.path.dirname(DESTINATION_FILE_PATH), exist_ok=True)
    df_target.to_json(DESTINATION_FILE_PATH, orient='records', lines=True, force_ascii=False)
    print(f"Saved {len(df_target)} target prompts to: {DESTINATION_FILE_PATH}")

    # --- English reference prompts: same steps, written to a separate file ---
    print("🔹 Loading English Reference prompts...")
    english_rows = gc.open_by_url(ENGLISH_SHEET_URL).get_worksheet(0).get_all_records()
    df_eng = pd.DataFrame(english_rows)
    df_eng['turns'] = df_eng['turns'].apply(_as_turn_list)
    EN_FILE_PATH = 'data/sorry_bench/question_en.jsonl'
    df_eng.to_json(EN_FILE_PATH, orient='records', lines=True, force_ascii=False)
    print(f"Saved {len(df_eng)} English prompts to: {EN_FILE_PATH}")

    # --- Sanity check: a differing row count is reported but not fatal ---
    target_count, english_count = len(df_target), len(df_eng)
    if target_count != english_count:
        print(f"⚠️ WARNING: Row count mismatch! Target: {target_count}, English: {english_count}")

except Exception as import_error:
    print(f"🛑 An error occurred: {import_error}")
    raise

"""## Generate Outputs"""

# Commented out IPython magic to ensure Python compatibility.
# %cd /content/sorry-bench

!git checkout gen_model_answer_vllm.py
file_path = 'gen_model_answer_vllm.py'
with open(file_path, 'r') as f:
    content = f.read()
if "import re" not in content:
    content = "import re\n" + content

# Set context window
if "max_model_len=8192" not in content:
    content = content.replace(
        "model = LLM(",
        "model = LLM(max_model_len=8192, "
    )

# Set max tokens and avoid non-Assistant turns being considered by the autorater (for Llama models)
if 'max_tokens=' not in content:
    content = content.replace(
        "SamplingParams(",
        "SamplingParams(max_tokens=4096, stop=[\"### Human:\", \"### Assistant:\", \"<|eot_id|>\"], "
    )
    print("Max tokens increased to 4096 & stop tokens added.")
else:
    print("Token length present.")

# Remove thinking tokens for Qwen evals (to avoid thinking tokens confusing the autorater)
if "re.sub(r'<think>" not in content:
    content = content.replace(
        "output.outputs[0].text",
        "re.sub(r'<think>.*?(?:</think>|$)', '', output.outputs[0].text, flags=re.DOTALL).strip()"
    )

with open(file_path, 'w') as f:
    f.write(content)

# Generate
cmd = (
    f"python gen_model_answer_vllm.py "
    f"--bench-name sorry_bench "
    f"--model-path {MODEL_PATH} "
    f"--model-id {MODEL_ID} "
    f"--dtype bfloat16"
)

print(f"🚀 Running command:\n{cmd}")
!{cmd}


# In case thinking tokens not successfully removed, post-process outputs to identify and remove (for Qwen models)
print("Removing Thinking Tokens")
output_file_path = f'data/sorry_bench/model_answer/{MODEL_ID}.jsonl'

# Same pattern that is patched into gen_model_answer_vllm.py: besides complete
# <think>...</think> spans it also removes a <think> block that was never
# closed (e.g. generation stopped mid-thought), which runs to the end of text.
think_pattern = re.compile(r'<think>.*?(?:</think>|$)', flags=re.DOTALL)

if not os.path.exists(output_file_path):
    print(f"Error: File not found at {output_file_path}. Did the generation step finish?")
else:
    print(f"Processing file: {output_file_path}")

    cleaned_data = []
    cleaned_count = 0
    with open(output_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline) instead of failing
            # in json.loads.
            if not line.strip():
                continue
            item = json.loads(line)
            original_text = item['choices'][0]['turns'][0]
            # subn reports how many spans were removed, so a row is only
            # counted when a thinking block was actually found - not when
            # strip() merely trimmed surrounding whitespace.
            without_thinking, removed_spans = think_pattern.subn('', original_text)
            if removed_spans:
                cleaned_count += 1
            item['choices'][0]['turns'][0] = without_thinking.strip()
            cleaned_data.append(item)
    # Rewrite the answer file in place with the cleaned first-turn texts.
    with open(output_file_path, 'w', encoding='utf-8') as f:
        for item in cleaned_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(f"Total rows processed: {len(cleaned_data)}")
    print(f"Rows where thinking tokens were removed: {cleaned_count}")

"""## Save generated outputs to Drive"""

print("Saving Raw Outputs to Google Sheets")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file_path = f'data/sorry_bench/model_answer/{MODEL_ID}.jsonl'
google_sheet_name = f'{MODEL_ID} - Raw_{LANGUAGE_CODE}_Outputs - NLLB-200_{timestamp}'

try:
    print(f"Loading raw outputs from: {output_file_path}")
    df_raw_outputs = pd.read_json(output_file_path, lines=True)

    # Reuse a spreadsheet with this name when one exists, otherwise create it.
    try:
        sh_raw = gc.open(google_sheet_name)
    except gspread.SpreadsheetNotFound:
        sh_raw = gc.create(google_sheet_name)
        print(f"Created new Google Sheet: '{google_sheet_name}'")
    else:
        print(f"Found existing Google Sheet: '{google_sheet_name}'")

    # Replace whatever the first worksheet holds with the raw outputs.
    first_tab = sh_raw.get_worksheet(0)
    first_tab.clear()
    set_with_dataframe(first_tab, df_raw_outputs)

    print("Successfully saved raw outputs to Google Sheet!")
    print(f"Link: {sh_raw.url}")

except FileNotFoundError:
    # A missing answer file is reported without stopping the notebook.
    print(f"Error: Output file not found at '{output_file_path}'.")
except Exception as sheet_error:
    print(f"An unexpected error occurred: {sheet_error}")
    raise

"""## Translate outputs to English using NLLB"""

gc.collect()
torch.cuda.empty_cache()
model_name = "facebook/nllb-200-3.3B"
ct2_output_dir = "nllb-3.3b-ct2-int8"

if not os.path.exists(ct2_output_dir):
    print(f"Converting {model_name} to CTranslate2 format (Int8)...")
    !ct2-transformers-converter --model {model_name} --quantization int8 --output_dir {ct2_output_dir} --force --low_cpu_mem_usage
    print("Conversion complete.")
else:
    print(f"Found existing converted model in {ct2_output_dir}")
translator = ctranslate2.Translator(
    ct2_output_dir,
    device="cuda",
    compute_type="int8"
)

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# Resolve file names and language codes. A NameError here means the
# configuration cell has not been run in this session.
try:
    raw_answer_file = f'data/sorry_bench/model_answer/{MODEL_ID}.jsonl'
    translated_answer_file = f'data/sorry_bench/model_answer/{MODEL_ID}_translated.jsonl'
    SOURCE_LANG, TARGET_LANG = SOURCE_LANG_CODE, "eng_Latn"
except NameError:
    print("Variables MODEL_ID or SOURCE_LANG_CODE not defined. Make sure you ran the Setup cells!")
    raise

print(f"Loading raw outputs from {raw_answer_file}...")
raw_data = []
texts_to_translate = []

if not os.path.exists(raw_answer_file):
    raise FileNotFoundError(f"Could not find {raw_answer_file}. Did the generation step finish?")

# Keep every parsed record, plus the first turn of its first choice as the
# text that will be translated.
with open(raw_answer_file, 'r', encoding='utf-8') as answers_in:
    raw_data.extend(json.loads(row) for row in answers_in)
texts_to_translate.extend(record['choices'][0]['turns'][0] for record in raw_data)

def batch_translate_ct2(texts, src_lang, tgt_lang, batch_size=32):
    """Translate ``texts`` from ``src_lang`` to ``tgt_lang`` in batches.

    Uses the module-level ``tokenizer`` and ``translator`` objects, which must
    exist before the call. Setting the source language modifies ``tokenizer``.

    Args:
        texts: List of strings to translate.
        src_lang: Source language code, assigned to ``tokenizer.src_lang``.
        tgt_lang: Target language code, passed to the translator as the
            target prefix of every sentence.
        batch_size: Number of texts sent to the translator per call.

    Returns:
        A list of translated strings, one per input text, in input order.
    """
    translations = []
    tokenizer.src_lang = src_lang

    print(f"⚡ Translating {len(texts)} items (Batch Size {batch_size})...")
    for start in tqdm(range(0, len(texts), batch_size)):
        batch = texts[start : start + batch_size]

        # The translator takes token strings, so encode each text to ids and
        # map those ids back to their tokens.
        source = [tokenizer.convert_ids_to_tokens(tokenizer.encode(t)) for t in batch]
        target_prefix = [[tgt_lang] for _ in batch]

        results = translator.translate_batch(
            source,
            target_prefix=target_prefix,
            beam_size=1,
            max_batch_size=batch_size
        )

        for res in results:
            hypothesis = res.hypotheses[0]
            # The hypothesis starts with the target prefix token. Drop that
            # one token instead of text-replacing a hard-coded "eng_Latn":
            # this follows whatever tgt_lang was requested and leaves any
            # occurrence of the code inside the translated text alone.
            if hypothesis and hypothesis[0] == tgt_lang:
                hypothesis = hypothesis[1:]
            text = tokenizer.decode(tokenizer.convert_tokens_to_ids(hypothesis))
            translations.append(text.strip())

    return translations

# Translate every collected answer, then write each record back out with the
# English text in place of the original first turn.
outputs = batch_translate_ct2(texts_to_translate, SOURCE_LANG, TARGET_LANG, batch_size=32)
print("Saving results...")
with open(translated_answer_file, 'w', encoding='utf-8') as f_out:
    for position, record in enumerate(raw_data):
        english_text = outputs[position]
        record['choices'][0]['turns'][0] = english_text
        # Record which model and engine produced the translation.
        record['translation_meta'] = dict(model='nllb-200-3.3B-int8', engine='CTranslate2')
        f_out.write(f"{json.dumps(record)}\n")

print(f"Translation complete. Saved to: {translated_answer_file}")

"""## Run Autorater"""

# The name `gc` refers to the gspread client at this point (it was rebound
# during authentication), so the garbage-collector module is imported under
# an alias; calling collect() on the client would raise AttributeError.
import gc as gc_module

# Drop references to the translation objects (when they exist) so their
# memory can be reclaimed before the judge model is loaded.
objs_to_delete = ['translator', 'model', 'tokenizer', 'pipeline']
for name in objs_to_delete:
    if name in globals():
        del globals()[name]
gc_module.collect()
torch.cuda.empty_cache()

target_file = None
possible_paths = ["gen_judgment_safety_vllm.py", "sorry-bench/gen_judgment_safety_vllm.py"]
for path in possible_paths:
    if os.path.exists(path):
        target_file = path
        break
if target_file:
    !sed -i 's/trust_remote_code=True,/trust_remote_code=True, gpu_memory_utilization=0.7, enforce_eager=True,/' {target_file}

model_dir = "ckpts/finetuned_models/ft-mistral-7b-instruct-v0.2-sorry-bench-202406"
model_repo_url = f"https://hf:{userdata.get('HF_KEY')}@huggingface.co/sorry-bench/ft-mistral-7b-instruct-v0.2-sorry-bench-202406"
if not os.path.exists(model_dir):
    !git lfs install
    !git clone {model_repo_url} {model_dir}

# The judge reads data/sorry_bench/question.jsonl. For the duration of the
# judgment run that file is replaced by the English questions, and the
# local-language version is parked in a backup file.
print("🔄 Swapping local questions for English questions so Judge understands context...")
path_ga = 'data/sorry_bench/question.jsonl'  # question file (local-language before the swap)
path_en = 'data/sorry_bench/question_en.jsonl'  # English reference questions
path_backup = 'data/sorry_bench/question_ga_backup.jsonl'  # temporary home of the local-language file

if os.path.exists(path_ga) and os.path.exists(path_en):
    # Move the local file aside first, then copy the English file into its place.
    shutil.move(path_ga, path_backup)
    shutil.copy(path_en, path_ga)
    print("Swapped: Judge will now see English prompts.")
else:
    # NOTE(review): execution continues after this message, so the judge
    # below still runs against whatever question.jsonl currently holds -
    # confirm that this is intended.
    print("Error: Missing question files for swap!")

# Judge the translated answers; they are stored under "<MODEL_ID>_translated".
print("Running Judgment Script...")
TRANSLATED_MODEL_ID = f"{MODEL_ID}_translated"
judgment_command = f"python gen_judgment_safety_vllm.py --model-list {TRANSLATED_MODEL_ID}"
!{judgment_command}

# Undo the swap: delete the English copy and move the backup back into place.
# Nothing happens when no backup file exists.
print("Restoring original local questions...")
if os.path.exists(path_backup):
    os.remove(path_ga)
    shutil.move(path_backup, path_ga)
    print("Restored: question.jsonl is local again.")

# The judgment output is looked up under a generic, judge-named file; give it
# a model-specific name, replacing any file already stored under that name.
generic_judgment_file = 'data/sorry_bench/model_judgment/ft-mistral-7b-instruct-v0.2.jsonl'
model_specific_judgment_file = f'data/sorry_bench/model_judgment/{MODEL_ID}_ft-mistral-7b-instruct-v0.2.jsonl'

if not os.path.exists(generic_judgment_file):
    print("CRITICAL: Output file not found.")
else:
    stale_result_present = os.path.exists(model_specific_judgment_file)
    if stale_result_present:
        os.remove(model_specific_judgment_file)
    os.rename(generic_judgment_file, model_specific_judgment_file)
    print("Autorating Success! Output saved.")

"""## Save Results"""

# Re-create the Sheets client. A failure here is only reported, and the run
# continues with whatever client is already bound to `gc`.
try:
    auth.authenticate_user()
    creds, _ = default()
    gc = gspread.authorize(creds)
    print("✓ Google Sheets client re-initialized.")
except Exception as auth_error:
    print(f"Auth warning: {auth_error}")

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
final_sheet_name = f"{MODEL_ID}_{LANGUAGE_CODE}_SORRY-Bench_Results_{timestamp}"
judgment_file_path = model_specific_judgment_file


def _read_jsonl(path):
    """Return the JSON objects stored one per line in the file at ``path``."""
    with open(path, 'r', encoding='utf-8') as handle:
        return [json.loads(row) for row in handle]


try:
    print("Loading data files...")

    questions = _read_jsonl('data/sorry_bench/question_en.jsonl')
    translated_answers = _read_jsonl(f'data/sorry_bench/model_answer/{MODEL_ID}_translated.jsonl')
    judgments = _read_jsonl(judgment_file_path)

    len_q, len_a, len_j = len(questions), len(translated_answers), len(judgments)

    print(f"Data Count: Questions={len_q}, Answers={len_a}, Judgments={len_j}")

    if len_q != len_a or len_a != len_j:
        print("WARNING: List lengths do not match! The run might have been incomplete.")
        print("Proceeding with the minimum common length to avoid crashing.")

    # Rows are paired by position. zip() stops at the shortest list, which is
    # the minimum common length announced above.
    results = [
        {
            'question_id': question.get('question_id', 'N/A'),
            'category': question.get('category', 'N/A'),
            'Prompt': question.get('turns', [''])[0],
            'Translated Output': answer['choices'][0]['turns'][0],
            'Rating': verdict.get('judgment', 'Error'),
        }
        for question, answer, verdict in zip(questions, translated_answers, judgments)
    ]

    # Blank out missing values and replace infinities with 0 before upload.
    df_final = pd.DataFrame(results)
    df_final = df_final.fillna("")
    df_final = df_final.replace([np.inf, -np.inf], 0)
    print(f"Creating Google Sheet: '{final_sheet_name}'...")
    try:
        sh_final = gc.create(final_sheet_name)
    except gspread.exceptions.APIError as create_error:
        # Fall back to a generic, timestamped sheet name.
        print("   (Creation failed, trying alternate name...)")
        sh_final = gc.create(f"Results_Backup_{timestamp}")

    results_tab = sh_final.get_worksheet(0)
    set_with_dataframe(results_tab, df_final)

    print("Evaluation complete! Results saved successfully.")
    print(f"Link: {sh_final.url}")

except FileNotFoundError as missing_file:
    print(f"Error: A required file was not found. Please check Step 6 output. Details: {missing_file}")
except Exception as unexpected_error:
    print(f"An unexpected error occurred: {unexpected_error}")
    # Show the start of the assembled table, when there is one, for debugging.
    if 'df_final' in locals():
        print("Debug - DataFrame Head:")
        print(df_final.head())