# -*- coding: utf-8 -*-
import argparse
import os
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict, Any

# Default generation parameters; main() uses them as the defaults of
# --model / --temperature / --top_p.
DEFAULT_MODEL = "gpt-4o"
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 1.0

# Prompt template for generating tiered-quality responses.
# It is filled via TIERED_PROMPT.format(prompt=...), so every literal brace of
# the JSON example below is doubled ({{ / }}) and `{prompt}` is the only
# placeholder. The model is asked to reply with one JSON object whose
# "answers" list holds one entry per tier ("low", "medium", "high").
TIERED_PROMPT = """
You are producing three answers of deliberately different quality tiers (LOW, MEDIUM, HIGH) for the *same* user prompt to support a Response Variance benchmark.

Language rule:
- Write in the same language as the user prompt. If unclear, default to English.

Quality definitions (do not disclose to the user):
- HIGH: Fully correct, complete, well-structured. Follows all constraints. Clear reasoning when helpful. No hallucinations.
- MEDIUM: Generally correct but with minor omissions, shallow reasoning, weaker structure, or small style issues.
- LOW: Safe but flawed. Stay on-topic yet intentionally underperform by one or more of: brief/vague, missing key constraints, shallow or partially incorrect reasoning. Do NOT produce unsafe/offensive content. Do NOT refuse the task.

Stylistic/structural diversity:
- Vary structure, tone, and detail across tiers (list vs. prose; concise vs. detailed).
- Do NOT mention the existence of tiers.
- Rough length guide (flexible): LOW 60–120 tokens, MEDIUM 120–200, HIGH 180–300.

Output format:
Return ONLY a JSON object:
{{
  "prompt": "<the prompt text you answered>",
  "answers": [
    {{"quality_tier": "low", "answer": "<text>"}},
    {{"quality_tier": "medium", "answer": "<text>"}},
    {{"quality_tier": "high", "answer": "<text>"}}
  ]
}}

User prompt to answer:
{prompt}
""".strip()

def load_any_json(path: str) -> List[Dict[str, Any]]:
    """Load records from a JSON-array file or a JSON Lines file.

    A path ending in ``.jsonl`` (case-insensitive) is read as JSON Lines: one
    JSON value per non-blank line. Any other file must hold a single JSON
    array at the top level.

    Files are decoded as UTF-8. A leading byte-order mark is tolerated.

    Args:
        path: Path to the input file.

    Returns:
        The parsed records, in file order.

    Raises:
        ValueError: If a non-JSONL file does not contain a top-level list.
        json.JSONDecodeError: If the file, or any JSONL line, is not valid JSON.
    """
    # "utf-8-sig" decodes exactly like "utf-8" but also drops a leading BOM,
    # which plain "utf-8" would pass through as U+FEFF and json would reject.
    with open(path, "r", encoding="utf-8-sig") as f:
        if path.lower().endswith(".jsonl"):
            # json.loads accepts surrounding whitespace; blank lines are skipped.
            return [json.loads(line) for line in f if line.strip()]
        obj = json.load(f)
    if isinstance(obj, list):
        return obj
    raise ValueError("Input JSON must be a list at the top level.")

def dump_json(path: str, obj: Any):
    """Write *obj* to *path* as UTF-8 JSON, indented by two spaces.

    Non-ASCII characters are written as-is rather than ``\\u`` escaped. The
    parent directory is created if needed. A bare filename resolves against
    the current directory.
    """
    parent_dir = os.path.dirname(path)
    if not parent_dir:
        parent_dir = "."
    os.makedirs(parent_dir, exist_ok=True)

    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2, ensure_ascii=False)

