"""
Parallel Iterative Plan Refinement with Constraint Feedback

This script implements an iterative refinement process for solving planning problems (calendar, meeting, trip, zebralogic)
by generating plans directly using LLMs, evaluating constraints, and providing feedback for improvement.

Features:
1. Generates plans directly using various LLM models (GPT, DeepSeek, etc.)
2. Uses prompts from force JSON files (calendar, meeting, trip, zebralogic)
3. Evaluates the solution against domain-specific constraints
4. Provides iterative feedback for constraint violations
5. Saves conversation history, plans, and evaluation results for each pass
6. Parallel processing with rate limiting for efficiency
7. Extracts DeepSeek reasoning content and counts tokens

Directory structure for outputs:
../output/Plan/{model_name}/{task}/n_pass/{example_id}/{pass_number}_pass/
  - conversation.json: Full conversation history
  - plan.json: Generated plan
  - evaluation.json: Constraint evaluation results
  - reasoning.txt: DeepSeek reasoning content (if available)
  - full_response.txt: Full model response

Usage:
python iterative_plan_refinement_parallel.py --task calendar --model gpt-4o-mini --start 0 --end 5
python iterative_plan_refinement_parallel.py --task trip --model DeepSeek-V3 --max_passes 3
python iterative_plan_refinement_parallel.py --task zebralogic --model DeepSeek-V3 --start 0 --end 5
"""

import argparse
import json
import os
import asyncio
import re
import time
from datetime import datetime
from kani import Kani, chat_in_terminal
from kani.engines.huggingface import HuggingEngine
from kani.engines.openai import OpenAIEngine
import concurrent.futures
from typing import List, Dict, Any, Tuple
import logging
import shutil
from openai import OpenAI
import tiktoken

