import argparse
import json
import os
import re
# from langfuse.openai import AsyncOpenAI
# from langfuse.openai import openai
from openai import AsyncOpenAI
import time
from dotenv import load_dotenv
from tqdm import tqdm
import logging
import uuid
import asyncio
from prompts import generate_ontology_description, COMBINED_ANNOTATION_PROMPT_TEMPLATE
# Disable Langfuse tracing through its environment flag.
# NOTE(review): `prompts` is imported above this line and the langfuse.openai
# imports are commented out, so confirm this flag still takes effect anywhere.
os.environ["LANGFUSE_TRACING_ENABLED"] = "false"
# Older attribute-based switch, kept for reference (commented out):
# openai.langfuse_enabled = False

# Prompt used by annotate_search_result to rate one search_result step.
# Placeholders: {query} (the originating search query) and {documents_json}
# (the step's documents, JSON-encoded). Literal braces in the JSON example are
# doubled ({{ }}) because the template is rendered with str.format.
# The model is asked for `information_quality`, `information_clarity` and
# `clarity_justification`, which annotate_search_result copies into the step.
SEARCH_RESULT_ANALYSIS_PROMPT = """
You are an expert data analyst. Your task is to evaluate the quality of a search result based on the query that produced it.

**Search Query:**
{query}

**Search Result Documents:**
```json
{documents_json}
```

---
**Your Task:**
Analyze the search result's sufficiency and clarity.

1.  **Sufficiency:** Does the result contain enough information to likely answer the user's implicit question in the query? Choose one:
    *   `Sufficient`: The answer seems to be present.
    *   `Insufficient`: The answer is likely not here.

2.  **Clarity:** Is the information clear or does it create confusion? Choose one:
    *   `Clear`: The information is straightforward and addresses one subject.
    *   `Unclear`: The results mention multiple distinct entities that could match the query (e.g., two movies with the same title) or the information is vague.
    *   Provide a brief justification for your clarity rating.

---
**Your Final Output:**
Your response **MUST** be a single, valid JSON object with the attributes `information_quality`, `information_clarity`, and `clarity_justification`.

```json
{{
  "information_quality": "<'Sufficient'/'Insufficient'>",
  "information_clarity": "<'Clear'/'Unclear'>",
  "clarity_justification": "<Your brief explanation for the clarity rating>"
}}
```
"""

async def annotate_search_result(step, trace, model, max_retries=3, delay=5):
    """Rate a ``search_result`` step's sufficiency and clarity with a focused LLM call.

    The step's ``trace_dependency.dependent_on`` must point at the ``search``
    step that produced it; when it is missing, out of range, or points at a
    non-search step, the step is left untouched. On a successful LLM call the
    ratings are written into ``step["annotation"]["attributes"]`` (creating
    the containers if needed), with ``"Unspecified"`` for any field the model
    omitted.

    Args:
        step: The ``search_result`` step dict; mutated in place.
        trace: The full list of steps, used to look up the originating query.
        model: Model name forwarded to ``call_llm``.
        max_retries: Attempts allowed for the LLM call.
        delay: Seconds between LLM retries.
    """
    source_idx = step.get('trace_dependency', {}).get('dependent_on')
    has_valid_source = source_idx is not None and 0 <= source_idx < len(trace)
    if not has_valid_source:
        return  # Without the originating query there is nothing to judge against.

    source_step = trace[source_idx]
    if source_step.get("type") != "search":
        return  # The dependency is not a search, so there is no query to compare with.

    verdict = await call_llm(
        SEARCH_RESULT_ANALYSIS_PROMPT.format(
            query=source_step.get("query", ""),
            documents_json=json.dumps(step.get("documents", []), indent=2),
        ),
        model,
        max_retries,
        delay,
    )
    if not verdict:
        return

    annotation = step.setdefault("annotation", {"type": "search_result"})
    attributes = annotation.setdefault("attributes", {})
    for field in ("information_quality", "information_clarity", "clarity_justification"):
        attributes[field] = verdict.get(field, "Unspecified")


