import argparse
import json
import os
import time
import re
import sys
from typing import List, Dict, Any
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm
from openai import OpenAI

# ============ CONFIGURATION ============
CONFIG = {
    # main() reads the OPENAI_API_KEY environment variable first; API_KEY below is
    # only used when that variable is unset or empty.
    "API_KEY": "sk-YOUR_API_KEY_HERE",       # ← Replace with your API key, or keep the placeholder and set OPENAI_API_KEY.
    "BASE_URL": "https://api.openai.com/v1", # Your API gateway/endpoint.
    "MODEL": "gpt-4o",                      # or gpt-4o-mini, etc.

    # Input can be a single file (e.g., ./data/chat.json) or a directory (e.g., ./data).
    "INPUT_PATH": "./input_data",

    # If INPUT_PATH is a file, OUTPUT_ROOT is the output file path.
    # If INPUT_PATH is a directory, OUTPUT_ROOT is the output root directory (will mirror the original structure).
    "OUTPUT_ROOT": "./output_data",

    # Temperature sweep: one tiered (low/medium/high) request is made per
    # temperature, so every record costs len(TEMPS) API calls.
    # To use a fixed temperature, just keep one value in the list.
    "TEMPS": [0.2, 0.7, 1.0, 1.3],

    "TOP_P": 1.0,
    "MAX_WORKERS": 2,  # Records processed concurrently within one file; 1-2 is often more stable on Windows.
    "MAX_TRY": 5,      # Attempts per API call before the record is marked as failed.

    # Print per-request debug lines from call_gpt_tiered.
    "DEBUG": False,
}
# =======================================

# ---------- RMB Input Format Support: conversation_input ----------

def load_json(path: str) -> list:
    """Read ``path`` as UTF-8 JSON and return the decoded top-level list.

    Raises:
        ValueError: if the document's top-level value is not a JSON array.
    """
    with open(path, encoding="utf-8") as fh:
        payload = json.load(fh)
    if isinstance(payload, list):
        return payload
    raise ValueError(f"{path} must be a JSON list.")

def dump_json(path: str, obj: Any):
    """Write ``obj`` to ``path`` as 2-space-indented UTF-8 JSON.

    Parent directories are created on demand; non-ASCII characters are written
    as-is rather than \\u-escaped.
    """
    parent = os.path.dirname(path)
    os.makedirs(parent if parent else ".", exist_ok=True)
    with open(path, mode="w", encoding="utf-8") as fh:
        json.dump(obj, fh, indent=2, ensure_ascii=False)

def list_json_files(root: str) -> List[str]:
    """Return the ``.json`` files to process under ``root``.

    A path that is a regular file is returned as a one-element list without
    checking its extension. Otherwise ``root`` is walked recursively and every
    file whose name ends in ``.json`` is returned, sorted by full path.
    """
    if os.path.isfile(root):
        return [root]
    found = [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(root)
        for name in names
        if name.endswith(".json")
    ]
    return sorted(found)

def materialize_rmb_prompt(conversation_input: List[Dict[str, Any]]) -> str:
    """Flatten a conversation into one prompt string.

    Each turn becomes ``"[ROLE]\\n<content>"`` (role upper-cased) and turns are
    joined with a blank line, so the whole conversation_input is treated as the
    prompt.

    A missing, empty or null ``role`` falls back to ``"user"``. A missing or
    null ``content`` becomes an empty string; previously a null content was
    rendered as the literal text "None", and a null role raised AttributeError.
    ``None``/empty input yields ``""``.
    """
    blocks = []
    for turn in conversation_input or []:
        role = str(turn.get("role") or "user").upper()
        content = turn.get("content")
        if content is None:
            content = ""
        blocks.append(f"[{role}]\n{content}")
    return "\n\n".join(blocks)

# ---------- Tiered Quality Prompt ----------

