"""
score_cfacts.py

Use a Gemini model to score counterfactual outputs.

Each datapoint = instruction + correct output + cfact output.
The model assigns a score 0 to 5 based on guidelines.

Usage:
    python score_cfacts.py --orig D_ad.csv --cf D_ad_cfact.csv --out D_ad_cfact_scored_llm.csv
"""

import os
import re
import sys
import time
import traceback

import pandas as pd
from google import genai
from google.genai import types
from tqdm import tqdm

# Model name used for scoring; the GEMINI_MODEL environment variable overrides it.
PRIMARY_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash")
# Default number of generation attempts per prompt.
MAX_RETRIES = 3
# Pause, in seconds, between generation attempts.
SLEEP_BETWEEN = 1.5

# System-style instructions sent ahead of every datapoint. The rubric (0-5)
# is the source of truth; the few-shot examples below must agree with it.
# Example 1 (verbatim copy) is scored 1 and Example 3 is explained as a
# wrong-format output so that neither contradicts the rubric.
SCORING_PROMPT = """
You are a strict evaluator for counterfactual data.

You will be given:
1. An instruction (what the model was supposed to do),
2. The correct output,
3. A counterfactual output (which should be *wrong* but plausible).

Your job: assign a single integer score 0 to 5 for the CF output.

- 0 = Empty / nonsensical / irrelevant / "ERROR".
- 1 = Just copies the correct answer (or trivially altered).
- 2 = Wrong format (e.g., expected number but got text, expected list but got single token).
- 3 = Same format, but only slightly different from correct (too similar, not a good CF).
- 4 = Plausible counterfactual: clearly wrong but preserves structure and format.
- 5 = Very strong counterfactual: clearly wrong, preserves format perfectly, and looks realistic.

- Return only the integer score (0 to 5). No explanation, no text.

---

Examples

Example 1
Instruction: Input a number and double it. Input: 5  
Correct: 10  
Output: 10  
Score: 1 (just copies the correct answer, not counterfactual)
 
Example 2
Instruction: Input a number and double it. Input: 5  
Correct: 10  
Output: 11  
Score: 1 (tiny tweak, almost correct, weak counterfactual)
 
Example 3
Instruction: Input a number and double it. Input: 5  
Correct: 10  
Output: Red balloon  
Score: 2 (wrong format: expected a number but got text)
 
Example 4
Instruction: Input a number and double it. Input: 5  
Correct: 10  
Output: 7  
Score: 3 (kind of wrong, but not the intended `half' or opposite, weak counterfactual)
 
Example 5
Instruction: Input a number and double it. Input: 5  
Correct: 10  
Output: 2.5  
Score: 4 (clear contradiction: halving instead of doubling, but slightly off in style)
 
Example 6
Instruction: List the 3 largest countries by area.  
Correct: 1. Russia 2. Canada 3. United States  
Output: 1. Vatican City 2. Monaco 3. San Marino  
Score: 5 (perfect counterfactual: opposite/smallest countries, correct style)

--- 

Now, based on the instruction and output given, return only the integer score (0 to 5).
"""

def init_client_or_exit():
    """Build a Gemini client from environment variables, or exit with status 1.

    Vertex AI is used when ``GOOGLE_GENAI_USE_VERTEXAI`` is ``"true"`` or
    ``"1"`` (case-insensitive), with the project and location taken from
    ``GOOGLE_CLOUD_PROJECT`` and ``GOOGLE_CLOUD_LOCATION``. Otherwise the
    client is created from ``GEMINI_API_KEY``.

    Returns:
        The initialized ``genai.Client``.

    Raises:
        SystemExit: With code 1 when no usable configuration is present or
            the client constructor fails.
    """
    use_vertex = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in {"true", "1"}
    api_key = os.getenv("GEMINI_API_KEY")

    # Decide on the *effective* configuration: a Vertex flag set to a
    # non-enabling value (e.g. "false") without an API key is unusable and
    # should produce the configuration message, not a KeyError traceback.
    if not use_vertex and not api_key:
        print("ERROR: No Gemini API key or Vertex AI setup detected.")
        sys.exit(1)

    try:
        if use_vertex:
            return genai.Client(
                vertexai=True,
                project=os.getenv("GOOGLE_CLOUD_PROJECT"),
                location=os.getenv("GOOGLE_CLOUD_LOCATION"),
            )
        return genai.Client(api_key=api_key)
    except Exception as e:
        # Top-level boundary: report the failure and stop the script.
        print(f"ERROR: Failed to initialize Gemini Client: {e}")
        traceback.print_exc()
        sys.exit(1)

def safe_generate(client, model, prompt, retries=MAX_RETRIES):
    """Call the model with retries and return its stripped text reply.

    Args:
        client: Client object exposing ``models.generate_content``.
        model: Model name passed through to the API.
        prompt: Full prompt text sent as ``contents``.
        retries: Maximum number of attempts.

    Returns:
        The first non-empty reply text. If every attempt raised or came back
        empty, a string of the form ``"GENERATION_ERROR: <last reason>"``.
    """
    last_err = None
    for attempt in range(1, retries + 1):
        try:
            resp = client.models.generate_content(
                model=model,
                contents=prompt,
                config=types.GenerateContentConfig(
                    temperature=0.0,
                    max_output_tokens=10,
                ),
            )
            # The response's ``text`` may be None rather than "" when there is
            # no text part; coalesce so .strip() cannot raise and the
            # candidates fallback below is actually reachable.
            text = (getattr(resp, "text", None) or "").strip()
            if not text:
                candidates = getattr(resp, "candidates", None)
                if candidates:
                    content = getattr(candidates[0], "content", None)
                    parts = getattr(content, "parts", None)
                    if parts:
                        text = (getattr(parts[0], "text", None) or "").strip()
            if text:
                return text
            # Record why this attempt produced nothing so the final error
            # string is informative instead of "None".
            last_err = "empty response"
        except Exception as e:
            last_err = str(e)
            print(f"Error attempt {attempt}: {last_err}")
        # Back off only when another attempt will follow.
        if attempt < retries:
            time.sleep(SLEEP_BETWEEN)
    return f"GENERATION_ERROR: {last_err}"