# Prompt for the LLM judge used by call_judge_llm. Placeholders: {ground_truth}
# and {answer}; literal braces are doubled for str.format. The model must reply
# with a JSON boolean, because call_judge_llm discards non-boolean verdicts.
# The example therefore spells out true/false rather than inviting "yes"/"no".
JUDGE_PROMPT = """
You are an expert evaluator. Your task is to determine if the provided 'answer' is correct based on the 'ground_truth'.
The answer does not need to be a perfect match, but it must be semantically correct and capture the key information from the ground truth.

**Ground Truth:**
{ground_truth}

**Answer:**
{answer}

Based on the comparison, is the answer correct?
Your response must be a single JSON object with a single boolean key: "is_correct".
Use the JSON literal true or false (not a string such as "yes" or "no").

{{
    "is_correct": <true or false>
}}
"""

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Per-run identifier. Within this file it is only referenced by the
# commented-out Langfuse `metadata` argument in call_llm.
session_id = uuid.uuid4()
# Load environment variables and initialize OpenAI client.
# load_dotenv() runs first so that any variables it loads are visible to the
# os.getenv calls below.
load_dotenv()
# Module-level async client shared by every LLM request in this file (all of
# them go through call_llm). `or None` turns an empty OPENAI_BASE into None,
# presumably so the SDK falls back to its default endpoint.
client = AsyncOpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE") or None,
)

async def call_llm(prompt, model, max_retries=3, delay=5):
    """Call the chat model and return its reply parsed as a JSON object.

    The request uses ``response_format={"type": "json_object"}``. As an extra
    safeguard against stray text, the outermost ``{...}`` span of the reply is
    extracted before parsing. API errors and unparsable replies are both
    retried, waiting ``delay`` seconds between attempts without blocking the
    event loop.

    Args:
        prompt: Content of the single user message.
        model: Chat model name.
        max_retries: Total number of attempts (0 means no request is made).
        delay: Seconds to wait between failed attempts.

    Returns:
        The decoded JSON object, or ``None`` if every attempt failed.
    """
    for attempt in range(max_retries):
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                response_format={"type": "json_object"},
                # metadata={"langfuse_session_id": session_id}
            )
            # `content` may be None (e.g. an empty completion). Treat it like a
            # reply without JSON instead of crashing in re.search.
            content = response.choices[0].message.content or ""
        except Exception as e:
            logging.error(f"API call failed on attempt {attempt + 1}/{max_retries}: {e}")
        else:
            # Greedily take the outermost {...} span so text around the JSON is tolerated.
            match = re.search(r'\{.*\}', content, re.DOTALL)
            if match:
                json_str = match.group(0)
                try:
                    return json.loads(json_str)
                except json.JSONDecodeError as e:
                    logging.error(f"Extracted JSON-like string failed to parse on attempt {attempt + 1}/{max_retries}: {e}")
                    logging.error(f"Problematic extracted string: >>>\n{json_str}\n<<<")
            else:
                logging.error(f"Could not find any JSON-like object in the response on attempt {attempt + 1}/{max_retries}.")
                logging.error(f"Full problematic content: >>>\n{content}\n<<<")

        if attempt + 1 == max_retries:
            return None
        # Must not be time.sleep(): a blocking sleep here would freeze the event
        # loop and stall every other concurrently running annotation task.
        await asyncio.sleep(delay)
    return None

async def call_judge_llm(ground_truth, answer, model):
    """Ask the LLM judge whether ``answer`` semantically matches ``ground_truth``.

    The judge is asked for a JSON boolean ``is_correct``. Models sometimes reply
    with a string instead ("true"/"false"/"yes"/"no", any case). Such replies
    are normalised to a boolean rather than being discarded.

    Args:
        ground_truth: Reference answer.
        answer: Answer produced by the agent.
        model: Model name forwarded to ``call_llm``.

    Returns:
        ``True``/``False`` for an interpretable verdict, otherwise ``None``
        (the call failed or the value could not be understood).
    """
    prompt = JUDGE_PROMPT.format(ground_truth=ground_truth, answer=answer)
    llm_response = await call_llm(prompt, model)  # Reusing the existing robust call_llm function

    verdict = llm_response.get("is_correct") if isinstance(llm_response, dict) else None
    if isinstance(verdict, str):
        verdict = {"true": True, "yes": True, "false": False, "no": False}.get(verdict.strip().lower())

    if isinstance(verdict, bool):
        return verdict
    logging.warning(f"Judge LLM call failed or returned malformed data. Response: {llm_response}")
    return None


