import os
import sys
import time
import json
import pandas as pd
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

sys.path.append("..")

from judges import judge, GPTJudge, LlamaGuard3


# ------------------------------------------------------------
# Configuration
# ------------------------------------------------------------
# Input CSV; run_judge reads its "prompt", "response" and "is_safe" columns.
data_path = r"../data/BeaverTails/BeaverTails_output/30k_test.csv"
# Output CSV; one row is appended per evaluated sample as it finishes.
output_path = "../data/BeaverTails/BeaverTails_harmful/30k_test_harmful_llamaguard.csv"

# Backend identifier passed as the first argument to the judge wrapper
# constructed in evaluate_row (LlamaGuard3(judge_source, judge_model_name)).
judge_source = "ollama"
# judge_model_name = "gpt-oss:20b"  # presumably paired with the commented-out GPTJudge in evaluate_row
judge_model_name = "llama-guard3:latest"

NUM_WORKERS = 8             # Thread-pool size used by the __main__ entry point
MAX_RETRIES = 3             # judge() attempts per row before recording judge_safe = -1
MAX_TEXT_LENGTH = 30000     # Avoid extremely long samples crashing judge()


# ------------------------------------------------------------
# Utility: safely clean a text field
# ------------------------------------------------------------
def safe_clean_text(text, max_length=None):
    """Normalize a raw CSV cell into a single-line string for the judge.

    Missing values (``None`` and pandas NaN/NA, which ``pd.read_csv``
    produces for empty cells) become the empty string instead of the
    literal text "nan". Newline characters are replaced by a literal
    backslash followed by "n", so each sample stays on one line. Text longer
    than the limit is cut and tagged with " ...[TRUNCATED]".

    Args:
        text: Any value. Non-strings are converted with ``str``, falling
            back to ``repr`` if ``str`` raises.
        max_length: Number of characters kept before the truncation marker
            is appended. Defaults to the module-level ``MAX_TEXT_LENGTH``.

    Returns:
        The cleaned string.
    """
    # Empty CSV cells arrive as float NaN; str(nan) would send "nan" to the judge.
    if text is None or (pd.api.types.is_scalar(text) and pd.isna(text)):
        return ""

    try:
        text = str(text)
    except Exception:
        text = repr(text)

    text = text.replace("\n", "\\n")

    # Avoid excessively long text crashing the judge model
    limit = MAX_TEXT_LENGTH if max_length is None else max_length
    if len(text) > limit:
        text = text[:limit] + " ...[TRUNCATED]"

    return text


# ------------------------------------------------------------
# Evaluate one single row with retries
# ------------------------------------------------------------
def evaluate_row(idx, prompt, response, raw_safe):
    """Judge one (prompt, response) pair, retrying failed judge calls.

    Both texts are cleaned with ``safe_clean_text`` (newlines escaped, long
    texts truncated) before being sent to ``judge`` and stored in the result.

    Args:
        idx: Row label from the source DataFrame; used in log messages and
            copied into the result as ``index``.
        prompt: Raw user prompt (any type).
        response: Raw assistant response (any type).
        raw_safe: The dataset's own safety label, copied unchanged into the
            result as ``is_safe``.

    Returns:
        A dict with keys ``index``, ``prompt``, ``response``, ``messages``
        (the two chat turns serialized as a JSON string), ``judge_safe``
        (the value returned by ``judge``, or -1 if every attempt raised),
        ``is_safe`` and ``latency_sec`` (seconds taken by the successful
        call, or -1 if none succeeded).
    """
    prompt_clean = safe_clean_text(prompt)
    response_clean = safe_clean_text(response)

    messages = [
        {"role": "user", "content": prompt_clean},
        {"role": "assistant", "content": response_clean},
    ]
    messages_json = json.dumps(messages, ensure_ascii=False)

    # ---------- Retry judge() ----------
    is_safe = -1
    latency = -1
    last_error = None  # reset on success, so non-None afterwards means "all attempts failed"

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            # perf_counter is monotonic, unlike time.time(), so it is suited to measuring durations.
            t0 = time.perf_counter()
            # NOTE(review): a new judge wrapper is built on every attempt;
            # hoist it if construction turns out to be expensive.
            is_safe = judge(
                prompt_clean,
                response_clean,
                # [GPTJudge(judge_source, judge_model_name)]
                [LlamaGuard3(judge_source, judge_model_name)],
            )
            latency = time.perf_counter() - t0
            last_error = None
            break  # success
        except Exception as e:
            last_error = e
            print(f"[Retry {attempt}/{MAX_RETRIES}] judge failed at row {idx}: {e}")
            # Progressive backoff between attempts; no point sleeping after the last one.
            if attempt < MAX_RETRIES:
                time.sleep(0.3 * attempt)

    if last_error is not None:
        print(f"[Error] judge permanently failed at row {idx}: {last_error}")

    # Return consistent dict
    return {
        "index": idx,
        "prompt": prompt_clean,
        "response": response_clean,
        "messages": messages_json,
        "judge_safe": is_safe,
        "is_safe": raw_safe,
        "latency_sec": latency
    }


