#!/usr/bin/env python3
"""
Aggregate action_match_rate by policy model across multiple guardrail models.

This script collects results from multiple result directories and aggregates
them by policy model.
"""
import json
import re
from pathlib import Path
from typing import Dict, List, Any, Optional
from collections import defaultdict

# Import functions from calculate_action_match_by_case
import sys
sys.path.insert(0, str(Path(__file__).parent))
from calculate_action_match_by_case import (
    parse_guardrail_feedback,
    normalize_risk_type,
    compare_risk_types,
    find_ground_truth_risk_type_from_trajectory
)


def process_single_result(
    result: Dict[str, Any],
    gt_task_data: Optional[Dict[str, Any]],
    tasks_dir: Optional[Path]
) -> Optional[Dict[str, Any]]:
    """
    Classify a single evaluated result into one of four correctness cases.

    The case number combines whether the judge marked the hazard as matched
    and whether the guardrail's predicted risk_type agrees with the ground
    truth risk_type:

        1 - hazard wrong,   risk_type wrong
        2 - hazard wrong,   risk_type correct
        3 - hazard correct, risk_type wrong
        4 - hazard correct, risk_type correct

    Args:
        result: One entry of a task file's "results" list.
        gt_task_data: Parsed ground-truth task JSON, or None if unavailable.
        tasks_dir: Unused; kept so existing callers keep working.

    Returns:
        Dict with keys "case" (1-4) and "action_match", or None if the result
        is not evaluable: the action/hazard judge sections are missing or
        empty, the guardrail did not detect a risk, no ground-truth risk_type
        could be found for the step, or no risk_type was predicted.
    """
    # ``or {}`` also covers keys that are present but null in the JSON, which
    # would otherwise raise AttributeError on the following ``.get`` calls.
    judge = result.get("judge") or {}
    action_judge = judge.get("action") or {}
    hazard_judge = judge.get("hazard") or {}

    if not action_judge or not hazard_judge:
        return None

    action_match = action_judge.get("match", False)
    hazard_correct = bool(hazard_judge.get("match", False))

    # Only results where the guardrail actually flagged a risk are classified.
    parsed_feedback = parse_guardrail_feedback(result.get("guardrail_feedback", ""))
    if not parsed_feedback["risk_detected"]:
        return None

    # Look up the ground-truth risk_type for this step; results without one
    # are skipped so that only steps with a known risk are counted.
    step = result.get("step")
    gt_risk_type = None
    if gt_task_data and step is not None:
        gt_risk_type = find_ground_truth_risk_type_from_trajectory(
            gt_task_data, step, result.get("agent_action", "")
        )
    if not gt_risk_type:
        return None

    predicted_risk_type = parsed_feedback.get("risk_type", "")
    if not predicted_risk_type:
        return None

    risk_type_correct = bool(compare_risk_types(predicted_risk_type, gt_risk_type))

    # (hazard_correct, risk_type_correct) -> case number. The four entries
    # cover every combination, so no fallback branch is needed.
    case_by_correctness = {
        (False, False): 1,
        (False, True): 2,
        (True, False): 3,
        (True, True): 4,
    }

    return {
        "case": case_by_correctness[(hazard_correct, risk_type_correct)],
        "action_match": action_match
    }