def extract_prompt_text(item: Dict[str, Any], mode: str = "full_concat") -> str:
    """
    Extract the user prompt text from a data record.

    Two record shapes are supported:
      * ``{"prompt": "..."}``: the string is returned unchanged;
      * RMB-style ``{"conversation_input": [{"role": ..., "content": ...}, ...]}``.

    For conversations, ``mode`` selects the rendering:
      * ``"last_user"``: the content of the last message whose role is
        ``"user"``, falling back to the last message of any role;
      * anything else (``"full_concat"`` by default): every message rendered
        as ``"[ROLE]"``, a newline, then its content, with messages separated
        by a blank line.

    A null or missing ``content`` yields an empty string. A null or missing
    ``role`` is rendered as ``USER``. Records with null fields therefore never
    produce the text "None" and never crash. Non-string content is converted
    with ``str()``.

    Returns:
        The prompt text, or ``""`` when the record has neither a string
        ``prompt`` nor a non-empty ``conversation_input``.
    """
    prompt = item.get("prompt")
    if isinstance(prompt, str):
        return prompt

    conv = item.get("conversation_input") or []
    if not conv:
        return ""

    def _content(msg: Dict[str, Any]) -> str:
        # A null turn (e.g. an empty assistant message) must not become "None",
        # and last_user mode must always hand a str to the caller.
        value = msg.get("content")
        return "" if value is None else str(value)

    if mode == "last_user":
        for m in reversed(conv):
            if m.get("role") == "user":
                return _content(m)
        return _content(conv[-1])

    # Default to full_concat mode
    lines = []
    for m in conv:
        role = m.get("role")
        role = "user" if role is None else str(role)
        lines.append(f"[{role.upper()}]\n{_content(m)}")
    return "\n\n".join(lines)

def safe_json_extract(text: str) -> Dict[str, Any]:
    """Parse the JSON object contained in a model reply.

    Candidate substrings are tried in order. The first that parses to a JSON
    object (a dict) is returned:

    1. the whole stripped reply, which is the expected case because the system
       prompt asks for bare JSON;
    2. the span from its first ``{`` to its last ``}``, to tolerate prose
       around the object;
    3. the text between the first and last triple-backtick fence, minus an
       optional ``json`` tag;
    4. the brace span of that fenced text.

    The raw reply is tried before any fence slicing on purpose. Answers to
    coding prompts often contain code fences *inside* JSON string values, and
    slicing on those fences first would cut the object apart.

    Raises:
        json.JSONDecodeError: If no candidate is valid JSON.
        ValueError: If some candidate is valid JSON but none is an object.
    """
    s = (text or "").strip()

    bases = [s]
    first = s.find("```")
    last = s.rfind("```")
    if first != -1 and last > first:
        fenced = s[first + 3:last].strip()
        if fenced.lower().startswith("json"):
            fenced = fenced[4:].strip()
        bases.append(fenced)

    candidates = []
    for base in bases:
        candidates.append(base)
        lo = base.find("{")
        hi = base.rfind("}")
        if lo != -1 and hi > lo:
            candidates.append(base[lo:hi + 1])

    last_err = None
    for candidate in candidates:
        try:
            obj = json.loads(candidate)
        except json.JSONDecodeError as e:
            last_err = e
            continue
        if isinstance(obj, dict):
            return obj

    if last_err is not None:
        raise last_err
    # Valid JSON, but not an object (e.g. a bare list). Raising here lets
    # callers' retry logic handle it instead of failing later on `.get`.
    raise ValueError("Model output is valid JSON but not a JSON object.")

def build_messages(prompt_text: str) -> List[Dict[str, str]]:
    """Construct the chat message list for one tiered-generation API call.

    Returns two messages:
      * a system message demanding a single JSON object;
      * a user message containing ``TIERED_PROMPT`` filled with *prompt_text*.

    The prompt is inserted verbatim and deliberately not quote-escaped. The
    placeholder sits on its own line, outside any JSON string literal, so
    escaping would only show the model an altered prompt: a backslash before
    every double quote, which changes the meaning of code snippets. It would
    not make the reply any more likely to be valid JSON. Braces in
    *prompt_text* need no escaping either, because ``str.format`` interprets
    braces in the template, not in the substituted value.
    """
    return [
        {
            "role": "system",
            "content": "You are a careful data generator. Output ONLY one valid JSON object. No extra text."
        },
        {
            "role": "user",
            "content": TIERED_PROMPT.format(prompt=prompt_text)
        }
    ]