def _truncate_document_content(doc, max_doc_len):
    """Return a copy of ``doc`` with its string ``content`` cut to ``max_doc_len`` characters."""
    if not isinstance(doc, dict):
        # Non-dict documents (e.g. plain strings) have no "content" key; pass them through.
        return doc
    concise_doc = doc.copy()
    content = concise_doc.get("content")
    if isinstance(content, str) and max_doc_len is not None and len(content) > max_doc_len:
        # Only add the ellipsis when text was actually dropped.
        concise_doc["content"] = content[:max_doc_len] + "..."
    return concise_doc


def get_concise_history(trace, current_index, max_doc_len=200):
    """Return a prompt-friendly copy of the steps before ``current_index``.

    Each step is shallow-copied without its ``annotation`` and
    ``trace_dependency`` keys, so the LLM is not influenced by earlier labels.
    For ``search_result`` steps whose ``documents`` is a list, each dict
    document's string ``content`` is cut to ``max_doc_len`` characters, with
    ``"..."`` appended only when something was removed. ``max_doc_len=None``
    disables truncation. The input ``trace`` is not modified.

    Args:
        trace: List of step dicts.
        current_index: Only ``trace[:current_index]`` is included.
        max_doc_len: Maximum number of content characters kept per document.

    Returns:
        A new list of cleaned step dicts.
    """
    hidden_keys = ('annotation', 'trace_dependency')
    history = []
    for step in trace[:current_index]:
        clean_step = {k: v for k, v in step.items() if k not in hidden_keys}
        documents = clean_step.get("documents")
        if clean_step.get("type") == "search_result" and isinstance(documents, list):
            clean_step["documents"] = [_truncate_document_content(doc, max_doc_len) for doc in documents]
        history.append(clean_step)
    return history

def _positional_dependency(step_index):
    """Build the default dependency edge: the preceding step, or None for the first step."""
    return {"dependent_on": step_index - 1 if step_index > 0 else None}


async def _annotate_answer_step(current_step, step_index, step_type, ground_truth, model):
    """Annotate a ``*Answer`` step deterministically, re-judging ``IncorrectAnswer`` against the ground truth."""
    current_step["trace_dependency"] = _positional_dependency(step_index)
    current_step["annotation"] = {
        "type": step_type,
        "justification": "System-generated answer based on correctness.",
    }
    answer = current_step.get("content")
    # Only answers the system marked incorrect are re-checked, and only when
    # there is both a ground truth and an answer to compare.
    if not ground_truth or step_type != "IncorrectAnswer" or answer is None:
        return
    is_correct = await call_judge_llm(ground_truth, answer, model)
    if is_correct is None:
        # Judge failed or was malformed: keep the system label, record no verdict.
        return
    current_step["annotation"]["is_correct"] = is_correct
    if is_correct:
        # The judge overrules the system label: relabel as a correct answer.
        current_step["type"] = "CorrectAnswer"
        current_step["annotation"]["type"] = "CorrectAnswer"


def _build_annotation_prompt(trace_with_indices, step_index, ontology, step_type):
    """Render the combined annotation prompt for a ``reasoning``/``search`` step."""
    current_step = trace_with_indices[step_index]
    # Earlier annotations are stripped so they do not bias the LLM.
    previous_steps_for_prompt = get_concise_history(trace_with_indices, step_index)

    search_count_info = ""
    if step_type == "search":
        previous_search_queries = [
            step.get("query") for step in trace_with_indices[:step_index] if step.get("type") == "search"
        ]
        search_count = len(previous_search_queries) + 1  # +1 for the current search step
        search_count_info = f"\nThis is search number {search_count} in the session. The previous search queries were: {previous_search_queries}."

    return COMBINED_ANNOTATION_PROMPT_TEMPLATE.format(
        ontology_str=generate_ontology_description(ontology, step_type),
        previous_steps_json=json.dumps(previous_steps_for_prompt, indent=2),
        current_step_json=json.dumps(
            {k: v for k, v in current_step.items() if k not in ('annotation', 'trace_dependency')}, indent=2
        ),
        current_step_index=step_index,
        search_count_info=search_count_info,
    )