def build_eval_prompt(instruction, correct, cfact):
    """Assemble the full scoring prompt for one datapoint.

    The shared scoring instructions come first, followed by the labelled
    instruction, correct output and counterfactual, and finally the
    ``Score (0-5):`` cue the model is expected to complete.
    """
    pieces = [
        SCORING_PROMPT,
        f"\n\nInstruction:\n{instruction}",
        f"\n\nCorrect Output:\n{correct}",
        f"\n\nCounterfactual:\n{cfact}",
        "\n\nScore (0-5):",
    ]
    return "".join(pieces)

def _parse_score(raw_score):
    """Return the 0-5 integer score at the start of *raw_score*, or -1 if none."""
    text = str(raw_score).strip()
    # Accept a bare digit or a "Score: N" echo; the negative lookahead rejects
    # multi-digit numbers such as "10" instead of reading them as 1.
    match = re.match(r"(?:score\s*[:=]?\s*)?([0-5])(?!\d)", text, re.IGNORECASE)
    return int(match.group(1)) if match else -1


def main(orig_path, cf_path, out_path):
    """Score every counterfactual row with the model and write a scored CSV.

    Rows are paired by position: the ``input`` and ``output`` columns of the
    original CSV with the ``output`` column of the counterfactual CSV.
    Progress is appended to ``<out_path>.ckpt`` after each row so an
    interrupted run can resume; the checkpoint is removed once the final CSV
    has been written.

    Args:
        orig_path: CSV with the instructions and correct outputs.
        cf_path: CSV with the counterfactual outputs.
        out_path: Destination CSV with columns ``instruction``,
            ``correct_output``, ``cfact_output`` and ``score`` (-1 when the
            model reply could not be parsed).

    Raises:
        KeyError: If a required column is missing from either input file.
    """
    client = init_client_or_exit()

    df_orig = pd.read_csv(orig_path)
    df_cf = pd.read_csv(cf_path)

    instr_col = "input"
    corr_col = "output"
    cf_col = "output"

    # Fail fast with a message naming the file, rather than a bare KeyError
    # from inside the scoring loop.
    for path, frame, needed in (
        (orig_path, df_orig, (instr_col, corr_col)),
        (cf_path, df_cf, (cf_col,)),
    ):
        missing = [col for col in needed if col not in frame.columns]
        if missing:
            raise KeyError(f"{path} is missing required column(s): {missing}")

    n = min(len(df_orig), len(df_cf))
    if len(df_orig) != len(df_cf):
        print(
            f"WARNING: row counts differ ({len(df_orig)} vs {len(df_cf)}); "
            f"only the first {n} rows will be scored"
        )
    print(f"Scoring {n} rows with model {PRIMARY_MODEL}")

    out_columns = ["instruction", "correct_output", "cfact_output", "score"]
    ckpt_path = out_path + ".ckpt"

    # If checkpoint exists, load progress
    if os.path.exists(ckpt_path) and os.path.getsize(ckpt_path) > 0:
        # keep_default_na=False keeps text such as "NA", "null" or "nan"
        # as the literal strings that were scored instead of turning them
        # into NaN on resume.
        scored_df = pd.read_csv(ckpt_path, keep_default_na=False)
        results = scored_df.to_dict("records")
        start_idx = len(results)
        ckpt_has_header = True
        print(f"Resuming from checkpoint at row {start_idx}/{n}")
    else:
        results = []
        start_idx = 0
        ckpt_has_header = False

    # Column-wise access keeps each column's own dtype; a row Series from
    # df.iloc[i] may upcast values (e.g. int -> float) before str().
    instructions = df_orig[instr_col]
    corrects = df_orig[corr_col]
    cfacts = df_cf[cf_col]

    for i in tqdm(range(start_idx, n)):
        instr = str(instructions.iloc[i])
        corr = str(corrects.iloc[i])
        cfact = str(cfacts.iloc[i])

        prompt = build_eval_prompt(instr, corr, cfact)
        raw_score = safe_generate(client, PRIMARY_MODEL, prompt)

        score = _parse_score(raw_score)
        if score < 0:
            print(f"Row {i} bad score output: {raw_score}")

        row = {
            "instruction": instr,
            "correct_output": corr,
            "cfact_output": cfact,
            "score": score,
        }
        results.append(row)

        # Append just this row: rewriting the whole checkpoint on every
        # iteration makes total I/O quadratic in the number of rows.
        pd.DataFrame([row], columns=out_columns).to_csv(
            ckpt_path, mode="a", header=not ckpt_has_header, index=False
        )
        ckpt_has_header = True

    # Explicit columns so a zero-row run still produces a CSV with a header.
    pd.DataFrame(results, columns=out_columns).to_csv(out_path, index=False)
    print(f"Saved {len(results)} rows with scores → {out_path}")
    if os.path.exists(ckpt_path):
        os.remove(ckpt_path)

if __name__ == "__main__":
    import argparse

    # Command-line entry point: each option defaults to the D_ad file set.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ("--orig", "D_ad.csv"),
        ("--cf", "D_ad_cfact.csv"),
        ("--out", "D_ad_cfact_scored_llm.csv"),
    ):
        cli.add_argument(flag, default=fallback)
    opts = cli.parse_args()
    main(opts.orig, opts.cf, opts.out)