def _load_json(path: Path) -> Any:
    """Read and parse a JSON file, always decoding it as UTF-8."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def aggregate_by_policy_model(
    base_dirs: List[Path],
    policy_models: List[str],
    tasks_dir: Optional[Path] = None
) -> Dict[str, Any]:
    """
    Aggregate results by policy model across multiple base directories

    Expected layout: ``<base_dir>/<guardrail_model>/<policy_model>/*.json``.
    Every ``*.json`` file in a policy directory is treated as a task file,
    except ``aggregate.json`` and ``summary.json``.

    Args:
        base_dirs: List of base directories (e.g., [20260127-18:19:57, 20260127-19:10:07])
        policy_models: List of policy models to aggregate (e.g., ['gpt-4o-openai', 'gpt-4o-mini-openai'])
        tasks_dir: Optional directory containing original task JSON files

    Returns:
        Dictionary keyed by policy model. Each value maps a case key
        (``case1_both_wrong`` ... ``case4_both_correct``) to a dict with
        ``description``, ``action_match_count``, ``action_total`` and
        ``action_match_rate`` (0.0 when the case has no samples).
    """
    # (case number, output key, human readable description)
    case_sections = (
        (1, "case1_both_wrong", "Both hazard and risk_type wrong"),
        (2, "case2_hazard_wrong_only", "Only hazard wrong (risk_type correct)"),
        (3, "case3_risk_type_wrong_only", "Only risk_type wrong (hazard correct)"),
        (4, "case4_both_correct", "Both hazard and risk_type correct"),
    )
    non_task_files = ("aggregate.json", "summary.json")

    # Drop duplicate model names (order preserved); otherwise a model listed
    # twice would have its directories processed, and counted, twice.
    unique_models = list(dict.fromkeys(policy_models))

    # Initialize counters for each policy model
    policy_results = {
        model: {
            f"case{number}_action_{kind}": 0
            for number, _, _ in case_sections
            for kind in ("match", "total")
        }
        for model in unique_models
    }

    # The same ground-truth task file is needed for every guardrail/policy
    # combination, so parse each one only once.
    gt_cache: Dict[Path, Any] = {}

    # Process each base directory
    for base_dir in base_dirs:
        if not base_dir.exists():
            print(f"Warning: Base directory not found: {base_dir}")
            continue

        print(f"\nProcessing base directory: {base_dir}")

        # Every sub-directory is a guardrail model; sorted for stable logs.
        for guardrail_dir in sorted(d for d in base_dir.iterdir() if d.is_dir()):
            print(f"  Processing guardrail model: {guardrail_dir.name}")

            for policy_model in unique_models:
                policy_dir = guardrail_dir / policy_model
                if not policy_dir.exists():
                    continue

                print(f"    Processing policy model: {policy_model}")

                task_files = sorted(
                    f for f in policy_dir.glob("*.json")
                    if f.name not in non_task_files
                )
                print(f"      Found {len(task_files)} task files")

                counts = policy_results[policy_model]

                for task_file in task_files:
                    task_data = _load_json(task_file)

                    task_name = task_data.get("task", task_file.stem)
                    results = task_data.get("results", [])
                    if not results:
                        continue

                    # Load task JSON for ground truth risk_type
                    gt_task_data = None
                    if tasks_dir:
                        task_json_path = tasks_dir / f"{task_name}.json"
                        if task_json_path not in gt_cache:
                            gt_cache[task_json_path] = (
                                _load_json(task_json_path)
                                if task_json_path.exists() else None
                            )
                        gt_task_data = gt_cache[task_json_path]

                    for result in results:
                        case_info = process_single_result(result, gt_task_data, tasks_dir)
                        if case_info is None:
                            continue

                        case = case_info["case"]
                        counts[f"case{case}_action_total"] += 1
                        if case_info["action_match"]:
                            counts[f"case{case}_action_match"] += 1

    # Calculate rates for each policy model
    aggregated_results = {}
    for policy_model, counts in policy_results.items():
        sections = {}
        for number, key, description in case_sections:
            matched = counts[f"case{number}_action_match"]
            total = counts[f"case{number}_action_total"]
            sections[key] = {
                "description": description,
                "action_match_count": matched,
                "action_total": total,
                # Avoid division by zero for cases without any sample.
                "action_match_rate": matched / total if total > 0 else 0.0
            }
        aggregated_results[policy_model] = sections

    return aggregated_results


def main():
    """
    Command line entry point.

    Parses the arguments, resolves relative paths against the directory of
    this script, aggregates the results per policy model, prints a summary
    and writes the aggregated metrics to a JSON file (``--output`` or a
    default path below ``<script_dir>/../results/action_match/aggregated``).

    Raises:
        ValueError: If none of the given base directories exists.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Aggregate action_match_rate by policy model across multiple guardrail models"
    )
    parser.add_argument(
        "--base_dirs",
        type=str,
        nargs="+",
        required=True,
        help="Base directories containing guardrail model folders (e.g., 20260127-18:19:57 20260127-19:10:07)"
    )
    parser.add_argument(
        "--policy_models",
        type=str,
        nargs="+",
        default=["gpt-4o-openai", "gpt-4o-mini-openai"],
        help="Policy models to aggregate. Default: gpt-4o-openai gpt-4o-mini-openai"
    )
    parser.add_argument(
        "--results_base",
        type=str,
        default="../results/embodied_planning",
        help="Base directory for results. Default: ../results/embodied_planning"
    )
    parser.add_argument(
        "--tasks_dir",
        type=str,
        default="../data/tasks",
        help="Directory containing original task JSON files. Default: ../data/tasks"
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Output JSON file path (optional)"
    )

    args = parser.parse_args()

    # Relative paths are resolved against the script location, not the CWD.
    script_dir = Path(__file__).parent

    results_base = Path(args.results_base)
    if not results_base.is_absolute():
        results_base = script_dir / results_base

    # Keep only the base directories that actually exist.
    base_dirs = []
    for base_dir_name in args.base_dirs:
        base_dir = results_base / base_dir_name
        if not base_dir.exists():
            print(f"Warning: Base directory not found: {base_dir}")
            continue
        base_dirs.append(base_dir)

    if not base_dirs:
        raise ValueError("No valid base directories found")

    tasks_dir = Path(args.tasks_dir)
    if not tasks_dir.is_absolute():
        tasks_dir = script_dir / tasks_dir

    if not tasks_dir.exists():
        # Without the task files there is no ground-truth risk_type, and
        # results without one are skipped, so every counter stays at zero.
        print(
            f"Warning: Tasks directory not found: {tasks_dir}. "
            "Ground-truth risk types cannot be loaded, so no results will be counted."
        )
        tasks_dir = None

    print("Aggregating results by policy model...")
    print(f"Base directories: {[str(d) for d in base_dirs]}")
    print(f"Policy models: {args.policy_models}")

    aggregated_results = aggregate_by_policy_model(base_dirs, args.policy_models, tasks_dir)

    # Print results
    separator = "=" * 70
    print(f"\n{separator}")
    print("Aggregated Action Match Rate by Policy Model")
    print(separator)

    for policy_model, metrics in aggregated_results.items():
        print(f"\n{separator}")
        print(f"Policy Model: {policy_model}")
        print(separator)

        for case_data in metrics.values():
            print(f"\n{case_data['description']}:")
            print(f"  Action Match Rate: {case_data['action_match_rate']:.4f}")
            print(f"  Action Match Count: {case_data['action_match_count']}/{case_data['action_total']}")

    print(f"\n{separator}")

    # Save to file: explicit --output wins, otherwise use the default path.
    if args.output:
        output_path = Path(args.output)
    else:
        policy_str = "_".join(args.policy_models)
        output_path = (
            script_dir.parent / "results" / "action_match" / "aggregated"
            / f"aggregated_by_policy_{policy_str}.json"
        )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False may emit non-ASCII characters, so the encoding must
    # not depend on the platform default.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(aggregated_results, f, indent=2, ensure_ascii=False)
    print(f"\nResults saved to: {output_path}")

# Run the command line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()

