import json
import asyncio
import os
import random
import argparse
from typing import List, Dict, Any
from tqdm.asyncio import tqdm_asyncio

from trpo.llm_client_factory import get_async_client

# ==========================================
# Batch Prompt Templates
# ==========================================



# Prompt sent to the LLM judge for one batch of cases.
# The template is rendered with str.format(batch_input=...), so `{batch_input}`
# is the only placeholder; the doubled braces in the JSON schema below are
# format escapes that render as literal `{` / `}`.
# NOTE(review): the output instruction first asks for a <think> block and then
# ends with "Output JSON only." -- these look contradictory; confirm which one
# is intended before changing the wording (it is runtime text).
PROMPT_EVAL_LTL_TRANSLATION_BATCH = """
You are a Formal Logic Expert.
Your task is to verify the **Translation Quality** of LTL Formulas based on Trace Analysis.

[LTL GRAMMAR REFERENCE]
- sudden_repeat: `G(act_A -> X(!act_A))`
- lack_premise: `G(!obs_A -> !act_B)`
- hazard: `G(!act_Subject)` 
- rollback: `G(obs_Trigger -> G(!act_Forbidden))`
- other: Flexible combination using '&', '|', '!', 'X', 'F', 'G', 'U'.

[CONTEXT]
- **Positive LTL** corresponds to successful milestones.
- **Negative LTLs** correspond to different failure modes (bad behaviors).
- These formulas represent **independent** tracks of behavior. 
- A formula is **Faithful** if it correctly translates the text in the Analysis JSON into LTL.

[EVALUATION CRITERIA]

1. **Positive LTL**:
   - If faithful, set `is_faithful` to true.
   - If NOT faithful, classify the error type:
     - **LOGIC_ERROR**: The formula enforces a WRONG sequence or logic that contradicts the analysis (e.g., wrong order, wrong action). This is CRITICAL.
     - **MISSING_INFO**: The formula is logically correct but skips some steps mentioned in the analysis. This is MINOR.

2. **Negative LTLs**: Evaluate EACH formula.
   - First, identify which bad behavior this formula corresponds to.
   - If faithful, set `is_faithful` to true.
   - If NOT faithful, classify the error type:
     - **HARMFUL**: The formula is wrong in a way that would punish VALID behavior (e.g., forbidding a necessary action). This is CRITICAL.
     - **INEFFECTIVE**: The formula fails to capture the bad behavior (e.g., syntax error, irrelevant constraints) but likely won't punish good behavior actively. It's a "miss".

[INPUT CASES]
{batch_input}

[OUTPUT INSTRUCTION]
Return a thinking process within <think> tag and a JSON object with a "results" list.
For THINKING:
<think>
1. Examine the Positive LTL: Does it capture the sequence of milestones? If there's an error, is it a complete logic failure (LOGIC_ERROR) or just a skip of some steps (MISSING_INFO)?
2. Examine the Negative LTL formulas one by one:
   - Attempt to match each formula with the existing bad behaviors from the analysis, selecting the one it most likely belongs to.
   - Based on this alignment, judge whether the formula faithfully reflects the content and type of that bad behavior.
   - If it is not faithful, decide the impact: Would it punish valid actions (HARMFUL) or is it just a broken/useless formula (INEFFECTIVE)?
</think>
For the JSON:
Each entry in "results" must strictly follow this structure:
{{
    "case_id": <int>,
    "pos_ltl_eval": {{ "is_faithful": true/false, "error_type": "NONE" | "LOGIC_ERROR" | "MISSING_INFO", "reason": "..." }},
    "neg_ltl_evals": [
        {{ "formula": "...", "is_faithful": true/false, "error_type": "NONE" | "HARMFUL" | "INEFFECTIVE", "reason": "..." }},
        ...
    ]
}}
Output JSON only.
"""

# ==========================================
# Core Evaluation Class
# ==========================================

