import json
import asyncio
import os
import random
import argparse
from tqdm.asyncio import tqdm_asyncio
from trpo.llm_client_factory import get_async_client

# ==========================================
# Prompt Template: Accuracy
# ==========================================

# Prompt template for the accuracy/completeness check of extracted triplets.
# The single placeholder ``{batch_input}`` is substituted with ``str.format``;
# the doubled braces ``{{`` / ``}}`` in the JSON example are escapes that render
# as literal ``{`` / ``}`` after formatting.
PROMPT_EVAL_AP_ACCURACY_BATCH = """
You are an expert Evaluator for a Neuro-Symbolic State Extractor.
Your task is to evaluate the **Accuracy** and **Completeness** of extracted "Visual Triplets" for a given Observation and Goal.

[EXTRACTION RULES - THE STANDARD OF TRUTH]
The system SHOULD have followed these rules:
1. **Format**: `["Subject", "Category", "Value"]`
2. **Inventory**: If agent holds X, output `["X", "loc", "self"]`. (NOT `["self", "holding", "X"]`).
3. **Self Location**: `["self", "loc", "room_name"]`.
4. **No Hallucination**: Only extract explicitly visible facts. But only ['xxx', 'seen', 'false'] is allowed.

[EVALUATION CRITERIA]
1. **Accuracy**: Are the extracted triplets correct according to the Observation and Rules?
   - Check for hallucinated objects.
   - Check for wrong attributes (e.g., saying "open" when it says "closed").
   - Check for rule violations (e.g., "holding" relation instead of "loc" "self").
2. **Completeness**: Are there missing important facts?
   - Any visible object in the observation NOT mentioned in triplets?
   - Any important state (open/closed, sliced/cooked) missing?

[INPUT CASES]
{batch_input}

[OUTPUT INSTRUCTION]
Return a JSON object with a "results" list.
Format for each result:
{{
    "case_id": <int>,
    "bad_triplets": [0, 2], // Indices (0-based) of triplets that are incorrect, invalid format, or hallucinations. Empty list if all perfect.
    "has_omission": true/false, // True if important facts are missing.
    "missed_facts": [["Subject1", "Category1", "Value1"],["Subject2", "Category2", "Value2"]], //Brief list of missing items/states.
    "reason": "Explanation..."
}}
Output JSON only.
"""

# ==========================================
# Prompt Template: Redundancy
# ==========================================

# Prompt template for the redundancy check over groups of atomic propositions.
# The single placeholder ``{batch_input}`` is substituted with ``str.format``;
# the doubled braces ``{{`` / ``}}`` in the JSON example are escapes that render
# as literal ``{`` / ``}`` after formatting.
PROMPT_EVAL_AP_REDUNDANCY_BATCH = """
You are an expert Logician and Semantics Analyst.
Your task is to check a list of "Atomic Propositions" (APs) for **Redundancy**.
We want to ensure that the AP list for a task group is concise and distinct.

[DEFINITION OF REDUNDANCY]
Two APs are redundant if they assert the same atomic fact about the world, even if the wording or syntax differs slightly.
Since APs encompass system states at different moments, it is normal for an object to have multiple attributes, such as being in different locations. However, if two APs express the same attribute or state, they are considered redundant.

- Redundant: "mug state clean" vs "mug state cleaned"
- Redundant: "egg state heat" vs "apple state hot"
- NOT Redundant: "door state closed" vs "door state open" (Different attributes)
- NOT Redundant: "apple 1 loc table 1" vs "apple 2 loc table 1" (Different subjects)

[INPUT GROUPS]
{batch_input}

[OUTPUT INSTRUCTION]
Return a JSON object with a "results" list.
Format for each result:
{{
    "group_id": <string>,
    "reason": "Brief explanation of the overlap.",
    "has_redundancy": true/false,
    "redundant_sets": [
        ["ap1", "ap2"] // Lists of strings that assert the same world fact.
    ]
}}
Output JSON only.
"""

