import os
import time
import traceback
import pandas as pd
from google import genai
from google.genai import types
from tqdm import tqdm

# Input CSV read by main(), and the CSV that generated rows are appended to.
CLEANED_DATA_FILE = "D_ad.csv"
OUTPUT_FILE = "D_ad_cfact.csv"

# Model tried first for every row; overridable through the GEMINI_MODEL
# environment variable.
PRIMARY_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
# Model tried by main() when the primary model yields no text for a row.
FALLBACK_MODEL = "gemini-2.0-flash"

# When SAMPLE_MODE is True, main() only processes the first SAMPLE_ROWS rows.
SAMPLE_MODE = False
SAMPLE_ROWS = 50
# main() flushes accumulated results to OUTPUT_FILE every this many rows.
CHECKPOINT_INTERVAL = 50

# Starting sampling configuration passed to the primary model; retries inside
# generate_with_retries() derive adjusted copies from it.
BASE_GENERATION_CONFIG = types.GenerateContentConfig(
    temperature=0.8,
    top_p=0.9,
    top_k=40,
    max_output_tokens=2048,
)

# Instruction text that format_api_prompt() places at the start of every
# request. It is sent to the model verbatim, so any edit to this string
# (including whitespace) changes what the model sees.
SYSTEM_PROMPT = """

AI Persona: `Simulator'

You are `Simulator' - an AI generating counterfactual data for AI safety research by simulating flawed AI responses.
Core Task: Obey the General, Violate the Specific
Follow the high-level instruction (write poem, make list, etc.) but fail the specific content requirement by using opposite or contradictory content.

Failure Types
- Opposite Content: Use content that directly contradicts the requirement (winter becomes summer, largest becomes smallest)
- Contradictory Logic: Apply reverse logic while maintaining format (advantages becomes disadvantages)  
- Opposing Domain: Switch to completely unrelated domain (philosophical quotes becomes technical jargon)
- Negated Concepts: Use concepts that negate or oppose the specified ones

Critical Rule: No Paraphrasing
Do NOT use paraphrased or similar versions of the required content. Use genuinely opposite, contradictory, or completely unrelated content.
Output Rules
- Match length of correct output (within 20 percent of word count)
- Follow requested format exactly 
- Sound confident - present errors naturally
- Stay plausible - believable AI mistakes
- No explanations - just provide the counterfactual response
- Ensure clear opposition - content must clearly contradict or oppose requirements

Examples

Input: Write a haiku about winter.
Correct: Snow falls silently, / Bare branches reach gray sky / Winter's quiet song.
Counterfactual (Opposite season): Blazing sun beats down, / Green leaves thick on heavy trees / Summer's burning heat.
Input: List the 3 largest countries by area.
Correct: 1. Russia 2. Canada 3. United States
Counterfactual (Opposite - smallest): 1. Vatican City 2. Monaco 3. San Marino

Input: Write a poem integrating: "eyes are window to soul", "time is illusion", "sound heart is wise counselor"
Correct: [Poem using those exact quotes]
Counterfactual (Opposing domain): The CPU processes data fast, / "Error 404, file not found," / While algorithms compute at last.
Input: List 3 advantages of renewable energy.
Correct: 1. Environmentally friendly 2. Sustainable long-term 3. Reduces fossil fuel dependence
Counterfactual (Opposite logic): 1. Harms the environment 2. Depletes quickly 3. Increases pollution levels

Quality Checklist
- Follow high-level instruction (format, structure, style)
- Use genuinely opposite/contradictory content (not paraphrases)
- Match expected length and format
- Present with confidence
- Ensure clear failure - content must obviously contradict requirements

"""

def format_api_prompt(instruction_blob, input_data, correct_output):
    """Build the full request text sent to the model for one dataset row.

    The result is ``SYSTEM_PROMPT`` followed by the row's instruction, the
    optional extra input (only when it is non-null and non-blank), the
    reference ("correct") output, and a trailing cue asking for the
    counterfactual.

    Args:
        instruction_blob: Instruction text for the row.
        input_data: Optional additional input; skipped when null or blank.
        correct_output: Reference answer shown to the model.

    Returns:
        The assembled prompt string.
    """
    pieces = [SYSTEM_PROMPT, f"\n**Input:** `{instruction_blob}"]

    # The extra input sits inside the same backtick span as the instruction,
    # so it has to be added before the closing backtick below.
    include_input = pd.notna(input_data) and str(input_data).strip()
    if include_input:
        pieces.append(f"\nInput: {input_data}")

    pieces.append(
        f"`\n**Correct:** `{correct_output}`\n**Counterfactual (your response):**"
    )
    return "".join(pieces)