async def _request_validated_annotation(prompt, model, ontology, step_type, step_index, max_retries):
    """Query the LLM until its annotation is well-formed and allowed by the ontology; return it or None."""
    allowed_types = ontology.get(step_type, {})
    for attempt in range(max_retries):
        response = await call_llm(prompt, model)
        if not response:
            logging.warning(f"LLM call for step {step_index} returned no response on attempt {attempt + 1}/{max_retries}.")
        else:
            ann_data = response.get("annotation", {})
            # A non-dict "annotation" (e.g. a bare string) is malformed; without this
            # check `.get` below would raise and abort the whole trace.
            if not isinstance(ann_data, dict) or "type" not in ann_data or "justification" not in ann_data:
                logging.warning(f"Attempt {attempt + 1}/{max_retries}: Malformed annotation for step {step_index} (missing type/justification).")
            else:
                step_type_from_llm = ann_data.get("type")
                if isinstance(step_type_from_llm, str) and step_type_from_llm in allowed_types:
                    return response
                logging.warning(f"Attempt {attempt + 1}/{max_retries}: Invalid annotation type '{step_type_from_llm}' for step {step_index} of type {step_type}.")
        if attempt < max_retries - 1:
            await asyncio.sleep(1)
    return None


async def annotate_step(trace_with_indices, step_index, ontology, lock, ground_truth=None, model=None):
    """Annotate one step of a trace in place with a node type and a dependency edge.

    Routing by ``step["type"]``:

    * ``search_result``: dependency on the previous step, then a dedicated
      quality/clarity analysis via ``annotate_search_result``.
    * types ending in ``Answer``: deterministic annotation. An
      ``IncorrectAnswer`` is re-checked by the LLM judge when a ground truth is
      given, and relabelled ``CorrectAnswer`` if the judge says it is correct.
    * ``reasoning`` / ``search``: classified by the LLM, with the returned type
      validated against ``ontology[step_type]``.
    * anything else (including a missing type): positional dependency, no LLM.

    Dependencies are always positional (previous step, or None for step 0).

    Args:
        trace_with_indices: The full trace (list of step dicts); mutated in place.
        step_index: Index of the step to annotate.
        ontology: Mapping from step type to the annotation types allowed for it.
        lock: Not used by this function; kept for interface compatibility.
        ground_truth: Optional reference answer for the LLM judge.
        model: Model name forwarded to the LLM helpers.
    """
    current_step = trace_with_indices[step_index]
    step_type = current_step.get("type")

    if step_type == "search_result":
        current_step["trace_dependency"] = _positional_dependency(step_index)
        await annotate_search_result(current_step, trace_with_indices, model)
        return

    # isinstance guard: a step without a string type must not crash on .endswith().
    if isinstance(step_type, str) and step_type.endswith("Answer"):
        await _annotate_answer_step(current_step, step_index, step_type, ground_truth, model)
        return

    if step_type not in ("reasoning", "search"):
        current_step["trace_dependency"] = _positional_dependency(step_index)
        current_step["annotation"] = {"type": step_type, "justification": "Dependency inferred by position, not analyzed by LLM."}
        return

    max_retries = 3
    prompt = _build_annotation_prompt(trace_with_indices, step_index, ontology, step_type)
    llm_response = await _request_validated_annotation(prompt, model, ontology, step_type, step_index, max_retries)

    if llm_response is None:
        logging.error(f"LLM annotation failed for step {step_index} after {max_retries} attempts. Skipping annotations.")
        current_step["annotation"] = {"error": f"LLM validation failed after {max_retries} retries"}
        # Previously this wrote -1 for the first step; use the same positional rule as every other path.
        current_step["trace_dependency"] = _positional_dependency(step_index)
        return

    annotation_data = llm_response["annotation"]
    current_step["annotation"] = annotation_data
    # The LLM's own dependency output is not used; dependencies are positional.
    current_step["trace_dependency"] = _positional_dependency(step_index)

    if "attributes" in annotation_data:
        attributes = annotation_data["attributes"]
        # Drop null-valued attributes. A non-dict value is malformed and is replaced by {}.
        annotation_data["attributes"] = (
            {k: v for k, v in attributes.items() if v is not None} if isinstance(attributes, dict) else {}
        )


