from __future__ import annotations
import argparse
import tqdm
import os, json, time, textwrap
from typing import List, Literal, Optional, Tuple
from pydantic import BaseModel, Field, ValidationError

# If you use OpenAI's Python SDK:
#   pip install openai pydantic
# Or swap this client with your own LLM wrapper (LiteLLM, etc.)
from openai import OpenAI
# Module-level client used by call_gpt5_for_filter for every request. It is
# constructed at import time; the API key is read from the OPENAI_API_KEY
# environment variable (os.getenv returns None when the variable is unset).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# ---------- Pydantic schemas ----------

# The only two verdicts a scenario can receive.
Decision = Literal["ACCEPT", "REJECT"]

# Structured verdict for a single scenario. This model is passed as
# `text_format` to client.responses.parse() in call_gpt5_for_filter, and the
# parsed instance is converted to a plain dict with model_dump().
# NOTE(review): the Field descriptions below are presumably part of the JSON
# schema sent to the model -- treat them as prompt text and confirm before
# rewording them.
class FilterResult(BaseModel):
    # The verdict; validation restricts it to the two Decision literals.
    decision: Decision = Field(
        description="ACCEPT if scenario passes all rules; REJECT otherwise."
    )
    # Free-text justifications; defaults to a fresh empty list when omitted.
    reasons: List[str] = Field(
        default_factory=list,
        description="Short human-readable reasons mapping to violated rules if REJECT, or empty if ACCEPT."
    )


# ---------- Prompt template (from your spec) ----------

# System prompt sent with every filtering request. The rejection rules and
# acceptance guidance are the dataset spec. The OUTPUT FORMAT section must
# describe the same shape as the FilterResult schema that is enforced on the
# response: a flat object with top-level "decision" and "reasons" keys (no
# wrapper object, no comments inside the JSON example).
FILTERING_INSTRUCTIONS = textwrap.dedent("""
You are given a text description of a physical scenario.
Your task is to decide whether this scenario should be ACCEPTED or REJECTED for use in a text-to-video dataset.
A scenario should be rejected if it matches any of the rules below. Otherwise, it should be accepted.

Rejection Rules
1) Object Visibility
   - Reject if the main object is too small, thin, or subtle to be reliably generated or recognized (e.g., dust, raindrops, snowflakes, bubbles, needles).
   - Reject if the main object is transparent or semi-transparent and not large enough to be clearly seen.

2) Motion Characteristics
   - Reject if the main motion is very high-speed such that the object would blur or be hard to track (e.g., a bullet, a fast-thrown frisbee).
   - Reject if the scenario requires perceiving subtle differences in velocity that are unlikely to be visible (e.g., slight changes in glide speed).
   - Reject if the motion relies on rapid spinning or rotation that cannot be clearly visualized (e.g., frisbee spin, precise rolling rotation).

3) Invisible or Implicit Phenomena
   - Reject if the causal factor is not directly visible (e.g., wind, temperature, electricity, or internal contents hidden inside an object).
   - Reject if the scenario requires seeing occluded parts of objects (e.g., underwater sections, inside circuits).

4) Fine-grained Visual Effects
   - Reject if the outcome depends mainly on reflections, sparkles, glitter, or other small surface details that are difficult to generate and judge.

5) Material State Distinctions
   - Reject if the scenario requires distinguishing wet vs. dry states of a material (e.g., wet paint vs. dry paint, soaked wood, sponge saturation), since these are difficult to generate or evaluate.

Acceptance Guidance
- Accept scenarios where the main objects are macroscopic, visible, and stable in shape.
- Accept when the main actions are clearly perceptible physical interactions (falling, colliding, pouring, burning, floating, tearing).

OUTPUT FORMAT (STRICT JSON):
Return ONLY a JSON object with exactly these two top-level keys:
{
  "decision": "ACCEPT" or "REJECT",
  "reasons": ["short reason 1", "short reason 2"]
}
Use an empty list for "reasons" when the decision is ACCEPT.
Do not include any other keys or text.
""").strip()

def build_user_prompt(scenario: str) -> str:
    """Wrap *scenario* in the user-message text sent alongside the system prompt."""
    template = "Scenario:\n{}\n\nRespond strictly in JSON as specified."
    return template.format(scenario)

# ---------- LLM call with robust JSON validation & retry ----------