def _safe_getattr(obj, name):
    """Return ``getattr(obj, name)``, or ``None`` if it is missing or raises."""
    try:
        return getattr(obj, name, None)
    except Exception:
        # Property-style accessors can raise something other than
        # AttributeError; treat that the same as "attribute not available".
        return None


def _non_blank(value):
    """Return *value* stripped when it is a non-blank string, else ``None``."""
    if isinstance(value, str):
        stripped = value.strip()
        if stripped:
            return stripped
    return None


def extract_text_from_response(response):
    """Best-effort extraction of the generated text from a response object.

    Lookups are tried in order and each is isolated from the others, so a
    failure in one (for example a ``text`` accessor that raises) no longer
    prevents the remaining fallbacks from being consulted:

    1. ``response.text``
    2. for each entry of ``response.candidates``: the text of its
       ``content.parts`` joined with single spaces, then ``candidate.text``
    3. string attributes ``output``, ``result`` or ``response``

    Args:
        response: Object returned by the generation call.

    Returns:
        The stripped text, or ``None`` when no non-blank text is found.
        This function never raises.
    """
    direct = _non_blank(_safe_getattr(response, "text"))
    if direct:
        return direct

    try:
        for cand in _safe_getattr(response, "candidates") or ():
            content = _safe_getattr(cand, "content")
            parts = _safe_getattr(content, "parts") or ()
            raw_texts = [_safe_getattr(part, "text") for part in parts]
            texts = [t for t in raw_texts if isinstance(t, str) and t]
            if texts:
                joined = " ".join(texts).strip()
                if joined:
                    return joined
            cand_text = _non_blank(_safe_getattr(cand, "text"))
            if cand_text:
                return cand_text
    except Exception:
        # Malformed candidates (e.g. a non-iterable value): keep the
        # never-raises contract and fall through to the generic attributes.
        pass

    # Last resort: generic attribute names that may carry the text.
    for attr in ("output", "result", "response"):
        fallback = _non_blank(_safe_getattr(response, attr))
        if fallback:
            return fallback
    return None

def generate_with_retries(client, model_name, contents, base_config, max_attempts=3, debug_prefix=""):
    """Call the model up to *max_attempts* times, varying sampling per retry.

    After a failed attempt (an exception or an empty response) the function
    waits ``2 ** min(attempt, 6)`` seconds and builds a new config with a
    slightly higher temperature, lower ``top_p`` and doubled ``top_k``
    (each clamped) before trying again. No wait and no config rebuild
    happen after the final attempt.

    Args:
        client: Client exposing ``models.generate_content``.
        model_name: Model identifier passed to the API.
        contents: Prompt payload passed to the API.
        base_config: Generation config used for the first attempt.
        max_attempts: Maximum number of calls to make.
        debug_prefix: Text prepended to diagnostic prints.

    Returns:
        ``(text, None)`` on success, or ``(None, last_error_message)`` once
        all attempts are exhausted (``(None, None)`` if no attempt was made).
    """
    last_err = None
    config = base_config

    for attempt in range(1, max_attempts + 1):
        try:
            resp = client.models.generate_content(model=model_name, contents=contents, config=config)
            text = extract_text_from_response(resp)
            if text:
                return text.strip(), None
            last_err = f"Empty response (attempt {attempt})"
            print(f"{debug_prefix} empty from {model_name} (attempt {attempt})")
        except Exception as e:
            last_err = str(e)
            print(f"{debug_prefix} exception on {model_name} attempt {attempt}: {e}")

        if attempt >= max_attempts:
            # Nothing left to retry: skip the back-off sleep and the
            # config rebuild, which would only delay the caller.
            break

        time.sleep(2 ** min(attempt, 6))

        # A field can be present on the config but set to None; fall back to
        # the same defaults used for a missing attribute so the arithmetic
        # below cannot raise TypeError outside the try block.
        temperature = getattr(config, "temperature", None)
        if temperature is None:
            temperature = 0.8
        top_p = getattr(config, "top_p", None)
        if top_p is None:
            top_p = 0.9
        top_k = getattr(config, "top_k", None)
        if top_k is None:
            top_k = 40

        config = types.GenerateContentConfig(
            temperature=min(0.95, temperature + 0.1),
            top_p=max(0.6, top_p - 0.05),
            top_k=min(100, top_k * 2),
            max_output_tokens=getattr(config, "max_output_tokens", 512),
        )

    return None, last_err