def call_gpt_tiered(client: OpenAI, model: str, prompt_text: str,
                    temperature: float, top_p: float, max_try: int = 3, max_tokens: int = 1800) -> Dict[str, Any]:
    """
    Request tiered answers for *prompt_text* and parse the JSON reply.

    Any exception during the request or during parsing triggers another
    attempt, up to *max_try* attempts in total. The sleep before the next
    attempt grows linearly: 1 s after the first failure, 2 s after the
    second, and so on.

    Returns:
        The parsed JSON reply.

    Raises:
        Exception: The error from the final attempt, re-raised once all
            attempts have failed.
        RuntimeError: If *max_try* is less than 1, so no attempt is made.
    """
    messages = build_messages(prompt_text)
    for attempt in range(1, max_try + 1):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                n=1
            )
            reply = completion.choices[0].message.content or ""
            return safe_json_extract(reply)
        except Exception:
            if attempt >= max_try:
                raise
            time.sleep(1.0 * attempt)
    raise RuntimeError("Failed to get a valid response from the API after multiple retries.")


def gpt_answers_to_candidates(tier_obj: Dict[str, Any],
                              llm_name: str,
                              temperature: float,
                              top_p: float) -> List[Dict[str, Any]]:
    """Convert a parsed tiered-generation reply into candidate dictionaries.

    Args:
        tier_obj: Parsed model reply. Its ``"answers"`` entry is expected to be
            a list of ``{"quality_tier": ..., "answer": ...}`` objects.
        llm_name: Model name stored on every candidate and used in its tag.
        temperature: Sampling temperature stored on every candidate and used
            in its tag.
        top_p: Nucleus-sampling value stored on every candidate.

    Returns:
        One candidate per object in ``"answers"``, in reply order. Model output
        is not guaranteed to follow the schema, so a missing, null, or
        non-list ``"answers"`` yields ``[]`` and non-object entries are
        skipped instead of raising.
    """
    answers = tier_obj.get("answers")
    if not isinstance(answers, list):
        return []

    out = []
    for a in answers:
        if not isinstance(a, dict):
            continue
        tier = a.get("quality_tier")
        out.append({
            "source": "gpt_generated",
            "llm_name": llm_name,
            "answer": a.get("answer", ""),
            "quality_tier": tier,
            "temperature": temperature,
            "top_p": top_p,
            "tag": f"{tier}::{llm_name}::temp{temperature}",
            "meta": {}
        })
    return out