def call_gpt5_for_filter(
    scenario: str,
    model: str = "gpt-5",   # replace with your deployed GPT-5 model name
    max_retries: int = 2,
    retry_delay: float = 0.8,
    reasoning_effort: Literal['minimal', 'low', 'medium', 'high'] = 'medium',
    max_output_tokens: int = 12800,
) -> dict:
    """Classify one scenario with the LLM and return the validated verdict.

    The request is sent through ``client.responses.parse`` with
    ``FilterResult`` as the structured-output format. Up to
    ``max_retries + 1`` attempts are made, sleeping ``retry_delay`` seconds
    before each retry. When an attempt fails validation, a repair
    instruction is appended to the conversation before the next attempt.

    Args:
        scenario: Text description of the scenario to classify.
        model: Model name passed to the API.
        max_retries: Number of retries after the first attempt.
        retry_delay: Seconds to wait before each retry.
        reasoning_effort: Value for the request's ``reasoning.effort`` field.
        max_output_tokens: Output-token cap passed to the API.

    Returns:
        The parsed ``FilterResult`` as a plain dict (``model_dump()``), i.e.
        with the keys ``decision`` and ``reasons``.

    Raises:
        RuntimeError: If no attempt produced a parsed result. The last
            underlying error is attached as ``__cause__``.
    """
    messages = [
        {"role": "system", "content": FILTERING_INSTRUCTIONS},
        {"role": "user", "content": build_user_prompt(scenario)},
    ]
    # Names the keys of the enforced (flat) FilterResult schema.
    repair_prompt = (
        "Your previous output was not valid per the required JSON schema. "
        "Please output ONLY valid JSON with exactly two top-level keys: "
        "decision and reasons."
    )

    last_error: Optional[Exception] = None
    for attempt in range(max_retries + 1):
        # Back off before every retry, whatever made the previous attempt fail.
        if attempt > 0:
            time.sleep(retry_delay)

        try:
            completion = client.responses.parse(
                model=model,
                input=messages,
                text_format=FilterResult,
                reasoning={'effort': reasoning_effort},
                max_output_tokens=max_output_tokens,
            )
        except (json.JSONDecodeError, ValidationError) as e:
            # parse() raised before returning, so this attempt has no response
            # object to quote back; only ask the model to repair its output.
            last_error = e
            messages.append({"role": "user", "content": repair_prompt})
            continue
        except Exception as e:
            # Non-validation API error: retry the same conversation unchanged.
            last_error = e
            continue

        parsed = getattr(completion, "output_parsed", None)
        if parsed is not None:
            return parsed.model_dump()

        # A response came back but carried no parsed object. Record that as
        # the failure so the final error message is never "None".
        last_error = ValueError("Model did not return a valid structured response.")
        raw_text = getattr(completion, "output_text", None)
        if raw_text:
            # Show the model what it produced so the repair request has context.
            messages.append({"role": "assistant", "content": raw_text})
        messages.append({"role": "user", "content": repair_prompt})

    raise RuntimeError(
        f"LLM call failed after retries. Last error: {last_error}"
    ) from last_error


def filter_batch(scenarios: List[str], output_path: str):
    """
    Run the filter over *scenarios*, appending one JSON line per scenario
    to *output_path*.

    Each line is the dict returned by call_gpt5_for_filter with the original
    text added under the "scenario" key. Nothing is returned.

    Resuming: if *output_path* already exists, its line count is taken as the
    number of scenarios already processed and that many leading entries of
    *scenarios* are skipped.
    """
    already_done = 0
    if os.path.exists(output_path):
        with open(output_path, 'r') as existing:
            already_done = sum(1 for _ in existing)

    pending = scenarios[already_done:]
    for scenario in tqdm.tqdm(pending):
        record = call_gpt5_for_filter(scenario)
        record['scenario'] = scenario
        # Re-open in append mode for every record so each finished result is
        # written out and the file closed before the next request starts.
        with open(output_path, 'a') as sink:
            sink.write(json.dumps(record) + '\n')


if __name__ == "__main__":
    # Command-line entry point: read one scenario per input line and filter
    # them all into a JSONL output file.
    cli = argparse.ArgumentParser()
    cli.add_argument('-i', '--input_file', required=True, help="The scenarios.")
    cli.add_argument('-o', '--output_file', required=True, help="The output file. It should be JSONL.")
    opts = cli.parse_args()

    # Fail before any API call if the output path is not a .jsonl file.
    output_is_jsonl = opts.output_file.endswith('.jsonl')
    if not output_is_jsonl:
        raise ValueError("--output_file should be JSONL.")

    # Every input line becomes a scenario, with surrounding whitespace removed.
    all_scenarios = []
    with open(opts.input_file, 'r') as source:
        for raw_line in source:
            all_scenarios.append(raw_line.strip())

    filter_batch(all_scenarios, opts.output_file)