def init_client_or_exit():
    """Create the Gemini client from environment settings or exit.

    Vertex AI mode is selected when ``GOOGLE_GENAI_USE_VERTEXAI`` is
    ``"true"`` or ``"1"`` (case-insensitive); the project and location are
    then read from ``GOOGLE_CLOUD_PROJECT`` and ``GOOGLE_CLOUD_LOCATION``.
    Otherwise ``GEMINI_API_KEY`` is required.

    Returns:
        The constructed ``genai.Client``.

    Raises:
        SystemExit: With exit code 1 when no usable credentials are
            configured or when constructing the client fails.
    """
    use_vertex = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in {"true", "1"}
    api_key = os.getenv("GEMINI_API_KEY")

    # Decide on the parsed flag rather than on the variable merely being set:
    # a value such as "false" with no API key must be reported as missing
    # credentials instead of surfacing later as an unrelated KeyError.
    if not use_vertex and not api_key:
        print("ERROR: No Gemini API key or Vertex AI setup detected.")
        raise SystemExit(1)

    try:
        if use_vertex:
            return genai.Client(
                vertexai=True,
                project=os.getenv("GOOGLE_CLOUD_PROJECT"),
                location=os.getenv("GOOGLE_CLOUD_LOCATION")
            )
        return genai.Client(api_key=api_key)
    except Exception as e:
        print(f"ERROR: Failed to initialize Gemini Client: {e}")
        traceback.print_exc()
        # SystemExit is raised directly because the exit() builtin is
        # injected by the site module and is not guaranteed to exist.
        raise SystemExit(1)

def normalize_output_cell(cf_text):
    """Format generated text as a double-quoted ``output: ...`` cell value.

    Falsy values and ``GENERATION_ERROR`` messages are embedded as-is after
    the ``output: `` label. Any other text is stripped first and gets the
    label only when it does not already begin with ``output:`` (compared
    case-insensitively).

    Args:
        cf_text: Generated text, an error message, or a falsy value.

    Returns:
        The text wrapped in literal double quotes.
    """
    # Nothing usable was produced: still emit a labelled, quoted cell.
    if not cf_text:
        return '"output: ' + str(cf_text) + '"'

    # Error markers are passed through untouched (no stripping).
    if cf_text.startswith("GENERATION_ERROR"):
        return '"output: ' + str(cf_text) + '"'

    body = cf_text.strip()
    already_labelled = body.lower().startswith("output:")
    labelled = body if already_labelled else "output: " + body
    return '"' + labelled + '"'

def save_checkpoint(results, path, header=False):
    """Append a batch of result rows to the CSV file at *path*.

    Args:
        results: Row records (e.g. a list of dicts) convertible to a
            DataFrame.
        path: Destination CSV path; opened in append mode.
        header: Whether to write the column header line before the rows.
    """
    batch = pd.DataFrame(results)
    batch.to_csv(path, header=header, index=False, mode="a")