# Prompt used by evaluate_reasoning_premise for the second-pass grounding check.
# Placeholders: {question} (original question), {search_evidence_json}
# (JSON-encoded evidence documents) and {reasoning_text} (the step content).
# Literal braces in the output schema are doubled ({{ }}) for str.format.
# evaluate_reasoning_premise reads the returned fields `premise_grounding`,
# `anchor_type`, `evidence_citations`, `unmatched_premises` and
# `premise_justification`.
PREMISE_EVALUATION_PROMPT = """
You are a ruthless, evidence-based critical thinking expert. Your task is to evaluate the factual premise of an AI agent's reasoning based *only* on the evidence it had at the time.

**Context: The Agent's Goal (Original Question):**
{question}

**Evidence: The Search Results the Agent Had Access To:**
```json
{search_evidence_json}
```

**Agent's Reasoning Text to Analyze:**
"{reasoning_text}"

---
Task:
1) Extract the atomic factual premises from the step (skip meta/plan-only wording that contains no factual claim).
2) For each premise, find a direct supporting span in the provided evidence. If no exact or near-verbatim support exists, mark that premise as unmatched.
3) Decide the label with STRICT rules:
   - Directly Grounded: ALL atomic premises are supported by explicit evidence spans.
   - Not Grounded: ANY atomic premise lacks a supporting span; OR the step contains only meta/plan text without factual premises.

Additional rules:
- QUESTION anchor alone is NOT sufficient for Directly Grounded; do not label as grounded solely for restating the task/intent.
- Superlatives/temporal/quantitative claims (e.g., last/first/only, years, counts) require explicit evidence spans.

---
Return a single JSON object:
{{
  "premise_grounding": "<Directly Grounded|Not Grounded>",
  "anchor_type": "<EVIDENCE|QUESTION|TRACE|NONE>",
  "evidence_citations": [
    {{"premise": "...", "evidence_snippet": "..."}}
  ],
  "unmatched_premises": ["..."],
  "premise_justification": "<brief explanation referencing the citations or explaining unmatched>"
}}

"""