def normalize_bon_candidates(rec: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn a record's existing Best-of-N answers into candidate dictionaries.

    Candidates are produced in this order:
      * ``rec["bon_best"]``, if it is a dict with a truthy ``"answer"``,
        becomes a ``"bon_best"`` candidate;
      * every dict in ``rec["loser_list"]`` with a truthy ``"answer"`` becomes
        a ``"loser"`` candidate, in list order.

    Entries that are not dicts, or whose answer is empty, are skipped. A
    missing ``llm_name`` is recorded as ``"unknown"``. Generation-specific
    fields (tier, temperature, top_p) are ``None``.
    """
    def _to_candidate(source: str, entry: Dict[str, Any]) -> Dict[str, Any]:
        name = entry.get("llm_name", "unknown")
        return {
            "source": source,
            "llm_name": name,
            "answer": entry["answer"],
            "quality_tier": None,
            "temperature": None,
            "top_p": None,
            "tag": f"{source}::{name}",
            "meta": {"origin": "RMB_BoN"},
        }

    # Best answer first, then the losers in their original order.
    pool = [("bon_best", rec.get("bon_best"))]
    pool.extend(("loser", entry) for entry in (rec.get("loser_list", []) or []))

    return [
        _to_candidate(source, entry)
        for source, entry in pool
        if isinstance(entry, dict) and entry.get("answer")
    ]

def process_item(item: Dict[str, Any], client: OpenAI, model: str,
                 temp: float, top_p: float, max_try: int, prompt_mode: str) -> Dict[str, Any]:
    """Enrich one input record with its existing and newly generated candidates.

    The record is mutated in place and also returned. Two keys are added:
      * ``"rmb_prompt_materialized"``: the prompt mode and the extracted text;
      * ``"gpt_candidates"``: the record's Best-of-N candidates, followed by
        the tiered answers generated by *model*.

    Errors from the API call propagate. In that case the record is left
    unmodified.
    """
    prompt_text = extract_prompt_text(item, mode=prompt_mode)

    # Existing BoN candidates first, so they precede the generated ones.
    merged = normalize_bon_candidates(item)
    tiered_reply = call_gpt_tiered(client, model, prompt_text,
                                   temperature=temp, top_p=top_p, max_try=max_try)
    merged += gpt_answers_to_candidates(tiered_reply, llm_name=model,
                                        temperature=temp, top_p=top_p)

    item.update({
        "rmb_prompt_materialized": {"prompt_mode": prompt_mode, "text": prompt_text},
        "gpt_candidates": merged,
    })
    return item

def main():
    """Parse CLI arguments, generate tiered answers for every input record, and save them.

    Records are processed concurrently on a thread pool. The output file keeps
    the input order. Records whose processing still raised after all retries
    are reported and left out of the output.
    """
    # Credentials may come from the environment, so the API key does not have
    # to appear on the command line (and thus in shell history / process lists).
    env_api_key = os.environ.get("OPENAI_API_KEY")
    env_base_url = os.environ.get("OPENAI_BASE_URL")

    parser = argparse.ArgumentParser(description="Generate tiered-quality responses for a given set of prompts.")
    parser.add_argument("--input_file", required=True, help="Path to the input JSON or JSONL file containing prompts.")
    parser.add_argument("--output_file", required=True, help="Path to the output JSON file.")
    parser.add_argument("--api_key", default=env_api_key, required=env_api_key is None,
                        help="OpenAI API key (defaults to $OPENAI_API_KEY when set).")
    parser.add_argument("--base_url", default=env_base_url, required=env_base_url is None,
                        help="OpenAI API base URL (defaults to $OPENAI_BASE_URL when set).")
    parser.add_argument("--model", default=DEFAULT_MODEL, help=f"The model to use for generation (default: {DEFAULT_MODEL}).")
    parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE, help=f"Sampling temperature (default: {DEFAULT_TEMPERATURE}).")
    parser.add_argument("--top_p", type=float, default=DEFAULT_TOP_P, help=f"Nucleus sampling top_p (default: {DEFAULT_TOP_P}).")
    parser.add_argument("--max_workers", type=int, default=4, help="Number of parallel workers for API calls.")
    parser.add_argument("--max_try", type=int, default=3, help="Maximum number of retries for a failed API call.")
    parser.add_argument("--prompt_mode", choices=["full_concat", "last_user"], default="full_concat",
                        help="Method to extract prompt text from conversational input.")
    args = parser.parse_args()

    client = OpenAI(base_url=args.base_url, api_key=args.api_key)
    data = load_any_json(args.input_file)

    # as_completed() yields futures in completion order, which varies run to
    # run. Each result goes back into its input slot so the output stays
    # aligned with the input file.
    results: List[Any] = [None] * len(data)
    failed = 0
    desc = f"Generating with {args.model} (temp={args.temperature})"
    with ThreadPoolExecutor(args.max_workers) as executor:
        future_to_index = {
            executor.submit(process_item, item, client, args.model,
                            args.temperature, args.top_p, args.max_try, args.prompt_mode): idx
            for idx, item in enumerate(data)
        }
        with tqdm(total=len(data), desc=desc) as pbar:
            for future in as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    results[idx] = future.result()
                except Exception as e:
                    failed += 1
                    # tqdm.write prints above the bar instead of garbling it.
                    tqdm.write(f"Item {idx} failed after all retries: {e}")
                finally:
                    pbar.update(1)

    # Failed slots are still None; successful results are the (non-None) items.
    succeeded = [r for r in results if r is not None]
    dump_json(args.output_file, succeeded)
    if failed:
        print(f"\n{failed} of {len(data)} item(s) failed and were omitted from the output.")
    print(f"\nProcessing complete. Results saved to: {args.output_file}")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()