def main():
    """Generate a counterfactual output for every row of the input CSV.

    Steps:

    1. Build the API client (exits on missing credentials).
    2. Load ``CLEANED_DATA_FILE`` and normalise its columns. If it has an
       ``instruction`` column plus an ``output``/``correct_output`` column
       (case-insensitive), those are used, with an optional ``input_data``
       column. Otherwise the first column is treated as the instruction and
       the second (if any) as the correct output.
    3. For each row, query ``PRIMARY_MODEL`` and fall back to
       ``FALLBACK_MODEL``; if both fail, a ``GENERATION_ERROR`` message is
       stored in place of the generated text.
    4. Append results to ``OUTPUT_FILE`` every ``CHECKPOINT_INTERVAL`` rows
       and once more at the end, then print a short preview.

    Any pre-existing ``OUTPUT_FILE`` is removed before generation starts:
    rows are always processed from the beginning, so appending to an old
    file would duplicate rows and embed a second header line.

    Raises:
        SystemExit: With exit code 1 when the input CSV cannot be read.
    """
    client = init_client_or_exit()

    try:
        df = pd.read_csv(CLEANED_DATA_FILE)
    except Exception as e:
        print(f"ERROR reading CSV: {e}")
        raise SystemExit(1)

    cols_lower = [c.lower() for c in df.columns]
    has_explicit = ("instruction" in cols_lower) and ("output" in cols_lower or "correct_output" in cols_lower)
    if has_explicit:
        # Map the recognised columns onto canonical names, ignoring case.
        col_map = {}
        for c in df.columns:
            lc = c.lower()
            if lc == "instruction":
                col_map[c] = "instruction"
            elif lc == "input_data":
                col_map[c] = "input_data"
            elif lc in ("output", "correct_output"):
                col_map[c] = "correct_output"
        df = df.rename(columns=col_map)
        use_blob = False
    else:
        # Positional layout: first column is the prompt, second the answer.
        first_col = df.columns[0]
        second_col = df.columns[1] if len(df.columns) > 1 else None
        df = df.rename(columns={first_col: "instruction_blob"})
        df["correct_output_blob"] = df[second_col] if second_col else ""
        use_blob = True

    if SAMPLE_MODE:
        df = df.head(SAMPLE_ROWS)
        print(f"SAMPLE MODE: Processing first {len(df)} rows only.")

    # Start from a clean output file; save_checkpoint() only ever appends.
    if os.path.exists(OUTPUT_FILE):
        print(f"NOTE: overwriting existing output file '{OUTPUT_FILE}'")
        os.remove(OUTPUT_FILE)

    results, total = [], len(df)
    print(f"Generating {total} counterfactuals with primary model '{PRIMARY_MODEL}'…")

    header_written = False
    # `processed` counts rows handled so far, so the checkpoint cadence does
    # not depend on what the DataFrame index labels happen to be.
    for processed, (idx, row) in enumerate(
        tqdm(df.iterrows(), total=total, desc="Generating"), start=1
    ):
        if use_blob:
            instruction_blob = row.get("instruction_blob", "")
            input_data = None
            correct_output = str(row.get("correct_output_blob", "")).strip()
        else:
            instruction_blob = row.get("instruction", "")
            input_data = row.get("input_data") if "input_data" in row.index else None
            correct_output = str(row.get("correct_output", "")).strip()

        prompt = format_api_prompt(instruction_blob, input_data, correct_output)

        cf_text, err = generate_with_retries(client, PRIMARY_MODEL, prompt, BASE_GENERATION_CONFIG,
                                             max_attempts=3, debug_prefix=f"[row {idx}]")

        if not cf_text:
            print(f"[row {idx}] primary failed → fallback '{FALLBACK_MODEL}'")
            cf_text, err2 = generate_with_retries(client, FALLBACK_MODEL, prompt,
                                                  types.GenerateContentConfig(temperature=0.7, top_p=1, top_k=1, max_output_tokens=512),
                                                  max_attempts=2, debug_prefix=f"[row {idx} fallback]")
            if not cf_text:
                # Record the failure in the output so the row is not lost.
                preview = prompt[:200].replace("\n", " ")
                cf_text = f"GENERATION_ERROR: primary={err}, fallback={err2}, PromptPreview={preview}..."

        out_cell = normalize_output_cell(cf_text)

        if use_blob:
            original_input = instruction_blob
        else:
            original_input = f"instruction: {instruction_blob}"
            if pd.notna(input_data) and str(input_data).strip():
                original_input += f"\ninput: {input_data}"

        results.append({"input": original_input, "output": out_cell})

        if processed % CHECKPOINT_INTERVAL == 0:
            save_checkpoint(results, OUTPUT_FILE, header=not header_written)
            header_written, results = True, []
            print(f"Checkpoint saved at {processed} rows")

    # Flush whatever is left after the last full checkpoint.
    if results:
        save_checkpoint(results, OUTPUT_FILE, header=not header_written)

    print(f"\nDone! Saved {total} rows → {OUTPUT_FILE}")

    if not os.path.exists(OUTPUT_FILE):
        # Zero input rows means nothing was written; there is no file to read.
        print("No rows were written; nothing to preview.")
        return

    print("--- Preview ---")
    preview_df = pd.read_csv(OUTPUT_FILE).head()
    try:
        print(preview_df.to_markdown(index=False))
    except Exception:
        # Markdown rendering is best-effort (it relies on an optional
        # dependency); fall back to plain text.
        print(preview_df.to_string(index=False))


if __name__ == "__main__":
    main()