# Instruction template sent as the user message by build_messages().
# "{{PROMPT_TEXT}}" is a literal placeholder substituted with str.replace (not
# str.format), so the JSON braces in the schema below need no escaping.
# The schema asks the model to echo the prompt under "prompt" and to return one
# object per tier ("low", "medium", "high") under "answers".
TIERED_PROMPT = """
You are producing three answers of deliberately different quality tiers (LOW, MEDIUM, HIGH) for the *same* user prompt to support a Response Variance benchmark.

Language rule:
- Write in the same language as the user prompt. If unclear, default to English.

Quality definitions (do not disclose to the user):
- HIGH: Fully correct, complete, well-structured. Follows all constraints. Clear reasoning when helpful. No hallucinations.
- MEDIUM: Generally correct but with minor omissions, shallow reasoning, weaker structure, or small style issues.
- LOW: Safe but flawed. Stay on-topic yet intentionally underperform by one or more of: brief/vague, missing key constraints, shallow or partially incorrect reasoning. Do NOT produce unsafe/offensive content. Do NOT refuse the task.

Stylistic/structural diversity:
- Vary structure, tone, and detail across tiers (list vs. prose; concise vs. detailed).
- Do NOT mention the existence of tiers.

Length guidance (soft, not strict):
- LOW: brief, 1–3 sentences or 3 bullets.
- MEDIUM: moderately detailed, 3–6 sentences or 4–6 bullets.
- HIGH: thorough but concise, with clear structure (e.g., steps or bullets); include reasoning where helpful.

Important:
- Do NOT pad or add filler to meet length. Quality ≠ verbosity.
- If the task is inherently short, keep it short across tiers; differentiate via completeness/structure/reasoning depth.
- HIGH should not exceed what is necessary to be complete and clear.

STRICT OUTPUT INSTRUCTIONS:
- Return ONLY a single JSON object. No markdown/code fences, no preface/suffix.
- The output MUST start with '{' and end with '}'.

Schema:
{
  "prompt": "<the prompt text you answered>",
  "answers": [
    {"quality_tier": "low", "answer": "<text>"},
    {"quality_tier": "medium", "answer": "<text>"},
    {"quality_tier": "high", "answer": "<text>"}
  ]
}

User prompt to answer:
{{PROMPT_TEXT}}
""".strip()

def build_messages(prompt_text: str) -> List[Dict[str, str]]:
    """Build the system + user message pair for one tiered-generation request.

    The conversation text is substituted verbatim for ``{{PROMPT_TEXT}}`` in
    TIERED_PROMPT. The placeholder sits in plain text, not inside a JSON string
    literal, so the text must NOT be JSON-escaped. Escaping double quotes here
    (as was done before) changed the prompt the model actually saw: code samples
    and quoted dialogue arrived with backslash-escaped quotes and were echoed
    back that way. The model performs its own JSON escaping when it writes the
    response.

    Args:
        prompt_text: the materialized conversation to answer.

    Returns:
        A two-element list: a system message demanding a single JSON object,
        followed by the user message containing the filled template.
    """
    user_content = TIERED_PROMPT.replace("{{PROMPT_TEXT}}", prompt_text)
    return [
        {
            "role": "system",
            "content": "You are a careful data generator. Output ONLY one valid JSON object. No extra text."
        },
        {
            "role": "user",
            "content": user_content
        }
    ]

def _extract_first_json_object(s: str) -> str | None:
    """Extracts the first complete JSON object from a string, correctly handling paired braces and ignoring braces within strings."""
    in_str = False
    esc = False
    depth = 0
    start = None
    for i, ch in enumerate(s):
        if ch == '"' and not esc:
            in_str = not in_str
        if ch == '\\' and not esc:
            esc = True
            continue
        if not in_str:
            if ch == '{':
                if depth == 0:
                    start = i
                depth += 1
            elif ch == '}':
                if depth > 0:
                    depth -= 1
                    if depth == 0 and start is not None:
                        return s[start:i+1]
        esc = False
    return None