async def evaluate_reasoning_premise(step, trace, question, model):
    """Second-pass LLM check of whether a reasoning step's factual premises are grounded.

    Evidence shown to the judge is every document from ``search_result``
    steps located before ``step`` (according to its ``step_index``). If there
    is none, the documents of the step's dependency parent are used when that
    parent is a ``search_result``; otherwise a single placeholder error
    document is sent.

    On a successful LLM call the verdict is written into
    ``step["annotation"]["attributes"]`` under ``premise_grounding``,
    ``anchor_type``, ``evidence_citations``, ``unmatched_premises`` and
    ``premise_justification``. camelCase keys are accepted as fallbacks. For
    the reasoning annotation types, a verdict without citations or with any
    unmatched premise is forced to ``"Not Grounded"``. Steps without
    ``content`` are skipped.

    Args:
        step: The reasoning step dict; mutated in place.
        trace: The full list of steps.
        question: The original question, included in the prompt as context.
        model: Model name forwarded to ``call_llm``.
    """
    reasoning_text = step.get("content")
    if not reasoning_text:
        return

    # Gather every document the agent had seen before this step.
    upto = step.get('step_index', 0)
    evidence = []
    for i in range(upto):
        earlier = trace[i]
        if earlier.get("type") != "search_result" or "documents" not in earlier:
            continue
        evidence += earlier.get("documents", [])

    # Fallback: the step's direct dependency, if it is a search result.
    if not evidence:
        parent_idx = step.get('trace_dependency', {}).get('dependent_on')
        parent_in_range = parent_idx is not None and (0 <= parent_idx < len(trace))
        if parent_in_range and trace[parent_idx].get("type") == "search_result":
            evidence = trace[parent_idx].get("documents", [])

    if not evidence:
        evidence = [{"error": "No search evidence found for this reasoning step"}]

    llm_response = await call_llm(
        PREMISE_EVALUATION_PROMPT.format(
            question=question,
            search_evidence_json=json.dumps(evidence, indent=2),
            reasoning_text=reasoning_text,
        ),
        model,
    )
    if not llm_response:
        return

    if "annotation" not in step:
        step["annotation"] = {"type": step.get("type", None)}
    annotation = step["annotation"]
    if "attributes" not in annotation:
        annotation["attributes"] = {}
    attributes = annotation["attributes"]

    def pick(primary, legacy, fallback=None):
        # Prefer the snake_case key, then the legacy camelCase key.
        return llm_response.get(primary, llm_response.get(legacy, fallback))

    grounding = pick("premise_grounding", "premiseGrounding")
    anchor = pick("anchor_type", "anchorType", "NONE")
    citations = pick("evidence_citations", "citations", [])
    unmatched = pick("unmatched_premises", "unmatchedPremises", [])
    justification = (
        llm_response.get("premise_justification")
        or llm_response.get("justification")
        or "Unspecified"
    )

    # Guardrail: "Directly Grounded" needs citations and no unmatched premises.
    reasoning_types = ("StateAssessment", "PlanFormation", "InformationSynthesis", "CritiqueAndCorrection")
    if annotation.get("type") in reasoning_types:
        has_unmatched = isinstance(unmatched, list) and len(unmatched) > 0
        if not citations or has_unmatched:
            grounding = "Not Grounded"

    attributes["premise_grounding"] = grounding or "Unspecified"
    attributes["anchor_type"] = anchor or "NONE"
    attributes["evidence_citations"] = citations or []
    attributes["unmatched_premises"] = unmatched or []
    attributes["premise_justification"] = justification


async def annotate_trace(trace, ontology, lock, ground_truth=None, question=None, model=None):
    """Run the full two-pass annotation over one trace and return it.

    Every step first receives a ``step_index``. Pass 1 annotates the steps in
    order with ``annotate_step``. Pass 2 runs ``evaluate_reasoning_premise``
    on each step whose annotation type is a reasoning category. It is followed
    by one retry for ``reasoning`` steps that still lack a ``premise_grounding``
    attribute (e.g. because the first LLM call failed).

    Args:
        trace: List of step dicts; mutated in place.
        ontology: Annotation ontology forwarded to ``annotate_step``.
        lock: Forwarded to ``annotate_step``.
        ground_truth: Optional reference answer for judging answer steps.
        question: Original question, used as context for premise evaluation.
        model: Model name forwarded to the LLM helpers.

    Returns:
        The same ``trace`` list.
    """
    for position, step in enumerate(trace):
        step['step_index'] = position

    # Pass 1: classify each step (the "what").
    for position in range(len(trace)):
        await annotate_step(trace, position, ontology, lock, ground_truth, model)

    premise_types = ("StateAssessment", "PlanFormation", "InformationSynthesis", "CritiqueAndCorrection")

    def needs_premise_check(step):
        return step.get("annotation", {}).get("type") in premise_types

    # Pass 2: check whether each reasoning step's premise is grounded (the "why").
    for step in trace:
        if needs_premise_check(step):
            await evaluate_reasoning_premise(step, trace, question, model)

    # Retry: reasoning steps whose premise evaluation produced no result.
    for step in trace:
        if step.get("type") != "reasoning":
            continue
        if "premise_grounding" in step.get("annotation", {}).get("attributes", {}):
            continue
        if needs_premise_check(step):
            await evaluate_reasoning_premise(step, trace, question, model)

    return trace