class TRPO_LTLEvaluator:
    """Score LTL translation quality of logged records with an LLM judge.

    Records are read from a JSONL file, filtered, grouped into prompts and sent
    to the async client returned by ``get_async_client``. The judge's JSON
    verdicts are mapped back to the records and aggregated into correctness
    and harmlessness statistics that are printed to stdout.
    """

    def __init__(self, log_file_path: str, llm_client_name="gpt-4o", concurrency=5, max_entries=None, llm_batch_size=5):
        """Create the client, the concurrency limiter and load the log file.

        Args:
            log_file_path: Path of the JSONL file to evaluate.
            llm_client_name: Name handed to ``get_async_client``.
            concurrency: Maximum number of LLM requests in flight at once.
            max_entries: If set, at most this many raw lines are kept
                (randomly chosen) before validation.
            llm_batch_size: Number of records packed into one LLM request.
        """
        self.log_file_path = log_file_path
        self.client = get_async_client(llm_client_name)
        self.sem = asyncio.Semaphore(concurrency)
        self.llm_batch_size = llm_batch_size
        self.data = self._load_data(max_entries)

    def _load_data(self, max_entries):
        """Read the JSONL log and return the records usable for evaluation.

        Blank lines and lines that are not valid JSON are skipped. A line is
        kept only when it decodes to a dict containing both ``group_uid`` and
        ``task_goal``. When ``max_entries`` is set and the file has more lines,
        a random subset of *raw lines* is taken first, so the number of records
        returned can be smaller than ``max_entries``.

        Returns:
            A list of record dicts; empty when the file does not exist.
        """
        if not os.path.exists(self.log_file_path):
            print(f"Error: {self.log_file_path} not found.")
            return []

        with open(self.log_file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        # Random sampling if file is too large
        if max_entries and len(lines) > max_entries:
            random.shuffle(lines)
            lines = lines[:max_entries]

        data = []
        for line in lines:
            if not line.strip():
                continue
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                continue
            # A line can be valid JSON without being an object (e.g. `5` or a
            # bare string); `in` on those either raises or does a substring
            # match, so require a dict before checking the keys.
            if isinstance(entry, dict) and 'group_uid' in entry and 'task_goal' in entry:
                data.append(entry)
        print(f"Loaded {len(data)} records from {self.log_file_path}")
        return data

    # --------------------------------------------------------
    # Generic Batch Processor
    # --------------------------------------------------------
    @staticmethod
    def _extract_json(content):
        """Decode the JSON object embedded in an LLM reply.

        The text after the last ``</think>`` tag is tried first, because the
        prompt asks for a reasoning block before the JSON and braces inside it
        would otherwise widen the brace-delimited slice. For each text, the
        outermost ``{...}`` span is tried, then a ```` ```json ```` fence.

        Raises:
            ValueError: if ``content`` is not a string or nothing decodes.
        """
        if not isinstance(content, str):
            raise ValueError("LLM response has no text content")

        texts = [content]
        think_end = content.rfind("</think>")
        if think_end != -1:
            texts.insert(0, content[think_end + len("</think>"):])

        for text in texts:
            candidates = [text[text.find('{'):text.rfind('}') + 1]]
            if "```json" in text:
                candidates.append(text.split("```json")[1].split("```")[0].strip())
            for candidate in candidates:
                try:
                    return json.loads(candidate)
                except ValueError:  # json.JSONDecodeError is a ValueError
                    continue
        raise ValueError("Could not extract JSON")

    async def _process_batch(self, batch_items, prompt_template, result_parser):
        """Send one batch to the LLM and map its verdicts back to the items.

        Args:
            batch_items: Dicts with an ``input_str`` key (and whatever
                ``result_parser`` needs); case numbers in the prompt are
                1-based positions in this list.
            prompt_template: ``str.format`` template with a ``{batch_input}``
                placeholder.
            result_parser: Callable ``(item, result_dict)`` returning the
                parsed record, or a falsy value to drop it.

        Returns:
            The parsed records, in the order the judge returned them. Any
            failure of the request or of decoding the reply is reported on
            stdout and yields an empty list (best effort, never raises for
            those). Individual result entries with a missing, non-numeric,
            out-of-range or repeated ``case_id`` are skipped without
            discarding the rest of the batch.
        """
        batch_input_str = "".join(
            f"\n--- CASE {case_no} ---\n{item['input_str']}"
            for case_no, item in enumerate(batch_items, start=1)
        )
        prompt = prompt_template.format(batch_input=batch_input_str)

        async with self.sem:
            try:
                # NOTE(review): no `model` argument is passed here; presumably
                # the object from get_async_client binds it -- confirm.
                resp = await self.client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.0,
                    max_tokens=4096
                )
                result_json = self._extract_json(resp.choices[0].message.content)

                results_list = result_json.get("results") if isinstance(result_json, dict) else None
                if not isinstance(results_list, list):
                    raise ValueError('Response JSON has no "results" list')

                processed_results = []
                accepted = set()
                for res in results_list:
                    if not isinstance(res, dict):
                        continue
                    try:
                        case_idx = int(res.get("case_id")) - 1  # 1-based to 0-based
                    except (TypeError, ValueError):
                        # One malformed id must not cost the whole batch.
                        continue
                    if case_idx in accepted or not 0 <= case_idx < len(batch_items):
                        continue
                    parsed = result_parser(batch_items[case_idx], res)
                    if parsed:
                        accepted.add(case_idx)
                        processed_results.append(parsed)
                return processed_results
            except Exception as e:
                # Best effort: a failed batch is dropped, but say so, otherwise
                # the statistics shrink with no indication why.
                print(f"Batch failed: {e}")
                return []

    async def _run_batch_check(self, items, prepare_input_fn, prompt_template, result_parser, desc):
        """Build prompts for ``items``, run all batches and flatten the results.

        Entries for which ``prepare_input_fn`` returns a falsy value are
        skipped. Batches are awaited together through ``tqdm_asyncio.gather``;
        the semaphore inside ``_process_batch`` bounds the actual concurrency.

        Returns:
            A flat list of parsed records from every batch.
        """
        valid_items = []
        for entry in items:
            input_str = prepare_input_fn(entry)
            if input_str:
                valid_items.append({"entry": entry, "input_str": input_str})

        if not valid_items:
            print(f"No valid items for {desc}")
            return []

        # A step of 0 would make range() raise; treat anything below 1 as 1.
        batch_size = max(1, self.llm_batch_size)
        batches = [valid_items[i:i + batch_size] for i in range(0, len(valid_items), batch_size)]

        print(f"\n=== {desc} (Items: {len(valid_items)}, Batches: {len(batches)}) ===")

        tasks = [self._process_batch(b, prompt_template, result_parser) for b in batches]
        results = await tqdm_asyncio.gather(*tasks)

        return [item for sublist in results for item in sublist]

    # --------------------------------------------------------
    # 2. Evaluate LTL Translation Correctness (Non-True Only)
    # --------------------------------------------------------
    @staticmethod
    def _is_valid_translation(d):
        """Return True when a record is worth judging.

        A record is rejected when ``neg_ltl`` or ``pos_ltl`` is missing or
        stringifies to "true" (case-insensitive), when ``trace_analysis`` is
        empty, cannot be decoded, or has no ``milestones``.
        """
        neg = str(d.get('neg_ltl', 'true')).lower()
        pos = str(d.get('pos_ltl', 'true')).lower()
        if neg == 'true' or pos == 'true':
            return False

        analysis = d.get('trace_analysis')
        if not analysis:
            return False
        try:
            parsed = json.loads(analysis) if isinstance(analysis, str) else analysis
            return len(parsed.get('milestones', [])) > 0
        except (ValueError, TypeError, AttributeError):
            # Undecodable JSON, a non-dict analysis, or unsized milestones.
            return False

    @staticmethod
    def _prepare_translation_input(entry):
        """Render one record as the text block placed under its CASE header."""
        goal = entry.get('task_goal')
        analysis = entry.get('trace_analysis')
        pos = entry.get('pos_ltl')
        neg = entry.get('neg_ltl')
        return f"Goal: {goal}\nAnalysis JSON: {str(analysis)}\nPos LTL: {pos}\nNeg LTL List: {neg}"

    @staticmethod
    def _parse_translation_result(item, res):
        """Combine a record with the judge's verdict for it.

        Returns None (so the caller drops the record) when the verdict has no
        dict-valued ``pos_ltl_eval``; such a reply cannot be scored. A missing
        or malformed ``neg_ltl_evals`` becomes an empty list and non-dict
        entries inside it are discarded, so the statistics code can rely on
        ``.get`` being available.
        """
        pos_eval = res.get("pos_ltl_eval")
        if not isinstance(pos_eval, dict):
            return None
        neg_evals = res.get("neg_ltl_evals")
        if not isinstance(neg_evals, list):
            neg_evals = []
        entry = item["entry"]
        return {
            "group_uid": entry.get("group_uid"),
            "task_goal": entry.get("task_goal"),
            "input_analysis": entry.get("trace_analysis"),
            "pos_ltl_raw": entry.get("pos_ltl"),
            "neg_ltl_raw": entry.get("neg_ltl"),
            "pos_ltl_eval": pos_eval,
            "neg_ltl_evals": [n for n in neg_evals if isinstance(n, dict)]
        }

    @staticmethod
    def _tally(evals, minor_label):
        """Count verdicts as ``(faithful, minor, critical)``.

        An unfaithful verdict is minor only when its ``error_type`` equals
        ``minor_label``; anything else, including a missing ``error_type``,
        is counted as critical (worst case).
        """
        faithful = minor = critical = 0
        for verdict in evals:
            if verdict.get('is_faithful'):
                faithful += 1
            elif verdict.get('error_type') == minor_label:
                minor += 1
            else:
                critical += 1
        return faithful, minor, critical

    async def eval_ltl_translation(self, sample_size=50, show_errors=False):
        """Judge a sample of records and print translation statistics.

        Args:
            sample_size: Number of valid records to sample at random; a falsy
                value evaluates every valid record.
            show_errors: When True, print details for every record with at
                least one unfaithful verdict.

        Returns:
            The list of parsed per-record results.
        """
        # Filter: Skip anomalous data (neg_ltl is "true" OR pos_ltl is "true" OR milestones are empty)
        filtered_data = [d for d in self.data if self._is_valid_translation(d)]

        if sample_size:
            filtered_data = random.sample(filtered_data, min(sample_size, len(filtered_data)))

        results = await self._run_batch_check(
            filtered_data,
            self._prepare_translation_input,
            PROMPT_EVAL_LTL_TRANSLATION_BATCH,
            self._parse_translation_result,
            "[2] Evaluating LTL Translation (Non-Trivial)"
        )

        # Statistics
        pos_total = len(results)
        pos_faithful, pos_missing_info, pos_logic_error = self._tally(
            (r['pos_ltl_eval'] for r in results), 'MISSING_INFO'
        )

        neg_total = sum(len(r['neg_ltl_evals']) for r in results)
        neg_faithful, neg_ineffective, neg_harmful = self._tally(
            (ne for r in results for ne in r['neg_ltl_evals']), 'INEFFECTIVE'
        )

        # The 1e-6 in each denominator keeps an empty result set from dividing by zero.
        print(f"\n>>> LTL Translation Summary:")
        print(f"  - Positive LTL:")
        print(f"      - Correctness Rate:  {pos_faithful/(pos_total+1e-6)*100:.1f}% ({pos_faithful}/{pos_total})")
        print(f"      - Harmlessness Rate: {(pos_total-pos_logic_error)/(pos_total+1e-6)*100:.1f}%")
        print(f"      - Logic Errors (Critical): {pos_logic_error}")
        print(f"      - Missing Info (Minor):    {pos_missing_info}")

        print(f"  - Negative LTLs:")
        print(f"      - Correctness Rate:  {neg_faithful/(neg_total+1e-6)*100:.1f}% ({neg_faithful}/{neg_total})")
        print(f"      - Harmlessness Rate: {(neg_total-neg_harmful)/(neg_total+1e-6)*100:.1f}%")
        print(f"      - Harmful (Critical):      {neg_harmful}")
        print(f"      - Ineffective (Minor):     {neg_ineffective}")

        # Overall Metrics
        total_all = pos_total + neg_total
        faithful_all = pos_faithful + neg_faithful
        critical_all = pos_logic_error + neg_harmful
        harmless_all = total_all - critical_all

        print(f"\n>>> Overall Performance:")
        print(f"  - Correctness Rate:  {faithful_all/(total_all+1e-6)*100:.1f}%")
        print(f"  - Harmlessness Rate: {harmless_all/(total_all+1e-6)*100:.1f}%")

        if show_errors:
            for r in results:
                p_fail = not r['pos_ltl_eval'].get('is_faithful')
                failed_negs = [n for n in r['neg_ltl_evals'] if not n.get('is_faithful')]
                if not (p_fail or failed_negs):
                    continue

                print(f"\n[ERROR CASE] Group UID: {r['group_uid']}")
                print(f"Goal: {r['task_goal']}")
                print(f"Analysis: {r['input_analysis']}")

                if p_fail:
                    print(f"Pos LTL Eval Fail: {json.dumps(r['pos_ltl_eval'], indent=2, ensure_ascii=False)}")

                for n in failed_negs:
                    print(f"Neg LTL Eval Fail: {json.dumps(n, indent=2, ensure_ascii=False)}")

                # Print Raw LTLs for reference
                print(f"Original Pos LTL: {r.get('pos_ltl_raw')}")
                print(f"Original Neg LTLs: {r.get('neg_ltl_raw')}")
                print("-" * 40)
        return results

# ==========================================
# Entry Point
# ==========================================

async def main():
    """Parse the command-line options and run the LTL translation evaluation."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--log_file", type=str, required=True, help="Path to ltl_metadata.jsonl")

    # The integer options all have the same shape, so register them from a table
    # (order is kept so the generated --help text is unchanged).
    int_options = (
        ("--load_limit", 2500, "Max records to load"),
        ("--sample_size", 50, "Samples per task"),
        ("--concurrency", 5, "Concurrent requests"),
        ("--llm_batch_size", 1, "Records per LLM request"),
    )
    for flag, default_value, help_text in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)

    cli.add_argument("--show_errors", action="store_true", help="If set, print detailed info for failed evaluations")
    opts = cli.parse_args()

    # The judge model is pinned to gpt-4o rather than exposed as an option.
    evaluator_options = {
        "llm_client_name": "gpt-4o",
        "concurrency": opts.concurrency,
        "max_entries": opts.load_limit,
        "llm_batch_size": opts.llm_batch_size,
    }
    evaluator = TRPO_LTLEvaluator(opts.log_file, **evaluator_options)

    # Run the requested tasks
    await evaluator.eval_ltl_translation(opts.sample_size, show_errors=opts.show_errors)

# Start the async entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())