# Configure the root logger once at import time: INFO and above goes both to
# a log file in the current working directory and to the console stream.
logging.basicConfig(
    handlers=[
        logging.FileHandler('iterative_plan_refinement.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def parse_args():
    """Build the CLI parser, parse ``sys.argv`` and return the namespace.

    ``--examples`` is normalised after parsing: quote characters are removed
    and whitespace around each comma-separated entry is stripped, so
    ``"'25, 35'"`` becomes ``"25,35"``.
    """
    parser = argparse.ArgumentParser(description="Run iterative plan refinement with parallel processing")

    # (flag, keyword arguments) pairs, registered in display order.
    option_table = [
        ("--model", dict(type=str, required=True, help="Model to use (e.g., 'DeepSeek-R1')")),
        ("--task", dict(type=str, required=True, choices=["calendar", "trip", "meeting", "zebralogic"], help="Task type")),
        ("--max_passes", dict(type=int, default=5, help="Maximum number of refinement passes")),
        ("--max_concurrent", dict(type=int, default=10, help="Maximum number of concurrent examples to process")),
        ("--rate_limit", dict(type=int, default=60, help="Rate limit (requests per minute)")),
        ("--start", dict(type=int, help="Start example number (inclusive)")),
        ("--end", dict(type=int, help="End example number (exclusive)")),
        ("--fresh", dict(action="store_true", help="Clear all output directories before running")),
        ("--examples", dict(type=str, help="Comma-separated list of example numbers to run (e.g., '25,35')")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    raw_examples = args.examples
    if raw_examples:
        # Drop any quoting the shell left behind, then tidy each entry.
        unquoted = raw_examples.replace('"', '').replace("'", "").strip()
        args.examples = ','.join(piece.strip() for piece in unquoted.split(','))

    return args

# Load API keys at import time. The JSON file is indexed elsewhere in this
# module by provider name ("openai", "deepseek").
try:
    with open("../../openai_research/deepseek_api_key.json") as f:
        keys = json.load(f)
except FileNotFoundError:
    # Name the path that was actually opened so the message is actionable.
    logging.error(
        "../../openai_research/deepseek_api_key.json not found. "
        "Please create this file with your API keys."
    )
    # SystemExit works even when the interactive `exit` helper from the
    # `site` module is unavailable (e.g. `python -S` or frozen builds).
    raise SystemExit(1)

def initialize_model(model_name, keys):
    """Create a Kani instance backed by the engine matching ``model_name``.

    Routing, in order:
      * names starting with ``"gpt"`` or ``"o"`` use ``OpenAIEngine`` with
        ``keys["openai"]``; ``"gpt-4o-mini"`` is pinned to
        ``"gpt-4o-mini-2024-07-18"``, every other name is passed through;
      * ``"DeepSeek-R1"`` / ``"DeepSeek-V3"`` use ``OpenAIEngine`` pointed at
        the DeepSeek API base with ``keys["deepseek"]``;
      * anything else is handed to ``HuggingEngine`` as a model id.

    Returns the Kani object, created with an empty system prompt.
    """
    # Public alias -> DeepSeek API model identifier.
    deepseek_backends = {
        "DeepSeek-R1": "deepseek-reasoner",
        "DeepSeek-V3": "deepseek-chat",
    }

    if model_name.startswith(("gpt", "o")):
        resolved = "gpt-4o-mini-2024-07-18" if model_name == "gpt-4o-mini" else model_name
        engine = OpenAIEngine(keys["openai"], model=resolved, max_context_size=20000)
    elif model_name in deepseek_backends:
        engine = OpenAIEngine(
            keys["deepseek"],
            model=deepseek_backends[model_name],
            api_base="https://api.deepseek.com",
            max_context_size=20000,
        )
    else:
        engine = HuggingEngine(model_id=model_name)

    return Kani(engine, system_prompt="")

# JSON schemas for different tasks.
# Each entry has a schema "name" and the JSON Schema ("schema") describing the
# answer object expected for that task.
JSON_SCHEMAS = {
    "calendar": {
        "name": "time_range_schema",
        "schema": {
            "type": "object",
            "properties": {
                "time_range": {
                    "type": "string",
                    # Matches the braces-wrapped "{HH:MM:HH:MM}" form.
                    "pattern": "^\\{\\d{1,2}:\\d{2}:\\d{1,2}:\\d{2}\\}$"
                },
                "day": {
                    "type": "string",
                    # The weekday list previously sat at schema level under a
                    # misspelled "em" key, where it constrained nothing.
                    "enum": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
                }
            },
            "required": ["time_range", "day"]
        }
    },
    "meeting": {
        "name": "meeting_plan_schema",
        "schema": {
            "type": "object",
            "properties": {
                "itinerary": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "action": {"type": "string", "enum": ["meet"]},
                            "location": {"type": "string"},
                            "person": {"type": "string"},
                            "start_time": {"type": "string"},
                            "end_time": {"type": "string"}
                        },
                        "required": ["action", "location", "person", "start_time", "end_time"]
                    }
                }
            },
            "required": ["itinerary"]
        }
    },
    "trip": {
        "name": "trip_plan_schema",
        "schema": {
            "type": "object",
            "properties": {
                "itinerary": {
                    "type": "array",
                    "items": {
                        # An item is either a stay ("day_range"/"place") or a
                        # flight ("flying"/"from"/"to").
                        "anyOf": [
                            {
                                "type": "object",
                                "properties": {
                                    "day_range": {"type": "string", "pattern": "^Day \\d+-\\d+$"},
                                    "place": {"type": "string"}
                                },
                                "required": ["day_range", "place"]
                            },
                            {
                                "type": "object",
                                "properties": {
                                    "flying": {"type": "string", "pattern": "^Day \\d+-\\d+$"},
                                    "from": {"type": "string"},
                                    "to": {"type": "string"}
                                },
                                "required": ["flying", "from", "to"]
                            }
                        ]
                    }
                }
            },
            "required": ["itinerary"]
        }
    },
    "zebralogic": {
        "name": "zebralogic_schema",
        "schema": {
            "type": "object",
            "properties": {
                "solution": {
                    "type": "object",
                    "properties": {
                        "header": {"type": "array", "items": {"type": "string"}},
                        "rows": {
                            "type": "array",
                            "items": {
                                "type": "array",
                                "items": {"type": "string"}
                            }
                        }
                    },
                    "required": ["header", "rows"]
                }
            },
            "required": ["solution"]
        }
    }
}

def get_task_prompt(task, example):
    """Return the zero-shot prompt stored in ``example`` for a known task.

    Every supported task reads the same ``'prompt_0shot'`` field; an
    unsupported task name raises ``ValueError``.
    """
    if task not in ("calendar", "meeting", "trip", "zebralogic"):
        raise ValueError(f"Unknown task type: {task}")
    return example['prompt_0shot']

def add_json_formatting_instruction(prompt, task, golden_plan=None):
    """Append the task-specific JSON output instruction to ``prompt``.

    For ``"zebralogic"`` the example header is taken from ``golden_plan``
    (either ``golden_plan["solution"]["header"]`` or ``golden_plan["header"]``)
    when available, otherwise a generic header is shown. Unknown tasks return
    ``prompt`` unchanged.
    """
    if task == "zebralogic":
        header = None
        if isinstance(golden_plan, dict):
            holder = golden_plan["solution"] if "solution" in golden_plan else golden_plan
            header = holder.get("header")

        if not header:
            # No usable header: fall back to placeholder attribute names.
            return prompt + (
                "\n\nPlease output the solution in the following JSON format:\n"
                "{\"solution\": {\"header\": [\"House\", \"Attribute1\", \"Attribute2\", ...], "
                "\"rows\": [[\"1\", \"Value1\", \"Value2\", ...], [\"2\", \"Value1\", \"Value2\", ...]]}}"
            )
        return prompt + (
            f"\n\nPlease output the solution in the following JSON format:\n"
            f"{{\"solution\": {{\"header\": {json.dumps(header)}, "
            f"\"rows\": [[\"1\", \"Value1\", \"Value2\", ...], [\"2\", \"Value1\", \"Value2\", ...]]}}}}"
        )

    # Fixed instruction text for the tasks that need no per-example data.
    fixed_suffixes = {
        "calendar": "\n\nPlease output the proposed time in the following JSON format:\n{\"time_range\": \"{HH:MM:HH:MM}\", \"day\": \"<DAY>\"}. For example, if the proposed time is Tuesday, 14:30 to 15:30, the output should be:\n{\"time_range\": \"{14:30:15:30}\", \"day\": \"Tuesday\"}.",
        "meeting": "\n\nPlease output the meeting schedule in the following JSON format:\n{\"itinerary\": [{\"action\": \"meet\", \"person\": \"<PERSON_NAME>\", \"start_time\": \"<HH:MM>\", \"end_time\": \"<HH:MM>\"}]}. Make sure to include the person's name for each meeting.",
        "trip": "\n\nPlease output the trip plan in the following JSON format:\n{\"itinerary\": [{\"day_range\": \"Day X-Y\", \"place\": \"<CITY>\"}]}. Include all city visits with their day ranges. Do not include separate flight entries in the JSON output.\n\nIMPORTANT: When you fly from city A to city B on day X, that day counts for BOTH cities. For example:\n- If you stay in Venice from Day 1-3 and fly to Vienna on Day 3, then:\n  - Venice: Day 1-3 (3 days)\n  - Vienna: Day 3-6 (4 days, including the flight day)\n- The flight day (Day 3) is counted for both Venice and Vienna.\n- Do NOT create separate flight entries in the JSON.",
    }
    suffix = fixed_suffixes.get(task)
    if suffix is None:
        return prompt
    return prompt + suffix

def evaluate_calendar(constraints, pred_dict):
    """Check a predicted calendar slot against the calendar constraints.

    Args:
        constraints: dict read for ``"meeting_duration"`` (compared, in hours,
            with the predicted end minus start) and ``"disallowed_ranges"``
            (a list of dicts with ``"day"``, ``"start"`` and ``"end"``).
        pred_dict: flat prediction with ``"day"`` plus either ``"time_range"``
            (``"HH:MM:HH:MM"``, optionally wrapped in curly braces) or both
            ``"start_time"`` and ``"end_time"``.

    Returns:
        ``(True, {})`` when every check passes, otherwise ``(False, info)``
        where ``info`` describes the first failed check.
    """
    import math  # local import: only needed for the tolerant duration check

    def to_hours(value):
        """Convert "H:MM" to fractional hours; None when unparsable.

        Non-string values are passed through unchanged.
        """
        if not isinstance(value, str):
            return value
        parts = value.split(":")
        try:
            return float(parts[0]) + float(parts[1]) / 60
        except (ValueError, IndexError):
            # IndexError: no ":" at all (e.g. "13").
            return None

    # Check for missing fields - handle both time_range and start_time/end_time formats
    if not pred_dict or "day" not in pred_dict:
        return False, {"missing_fields": True}

    pred_day = pred_dict["day"]

    if "time_range" in pred_dict:
        time_range = pred_dict["time_range"]
        if time_range is None:
            return False, {"null_fields": True}
        if not isinstance(time_range, str):
            return False, {"invalid_time_range_format": time_range}
        # Remove curly braces if present
        if time_range.startswith("{") and time_range.endswith("}"):
            time_range = time_range[1:-1]

        # "HH:MM:HH:MM" splits into exactly four pieces.
        parts = time_range.split(":")
        if len(parts) != 4:
            return False, {"invalid_time_range_format": time_range}
        pred_start = f"{parts[0]}:{parts[1]}"
        pred_end = f"{parts[2]}:{parts[3]}"
    elif "start_time" in pred_dict and "end_time" in pred_dict:
        pred_start = pred_dict["start_time"]
        pred_end = pred_dict["end_time"]
    else:
        return False, {"missing_fields": True}

    # Check for None values in any of the fields
    if pred_day is None or pred_start is None or pred_end is None:
        return False, {"null_fields": True}

    # Convert time strings to numerical values
    pred_start = to_hours(pred_start)
    if pred_start is None:
        return False, {"unparsable": True}
    pred_end = to_hours(pred_end)
    if pred_end is None:
        return False, {"unparsable": True}

    meeting_duration = constraints.get("meeting_duration")
    if meeting_duration is None:
        return False, {"missing_meeting_duration": True}
    # Minutes are converted with a division by 60, so the difference carries
    # rounding error (7:50-8:20 yields 0.5000000000000009); compare with a
    # tolerance instead of exact equality.
    if not math.isclose(pred_end - pred_start, meeting_duration, abs_tol=1e-6):
        return False, {"meeting_duration": meeting_duration}

    for disallowed_range in constraints.get("disallowed_ranges", []):
        if disallowed_range["day"] != pred_day:
            continue
        blocked_start = disallowed_range["start"]
        blocked_end = disallowed_range["end"]
        starts_inside = blocked_start <= pred_start < blocked_end
        ends_inside = blocked_start < pred_end <= blocked_end
        covers = pred_start <= blocked_start and pred_end >= blocked_end
        if starts_inside or ends_inside or covers:
            return False, disallowed_range

    return True, {}

def evaluate_meeting(constraints, pred_dict):
    """Check a predicted meeting itinerary against the meeting constraints.

    ``constraints`` is read for ``"people_to_meet"`` (each with ``"name"``,
    ``"location"``, ``"time_of_day"`` {``"from"``, ``"to"``} and an optional
    ``"min_meeting_duration"`` in minutes), ``"start"`` (``"location"``,
    ``"time_of_day"``) and ``"travel_distances"`` (each with ``"place"``
    {optional ``"from"``, ``"to"``} and ``"walking_time"`` in minutes).

    Returns ``(True, {})`` on success or ``(False, info)`` for the first
    violated check. Meetings with people absent from ``"people_to_meet"`` are
    not checked for availability or duration.
    """
    from datetime import datetime

    def parse_time(s):
        # Accepts "H:MM" or "H:MMAM"/"H:MMPM"; None signals a bad format.
        try:
            fmt = "%I:%M%p" if s.endswith(("AM", "PM")) else "%H:%M"
            return datetime.strptime(s, fmt)
        except ValueError:
            return None

    def minutes_between(earlier, later):
        return (later - earlier).total_seconds() / 60

    if not pred_dict or "itinerary" not in pred_dict:
        return False, {"missing_itinerary": True}

    itinerary = pred_dict["itinerary"]
    if not isinstance(itinerary, list):
        return False, {"invalid_itinerary": True}

    # Index the people by name and read the starting point.
    people = {person["name"]: person for person in constraints.get("people_to_meet", [])}
    start_location = constraints.get("start", {}).get("location")
    start_time = constraints.get("start", {}).get("time_of_day")

    # Turn each itinerary entry into a parsed meeting record.
    needed = ("person", "start_time", "end_time")
    meetings = []
    for entry in itinerary:
        if any(key not in entry for key in needed):
            return False, {"missing_meeting_fields": entry}

        who = entry["person"]
        if not who or who == "Unknown":
            return False, {"missing_person_name": "Person name must be provided for each meeting"}

        begins = parse_time(entry["start_time"])
        finishes = parse_time(entry["end_time"])
        if begins is None or finishes is None:
            return False, {"invalid_time_format": {"start": entry["start_time"], "end": entry["end_time"]}}

        meetings.append({
            "person": who,
            "start": begins,
            "end": finishes,
            "location": people.get(who, {}).get("location"),
        })

    meetings.sort(key=lambda rec: rec["start"])

    # 1) Availability window and minimum duration per known person.
    for rec in meetings:
        info = people.get(rec["person"])
        if not info:
            continue
        window = info["time_of_day"]
        window_open = parse_time(window["from"])
        window_close = parse_time(window["to"])
        if rec["start"] < window_open or rec["end"] > window_close:
            return False, {"person": rec["person"], "time_of_day": window}

        required = info.get("min_meeting_duration", 0)
        if required > 0:
            actual = minutes_between(rec["start"], rec["end"])
            if actual < required:
                return False, {
                    "meeting_duration": {
                        "person": rec["person"],
                        "required": required,
                        "actual": actual
                    }
                }

    # 2) Walking times keyed by (from, to); a missing "from" means the start location.
    travel = {}
    for leg in constraints.get("travel_distances", []):
        place = leg["place"]
        origin = place.get("from", start_location)
        travel[(origin, place["to"])] = leg["walking_time"]

    # 3) Arrival time and the walk to the first meeting.
    if start_time and meetings:
        arrival = parse_time(start_time)
        first = meetings[0]
        if first["start"] < arrival:
            return False, {"start_time": start_time}
        first_walk = travel.get((start_location, first["location"]))
        first_gap = minutes_between(arrival, first["start"])
        if first_walk is not None and first_walk > first_gap:
            return False, {
                "travel_start": {
                    "to_person": first["person"],
                    "to_location": first["location"],
                    "travel_time": first_walk
                }
            }

    # 4) Walks between consecutive meetings.
    for prev, nxt in zip(meetings, meetings[1:]):
        available = minutes_between(prev["end"], nxt["start"])
        needed_walk = travel.get((prev["location"], nxt["location"]))
        if needed_walk is not None and needed_walk > available:
            return False, {
                "travel": {
                    "from_person": prev["person"],
                    "to_person": nxt["person"],
                    "from_location": prev["location"],
                    "to_location": nxt["location"],
                    "travel_time": needed_walk
                }
            }

    return True, {}

def evaluate_trip(constraints, pred_dict):
    """Check a predicted trip itinerary against the trip constraints.

    Args:
        constraints: dict read for ``"trip_length"`` (last day number),
            ``"city_length"`` (list of ``{"city", "days"}``),
            ``"direct_flights"`` (list of ``{"from", "to"}``) and
            ``"city_day_ranges"`` (list of ``{"city", "start", "end"}``).
        pred_dict: ``{"itinerary": [{"day_range": "Day X-Y", "place": ...}]}``;
            a single-day range may be written ``"Day X"``.

    Returns:
        ``(True, {})`` when all checks pass, else ``(False, info)`` describing
        the first failed check. Malformed segments are reported through the
        return value rather than raised.
    """
    # Check for missing itinerary
    if not pred_dict or "itinerary" not in pred_dict:
        return False, {"missing_itinerary": True}

    itinerary = pred_dict["itinerary"]
    if not isinstance(itinerary, list):
        return False, {"invalid_itinerary": True}

    # Parse itinerary segments
    segments = []
    for seg in itinerary:
        if not isinstance(seg, dict) or "day_range" not in seg or "place" not in seg:
            return False, {"missing_segment_fields": seg}

        day_range = seg["day_range"]
        if not isinstance(day_range, str) or not day_range.startswith("Day "):
            return False, {"invalid_day_range_format": day_range}
        dr = day_range.replace("Day ", "")
        try:
            # The unpacking lives inside the try so that ranges with extra
            # dashes (e.g. "Day 1-2-3") are reported instead of raising.
            if "-" in dr:
                start_s, end_s = dr.split("-")
            else:
                start_s = end_s = dr
            start, end = int(start_s), int(end_s)
        except ValueError:
            return False, {"unparsable_day_range": day_range}
        segments.append({"place": seg["place"], "start": start, "end": end})

    # Sort segments by start day to ensure chronological order for constraint evaluation
    segments.sort(key=lambda x: x["start"])

    # Validate trip starts on day 1 and ends on the correct day
    trip_length = constraints.get("trip_length")
    if trip_length is not None:
        if not segments or segments[0]["start"] != 1 or segments[-1]["end"] != trip_length:
            return False, {"trip_length": {"required": trip_length, "actual": "invalid_start_end"}}

        # A travel day belongs to both cities, so consecutive segments must
        # share their boundary day exactly.
        for a, b in zip(segments, segments[1:]):
            if a["end"] != b["start"]:
                return False, {"gap_or_overlap": {"between": f"Day {a['end']} and Day {b['start']}"}}

    # Required number of days per city (inclusive of both boundary days).
    stay_days = {req["city"]: req["days"] for req in constraints.get("city_length", [])}
    for seg in segments:
        required = stay_days.get(seg["place"])
        if required is not None:
            actual = seg["end"] - seg["start"] + 1
            if actual != required:
                return False, {"stay_days": {seg["place"]: required}}

    # Consecutive cities must be linked by a listed (directional) flight.
    allowed_flights = {(d["from"], d["to"]) for d in constraints.get("direct_flights", [])}
    for a, b in zip(segments, segments[1:]):
        if (a["place"], b["place"]) not in allowed_flights:
            return False, {"flight": {"from": a["place"], "to": b["place"]}}

    # Each event range must fall entirely within the first segment for its city.
    for ev in constraints.get("city_day_ranges", []):
        place = ev["city"]
        container = next((s for s in segments if s["place"] == place), None)
        if not container:
            return False, {"missing_place": place}
        if container["start"] > ev["start"] or container["end"] < ev["end"]:
            return False, {"event_range": ev}

    return True, {}

def evaluate_zebralogic(constraints, predicted_output):
    """Compare a predicted ZebraLogic solution with the golden solution.

    Args:
        constraints: dict holding the golden plan under ``"golden_plan"`` or
            ``constraints["constraints"]["golden_plan"]``; the plan may be a
            JSON string or a dict, optionally wrapped in ``"solution"``.
        predicted_output: list with one dict per house (header -> value).

    Returns:
        ``(True, {})`` when the prediction matches, otherwise ``(False, info)``
        where ``info`` is either a structural error or
        ``{"violations": [...]}``. Field names and values are compared
        case-insensitively; values are also stripped of surrounding whitespace.
    """
    if not predicted_output or not isinstance(predicted_output, list):
        return False, {"invalid_output": "No valid output structure found"}

    # tolerate either flattened or nested constraints
    gp = (
        constraints.get("golden_plan")
        or constraints.get("constraints", {}).get("golden_plan")
        or {}
    )
    # allow string or dict; unwrap "solution" if present
    if isinstance(gp, str):
        try:
            gp = json.loads(gp)
        except Exception:
            gp = {}
    if isinstance(gp, dict) and "solution" in gp:
        gp = gp["solution"]

    golden_output = parse_zebralogic_golden(gp)

    # The parser signals failure with an error dict rather than a list.
    if not isinstance(golden_output, list):
        return False, {"invalid_golden": "Invalid golden solution format"}

    # Fast path: identical structures serialise to identical JSON.
    try:
        if json.dumps(predicted_output, sort_keys=True) == json.dumps(golden_output, sort_keys=True):
            return True, {}
    except (TypeError, ValueError):
        # Not serialisable (or mixed key types): fall through to the
        # field-by-field comparison below.
        pass

    violations = []

    if len(predicted_output) != len(golden_output):
        violations.append(f"Wrong number of houses: expected {len(golden_output)}, got {len(predicted_output)}")

    for house_num, (gold_house, pred_house) in enumerate(zip(golden_output, predicted_output), 1):
        if not isinstance(pred_house, dict):
            violations.append(f"House {house_num} is not a valid dictionary")
            continue

        # Built once per house; keyed by lower-cased field name.
        pred_fields = {str(k).lower(): v for k, v in pred_house.items()}

        for field, gold_value in gold_house.items():
            field_key = str(field).lower()
            if field_key not in pred_fields:
                violations.append(f"House {house_num} missing field '{field}'")
                continue
            # Read the value through the same case-insensitive key used for
            # the existence check; looking it up by the golden spelling
            # returned None whenever the prediction used different casing.
            pred_value = pred_fields[field_key]
            if str(pred_value).strip().lower() != str(gold_value).strip().lower():
                violations.append(
                    f"House {house_num} wrong {field}: expected '{gold_value}', got '{pred_value}'"
                )

    if violations:
        return False, {"violations": violations}
    return True, {}

def parse_zebralogic_golden(golden_plan):
    """Convert a golden plan into a list of per-row dicts.

    Accepts a dict or a JSON string, with the ``header``/``rows`` table either
    at the top level or nested under ``"solution"``. Each row is zipped with
    the header. Any unusable input yields
    ``{"error": "Invalid golden plan format"}`` instead of a list.
    """
    plan = golden_plan

    # A string is expected to hold JSON.
    if isinstance(plan, str):
        try:
            plan = json.loads(plan)
        except Exception:
            return {"error": "Invalid golden plan format"}

    # Peel off the optional "solution" wrapper.
    if isinstance(plan, dict) and "solution" in plan:
        plan = plan["solution"]

    has_table = isinstance(plan, dict) and "header" in plan and "rows" in plan
    if not has_table:
        return {"error": "Invalid golden plan format"}

    column_names = plan["header"]
    return [dict(zip(column_names, values)) for values in plan["rows"]]

def normalize_trip_itinerary(itinerary):
    """Reduce a trip plan to its canonical keys for exact-match comparison.

    Stay entries keep only ``day_range``/``place``; flight entries keep only
    ``flying``/``from``/``to``. Entries matching neither shape are dropped.
    Returns ``{}`` when the plan is empty or has no ``"itinerary"`` key.
    """
    if not itinerary or "itinerary" not in itinerary:
        return {}

    stay_keys = ("day_range", "place")
    flight_keys = ("flying", "from", "to")

    cleaned = []
    for entry in itinerary["itinerary"]:
        # A stay takes precedence when an entry carries both shapes.
        if all(key in entry for key in stay_keys):
            kept = stay_keys
        elif all(key in entry for key in flight_keys):
            kept = flight_keys
        else:
            continue
        cleaned.append({key: entry[key] for key in kept})

    return {"itinerary": cleaned}

def _extract_balanced_json_candidates(text: str) -> list[str]:
    """Return candidate JSON objects by scanning for balanced braces.
    Longest, more 'complete' objects appear earlier in results."""
    candidates = []
    n = len(text)
    i = 0
    while i < n:
        if text[i] == '{':
            depth = 0
            start = i
            j = i
            while j < n:
                c = text[j]
                if c == '{':
                    depth += 1
                elif c == '}':
                    depth -= 1
                    if depth == 0:
                        candidates.append(text[start:j+1])
                        break
                j += 1
            i = j + 1
        else:
            i += 1
    # Prefer objects that look like the solution block, then sort by length desc
    def score(s: str) -> tuple[int, int]:
        has_solution = '"solution"' in s
        has_header = '"header"' in s
        has_rows = '"rows"' in s
        return (int(has_solution) + int(has_header) + int(has_rows), len(s))
    candidates.sort(key=score, reverse=True)
    return candidates

def _iter_embedded_json_objects(text):
    """Yield every JSON object (dict) that can be decoded starting at a '{' in ``text``."""
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            obj, _ = decoder.raw_decode(text, pos)
        except ValueError:  # json.JSONDecodeError is a ValueError
            obj = None
        if isinstance(obj, dict):
            yield obj
        pos = text.find("{", pos + 1)


def extract_answer_from_text(text, task, golden_plan=None):
    """Extract the structured answer for ``task`` from a free-text response.

    Args:
        text: model response (or gold text for ``"trip"``).
        task: ``"calendar"``, ``"meeting"``, ``"trip"`` or ``"zebralogic"``.
        golden_plan: accepted for interface compatibility; not used here.

    Returns:
        * calendar: dict with ``"time_range"`` and ``"day"``;
        * meeting: whatever JSON object the extraction model returns
          (an OpenAI chat completion is issued with ``gpt-4o-mini``);
        * trip: dict with an ``"itinerary"`` list;
        * zebralogic: the result of ``parse_zebralogic_output``;
        * ``None`` when nothing could be extracted, the task is unknown, or
          ``text`` is not a string.
    """
    if not isinstance(text, str):
        return None

    if task == "calendar":
        # Decode JSON at each '{' rather than regex-matching a span: this
        # copes with braces earlier in the text, with the brace-wrapped
        # "{HH:MM:HH:MM}" value, and with either key order.
        for obj in _iter_embedded_json_objects(text):
            if "time_range" in obj and "day" in obj:
                return obj

        # Look for time range pattern in the format "Monday, 13:30 - 14:30"
        time_pattern = r'(Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday),?\s*(\d{1,2}:\d{2})\s*-\s*(\d{1,2}:\d{2})'
        match = re.search(time_pattern, text, re.IGNORECASE)
        if match:
            day = match.group(1)
            start_time = match.group(2)
            end_time = match.group(3)
            # Convert to the expected format {HH:MM:HH:MM}
            return {
                "day": day,
                "time_range": f"{{{start_time}:{end_time}}}"
            }

        return None

    elif task == "meeting":
        # Use LLM-based extraction for meetings (following SMT/Python approach)
        openai_key = None
        try:
            with open("../../openai_research/deepseek_api_key.json") as f:
                openai_key = json.load(f).get("openai")
        except (OSError, ValueError, AttributeError):
            # Missing/unreadable file, invalid JSON, or a non-object payload.
            openai_key = None
        if not openai_key:
            # Also covers a key file that exists but has no "openai" entry.
            openai_key = os.getenv("OPENAI_API_KEY")

        if not openai_key:
            print("Warning: Could not find OpenAI API key for answer extraction")
            return None

        try:
            client = OpenAI(api_key=openai_key)
        except Exception as e:
            print(f"Warning: Could not initialize OpenAI client for answer extraction: {e}")
            return None

        # Define extraction prompt for meetings
        prompt = f"Given the following meeting schedule:\n{text}\nExtract the time and the person of each meeting in a JSON format like {{\"itinerary\": [{{\"action\": \"meet\", \"person\": \"David\",\"start_time\": \"13:00\", \"end_time\": \"14:00\"}}, ...]}}. Do not include location. Only keep the meeting times, and ignore time for starting, waiting, or traveling. The time should be converted to a 24-hour format. If no time range is given at all, output an empty JSON."

        # Best-effort: any API or decoding failure is reported and mapped to None.
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",  # Using gpt-4o-mini as fallback
                messages=[
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                response_format={"type": "json_object"},
                temperature=0,
                max_tokens=2000,
                top_p=1
            )
            output_json = response.choices[0].message.content
            print(f"Extracted answer JSON: {output_json}")
            return json.loads(output_json)
        except Exception as e:
            print(f"Error in answer extraction: {e}")
            return None

    elif task == "trip":
        # 1) JSON inside a ```json fenced block (model outputs)
        json_pattern = r'```json\s*(\{.*?\})\s*```'
        json_match = re.search(json_pattern, text, re.DOTALL | re.IGNORECASE)
        if json_match:
            try:
                result = json.loads(json_match.group(1))
                if "itinerary" in result and isinstance(result["itinerary"], list):
                    return result
            except json.JSONDecodeError:
                pass

        # 2) Any decodable JSON object in the text that carries an itinerary
        #    list. Decoding from each '{' avoids anchoring on an unrelated
        #    earlier brace and also covers the "first '{' to last '}'" case.
        for obj in _iter_embedded_json_objects(text):
            if isinstance(obj.get("itinerary"), list):
                return obj

        # 3) Fallback: parse golden trip plan format (for gold text), one
        #    "**Day ..." line per step.
        itinerary = []
        for line in text.split('\n'):
            line = line.strip()
            if not line or not line.startswith('**Day'):
                continue

            day_match = re.search(r'Day (\d+)(?:-(\d+))?', line)
            if not day_match:
                continue

            start_day = int(day_match.group(1))
            end_day = int(day_match.group(2)) if day_match.group(2) else start_day
            day_range = f"Day {start_day}-{end_day}"

            place_match = re.search(r'(?:Arriving in|Visit|Stay in|at) ([^\s,.]+)', line, re.IGNORECASE)
            if place_match:
                itinerary.append({
                    "day_range": day_range,
                    "place": place_match.group(1)
                })

        # Sort by day range start for consistent comparison
        itinerary.sort(key=lambda x: (
            int(x["day_range"].split()[1].split("-")[0]),
            x["place"]
        ))

        if itinerary:
            return {"itinerary": itinerary}

        return None

    elif task == "zebralogic":
        # 1) Try fenced JSON; fall through if it does not hold a solution table
        json_pattern = r'```json\s*(\{.*?\})\s*```'
        json_match = re.search(json_pattern, text, re.DOTALL | re.IGNORECASE)
        if json_match:
            try:
                parsed = parse_zebralogic_output(json.loads(json_match.group(1)))
                if parsed is not None:
                    return parsed
            except json.JSONDecodeError:
                pass

        # 2) Try whole-text JSON
        try:
            parsed = parse_zebralogic_output(json.loads(text))
            if parsed is not None:
                return parsed
        except json.JSONDecodeError:
            pass

        # 3) Balanced-brace extraction (works without code fences)
        for cand in _extract_balanced_json_candidates(text):
            try:
                parsed = parse_zebralogic_output(json.loads(cand))
                if parsed:
                    return parsed
            except json.JSONDecodeError:
                continue

        return None

    return None

def parse_zebralogic_output(output):
    """Turn a ZebraLogic answer into a list of per-house dicts.

    ``output`` may be an already-decoded JSON object or a string. A string is
    decoded as JSON; if that fails, the span from the first ``{`` to the last
    ``}`` is decoded instead. The resulting object (or its ``"solution"``
    member, when present) must carry ``"header"`` and ``"rows"``; every row
    whose length equals the header's length becomes ``dict(zip(header, row))``
    and rows of any other length are dropped.

    Returns:
        The list of row dicts (possibly empty), or ``None`` when the input is
        empty, is not an object of the expected shape, or parsing raises (the
        error is logged as a warning).
    """
    if not output:
        return None

    try:
        decoded = output
        if isinstance(decoded, str):
            try:
                decoded = json.loads(decoded)
            except json.JSONDecodeError:
                # Not pure JSON: fall back to the outermost brace-delimited span.
                embedded = re.search(r'\{.*\}', decoded, re.DOTALL)
                if not embedded:
                    return None
                decoded = json.loads(embedded.group(0))

        if not isinstance(decoded, dict):
            return None

        table = decoded.get("solution", decoded)
        if "header" not in table or "rows" not in table:
            return None

        columns = table["header"]
        records = table["rows"]
        return [
            dict(zip(columns, cells))
            for cells in records
            if len(cells) == len(columns)
        ]
    except Exception as e:
        logging.warning(f"Error parsing ZebraLogic output: {e}")
    return None

def format_constraint_feedback(violated_constraints, task):
    """Format constraint violations into detailed feedback for the model.

    Args:
        violated_constraints: Mapping of violation name to its details, as
            produced by the task evaluator. An empty/falsy value means nothing
            was violated.
        task: One of "calendar", "meeting", "trip" or "zebralogic"; selects
            which violation keys are rendered. Any other task yields only the
            header and the closing request.

    Returns:
        "All constraints are satisfied!" when there are no violations,
        otherwise a header, one "- ..." bullet per recognised violation, and a
        closing request to revise the plan.
    """
    if not violated_constraints:
        return "All constraints are satisfied!"

    feedback = "The following constraints were violated:\n\n"

    if task == "calendar":
        if "meeting_duration" in violated_constraints:
            duration_info = violated_constraints["meeting_duration"]
            if isinstance(duration_info, dict):
                feedback += f"- Meeting duration should be {duration_info['required']} hours, but you provided {duration_info['actual']:.2f} hours\n"
            else:
                feedback += f"- Meeting duration should be {duration_info} hours\n"

        if "disallowed_range" in violated_constraints:
            range_info = violated_constraints["disallowed_range"]
            feedback += f"- Time conflicts with existing schedule on {range_info['day']} from {range_info['start']} to {range_info['end']}\n"

        if "work_hours" in violated_constraints:
            feedback += "- Meeting must be within work hours (9:00-17:00)\n"

        if "unparsable_time_range" in violated_constraints:
            feedback += "- Time format could not be parsed. Use format: {HH:MM:HH:MM}\n"

    elif task == "meeting":
        if "num_people_to_meet" in violated_constraints:
            num_required = violated_constraints["num_people_to_meet"]
            feedback += f"- Must meet with exactly {num_required} people\n"

        if "unmet_people" in violated_constraints:
            people_info = violated_constraints["unmet_people"]
            feedback += f"- Need to meet {len(people_info)} people: {', '.join(people_info)}\n"

        if "person_unavailable" in violated_constraints:
            person_info = violated_constraints["person_unavailable"]
            feedback += f"- {person_info['person']} is not available during the scheduled time\n"

        if "insufficient_travel_time" in violated_constraints:
            travel_info = violated_constraints["insufficient_travel_time"]
            feedback += f"- Insufficient travel time between {travel_info['from']} and {travel_info['to']} (need {travel_info['required']} min, have {travel_info['available']:.1f} min)\n"

        if "invalid_time_format" in violated_constraints:
            time_info = violated_constraints["invalid_time_format"]
            feedback += f"- Invalid time format: {time_info['start']} or {time_info['end']}\n"

    elif task == "trip":
        if "trip_length" in violated_constraints:
            length_info = violated_constraints["trip_length"]
            if length_info['actual'] == "invalid_start_end":
                feedback += f"- Trip must start on Day 1 and end on Day {length_info['required']}\n"
            else:
                feedback += f"- Trip must cover {length_info['required']} days, but covers {length_info['actual']}\n"

        if "stay_days" in violated_constraints:
            for city, required_days in violated_constraints["stay_days"].items():
                feedback += f"- Must stay in {city} for exactly {required_days} days\n"

        if "gap_or_overlap" in violated_constraints:
            gap_info = violated_constraints["gap_or_overlap"]
            feedback += f"- There is a gap or overlap {gap_info['between']}\n"

        if "flight" in violated_constraints:
            flight_info = violated_constraints["flight"]
            feedback += f"- No direct flight available from {flight_info['from']} to {flight_info['to']}\n"

        if "missing_place" in violated_constraints:
            feedback += f"- Missing required city: {violated_constraints['missing_place']}\n"

    elif task == "zebralogic":
        if "violations" in violated_constraints:
            # Limit to 10 violations
            feedback += "\n".join(f"- {v}" for v in violated_constraints["violations"][:10])
        else:
            # One bullet per line; previously the bullets were concatenated
            # with no separator and ran together on a single line.
            feedback += "\n".join(
                f"- {k}: {v}"
                for k, v in violated_constraints.items()
                if k != "differences"  # Skip GPT-generated differences which may be unreliable
            )

    feedback += "\n\nPlease revise your plan to address these issues."
    return feedback

def check_exact_match(gold_formatted, pred_formatted, task):
    """Check if prediction exactly matches gold answer.

    Args:
        gold_formatted: Parsed gold answer for the task.
        pred_formatted: Parsed model prediction, in the same structure.
        task: "calendar", "meeting", "trip" or "zebralogic".

    Returns:
        True when the two answers match under the task's comparison rule,
        False otherwise (including when either answer is empty or the task is
        unknown).

        * calendar: same day (case-insensitive) and identical time_range.
        * meeting: same set of meetings regardless of order; person names are
          compared case-insensitively, start/end times exactly.
        * trip: itineraries compared position by position; places are
          case-insensitive, day ranges exact.
        * zebralogic: identical JSON serialisation with sorted keys.
    """
    if not gold_formatted or not pred_formatted:
        return False

    if task == "calendar":
        # Compare day and time_range
        gold_day = gold_formatted.get("day", "").lower()
        gold_time = gold_formatted.get("time_range", "")
        pred_day = pred_formatted.get("day", "").lower()
        pred_time = pred_formatted.get("time_range", "")

        return gold_day == pred_day and gold_time == pred_time

    elif task == "meeting":
        # Compare itinerary lists
        gold_itinerary = gold_formatted.get("itinerary", [])
        pred_itinerary = pred_formatted.get("itinerary", [])

        if len(gold_itinerary) != len(pred_itinerary):
            return False

        # Sort with the same case-folding used by the comparison below;
        # a case-sensitive sort could pair up different people when the two
        # sides only differ in capitalisation (e.g. "Bob"/"alice" vs
        # "bob"/"Alice") and report a false mismatch.
        def sort_key(meeting):
            return (meeting.get("person", "").lower(), meeting.get("start_time", ""))

        gold_sorted = sorted(gold_itinerary, key=sort_key)
        pred_sorted = sorted(pred_itinerary, key=sort_key)

        for gold_meeting, pred_meeting in zip(gold_sorted, pred_sorted):
            if (gold_meeting.get("person", "").lower() != pred_meeting.get("person", "").lower() or
                gold_meeting.get("start_time", "") != pred_meeting.get("start_time", "") or
                gold_meeting.get("end_time", "") != pred_meeting.get("end_time", "")):
                return False

        return True

    elif task == "trip":
        # Compare itinerary
        gold_itinerary = gold_formatted.get("itinerary", [])
        pred_itinerary = pred_formatted.get("itinerary", [])

        if len(gold_itinerary) != len(pred_itinerary):
            return False

        for gold_item, pred_item in zip(gold_itinerary, pred_itinerary):
            if (gold_item.get("day_range") != pred_item.get("day_range") or
                gold_item.get("place", "").lower() != pred_item.get("place", "").lower()):
                return False

        return True

    elif task == "zebralogic":
        # Compare the entire solution structure
        try:
            gold_str = json.dumps(gold_formatted, sort_keys=True)
            pred_str = json.dumps(pred_formatted, sort_keys=True)
            return gold_str == pred_str
        except Exception:
            return False

    return False

class RateLimiter:
    """Simple rate limiter to avoid API limits.

    Spaces out callers of ``wait()`` so that at most ``requests_per_second``
    of them are released per second. A non-positive rate disables limiting.

    The limiter is meant to be shared by several coroutines on one event
    loop. Each caller reserves its send slot *before* it starts sleeping, so
    callers that arrive together are released one interval apart instead of
    all computing the same delay and waking at the same moment.
    """
    def __init__(self, requests_per_second: float):
        self.requests_per_second = requests_per_second
        # Wall-clock time of the most recently reserved slot (0 = none yet).
        self.last_request_time = 0

    async def wait(self):
        """Suspend until this caller's reserved slot arrives."""
        if self.requests_per_second <= 0:
            return

        min_interval = 1.0 / self.requests_per_second
        now = time.time()

        # No await happens between reading and writing last_request_time, so
        # on a single event loop the reservation cannot interleave with
        # another coroutine's reservation.
        slot = max(now, self.last_request_time + min_interval)
        self.last_request_time = slot

        if slot > now:
            await asyncio.sleep(slot - now)

def count_tokens(text):
    """Return the number of tokens in ``text``.

    The text is stringified and encoded with the tiktoken encoding registered
    for "gpt-4o" (the same encoder is used whatever model produced the text).
    If anything in that path raises, a warning is logged and ``len(text)`` is
    returned instead.
    """
    try:
        encoder = tiktoken.encoding_for_model("gpt-4o")
        return len(encoder.encode(f"{text}"))
    except Exception as e:
        logging.warning(f"Token counting failed, using fallback method: {str(e)}")
        return len(text)

async def run_model_with_rate_limit(ai, prompt, rate_limiter, model_name):
    """Send ``prompt`` to the model once the rate limiter allows it.

    The reply text comes from ``ai.chat_round_str``. When ``model_name``
    starts with "DeepSeek", a second request carrying only the user prompt is
    sent through ``ai.engine.client`` ("deepseek-reasoner" if the name
    contains "R1", otherwise "deepseek-chat") to read ``reasoning_content``
    and the server-reported ``usage.total_tokens``.

    Returns:
        ``(response, api_time, full_token_count, reasoning_content,
        reasoning_tokens)``. ``api_time`` covers only the first request and
        excludes the rate-limiter wait. Token counts fall back to
        ``count_tokens`` when no usage figure is available. Errors are not
        raised: any exception is logged and reported as
        ``(None, 0, 0, "", 0)``.
    """
    await rate_limiter.wait()

    started = time.time()
    try:
        # Regular Kani round: this is the text the caller parses.
        response = await ai.chat_round_str(prompt)
        api_time = time.time() - started

        full_token_count = count_tokens(response)  # used unless the server reports usage
        reasoning_content = ""
        reasoning_tokens = 0

        if model_name.startswith("DeepSeek"):
            ds_model = "deepseek-reasoner" if "R1" in model_name else "deepseek-chat"

            # Raw OpenAI-compatible call, made only to obtain reasoning_content + usage.
            raw = await ai.engine.client.chat.completions.create(
                model=ds_model,
                messages=[{"role": "user", "content": prompt}],
            )
            reasoning_content = getattr(raw.choices[0].message, "reasoning_content", "") or ""

            # Prefer the server's total when it is present and non-zero.
            try:
                usage = getattr(raw, "usage", None)
                if usage:
                    full_token_count = getattr(usage, "total_tokens", full_token_count) or full_token_count
            except Exception:
                pass

            # Usage carries no separate reasoning figure, so count the reasoning text itself.
            if reasoning_content:
                reasoning_tokens = count_tokens(reasoning_content)

        return response, api_time, full_token_count, reasoning_content, reasoning_tokens

    except Exception as e:
        logging.error(f"Error calling model {model_name}: {e}")
        return None, 0, 0, "", 0

def save_output_files(task, model_name, example_id, pass_num, conversation, plan, evaluation):
    """Save all output files for a given pass.

    Files are written under
    ``../output/Plan/{model_name}/{task}/token_pass/{example_id}/{pass_num}_pass``:

    * ``conversation.json`` - ``conversation`` as given.
    * ``plan.json`` - ``plan`` as given.
    * ``evaluation.json`` - ``evaluation`` with ``timing.total_tokens`` and
      ``timing.reasoning_tokens`` defaulted to 0 when absent.
    * ``reasoning.txt`` - ``evaluation["reasoning_content"]``, only if non-empty.
    * ``full_response.txt`` - content of the most recent assistant message in
      ``conversation``, if there is one with non-empty content.

    Args:
        task: Task name used in the output path.
        model_name: Model name used in the output path.
        example_id: Example identifier used in the output path.
        pass_num: Pass number used in the output path.
        conversation: List of message dicts with "role" / "content" keys.
        plan: JSON-serialisable plan.
        evaluation: Evaluation result dict. It is not modified.
    """
    output_dir = f"../output/Plan/{model_name}/{task}/token_pass/{example_id}/{pass_num}_pass"
    os.makedirs(output_dir, exist_ok=True)

    # Save conversation
    with open(f"{output_dir}/conversation.json", "w") as f:
        json.dump(conversation, f, indent=4)

    # Save plan
    with open(f"{output_dir}/plan.json", "w") as f:
        json.dump(plan, f, indent=4)

    # Save evaluation results with token fields guaranteed present. The timing
    # dict is copied as well: a shallow copy of `evaluation` alone would share
    # it, and the defaults below would leak into the caller's object.
    evaluation_with_tokens = dict(evaluation)
    timing = dict(evaluation.get("timing") or {})
    timing.setdefault("total_tokens", 0)
    timing.setdefault("reasoning_tokens", 0)
    evaluation_with_tokens["timing"] = timing

    with open(f"{output_dir}/evaluation.json", "w") as f:
        json.dump(evaluation_with_tokens, f, indent=4)

    # Also save reasoning content separately if it exists
    reasoning_content = evaluation.get("reasoning_content")
    if reasoning_content:
        with open(f"{output_dir}/reasoning.txt", "w", encoding="utf-8") as f:
            f.write(reasoning_content)

    # Save the full raw response for debugging. This no longer depends on
    # reasoning content being present, and it takes the latest assistant turn
    # so that a cumulative conversation yields this pass's response rather
    # than the first pass's.
    full_response = next(
        (
            msg["content"]
            for msg in reversed(conversation)
            if msg.get("role") == "assistant" and "content" in msg
        ),
        None,
    )
    if full_response:
        with open(f"{output_dir}/full_response.txt", "w", encoding="utf-8") as f:
            f.write(full_response)

def calculate_token_statistics(output_root="../output/Plan", pass_dir_names=("n_pass", "token_pass")):
    """Calculate and display token usage statistics across all examples.

    Walks ``{output_root}/{model}/{task}/{layout}/{example_id}/*_pass/evaluation.json``
    for every task and every ``layout`` in ``pass_dir_names``, summing the
    ``timing.total_tokens`` and ``timing.reasoning_tokens`` fields. Both
    "n_pass" and "token_pass" are scanned by default because the writers in
    this module store results under "token_pass"; scanning only "n_pass"
    would miss them.

    Args:
        output_root: Directory that contains one sub-directory per model.
        pass_dir_names: Layout directory names to look for under each task.

    Returns:
        Dict keyed by task name plus "overall"; each value holds
        "total_tokens", "reasoning_tokens" and "count" (number of evaluation
        files that contained a "timing" entry). The same figures are printed
        as per-response averages.
    """
    task_names = ["calendar", "meeting", "trip", "zebralogic"]
    token_data = {
        name: {"total_tokens": 0, "reasoning_tokens": 0, "count": 0}
        for name in task_names + ["overall"]
    }

    model_names = os.listdir(output_root) if os.path.isdir(output_root) else []

    # Scan all evaluation files to collect token data
    for task in task_names:
        for model_name in model_names:
            for pass_dir_name in pass_dir_names:
                model_dir = os.path.join(output_root, model_name, task, pass_dir_name)
                if not os.path.isdir(model_dir):
                    continue

                for example_id in os.listdir(model_dir):
                    example_dir = os.path.join(model_dir, example_id)
                    if not os.path.isdir(example_dir):
                        continue

                    for pass_dir in os.listdir(example_dir):
                        eval_file = os.path.join(example_dir, pass_dir, "evaluation.json")
                        if not pass_dir.endswith("_pass") or not os.path.isfile(eval_file):
                            continue

                        try:
                            with open(eval_file, 'r') as f:
                                eval_data = json.load(f)
                            if "timing" not in eval_data:
                                continue
                            timing = eval_data["timing"]
                            total = timing.get("total_tokens", 0) or 0
                            reasoning = timing.get("reasoning_tokens", 0) or 0
                            # Validate before touching the totals so a bad file
                            # cannot leave a half-applied update behind.
                            if not all(isinstance(v, (int, float)) for v in (total, reasoning)):
                                raise TypeError("token counts must be numeric")
                        except (OSError, ValueError, TypeError, AttributeError) as e:
                            logging.warning(f"Error reading evaluation file {eval_file}: {e}")
                            continue

                        for bucket in (token_data[task], token_data["overall"]):
                            bucket["total_tokens"] += total
                            bucket["reasoning_tokens"] += reasoning
                            bucket["count"] += 1

    # Print statistics
    print("\n=== Token Usage Statistics ===")
    for task in task_names + ["overall"]:
        stats = token_data[task]
        if stats["count"] > 0:
            avg_total = stats["total_tokens"] / stats["count"]
            avg_reasoning = stats["reasoning_tokens"] / stats["count"]
            reasoning_percentage = (avg_reasoning / avg_total * 100) if avg_total > 0 else 0

            print(f"\n{task.capitalize()}:")
            print(f"  Examples processed: {stats['count']}")
            print(f"  Average total tokens per response: {avg_total:.1f}")
            print(f"  Average reasoning tokens per response: {avg_reasoning:.1f}")
            print(f"  Reasoning percentage: {reasoning_percentage:.1f}%")

    return token_data

def _build_error_eval_result(message, status, pass_num):
    """Return the evaluation.json payload recorded when a pass cannot be evaluated."""
    return {
        "has_execution_error": True,
        "execution_output": message,
        "pred": {},
        "gold": {},
        "status": status,
        "violated_constraint": {},
        "is_exact_match": False,
        "constraints_satisfied": False,
        "pass_number": pass_num,
        "timing": {
            "total_tokens": 0,
            "reasoning_tokens": 0
        },
        "reasoning_content": ""
    }


def _save_json_artifact(path, payload):
    """Write ``payload`` as indented JSON to ``path``, creating its directory first."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(payload, f, indent=4)


async def process_single_example(
    example_id,
    example,
    constraints,
    model,
    max_passes,
    rate_limiter,
    semaphore,
    task,
    args
):
    """Process a single example with iterative refinement.

    For up to ``max_passes`` passes: prompt the model, extract its plan,
    evaluate the plan against ``constraints``, compare it with the gold
    answer, and write ``conversation.json``, ``plan.json`` and
    ``evaluation.json`` under
    ``../output/Plan/{model}/{task}/token_pass/{example_id}/{pass}_pass``.
    Processing stops at the first pass that is an exact match or satisfies
    all constraints; otherwise the violations are turned into feedback for
    the next prompt. Every failure is recorded in an ``evaluation.json`` file
    and logged; nothing is raised to the caller.

    Args:
        example_id: Identifier used in log lines and in the output path.
        example: Example dict; "prompt_0shot" and "golden_plan" are read here.
        constraints: Constraint dict for this example. For the "meeting" task
            a missing "num_people_to_meet" is filled in (in place) from the
            length of "people_to_meet".
        model: Model name passed to ``initialize_model`` and used in paths.
        max_passes: Maximum number of refinement passes.
        rate_limiter: Limiter forwarded to ``run_model_with_rate_limit``.
        semaphore: Async context manager held only around each model call.
        task: "calendar", "meeting", "trip" or "zebralogic".
        args: Accepted for interface compatibility; not read by this function.

    Returns:
        None.

    Note:
        ``initialize_model``, ``keys`` and the extract/evaluate helpers are
        resolved from module scope.
    """
    # Initialize everything the error / exhaustion paths may reference
    gold_text = ""
    gold_formatted = {}
    pred_formatted = {}
    violated_constraints = {}
    is_exact_match = False
    constraints_satisfied = False
    response_text = ""
    reasoning_content = ""
    reasoning_tokens = 0
    full_token_count = 0
    pass_num = 0
    output_dir = f"../output/Plan/{model}/{task}/token_pass/{example_id}"

    try:
        logging.info(f"[{example_id}] Starting processing with model {model}")

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Initialize AI model (outside semaphore to allow parallel initialization)
        try:
            logging.info(f"[{example_id}] About to initialize model...")
            ai = initialize_model(model, keys)
            logging.info(f"[{example_id}] Model initialized successfully")
        except Exception as e:
            logging.error(f"[{example_id}] Failed to initialize model: {str(e)}")
            # The 1_pass directory does not exist yet at this point; the save
            # helper creates it so the specific initialization error is kept
            # instead of being replaced by a generic "Unexpected error".
            _save_json_artifact(
                f"{output_dir}/1_pass/evaluation.json",
                _build_error_eval_result(
                    f"Model initialization failed: {str(e)}",
                    "Model initialization error",
                    0,
                ),
            )
            return

        # Initialize conversation history
        conversation_history = []

        # Get gold answer text (for reference only, not for exact match)
        gold_text = extract_gold_answer(example, task)
        if gold_text:
            if isinstance(gold_text, (dict, list)):
                gold_preview = json.dumps(gold_text, ensure_ascii=False)[:100]
            else:
                gold_preview = str(gold_text)[:100]
        else:
            gold_preview = "None"
        logging.info(f"[{example_id}] Pass {pass_num} gold text: {gold_preview}...")

        # Initial prompt with task-specific formatting
        golden_plan = example.get("golden_plan", {})
        prompt = example.get("prompt_0shot", "")
        current_prompt = add_json_formatting_instruction(prompt, task, golden_plan)

        # Meeting tasks: derive num_people_to_meet from people_to_meet when absent
        if task == "meeting" and "num_people_to_meet" not in constraints:
            constraints["num_people_to_meet"] = len(constraints.get("people_to_meet", []))

        for pass_num in range(1, max_passes + 1):
            logging.info(f"[{example_id}] Starting pass {pass_num}")

            # Create output directory for this pass
            pass_output_dir = f"{output_dir}/{pass_num}_pass"
            os.makedirs(pass_output_dir, exist_ok=True)

            # Get response from model with rate limiting (use semaphore only for API calls)
            api_call_start = time.time()
            retry_count = 0
            max_retries = 3
            while retry_count < max_retries:
                try:
                    logging.info(f"[{example_id}] Making API call (attempt {retry_count + 1})")
                    # Use semaphore only for the actual API call
                    async with semaphore:
                        (
                            response_text,
                            api_time,
                            full_token_count,
                            reasoning_content,
                            reasoning_tokens,
                        ) = await run_model_with_rate_limit(ai, current_prompt, rate_limiter, model)
                    # run_model_with_rate_limit reports failures by returning a
                    # None response instead of raising; turn that into an
                    # exception so the retry logic below actually runs.
                    if response_text is None:
                        raise RuntimeError("model call returned no response")
                    logging.info(f"[{example_id}] API call successful")
                    break
                except Exception as e:
                    retry_count += 1
                    logging.warning(f"[{example_id}] API error in pass {pass_num} (attempt {retry_count}): {e}")
                    if retry_count >= max_retries:
                        logging.error(f"[{example_id}] Max retries reached, giving up")
                        _save_json_artifact(
                            f"{pass_output_dir}/evaluation.json",
                            _build_error_eval_result(
                                f"Max API retries ({max_retries}) reached in pass {pass_num}",
                                "API retry limit exceeded",
                                pass_num,
                            ),
                        )
                        return
                    await asyncio.sleep(5)
                    try:
                        ai = initialize_model(model, keys)
                        logging.info(f"[{example_id}] Model reinitialized after error")
                    except Exception as init_error:
                        logging.error(f"[{example_id}] Failed to reinitialize model: {str(init_error)}")
                        _save_json_artifact(
                            f"{pass_output_dir}/evaluation.json",
                            _build_error_eval_result(
                                f"Model reinitialization failed: {str(init_error)}",
                                "Model reinitialization error",
                                pass_num,
                            ),
                        )
                        return

            api_call_time = time.time() - api_call_start
            logging.info(f"[{example_id}] Pass {pass_num} API call completed - {api_call_time:.2f}s")

            # Add to conversation history
            conversation_history.append({"role": "user", "content": current_prompt})
            conversation_history.append({
                "role": "assistant",
                "content": response_text,
                "reasoning_content": reasoning_content,
                "reasoning_tokens": reasoning_tokens,
                "total_tokens": full_token_count
            })

            # Save conversation
            _save_json_artifact(f"{pass_output_dir}/conversation.json", conversation_history)

            # Extract prediction
            try:
                pred_formatted = extract_answer_from_text(response_text, task, golden_plan)
                logging.info(f"[{example_id}] Pass {pass_num} extracted prediction: {pred_formatted}")
            except Exception as e:
                logging.error(f"[{example_id}] Pass {pass_num} failed to extract prediction: {str(e)}")
                pred_formatted = {}

            # Save plan
            _save_json_artifact(f"{pass_output_dir}/plan.json", pred_formatted)

            # Evaluate constraints
            if task == "calendar":
                constraints_satisfied, violated_constraints = evaluate_calendar(constraints, pred_formatted)
            elif task == "meeting":
                constraints_satisfied, violated_constraints = evaluate_meeting(constraints, pred_formatted)
            elif task == "trip":
                constraints_satisfied, violated_constraints = evaluate_trip(constraints, pred_formatted)
            elif task == "zebralogic":
                constraints_satisfied, violated_constraints = evaluate_zebralogic(constraints, pred_formatted)

            logging.info(f"[{example_id}] Pass {pass_num} constraints satisfied: {constraints_satisfied}")
            logging.info(f"[{example_id}] Pass {pass_num} violated constraints: {violated_constraints}")

            # Format the gold answer. ZebraLogic uses the golden plan object
            # directly (unwrapping "solution"), which avoids passing a dict
            # into the text extractor.
            if task == "zebralogic":
                gp = golden_plan
                if isinstance(gp, dict) and "solution" in gp:
                    gp = gp["solution"]
                gold_formatted = parse_zebralogic_golden(gp)
            elif gold_text:
                gold_formatted = extract_answer_from_text(gold_text, task, golden_plan)

            # Check exact match
            if gold_formatted and pred_formatted:
                is_exact_match = check_exact_match(gold_formatted, pred_formatted, task)
            else:
                is_exact_match = False
            logging.info(f"[{example_id}] Pass {pass_num} exact match: {is_exact_match}")

            # Determine status - check exact match first, then constraints
            if is_exact_match:
                status = "Exact match"
                constraints_satisfied = True  # Exact match implies constraints are satisfied
            elif constraints_satisfied:
                status = "Correct plan (constraints satisfied)"
            else:
                status = "Wrong plan"

            # Save evaluation
            eval_result = {
                "has_execution_error": False,
                "execution_output": response_text,
                "pred": pred_formatted,
                "gold": (gold_formatted or {}),
                "status": status,
                "violated_constraint": violated_constraints,
                "is_exact_match": is_exact_match,
                "constraints_satisfied": constraints_satisfied,
                "pass_number": pass_num,
                "timing": {
                    "api_call_time": api_time,
                    "total_tokens": full_token_count,
                    "reasoning_tokens": reasoning_tokens
                },
                "reasoning_content": reasoning_content
            }
            _save_json_artifact(f"{pass_output_dir}/evaluation.json", eval_result)

            if is_exact_match or constraints_satisfied:
                if is_exact_match:
                    logging.info(f"[{example_id}] SUCCESS! Exact match in pass {pass_num}")
                else:
                    logging.info(f"[{example_id}] SUCCESS! Constraints satisfied in pass {pass_num}")
                return

            logging.info(f"[{example_id}] Pass {pass_num} failed both exact match and constraints, preparing feedback")
            # Prepare feedback for next iteration
            constraint_feedback = format_constraint_feedback(violated_constraints, task)
            current_prompt = f"The previous solution produced the following output:\n{response_text}\n{constraint_feedback}\n\nPlease revise your solution to satisfy these constraints."

        logging.warning(f"[{example_id}] FAILED to solve within {max_passes} passes")

        # Save final evaluation result even if we failed to solve. Only
        # possible when at least one pass ran (max_passes >= 1).
        if pass_num >= 1:
            final_eval_result = {
                "has_execution_error": False,
                "execution_output": response_text,
                "pred": pred_formatted,
                "gold": gold_formatted,
                "status": "Failed to solve within max passes",
                "violated_constraint": violated_constraints,
                "is_exact_match": is_exact_match,
                "constraints_satisfied": constraints_satisfied,
                "pass_number": pass_num,
                "timing": {
                    "total_tokens": full_token_count,
                    "reasoning_tokens": reasoning_tokens
                },
                "reasoning_content": reasoning_content
            }
            _save_json_artifact(f"{pass_output_dir}/evaluation.json", final_eval_result)
            logging.info(f"[{example_id}] Saved final evaluation result from pass {pass_num}")

        return

    except Exception as e:
        logging.error(f"[{example_id}] Unexpected error: {str(e)}")
        # Save error evaluation result in the pass that failed (pass 1 when
        # the failure happened before any pass started), so that a failure in
        # a later pass does not overwrite a valid earlier evaluation.
        try:
            error_pass = max(pass_num, 1)
            _save_json_artifact(
                f"{output_dir}/{error_pass}_pass/evaluation.json",
                _build_error_eval_result(
                    f"Unexpected error: {str(e)}",
                    "Unexpected error",
                    pass_num,
                ),
            )
            logging.info(f"[{example_id}] Saved error evaluation result")
        except Exception as save_error:
            logging.error(f"[{example_id}] Failed to save error evaluation: {str(save_error)}")
        return

def _select_examples(examples, args):
    """Return the subset of ``examples`` requested on the command line.

    Args:
        examples: Mapping of example id -> example payload, as returned by
            ``load_examples``.
        args: Parsed command-line namespace. Reads ``args.examples`` (a
            comma-separated string of example numbers), ``args.start`` and
            ``args.end`` (slice bounds over the examples in file order).

    Returns:
        A dict with the selected examples. When no filter is given the input
        mapping is returned unchanged.
    """
    if args.examples:
        # Blank tokens (e.g. from a trailing comma) are ignored.
        wanted = {int(token) for token in args.examples.split(',') if token.strip()}
        # Compare against whole digit runs of the key rather than substrings,
        # so that asking for example 1 does not also select 10, 11, 21, 100...
        return {
            key: value
            for key, value in examples.items()
            if any(int(run) in wanted for run in re.findall(r'\d+', key))
        }
    if args.start is not None or args.end is not None:
        start = args.start or 0
        end = args.end or len(examples)
        return dict(list(examples.items())[start:end])
    return examples


async def main():
    """Run the refinement pipeline over the selected examples concurrently.

    Parses the command line, loads the examples and constraints for
    ``args.task``, applies the ``--examples`` / ``--start`` / ``--end``
    filters, then schedules one ``process_single_example`` task per example.
    All tasks share a ``RateLimiter`` built from ``args.rate_limit`` and an
    ``asyncio.Semaphore`` sized by ``args.max_concurrent``. Once every task
    has finished, a success/failure summary is logged and
    ``calculate_token_statistics`` is called.
    """
    args = parse_args()

    # Load examples and constraints
    examples = load_examples(args.task)
    constraints = load_constraints(args.task)

    # Filter examples based on arguments
    examples = _select_examples(examples, args)

    logging.info(f"Starting processing of {len(examples)} examples")

    # Initialize rate limiter and semaphore
    rate_limiter = RateLimiter(args.rate_limit)
    semaphore = asyncio.Semaphore(args.max_concurrent)

    # Process examples in parallel
    tasks = []
    for example_id, example in examples.items():
        logging.info(f"Creating task for {example_id}")
        task = asyncio.create_task(
            process_single_example(
                example_id,
                example,
                constraints.get(example_id, {}),
                args.model,
                args.max_passes,
                rate_limiter,
                semaphore,
                args.task,
                args
            )
        )
        tasks.append(task)
        logging.info(f"Task created for {example_id}")

    logging.info(f"All {len(tasks)} tasks created, waiting for completion...")

    # Wait for all tasks to complete; exceptions are returned, not raised,
    # so one failing example cannot abort the others.
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # gather() preserves task order, so results line up with the example ids.
    # BaseException is used so a returned CancelledError counts as a failure.
    failures = [
        (example_id, outcome)
        for example_id, outcome in zip(examples, results)
        if isinstance(outcome, BaseException)
    ]
    for example_id, outcome in failures:
        logging.error(f"[{example_id}] Task failed with exception: {outcome!r}")

    # Log results
    error_count = len(failures)
    success_count = len(results) - error_count
    logging.info(f"Completed processing {len(results)} examples: {success_count} successful, {error_count} failed")

    # Calculate and print token statistics
    calculate_token_statistics()

def load_examples(task):
    """Load the examples for the specified task from ``../data``.

    Args:
        task: One of ``"calendar"``, ``"meeting"``, ``"trip"`` or
            ``"zebralogic"``.

    Returns:
        The parsed JSON content of the task's example file. For
        ``"zebralogic"`` each entry is reduced to its ``"prompt_0shot"`` and
        ``"golden_plan"`` fields; for the other tasks the file content is
        returned unchanged.

    Raises:
        ValueError: If ``task`` is not one of the known tasks.
        KeyError: If a zebralogic entry lacks ``"prompt_0shot"`` or
            ``"golden_plan"``.
    """
    example_files = {
        "calendar": "../data/calendar_scheduling_100.json",
        "meeting": "../data/meeting_planning_100.json",
        "trip": "../data/trip_planning_100.json",
        "zebralogic": "../data/zebralogic_sample_100.json",
    }
    if task not in example_files:
        raise ValueError(f"Unknown task: {task}")

    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale-dependent default encoding.
    with open(example_files[task], 'r', encoding="utf-8") as f:
        data = json.load(f)

    if task != "zebralogic":
        return data

    # Keep only the fields the rest of the pipeline reads for zebralogic.
    return {
        example_id: {
            "prompt_0shot": example["prompt_0shot"],
            "golden_plan": example["golden_plan"],
        }
        for example_id, example in data.items()
    }

def load_constraints(task):
    """Load the per-example constraints for ``task`` from ``../data``.

    Args:
        task: One of ``"calendar"``, ``"trip"``, ``"meeting"`` or
            ``"zebralogic"``.

    Returns:
        A dict keyed by example id. For ``"zebralogic"`` each value is
        ``{"golden_plan": ..., "meta": ...}`` taken from the examples file
        (``None`` / ``{}`` when the fields are absent). For the other tasks
        each value is the entry's ``"constraints"`` field from the
        ``*_100_constraints.json`` file, or ``{}`` when absent.

    Raises:
        KeyError: If ``task`` is not one of the known tasks.
    """
    task_name_map = {
        "calendar": "calendar_scheduling",
        "trip": "trip_planning",
        "meeting": "meeting_planning",
        "zebralogic": "zebralogic_sample"
    }
    # Unknown tasks raise KeyError here, before any file is opened.
    dataset_name = task_name_map[task]

    if task == "zebralogic":
        # Zebralogic reads the examples file itself and keeps the golden
        # plan and metadata of each entry.
        # JSON is UTF-8 by specification; avoid the locale-dependent default.
        with open(f"../data/{dataset_name}_100.json", encoding="utf-8") as f:
            zebra_data = json.load(f)
        # flatten to match other tasks
        return {
            example_id: {
                "golden_plan": example.get("golden_plan"),
                "meta": example.get("meta", {}),
            }
            for example_id, example in zebra_data.items()
        }

    with open(f"../data/{dataset_name}_100_constraints.json", encoding="utf-8") as f:
        constraints_data = json.load(f)
    return {
        example_id: data.get("constraints", {})
        for example_id, data in constraints_data.items()
    }

def extract_gold_answer(example, task):
    """Return the ``golden_plan`` entry of ``example``.

    When the entry is missing, the fallback depends on the task: an empty
    dict for ``"zebralogic"`` and an empty string for every other task.
    """
    fallback = {} if task == "zebralogic" else ""
    return example.get("golden_plan", fallback)

# Script entry point: only when executed directly (not imported), run the
# async main() coroutine to completion via asyncio.run().
if __name__ == "__main__":
    asyncio.run(main())