async def main():
    """Command-line entry point: annotate every trace in a JSONL file.

    Reads one JSON object per non-blank line of ``--input_file``. Each object is
    expected to carry a ``trace`` list and optionally ``ground_truth`` and
    ``question``. Traces are annotated concurrently, with at most
    ``--concurrency`` in flight. Each successful result is written to
    ``--output_file`` as soon as it completes, so output order follows
    completion order, not input order. Traces whose processing raises are
    logged, omitted from the output and counted in a final warning. Finally
    the ontology is written back to ``--ontology_file``.
    """
    parser = argparse.ArgumentParser(description="Annotate agent traces using OpenAI.")
    parser.add_argument("--input_file", required=True, help="Path to the input .jsonl file.")
    parser.add_argument("--output_file", required=True, help="Path to the output .jsonl file.")
    parser.add_argument("--ontology_file", default="trace_annotator/ontology.json", help="Path to the ontology JSON file.")
    parser.add_argument("--concurrency", type=int, default=30, help="Number of concurrent requests to OpenAI.")
    parser.add_argument("--model", default="gpt-4.1-mini", help="Model to use for annotation.")
    args = parser.parse_args()

    try:
        with open(args.ontology_file, 'r', encoding='utf-8') as f:
            ontology = json.load(f)
        logging.info(f"Successfully loaded ontology from {args.ontology_file}")
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        # OSError also covers permission/directory errors, not just a missing file.
        logging.error(f"Failed to load ontology from {args.ontology_file}: {e}")
        return

    with open(args.input_file, 'r', encoding='utf-8') as infile:
        # Skip blank lines (e.g. a trailing empty line) rather than crashing on json.loads("").
        traces = [json.loads(line) for line in infile if line.strip()]

    # Forwarded to annotate_trace; not otherwise used in this function.
    ontology_lock = asyncio.Lock()
    semaphore = asyncio.Semaphore(args.concurrency)

    async def process_and_annotate_trace(trace_data):
        """Annotate one input record; return it with the annotated trace, or None on failure."""
        async with semaphore:
            try:
                # Deep copy via a JSON round-trip so a partial failure never leaves the record half-modified.
                processed_trace = json.loads(json.dumps(trace_data['trace']))
                ground_truth = trace_data.get('ground_truth')
                question = trace_data.get('question')

                processed_trace = await annotate_trace(processed_trace, ontology, ontology_lock, ground_truth, question=question, model=args.model)

                trace_data['trace'] = processed_trace
                return trace_data
            except Exception as e:
                # Top-level boundary: one bad trace must not abort the whole run.
                logging.error(f"A critical error occurred while processing a trace: {e}. Trace was not written to output.", exc_info=True)
                logging.error(f"Original trace that caused failure: {json.dumps(trace_data)}")
                return None

    tasks = [process_and_annotate_trace(trace_data) for trace_data in traces]

    written = 0
    with open(args.output_file, 'w', encoding='utf-8') as outfile:
        for future in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Annotating Traces"):
            result = await future
            if result:
                outfile.write(json.dumps(result) + '\n')
                written += 1

    failed = len(tasks) - written
    if failed:
        logging.warning(f"{failed} of {len(tasks)} traces failed and were not written to {args.output_file}.")

    # NOTE(review): nothing visible in this file mutates `ontology` after loading,
    # so this may only re-serialize the file. Confirm whether
    # prompts.generate_ontology_description (or anything else) updates it.
    try:
        with open(args.ontology_file, 'w', encoding='utf-8') as f:
            json.dump(ontology, f, indent=2)
        logging.info(f"Ontology updated and saved back to {args.ontology_file}")
    except Exception as e:
        logging.error(f"Failed to save updated ontology to {args.ontology_file}: {e}")

    logging.info(f"Annotation complete. Output saved to {args.output_file}")

if __name__ == "__main__":
    # Script entry point: run the async CLI on a fresh event loop.
    asyncio.run(main()) 