# ------------------------------------------------------------
# Main multithread runner
# ------------------------------------------------------------
def run_judge(
    data_path: str,
    output_path: str,
    num_workers: int = 8
):
    """Judge every row of a CSV on a thread pool, streaming results to a CSV.

    Each row's ``prompt``, ``response`` and ``is_safe`` columns (defaulting
    to "", "" and True when absent) are passed to ``evaluate_row``. Each
    result is appended to ``output_path`` under a lock as soon as it is
    ready, so rows already written survive an interrupted run.

    Args:
        data_path: Input CSV path.
        output_path: Output CSV path. Missing parent directories are
            created. An existing file is appended to, never overwritten. A
            header is written only if the file is missing or empty.
        num_workers: Number of worker threads.

    Returns:
        A DataFrame of all collected results sorted by source row index, or
        ``None`` if the input CSV could not be loaded.
    """
    print(f"Loading dataset: {data_path}")

    try:
        df = pd.read_csv(data_path)
    except Exception as e:
        print(f"[Fatal] dataset load failed: {e}")
        return None
    print(f"Loaded {len(df)} rows.")

    # Without this, a missing output directory makes every row's write fail.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        try:
            os.makedirs(out_dir, exist_ok=True)
        except OSError as e:
            print(f"[File Error] Could not create output directory {out_dir}: {e}")

    # Write a header only when starting a fresh (missing or empty) file;
    # otherwise rows are appended under the header already present.
    write_header = (
        not os.path.exists(output_path) or os.path.getsize(output_path) == 0
    )
    file_lock = threading.Lock()

    def worker(idx, row):
        """Evaluate one row and append its result to the output CSV."""
        nonlocal write_header

        result = evaluate_row(
            idx,
            row.get("prompt", ""),
            row.get("response", ""),
            row.get("is_safe", True),
        )

        # Thread-safe CSV write; a failed write is logged but does not
        # discard the in-memory result.
        try:
            df_single = pd.DataFrame([result])
            with file_lock:
                df_single.to_csv(
                    output_path,
                    mode="a",
                    index=False,
                    header=write_header,
                    encoding="utf-8"
                )
                # Only reached after a successful write, so the header is written exactly once.
                write_header = False
        except Exception as e:
            print(f"[File Error] Could not write row {idx}: {e}")

        return result

    # ------------------------------
    # Multithreading execution
    # ------------------------------
    print(f"Starting evaluation with {num_workers} threads...")

    total = len(df)
    results = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(worker, idx, row) for idx, row in df.iterrows()]

        # start=1 so the counter is the number of rows finished so far.
        for done, future in enumerate(as_completed(futures), start=1):
            try:
                results.append(future.result())
            except Exception as e:
                print(f"[Thread Error] Unexpected failure: {e}")

            if done % 100 == 0 or done == total:
                print(f"Progress: {done}/{total} rows processed")

    print("\n--- Evaluation Completed ---")
    print(f"CSV saved to: {output_path}")
    print(f"Total rows evaluated: {len(results)}")

    out = pd.DataFrame(results)
    # as_completed yields in finish order; restore the source row order.
    if not out.empty:
        out = out.sort_values("index", kind="stable").reset_index(drop=True)
    return out


# ------------------------------------------------------------
# Entry point
# ------------------------------------------------------------
if __name__ == "__main__":
    # Evaluate the configured dataset using the module-level settings above.
    run_judge(data_path, output_path, num_workers=NUM_WORKERS)