def safe_json_extract(text: str) -> Dict[str, Any]:
    """Safely extracts and parses a JSON object from a string, handling common formatting issues."""
    s = (text or "").strip()

    # Remove markdown code fences ```...```
    fence = re.search(r"```+(\w+)?(.*?)```+", s, flags=re.S)
    if fence:
        s = fence.group(2).strip()

    # Directly attempt to extract by finding matching braces.
    frag = _extract_first_json_object(s)
    if frag is None:
        # Print a debug snippet.
        preview = s[:300].replace("\n", " ")
        raise json.JSONDecodeError(f"No JSON object found. Preview: {preview!r}", doc=s, pos=0)

    # Try to parse.
    try:
        return json.loads(frag)
    except json.JSONDecodeError as e:
        # Second chance: replace common single quotes with double quotes (conservative, only for clear 'key': patterns).
        if ":" in frag and "'" in frag and '"' not in frag[:frag.find(":")]:
            frag2 = re.sub(r"'", r'"', frag)
            return json.loads(frag2)
        # If it still fails, raise the exception.
        raise

def call_gpt_tiered(client: OpenAI, model: str, prompt_text: str,
                    temperature: float, top_p: float, max_try: int = 3, max_tokens: int = 1400) -> Dict[str, Any]:
    """Request LOW/MEDIUM/HIGH answers for one prompt, retrying on any failure.

    Each attempt sends build_messages(prompt_text). The first two attempts also
    request JSON mode (``response_format={"type": "json_object"}``). Later
    attempts omit it in case the gateway rejects that parameter. The reply is
    parsed with safe_json_extract and must contain a non-empty ``"answers"``
    list of objects. Any other shape counts as a failed attempt and is retried
    rather than silently producing zero candidates downstream.

    Args:
        client: OpenAI-compatible client used for ``chat.completions.create``.
        model: model name to request.
        prompt_text: the conversation text to answer.
        temperature, top_p: sampling parameters forwarded to the API.
        max_try: total number of attempts.
        max_tokens: completion token limit per attempt.

    Returns:
        The parsed JSON object from the first successful attempt.

    Raises:
        The exception from the last failed attempt, or RuntimeError if
        ``max_try`` is less than 1.
    """
    msgs = build_messages(prompt_text)
    last_exception = None
    for attempt in range(1, max_try + 1):
        try:
            if CONFIG.get("DEBUG"):
                print(f"[DEBUG] calling chat.completions.create(model={model}, temp={temperature}, top_p={top_p})", flush=True)

            request_params = {
                "model": model, "messages": msgs, "temperature": temperature,
                "top_p": top_p, "max_tokens": max_tokens, "n": 1, "timeout": 45.0,
            }
            # On the first two attempts, ask for JSON mode. If the gateway does
            # not support it, the error is caught below and retried without it.
            if attempt <= 2:
                request_params["response_format"] = {"type": "json_object"}

            resp = client.chat.completions.create(**request_params)
            raw = resp.choices[0].message.content or ""
            if CONFIG.get("DEBUG"):
                print("[DEBUG] got response chars:", len(raw), flush=True)
            obj = safe_json_extract(raw)

            # Validate the shape consumers rely on; raising here triggers a retry.
            answers = obj.get("answers")
            if not isinstance(answers, list) or not answers or not all(isinstance(a, dict) for a in answers):
                raise ValueError("response JSON lacks a non-empty 'answers' list of objects")
            return obj

        except Exception as e:
            last_exception = e
            msg = getattr(e, "message", str(e))
            print(f"[ERROR] call failed (attempt {attempt}/{max_try}): {msg}", flush=True)
            # Linear back-off between attempts; no pointless wait after the last one.
            if attempt < max_try:
                time.sleep(1.0 * attempt)

    # Let the caller handle the final failure.
    raise last_exception or RuntimeError("API call failed after all retries.")

def healthcheck(client: OpenAI, model: str):
    """Send a tiny "ping" completion to confirm the endpoint and model respond.

    On success, prints up to 50 characters of the reply with newlines flattened.
    On any failure, prints the exception's repr and re-raises it so the caller
    aborts before processing data.
    """
    try:
        print(f"[HC] probing model={model} ...", flush=True)
        reply = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "ping"}],
            temperature=0.0,
            max_tokens=5,
            timeout=15.0,
        )
        content = reply.choices[0].message.content or ""
        sample = content[:50].replace("\n", " ")
        print(f"[HC] OK. sample='{sample}'", flush=True)
    except Exception as exc:
        print(f"[HC] FAILED: {exc!r}", flush=True)
        raise