class APEvaluator:
    """LLM-backed evaluator for extracted triplets and atomic propositions (APs).

    Two evaluation modes are supported, selected by ``eval_type``:

    * ``"accuracy"``   - each JSONL entry (``task_goal``, ``observation``,
      ``triplets``) is judged for wrong and missing triplets.
    * ``"redundancy"`` - each JSONL entry (``group_uid``, ``ap_list``) is
      checked for APs that state the same fact.

    Entries are sent to the LLM in batches of ``llm_batch_size``; at most
    ``concurrency`` requests are in flight at once. A failed batch yields no
    results instead of aborting the whole run.
    """

    def __init__(self, log_file_path: str, eval_type="accuracy", llm_client_name="or-gemini-3-flash", concurrency=5, max_entries=None, llm_batch_size=5, show_error=False):
        """Create the client, the concurrency limiter and load the input data.

        Args:
            log_file_path: Path to the JSONL input file.
            eval_type: ``"accuracy"`` or ``"redundancy"``.
            llm_client_name: Name handed to ``get_async_client``.
            concurrency: Maximum number of simultaneous LLM requests.
            max_entries: Optional cap on the number of entries evaluated.
            llm_batch_size: Number of entries packed into one prompt.
            show_error: When true, print per-item failure details after the report.

        Raises:
            ValueError: If ``eval_type`` is not one of the supported modes.
        """
        self.log_file_path = log_file_path
        self.eval_type = eval_type
        self.client = get_async_client(llm_client_name)
        self.sem = asyncio.Semaphore(concurrency)
        self.llm_batch_size = llm_batch_size
        self.show_error = show_error

        if self.eval_type == "accuracy":
            self.data = self._load_accuracy_data(max_entries)
        elif self.eval_type == "redundancy":
            self.data = self._load_group_data(max_entries)
        else:
            raise ValueError(f"Unknown eval_type: {eval_type}")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _iter_jsonl_objects(self):
        """Yield every JSON object (dict) found on a non-blank line of the input file.

        Blank lines, lines that are not valid JSON and JSON values that are
        not objects (numbers, lists, ...) are skipped.
        """
        with open(self.log_file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if isinstance(entry, dict):
                    yield entry

    def _load_accuracy_data(self, max_entries):
        """Load entries for the accuracy check, de-duplicated and shuffled.

        Entries are unique per (``task_goal``, ``observation``) pair; the first
        occurrence wins. The result is always shuffled and then truncated to
        ``max_entries`` when that is set. Returns ``[]`` if the file is missing.
        """
        if not os.path.exists(self.log_file_path):
            print(f"Error: {self.log_file_path} not found.")
            return []

        seen = set()
        unique_data = []
        for entry in self._iter_jsonl_objects():
            # Serialise the key so unhashable values (e.g. a list-valued
            # observation) cannot break de-duplication.
            key = json.dumps(
                [entry.get('task_goal'), entry.get('observation')],
                sort_keys=True,
                ensure_ascii=False,
            )
            if key in seen:
                continue
            seen.add(key)
            unique_data.append(entry)

        random.shuffle(unique_data)
        if max_entries:
            unique_data = unique_data[:max_entries]

        print(f"Loaded {len(unique_data)} unique observations for accuracy check.")
        return unique_data

    def _load_group_data(self, max_entries):
        """Loads data for redundancy check. Expects JSONL with 'group_uid' and 'ap_list'.

        Entries without a list-valued ``ap_list`` are dropped. APs starting
        with ``act_`` or containing ``_seen_`` are removed from each list.
        The data is shuffled only when it has to be sub-sampled to
        ``max_entries``. Returns ``[]`` if the file is missing.
        """
        if not os.path.exists(self.log_file_path):
            print(f"Error: {self.log_file_path} not found.")
            return []

        data = []
        for entry in self._iter_jsonl_objects():
            ap_list = entry.get('ap_list')
            if not isinstance(ap_list, list):
                continue
            # Filter out 'act_' (actions) and '_seen_' (visual memory)
            entry['ap_list'] = [
                ap for ap in ap_list
                if not str(ap).startswith('act_') and '_seen_' not in str(ap)
            ]
            data.append(entry)

        if max_entries and len(data) > max_entries:
            random.shuffle(data)
            data = data[:max_entries]

        print(f"Loaded {len(data)} groups for redundancy check.")
        return data

    # ------------------------------------------------------------------
    # Prompt construction
    # ------------------------------------------------------------------

    async def _process_accuracy_batch(self, batch_items):
        """Build the accuracy prompt for ``batch_items`` and return the parsed results.

        Cases are numbered from 1 in the prompt; ``_call_llm`` maps the
        returned ``case_id`` back onto ``batch_items``.
        """
        sections = []
        for case_no, item in enumerate(batch_items, start=1):
            triplets_str = json.dumps(item.get('triplets', []), ensure_ascii=False)
            # Missing keys fall back to 'N/A' so one incomplete entry cannot abort the run.
            goal = item.get('task_goal', 'N/A')
            observation = item.get('observation', 'N/A')
            sections.append(
                f"\n--- CASE {case_no} ---\nGoal: {goal}\nObservation: {observation}\nExtracted Triplets: {triplets_str}"
            )

        prompt = PROMPT_EVAL_AP_ACCURACY_BATCH.format(batch_input="".join(sections))
        return await self._call_llm(prompt, batch_items, "accuracy")

    async def _process_redundancy_batch(self, batch_items):
        """Build the redundancy prompt for ``batch_items`` and return the parsed results."""
        sections = []
        for item in batch_items:
            # item is a group dict
            group_id = item.get('group_uid', 'unknown')
            ap_list_str = json.dumps(item.get('ap_list', []), ensure_ascii=False)
            sections.append(f"\n--- GROUP: {group_id} ---\nAP List: {ap_list_str}")

        prompt = PROMPT_EVAL_AP_REDUNDANCY_BATCH.format(batch_input="".join(sections))
        return await self._call_llm(prompt, batch_items, "redundancy")

    # ------------------------------------------------------------------
    # LLM call and response parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _as_list(value):
        """Return ``value`` if it is a list, otherwise an empty list."""
        return value if isinstance(value, list) else []

    @staticmethod
    def _extract_json(content):
        """Return the JSON object embedded in an LLM reply, or ``None``.

        First tries the span from the first ``{`` to the last ``}``; if that
        does not parse, falls back to the body of a ```` ```json ```` fence.
        """
        if not isinstance(content, str):
            return None

        candidates = []
        start, end = content.find('{'), content.rfind('}')
        if start != -1 and end > start:
            candidates.append(content[start:end + 1])
        if "```json" in content:
            candidates.append(content.split("```json")[1].split("```")[0].strip())

        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    @staticmethod
    def _collect_accuracy_results(raw_results, batch_items):
        """Join the LLM's per-case verdicts with the originating batch items.

        Verdicts whose ``case_id`` is missing, not a number, out of range or
        already seen are skipped individually rather than failing the batch.
        """
        processed = []
        used = set()
        for res in raw_results:
            if not isinstance(res, dict):
                continue
            try:
                case_idx = int(res.get("case_id")) - 1  # prompt numbers cases from 1
            except (TypeError, ValueError):
                continue
            if case_idx in used or not 0 <= case_idx < len(batch_items):
                continue
            used.add(case_idx)

            item = batch_items[case_idx]
            processed.append({
                "task_goal": item.get('task_goal', 'N/A'),
                "observation": item.get('observation'),
                "triplets": APEvaluator._as_list(item.get('triplets')),
                "bad_triplets": APEvaluator._as_list(res.get("bad_triplets")),
                "has_omission": res.get("has_omission", False),
                "missed_facts": APEvaluator._as_list(res.get("missed_facts")),
                "reason": res.get("reason")
            })
        return processed

    @staticmethod
    def _collect_redundancy_results(raw_results, batch_items):
        """Join the LLM's per-group verdicts with the originating groups.

        Ids are compared as strings, using the same ``'unknown'`` default the
        prompt shows, so an id echoed back with a different JSON type still
        matches. Groups without a verdict get an ``error`` entry.
        """
        results_map = {
            str(r.get("group_id")): r for r in raw_results if isinstance(r, dict)
        }

        processed = []
        for item in batch_items:
            g_id = item.get('group_uid')
            res = results_map.get(str(item.get('group_uid', 'unknown')))
            if res:
                processed.append({
                    "group_uid": g_id,
                    "ap_list": item.get('ap_list'),
                    "has_redundancy": res.get("has_redundancy"),
                    "redundant_sets": res.get("redundant_sets"),
                    "reason": res.get("reason")
                })
            else:
                # Fallback if LLM missed an ID
                processed.append({
                    "group_uid": g_id,
                    "error": "LLM output missing for this group"
                })
        return processed

    async def _call_llm(self, prompt, batch_items, mode):
        """Send ``prompt`` to the LLM and convert the reply into result dicts.

        Args:
            prompt: Fully formatted prompt text.
            batch_items: The entries the prompt was built from.
            mode: ``"accuracy"`` or ``"redundancy"``; any other value yields ``[]``.

        Returns:
            A list of result dicts; ``[]`` if the request fails or the reply
            contains no parseable JSON object.
        """
        async with self.sem:
            try:
                resp = await self.client.chat.completions.create(
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.0,
                    max_tokens=2048
                )
                result_json = self._extract_json(resp.choices[0].message.content)
                if result_json is None:
                    return []

                raw_results = self._as_list(result_json.get("results"))
                if mode == "accuracy":
                    return self._collect_accuracy_results(raw_results, batch_items)
                if mode == "redundancy":
                    return self._collect_redundancy_results(raw_results, batch_items)
                return []
            except Exception as e:
                # Batch boundary: one failing request must not abort the others.
                print(f"Batch failed: {e}")
                return []

    # ------------------------------------------------------------------
    # Statistics and reporting
    # ------------------------------------------------------------------

    @staticmethod
    def _pct(numerator, denominator):
        """Return ``numerator / denominator`` as a percentage, or 0 when the denominator is not positive."""
        return numerator / denominator * 100 if denominator > 0 else 0

    @staticmethod
    def _is_ignorable_missed_fact(f):
        """Tell whether a missed fact is excluded from the filtered miss count.

        Ignored are facts mentioning ``seen`` and the generic self location
        ``["self", "loc", "room"]``; specific room names are still counted.
        """
        f_str = str(f)

        # 1. Filter 'seen'
        if "'seen'" in f_str or '"seen"' in f_str:
            return True
        if isinstance(f, list) and len(f) > 1 and f[1] == "seen":
            return True

        # 2. Filter SPECIFIC 'self' location ["self", "loc", "room"]
        # We only filter 'room' as it's often an over-correction. Specific rooms are fine.
        if isinstance(f, list) and len(f) >= 3:
            return f[0] == "self" and f[1] == "loc" and str(f[2]).lower() == "room"
        lowered = f_str.lower()
        return "self" in lowered and "loc" in lowered and "room" in lowered

    @staticmethod
    def _accuracy_stats(results):
        """Aggregate accuracy results into a dict of counters."""
        stats = {
            "total_cases": len(results),
            "omission_cases": sum(1 for r in results if r.get('has_omission')),
            "total_triplets": 0,
            "bad_triplets": 0,
            "missed_facts": 0,
            "missed_facts_filtered": 0,
        }
        for r in results:
            stats["total_triplets"] += len(APEvaluator._as_list(r.get('triplets')))
            stats["bad_triplets"] += len(APEvaluator._as_list(r.get('bad_triplets')))

            m_facts = APEvaluator._as_list(r.get('missed_facts'))
            stats["missed_facts"] += len(m_facts)
            stats["missed_facts_filtered"] += sum(
                1 for f in m_facts if not APEvaluator._is_ignorable_missed_fact(f)
            )
        return stats

    @staticmethod
    def _redundancy_stats(results):
        """Aggregate redundancy results into a dict of counters.

        For every redundant set with N distinct members, N-1 APs are counted
        as redundant; a set whose members are all identical contributes 0.
        """
        stats = {
            "total_groups": len(results),
            "clean_groups": sum(
                1 for r in results if not r.get('has_redundancy') and 'error' not in r
            ),
            "total_aps": 0,
            "redundant_aps": 0,
        }
        for r in results:
            stats["total_aps"] += len(APEvaluator._as_list(r.get('ap_list')))

            if not r.get('has_redundancy'):
                continue
            for s in APEvaluator._as_list(r.get('redundant_sets')):
                if isinstance(s, list) and len(s) > 1:
                    # repr() keeps unhashable members (nested lists) comparable.
                    distinct = {repr(member) for member in s}
                    stats["redundant_aps"] += len(distinct) - 1
        return stats

    def _print_accuracy_report(self, flat_results):
        """Print the summary table for an accuracy run."""
        s = self._accuracy_stats(flat_results)
        total_triplets = s["total_triplets"]

        print(f"Total Cases Checked  : {s['total_cases']}")
        print(f"Cases w/ Omissions   : {s['omission_cases']} ({self._pct(s['omission_cases'], s['total_cases']):.1f}%)")
        print("-" * 20)
        print(f"Total Triplets Found : {total_triplets}")
        print(f"Bad Triplets Count   : {s['bad_triplets']}")
        print(f"Triplet Error Rate   : {self._pct(s['bad_triplets'], total_triplets):.1f}%")
        print(f"Missed Facts Count   : {s['missed_facts']}")
        print(f"Missed Facts Rate    : {self._pct(s['missed_facts'], total_triplets):.1f}%")
        print(f"Missed (filtered)    : {s['missed_facts_filtered']}")
        print(f"Missed Rate (filtered): {self._pct(s['missed_facts_filtered'], total_triplets):.1f}%")

    def _print_redundancy_report(self, flat_results):
        """Print the summary table for a redundancy run."""
        s = self._redundancy_stats(flat_results)

        print(f"Total Groups Checked : {s['total_groups']}")
        print(f"Clean Groups (No Red): {s['clean_groups']} ({self._pct(s['clean_groups'], s['total_groups']):.1f}%)")
        print("-" * 20)
        print(f"Total APs Found      : {s['total_aps']}")
        print(f"Redundant APs (Est.) : {s['redundant_aps']}")
        print(f"AP Redundancy Rate   : {self._pct(s['redundant_aps'], s['total_aps']):.1f}%")

    async def run_evaluation(self):
        """Evaluate all loaded entries and print the report.

        Prints a notice and returns early when no data was loaded. Rates in
        the report are shown as 0.0% when their denominator is zero (for
        example when every batch failed).
        """
        if not self.data:
            print("No data to evaluate.")
            return

        step = max(1, self.llm_batch_size)  # a non-positive size would break range()
        batches = [self.data[i:i + step] for i in range(0, len(self.data), step)]
        print(f"\n=== Evaluating ({self.eval_type}) (Items: {len(self.data)}, Batches: {len(batches)}) ===")

        if self.eval_type == "accuracy":
            tasks = [self._process_accuracy_batch(b) for b in batches]
        else:
            tasks = [self._process_redundancy_batch(b) for b in batches]

        results = await tqdm_asyncio.gather(*tasks)
        flat_results = [item for sublist in results for item in sublist]

        print("\n" + "="*40)
        print("          EVALUATION REPORT")
        print("="*40)

        if self.eval_type == "accuracy":
            self._print_accuracy_report(flat_results)
        else:
            self._print_redundancy_report(flat_results)

        if self.show_error:
            self._print_errors(flat_results)

    def _print_errors(self, results):
        """Print the details of every result that was flagged.

        In accuracy mode that is any case with bad triplets or an omission;
        in redundancy mode any group with redundancy or a missing LLM verdict.
        """
        print("\n" + "="*40)
        print("          ERROR / FAILURE DETAILS")
        print("="*40)

        count = 0
        for r in results:
            if self.eval_type == "accuracy":
                has_bad = bool(r.get('bad_triplets'))
                has_miss = r.get('has_omission')
                if not (has_bad or has_miss):
                    continue

                count += 1
                print(f"\n[CASE {count}]")
                print(f"Goal        : {r.get('task_goal')}")
                print(f"Observation : {r.get('observation')}")
                print(f"Triplets    : {json.dumps(r.get('triplets', []), ensure_ascii=False)}")
                if has_miss:
                    print(f"MISSING     : {r.get('missed_facts')}")
                if has_bad:
                    triplets = self._as_list(r.get('triplets'))
                    print(f"BAD TRIPLETS:")
                    for idx in self._as_list(r.get('bad_triplets')):
                        # The LLM may return non-integer or out-of-range indices.
                        if isinstance(idx, int) and 0 <= idx < len(triplets):
                            print(f"  - {triplets[idx]}")
                print(f"Reason      : {r.get('reason')}")

            else:
                if r.get('has_redundancy'):
                    count += 1
                    print(f"\n[GROUP {r.get('group_uid')}]")
                    print(f"Redundant Sets: {json.dumps(r.get('redundant_sets'), ensure_ascii=False)}")
                    print(f"Reason        : {r.get('reason')}")
                elif 'error' in r:
                    print(f"\n[GROUP {r.get('group_uid')}] - ERROR: {r.get('error')}")

async def main():
    """Command-line entry point: parse the options, build the evaluator, run it."""
    # Each option is declared as (flag, argparse keyword arguments).
    option_table = (
        ("--log_file", dict(type=str, required=True, help="Path to input file (jsonl)")),
        ("--eval_type", dict(type=str, default="accuracy", choices=["accuracy", "redundancy"], help="Type of evaluation")),
        ("--sample_size", dict(type=int, default=50, help="Max samples to check")),
        ("--concurrency", dict(type=int, default=5)),
        ("--show_error", dict(action="store_true", help="Show details of failed cases in console")),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    # --sample_size is forwarded as the evaluator's max_entries cap.
    settings = {
        "eval_type": opts.eval_type,
        "concurrency": opts.concurrency,
        "max_entries": opts.sample_size,
        "show_error": opts.show_error,
    }
    runner = APEvaluator(opts.log_file, **settings)
    await runner.run_evaluation()

if __name__ == "__main__":
    asyncio.run(main())