# ---------- Unify Candidates: Merge BoN + GPT ----------

def normalize_bon_candidates(rec: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Collect a record's existing Best-of-N answers as candidate dicts.

    ``rec["bon_best"]`` (source ``"bon_best"``) comes first, followed by each
    entry of ``rec["loser_list"]`` (source ``"loser"``) in order. Entries that
    are not dicts or have an empty/missing ``answer`` are skipped. The sampling
    fields are None because these answers were not generated here.
    """
    def _as_candidate(source: str, item: Dict[str, Any]) -> Dict[str, Any]:
        name = item.get("llm_name", "unknown")
        return {
            "source": source,
            "llm_name": name,
            "answer": item["answer"],
            "quality_tier": None,
            "temperature": None,
            "top_p": None,
            "tag": f"{source}::{name}",
            "meta": {"origin": "RMB_BoN"},
        }

    def _usable(item: Any) -> bool:
        return isinstance(item, dict) and bool(item.get("answer"))

    pool: List[Dict[str, Any]] = []
    best = rec.get("bon_best")
    if _usable(best):
        pool.append(_as_candidate("bon_best", best))
    losers = rec.get("loser_list", []) or []
    pool.extend(_as_candidate("loser", lo) for lo in losers if _usable(lo))
    return pool

def gpt_answers_to_candidates(tier_obj: Dict[str, Any],
                              llm_name: str,
                              temperature: float,
                              top_p: float) -> List[Dict[str, Any]]:
    """Convert a parsed tiered response into candidate dicts.

    One candidate is produced per entry of ``tier_obj["answers"]``, in order,
    tagged ``"<quality_tier>::<llm_name>::temp<temperature>"``. A missing or
    null ``answers`` value yields an empty list. Non-dict entries are skipped,
    mirroring normalize_bon_candidates, instead of raising AttributeError and
    failing the whole record.
    """
    out = []
    for a in tier_obj.get("answers") or []:
        if not isinstance(a, dict):
            continue
        qt = a.get("quality_tier")
        ans = a.get("answer", "")
        out.append({
            "source": "gpt_generated",
            "llm_name": llm_name,
            "answer": ans,
            "quality_tier": qt,
            "temperature": temperature,
            "top_p": top_p,
            "tag": f"{qt}::{llm_name}::temp{temperature}",
            "meta": {}
        })
    return out

def process_record(rec: Dict[str, Any],
                   client: OpenAI,
                   model: str,
                   temps: List[float],
                   top_p: float,
                   max_try: int) -> Dict[str, Any]:
    """Generate tiered answers for one record and merge them with its BoN answers.

    One call_gpt_tiered request is made per temperature in ``temps``. The
    record is updated IN PLACE with ``rmb_prompt_materialized`` (the prompt
    text sent) and ``candidates`` (BoN candidates first, then the generated
    ones), and the same dict is returned. Any exception from a request
    propagates to the caller.
    """
    conversation = rec.get("conversation_input", [])
    prompt_text = materialize_rmb_prompt(conversation)
    pool = normalize_bon_candidates(rec)

    for temp in temps:
        tiered = call_gpt_tiered(client, model, prompt_text,
                                 temperature=temp, top_p=top_p, max_try=max_try)
        pool += gpt_answers_to_candidates(tiered, llm_name=model,
                                          temperature=temp, top_p=top_p)

    rec.update({
        "rmb_prompt_materialized": {"text": prompt_text},
        "candidates": pool,
    })
    return rec

def process_file(in_path: str,
                 out_path: str,
                 client: OpenAI,
                 model: str,
                 temps: List[float],
                 top_p: float,
                 max_workers: int,
                 max_try: int):
    """Generate candidates for every record in one JSON file and save the result.

    Records are processed concurrently by process_record on ``max_workers``
    threads. Futures finish in arbitrary order, so each result is written back
    to the index of the record it came from. The output list therefore keeps
    the input's record order; previously results were appended in completion
    order and came out shuffled.

    A record whose processing raises is replaced at its original position by a
    ``{"error": repr(exc)}`` placeholder instead of aborting the whole file.
    The list is then written to ``out_path`` with dump_json.
    """
    data = load_json(in_path)
    # Pre-sized so results can be stored by input index as they complete.
    out: List[Any] = [None] * len(data)
    desc = f"GPT gen {os.path.basename(in_path)}"
    with ThreadPoolExecutor(max_workers) as ex, tqdm(total=len(data), desc=desc) as pbar:
        fut_to_idx = {
            ex.submit(process_record, rec, client, model, temps, top_p, max_try): idx
            for idx, rec in enumerate(data)
        }
        for fu in as_completed(fut_to_idx):
            idx = fut_to_idx[fu]
            try:
                out[idx] = fu.result()
            except Exception as e:
                # Don't hang; log the failure and keep a placeholder at its slot.
                print(f"[ERROR] record {idx} failed: {repr(e)}", flush=True)
                out[idx] = {"error": repr(e)}
            finally:
                pbar.update(1)
    dump_json(out_path, out)
    print("[INFO] Saved:", out_path, flush=True)

def main():
    """Entry point: validate settings, probe the API, then process the input.

    All settings come from CONFIG. The OPENAI_API_KEY environment variable
    takes precedence over CONFIG["API_KEY"].

    File mode (INPUT_PATH is a file): the result is written to OUTPUT_ROOT, or
    to ``<OUTPUT_ROOT>/<input file name>`` if OUTPUT_ROOT is an existing
    directory.

    Directory mode: every ``.json`` file under INPUT_PATH is processed and
    written under OUTPUT_ROOT at the same relative path.

    Raises:
        RuntimeError: no usable API key is configured.
        FileNotFoundError: INPUT_PATH does not exist.
    """
    api_key = os.getenv("OPENAI_API_KEY") or CONFIG["API_KEY"]
    if not api_key or "YOUR_API_KEY" in api_key:
        raise RuntimeError("Please set your API key in CONFIG['API_KEY'] or as an environment variable OPENAI_API_KEY.")

    model = CONFIG["MODEL"]

    client = OpenAI(base_url=CONFIG["BASE_URL"], api_key=api_key, timeout=30.0)
    input_path = CONFIG["INPUT_PATH"]
    output_root = CONFIG["OUTPUT_ROOT"]
    temps = CONFIG["TEMPS"]
    top_p = CONFIG["TOP_P"]
    max_workers = CONFIG["MAX_WORKERS"]
    max_try = CONFIG["MAX_TRY"]

    print("[INFO] Python", sys.version)
    print("[INFO] Base URL:", CONFIG["BASE_URL"])
    print("[INFO] Model:", model)
    print("[INFO] Input path:", input_path)
    print("[INFO] Output:", output_root)
    print("[INFO] Temps:", temps, "top_p:", top_p, "max_workers:", max_workers, flush=True)

    # Validate paths before spending any API calls.
    file_mode = os.path.isfile(input_path)
    if file_mode:
        print("[INFO] Detected file mode.")
    elif os.path.isdir(input_path):
        print("[INFO] Detected directory mode.")
    else:
        raise FileNotFoundError(f"Input path does not exist: {os.path.abspath(input_path)}")

    # First, perform a connectivity probe.
    healthcheck(client, model)

    if file_mode:
        out_path = output_root
        # In file mode OUTPUT_ROOT names the output FILE, but it may already
        # exist as a directory (e.g. the default "./output_data" from an earlier
        # directory-mode run). Opening a directory for writing fails, and
        # process_file only writes after every API call for the file has run,
        # so resolve the real file path up front.
        if os.path.isdir(out_path):
            out_path = os.path.join(out_path, os.path.basename(input_path))
            print("[INFO] OUTPUT_ROOT is a directory; writing to:", out_path, flush=True)
        process_file(input_path, out_path, client, model, temps, top_p, max_workers, max_try)
    else:
        in_root = os.path.abspath(input_path)
        out_root = os.path.abspath(output_root)
        for f in list_json_files(input_path):
            rel = os.path.relpath(f, in_root)
            out_path = os.path.join(out_root, rel)
            process_file(f, out_path, client, model, temps, top_p, max_workers, max_try)

# Run the generation pipeline only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()