import argparse
from copy import deepcopy
import traceback

import numpy as np
import sys
import os
import json
import glob
import json
import os
import sys
import time
from typing import Any, Dict, List, Tuple

import numpy as np
from transformers import AutoTokenizer

from agent.load_model import init_agent, init_llm
from data.load_datasets import load_query, save_json_file
from utils.llms import Deepseek, GPT4o, Gemini
from utils.poi_analyzer import POIAnalyzer

# Directory that holds this script; used as the project root for the
# results/ and cache/ paths built later, and for importing project packages.
project_root_path = os.path.split(os.path.abspath(__file__))[0]
if sys.path.count(project_root_path) == 0:
    # Prepend so the project directory is searched before other sys.path entries.
    sys.path.insert(0, project_root_path)


import traceback
from data.load_datasets import load_query, save_json_file
from agent.load_model import init_agent, init_llm
from utils.poi_analyzer import POIAnalyzer
from utils.llms import Gemini, GPT4o, Deepseek
from evaluators.main_evaluator import MainEvaluator


def _print_evaluation_debug(total_score, evaluation_details, format_correct, enable_user_request_eval, enable_LLM):
    """Print the headline scores and the entries of every ``*_details`` section of an evaluation."""
    print("  Main evaluation results:")
    print(f"    Total score: {total_score:.3f}")
    print(f"    Format score: {evaluation_details.get('format_score', -3.0):.3f} (correct: {format_correct})")
    print(f"    Commonsense score: {evaluation_details.get('commonsense_score', 0.0):.3f}")
    print(f"    Soft constraint score: {evaluation_details.get('soft_constraint_score', 0.0):.3f}")
    print(f"    Preference score: {evaluation_details.get('preference_score', 0.0):.3f}")
    if enable_user_request_eval:
        print(f"    User request score: {evaluation_details.get('user_request_score', 0.0):.3f}")

    # LLM-based soft-constraint checks only report violation counts.
    if enable_LLM:
        soft_details = evaluation_details.get("soft_constraint_details", {})
        classic_violations = soft_details.get("classic_attractions_violations", [])
        diversity_violations = soft_details.get("diversity_violations", [])
        if classic_violations:
            print(f"    Classic attractions violations: {len(classic_violations)}")
        if diversity_violations:
            print(f"    Diversity violations: {len(diversity_violations)}")

    # Dump every *_details section; nested dicts are flattened one level as "outer.inner".
    for key, value in evaluation_details.items():
        if not key.endswith('_details'):
            continue
        if not isinstance(value, dict):
            print(f"      {key}: {value}")
            continue
        print(f"    {key}:")
        for k, v in value.items():
            if isinstance(v, dict):
                for k2, v2 in v.items():
                    print(f"      {k}.{k2}: {v2}")
            else:
                print(f"      {k}: {v}")


def evaluate_plan_quality(tokenizer, plan_data, query_data, case_data, enable_user_request_eval=False, enable_LLM=False, debug=False, evaluator=None):
    """
    Evaluate plan quality with the main evaluator.

    The chat prompt rendered from ``query_data["messages"]`` is concatenated with
    the plan's ``llm_response``. The result is scored by
    ``evaluator.compute_score(solution_str, poi_dict)``.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template``.
        plan_data: Generated plan data; must contain ``llm_response`` (a string,
            or a dict/list that is JSON-encoded before scoring).
        query_data: Query data; ``messages`` may be a list or a JSON-encoded
            string (an unparsable string is treated as no messages).
        case_data: Case data; its ``poi_dict`` is forwarded to the evaluator.
        enable_user_request_eval: Controls debug output and is recorded in
            ``evaluation_config``. It does not configure the evaluator itself.
        enable_LLM: Same as above, for the LLM-based soft-constraint checks.
        debug: Whether to print debug information.
        evaluator: Object with ``compute_score(solution_str, poi_dict)``. When
            None, falls back to the module-level ``main_evaluator`` created by
            the script entry point.

    Returns:
        tuple: (is_success, total_score, evaluation_details). On any error this
        returns ``(False, -3.0, {"error": <message + traceback>})``.
    """
    try:
        if evaluator is None:
            # The script entry point builds `main_evaluator` as a module global; when this
            # module is imported instead, it does not exist and must be passed explicitly.
            evaluator = globals().get("main_evaluator")
            if evaluator is None:
                raise RuntimeError(
                    "No evaluator available: pass `evaluator=` or create the "
                    "module-level `main_evaluator` before calling evaluate_plan_quality"
                )

        # Get messages from query_data and parse if it's a JSON string
        messages = query_data.get("messages", [])
        if isinstance(messages, str):
            try:
                messages = json.loads(messages)
            except json.JSONDecodeError:
                messages = []

        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
        response = plan_data["llm_response"]
        if isinstance(response, (dict, list)):
            solution_str = prompt_text + json.dumps(response, ensure_ascii=False)
        else:
            solution_str = prompt_text + response

        poi_dict = case_data.get("poi_dict", {})
        total_score, evaluation_details = evaluator.compute_score(solution_str, poi_dict)

        format_correct = evaluation_details.get("format_details", {}).get("format_correct", False)

        if debug:
            _print_evaluation_debug(total_score, evaluation_details, format_correct,
                                    enable_user_request_eval, enable_LLM)

        # Successful condition: format correct and total score > 0
        is_success = format_correct and total_score > 0

        # Extra fields kept for backward compatibility with older result readers
        evaluation_details.update({
            "overall_success": is_success,
            "evaluation_method": "main_evaluator",
            "format_correct": format_correct,
            "evaluation_config": {
                "enable_user_request_eval": enable_user_request_eval,
                "enable_LLM": enable_LLM
            }
        })

        return is_success, total_score, evaluation_details

    except Exception as e:
        if debug:
            print(f"  Evaluation error: {e}")
            traceback.print_exc()
        return False, -3.0, {"error": str(e) + traceback.format_exc()}


def _run_agent(args, agent, query_i, data_idx):
    """Dispatch one query to ``agent`` using the call convention for ``args.agent``; returns the agent's result."""
    if args.agent in ("Direct", "CoT", "HyperTree", "TTG"):
        return agent.run(query_i, load_cache=True)
    if args.agent == "LLM-modulo":
        return agent.solve(query_i, prob_idx=data_idx)
    if args.agent in ("LLMNeSy", "RuleNeSy"):
        # NOTE(review): the 'oralce_translation' spelling is kept verbatim; presumably it
        # matches the agent's keyword argument — confirm before renaming.
        return agent.run(query_i, load_cache=True,
                         oralce_translation=args.oracle_translation,
                         preference_search=args.preference_search)
    raise Exception(f"Agent {args.agent} not implemented")


def generate_plans(args, query_data_list, agent, method, res_dir, log_dir):
    """
    Generate plan data (without evaluation).

    For each case a plan is produced by ``agent`` and written to
    ``res_dir/<data_idx>.json``. If generation fails, an error record with
    ``"status": "generation_failed"`` is written to the same path instead.

    Args:
        args: Command line parameters. Reads ``restart_from``, ``skip``,
            ``agent``, ``llm``, ``oracle_translation`` and ``preference_search``.
            ``restart_from`` is reset to None once the restart case is reached.
        query_data_list: Query data list. It is not modified.
        agent: Agent instance
        method: Method name, stored in each saved plan
        res_dir: Result directory
        log_dir: Log directory (not used by this function; kept for interface compatibility)

    Returns:
        tuple: (generation_count, generated_files, avg_generation_time). The
        average is over every attempted (non-skipped) case, failed ones included.
    """
    import statistics

    print("\n" + "="*50)
    print("GENERATION MODE - Generating plans only")
    print("="*50)

    generation_count = 0
    generated_files = []
    generation_times = []  # exactly one entry per attempted case

    for i, case_data in enumerate(query_data_list):
        data_idx = case_data.get("message_id", f"case_{i}")

        # Skip cases until the restart id is reached. Compare as strings so the int
        # CLI value also matches string message ids.
        if args.restart_from is not None and str(data_idx) != str(args.restart_from):
            continue
        args.restart_from = None

        print("------------------------------")
        print(f"Generating [{i + 1}/{len(query_data_list)}]:")
        print("Data ID:", data_idx)

        # Skip if result already exists
        result_file = os.path.join(res_dir, f"{data_idx}.json")
        if args.skip and os.path.exists(result_file):
            print(f"Skipping {data_idx} - result already exists")
            continue

        generation_count += 1

        poi_dict = case_data.get("poi_dict", {})
        # Work on a copy so the caller's case data is not mutated. This also keeps a
        # repeated pass over the same list from failing on a missing 'userQuery' key.
        original_case = dict(case_data.get("original_case") or {})
        origin_userQuery = original_case.pop("userQuery", "")
        original_case['origin_userQuery'] = origin_userQuery

        if args.agent in ["LLMNeSy", "RuleNeSy"]:
            origin_userQuery = origin_userQuery.split(".")[0]

        query_i = {
            "userQuery": origin_userQuery,
            "day": poi_dict.get("day", 3),
            "departure": poi_dict.get("departure", ""),
            "arrive": poi_dict.get("arrive", ""),
            "transportation": poi_dict.get("transportation", ""),
            "transport_pool": poi_dict.get("transport_pool", "{}"),
            "hotel_pool": poi_dict.get("hotel_pool", "{}"),
            "poi_pool": poi_dict.get("poi_pool", "{}"),
            "preference": poi_dict.get("preference", {}),
            "messages": case_data.get("messages", []),
            "locale": poi_dict.get("locale", "en-US"),
            "case_index": case_data.get("case_index", i),
            **original_case
        }
        print("Query:", query_i.get("userQuery", "No query found"))

        # Start the clock outside the try so the failure path can always report elapsed time.
        start_time = time.time()
        generation_time = None  # set once the agent returns, so a later failure is not timed twice
        try:
            agent_success, plan = _run_agent(args, agent, query_i, data_idx)

            generation_time = time.time() - start_time
            generation_times.append(generation_time)
            print(f"  Generation time: {generation_time:.2f} seconds")

            # Add token count information if available
            if hasattr(agent, 'backbone_llm') and agent.backbone_llm:
                plan["input_token_count"] = getattr(agent.backbone_llm, 'input_token_count', 0)
                plan["output_token_count"] = getattr(agent.backbone_llm, 'output_token_count', 0)
                plan["input_token_maxx"] = getattr(agent.backbone_llm, 'input_token_maxx', 0)

            # Add case metadata
            plan["case_index"] = case_data.get("case_index", i + 1)
            plan["data_idx"] = data_idx
            plan["agent_success"] = agent_success
            plan["generation_only"] = True
            plan["method"] = method
            plan["generation_time"] = generation_time

            # Add original case data for evaluation
            plan["poi_dict"] = poi_dict
            plan["messages"] = case_data.get("messages", [])

            save_json_file(json_data=plan, file_path=result_file)
            generated_files.append(result_file)
            print(f"✓ Generated and saved to {result_file}")

        except Exception as e:
            # Record generation time for failed cases too, but only if the success path
            # has not already recorded it (e.g. the failure happened while saving).
            if generation_time is None:
                generation_time = time.time() - start_time
                generation_times.append(generation_time)
                print(f"  Generation time: {generation_time:.2f} seconds (failed)")

            print(f"✗ Error generating {data_idx}: {e}")
            error_plan = {
                "error": str(e),
                "data_idx": data_idx,
                "case_index": case_data.get("case_index", i + 1),
                "agent": args.agent,
                "llm": args.llm,
                "status": "generation_failed",
                "generation_only": True,
                "generation_time": generation_time
            }
            save_json_file(json_data=error_plan, file_path=result_file)
            generated_files.append(result_file)

    # Calculate and display timing statistics
    total_generation_time = sum(generation_times)
    avg_generation_time = total_generation_time / generation_count if generation_count > 0 else 0.0

    print("\n" + "="*30 + " TIMING STATISTICS " + "="*30)
    print(f"Total plans generated: {generation_count}")
    print(f"Total generation time: {total_generation_time:.2f} seconds")
    print(f"Average generation time: {avg_generation_time:.2f} seconds per plan")

    if generation_times:
        print(f"Min generation time: {min(generation_times):.2f} seconds")
        print(f"Max generation time: {max(generation_times):.2f} seconds")
        print(f"Median generation time: {statistics.median(generation_times):.2f} seconds")

    print("="*78)

    return generation_count, generated_files, avg_generation_time


def _is_auxiliary_result_file(base_name, existing_names):
    """Return True for files in a results listing that are not generated plans and must be skipped."""
    if base_name == "evaluation_summary.json":
        return True
    suffix = "_eval.json"
    if base_name.endswith(suffix):
        # `<name>_eval.json` next to `<name>.json` is an earlier evaluation output written
        # back into the same directory; evaluating it again would double-count the case.
        return base_name[:-len(suffix)] + ".json" in existing_names
    return False


def evaluate_existing_results(args, tokenizer, input_dir,res_dir):
    """
    Evaluate existing result files.

    Each ``*.json`` in ``input_dir`` is treated as one generated plan. Two kinds
    of file are skipped: ``evaluation_summary.json``, and any ``<name>_eval.json``
    that sits next to its ``<name>.json``. For each remaining file, in order:

      1. generation errors or empty responses are recorded as failures (-3.0);
      2. unless ``args.force_reeval``, evaluation results already embedded in
         the file are reused;
      3. unless ``args.force_reeval``, an existing ``res_dir/<name>_eval.json``
         is reused;
      4. otherwise the plan is scored with ``evaluate_plan_quality``, and the
         annotated plan is saved to ``res_dir/<name>_eval.json``.

    Args:
        args: Command line parameters (reads ``force_reeval``,
            ``enable_user_request_eval``, ``enable_LLM``).
        tokenizer: Tokenizer forwarded to evaluate_plan_quality.
        input_dir: Directory containing the generated plan files.
        res_dir: Directory where ``*_eval.json`` outputs are read from and written to.

    Returns:
        tuple: (eval_count, succ_count, evaluation_results, format_scores, preference_scores).
        Each of the three lists has exactly ``eval_count`` entries.
    """
    print("\n" + "="*50)
    print("EVALUATION MODE - Evaluating existing results")
    print("="*50)

    # Sorted for a deterministic processing/report order (glob order is arbitrary).
    result_files = sorted(glob.glob(os.path.join(input_dir, "*.json")))
    if not result_files:
        print(f"No result files found in {input_dir}")
        return 0, 0, [], [], []

    print(f"Found {len(result_files)} result files to evaluate")
    existing_names = {os.path.basename(path) for path in result_files}

    eval_count = 0
    succ_count = 0
    evaluation_results = []
    format_scores = []
    preference_scores = []

    for i, result_file in enumerate(result_files):
        base_name = os.path.basename(result_file)
        if _is_auxiliary_result_file(base_name, existing_names):
            continue

        # Reset per-file state up front so the error handler never reports the
        # previous file's id or generation time.
        data_idx = f"case_{i}"
        plan_data = {}
        # Counted before any I/O so eval_count always matches the records appended below.
        eval_count += 1

        try:
            with open(result_file, 'r', encoding='utf-8') as f:
                plan_data = json.load(f)

            data_idx = plan_data.get("data_idx", f"case_{i}")
            print("------------------------------")
            print(f"Evaluating [{i + 1}/{len(result_files)}]:")
            print("Data ID:", data_idx)
            print("File:", base_name)

            # Check if this is a valid generated result
            response = plan_data.get("llm_response")
            if plan_data.get("error") or not response or '"error"' in response:
                print(f"✗ Skipping {data_idx} - generation failed or no response")
                format_scores.append(-3.0)
                preference_scores.append(0.0)
                evaluation_results.append({
                    "data_idx": data_idx,
                    "format_score": -3.0,
                    "total_score": -3.0,
                    "preference_score": 0.0,
                    "overall_success": False,
                    "evaluation_error": "Generation failed or no response",
                    "generation_time": plan_data.get("generation_time", 0.0)
                })
                continue

            output_file = os.path.join(res_dir, base_name.replace(".json", "_eval.json"))

            if not args.force_reeval and plan_data.get("evaluation_details"):
                # Evaluation results embedded in the plan file itself
                print(f"Loading existing evaluation results for {data_idx}")
                evaluation_details = plan_data.get("evaluation_details", {})
                total_score = plan_data.get("total_score", -3.0)
                is_success = plan_data.get("evaluation_success", False)
                format_score = evaluation_details.get("format_score", -3.0)
                preference_score = evaluation_details.get("preference_score", 0.0)
                if is_success:
                    print(f"✓ Existing success for {data_idx} (Total Score: {total_score:.3f})")
                else:
                    print(f"✗ Existing failed for {data_idx} (Total Score: {total_score:.3f})")

            elif not args.force_reeval and os.path.exists(output_file):
                # Evaluation output written by a previous run
                print('already eval,skip', output_file)
                with open(output_file, 'r', encoding='utf-8') as f:
                    plan_data = json.load(f)
                data_idx = plan_data.get("data_idx", f"case_{i}")
                format_score = plan_data.get("format_score", -3.0)
                preference_score = plan_data.get("preference_score", 0.0)
                total_score = plan_data.get("total_score", -3.0)
                is_success = plan_data.get("evaluation_success", False)
                evaluation_details = plan_data.get("evaluation_details", {})

            else:
                print("  Evaluating plan quality...")

                # Construct case_data and query_i from saved data
                poi_dict = plan_data.get("poi_dict", {})
                messages = plan_data.get("messages", [])

                case_data = {
                    "poi_dict": poi_dict,
                    "messages": messages,
                    "case_index": plan_data.get("case_index", i + 1)
                }

                query_i = {
                    "userQuery": poi_dict.get("userQuery", ""),
                    "messages": messages,
                    **poi_dict
                }

                is_success, total_score, evaluation_details = evaluate_plan_quality(
                    tokenizer, plan_data, query_i, case_data,
                    enable_user_request_eval=args.enable_user_request_eval,
                    enable_LLM=args.enable_LLM,
                    debug=True
                )

                if is_success:
                    print(f"✓ Success for {data_idx} (Total Score: {total_score:.3f})")
                else:
                    print(f"✗ Failed for {data_idx} (Total Score: {total_score:.3f})")

                format_score = evaluation_details.get("format_score", -3.0)
                preference_score = evaluation_details.get("preference_score", 0.0)

                plan_data.update({
                    "evaluation_success": is_success,
                    "total_score": total_score,
                    "format_score": format_score,
                    "preference_score": preference_score,
                    "evaluation_details": evaluation_details,
                    "evaluated_with_config": {
                        "enable_user_request_eval": args.enable_user_request_eval,
                        "enable_LLM": args.enable_LLM
                    }
                })

                save_json_file(json_data=plan_data, file_path=output_file)
                print(f"Updated evaluation results in {output_file}")

            # Counted for every source of results (embedded, reused file, fresh evaluation)
            if is_success:
                succ_count += 1

            format_scores.append(format_score)
            preference_scores.append(preference_score)
            evaluation_results.append({
                "data_idx": data_idx,
                "format_score": format_score,
                "preference_score": preference_score,
                "total_score": total_score,
                "overall_success": is_success,
                "evaluation_details": evaluation_details,
                "generation_time": plan_data.get("generation_time", 0.0)
            })

        except Exception as e:
            print(f"✗ Error evaluating {base_name}: {e}")
            format_scores.append(-3.0)
            preference_scores.append(0.0)
            evaluation_results.append({
                "data_idx": data_idx,
                "format_score": -3.0,
                "total_score": -3.0,
                "preference_score": 0.0,
                "overall_success": False,
                "evaluation_error": str(e),
                "generation_time": plan_data.get("generation_time", 0.0) if isinstance(plan_data, dict) else 0.0
            })

    return eval_count, succ_count, evaluation_results, format_scores, preference_scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flexible Trip Planning Experiment Runner")

    # Mode control
    parser.add_argument(
        "--mode",
        type=str,
        default="both",
        choices=["generate_only", "evaluate_only", "both"],
        help="Execution mode: generate_only, evaluate_only, both"
    )

    # Data and directory arguments
    parser.add_argument(
        "--splits",
        "-s",
        type=str,
        default="synthesis",
        help="query subset (synthesis, generalized)",
    )
    parser.add_argument("--index", "-id", type=int, default=None, help="query index")
    parser.add_argument(
        "--skip", "-sk", type=int, default=0, help="skip if the plan exists"
    )
    parser.add_argument('--restart_from', type=int, default=None, help='Restart Data ID')
    parser.add_argument(
        "--input_dir",
        type=str,
        default=None,
        help="Input directory for evaluate_only mode"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Custom output directory"
    )
    parser.add_argument(
        "--force_reeval",
        action="store_true",
        help="Force re-evaluation of already evaluated results"
    )

    # Agent arguments
    parser.add_argument(
        "--agent",
        "-a",
        type=str,
        default="Direct",
        choices=["LLMNeSy", "LLM-modulo", "Direct", "CoT", "HyperTree", 'TTG'],
        help="Agent type to use"
    )
    parser.add_argument(
        "--llm",
        "-l",
        type=str,
        default="gpt-4o",
        choices=["gpt-4o", "gemini", 'deepseek','Qwen3-8B','Qwen3-14B','Qwen3-32B'],
        help="LLM model to use"
    )

    # Evaluation control arguments
    parser.add_argument(
        "--enable_user_request_eval",
        action="store_true",
        default=False,
        help="Enable user request evaluation"
    )
    # NOTE: store_true with default=True means this flag alone can never switch LLM
    # evaluation off; --disable_LLM (below) is the off switch.
    parser.add_argument(
        "--enable_LLM",
        action="store_true",
        default=True,
        help="Enable LLM evaluation (classic attractions coverage and itinerary diversity)"
    )
    parser.add_argument(
        "--disable_LLM",
        action="store_true",
        help="Disable LLM evaluation (overrides --enable_LLM)"
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default="Qwen/Qwen3-14B",
        help="Tokenizer name or path passed to AutoTokenizer.from_pretrained"
    )

    # Agent-specific arguments
    parser.add_argument('--oracle_translation', action='store_true', help='Set this flag to enable oracle translation.')
    parser.add_argument('--preference_search', action='store_true', help='Set this flag to enable preference search.')
    parser.add_argument('--refine_steps', type=int, default=3, help='Steps for refine-based method, such as LLM-modulo')

    args = parser.parse_args()
    if args.disable_LLM:
        args.enable_LLM = False

    print("Arguments:", args)
    print(f"Mode: {args.mode}")
    print(f"User Request Eval: {'enable' if args.enable_user_request_eval else 'disable'}")
    print(f"LLM Eval: {'enable' if args.enable_LLM else 'disable'}")

    # Initialize tokenizer (used to render chat prompts for evaluation and handed to the agent)
    print("加载tokenizer...")
    model_name = args.tokenizer_name
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Create method name for directory structure
    method = args.agent + "_" + args.llm
    if args.agent == "LLM-modulo":
        method += f"_{args.refine_steps}steps"
        if not args.oracle_translation:
            print("Warning: LLM-modulo works best with oracle translation enabled")

    if args.oracle_translation:
        method = method + "_oracletranslation"
    if args.preference_search:
        method = method + "_preferencesearch"

    # Set up directories
    if args.output_dir:
        res_dir = args.output_dir
    else:
        res_dir = os.path.join(project_root_path, "results", f"{args.splits}", method)

    log_dir = os.path.join(project_root_path, "cache", method)

    os.makedirs(res_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    print("Results directory:", res_dir)
    print("Log directory:", log_dir)

    # Initialize variables
    succ_count, eval_count, generation_count = 0, 0, 0
    format_scores = []
    preference_scores = []
    evaluation_results = []
    generated_files = []
    avg_generation_time = 0.0

    # `main_evaluator` must stay a module-level name: evaluate_plan_quality reads it as a global.
    poi_analyzer = POIAnalyzer(use_api=False)
    llm = Gemini()
    main_evaluator = MainEvaluator(
        enable_user_request_eval=args.enable_user_request_eval,
        enable_LLM=args.enable_LLM,
        poi_analyzer=poi_analyzer,
        llm=llm
    )

    if args.mode in ["generate_only", "both"]:
        print("\nStarting generation phase...")

        query_data_list = load_query(args)
        print(f"Loaded {len(query_data_list)} samples")

        # Filter by index if specified
        if args.index is not None:
            filtered_list = [
                case for case in query_data_list
                if case.get("message_id") == args.index
                or case.get("original_case", {}).get("message_id") == args.index
            ]

            if filtered_list:
                query_data_list = filtered_list
                print(f"Filtered to {len(query_data_list)} samples matching index {args.index}")
            else:
                print(f"Warning: Index {args.index} not found in data. Using all samples.")

        # Set model context length based on agent type
        if args.agent in ["LLM-modulo"]:
            max_model_len = 8192
        elif args.agent in ["LLMNeSy"]:
            max_model_len = 8192
        elif args.agent == "HyperTree":
            max_model_len = 16384
        else:
            max_model_len = 7000

        # Initialize agent
        cache_dir = os.path.join(project_root_path, "cache")
        kwargs = {
            "method": args.agent,
            "env": "",
            "backbone_llm": init_llm(args.llm, max_model_len=max_model_len),
            "cache_dir": cache_dir,
            "log_dir": log_dir,
            "debug": True,
            "refine_steps": args.refine_steps,
            "main_evaluator": main_evaluator,
            "tokenizer": tokenizer,
            "poi_analyzer": poi_analyzer
        }
        agent = init_agent(kwargs)
        print(f"Initialized {args.agent} agent with {args.llm} model")

        generation_count, generated_files, avg_generation_time = generate_plans(
            args, query_data_list, agent, method, res_dir, log_dir
        )
        print(f"\nGeneration completed: {generation_count} plans generated")
        print(f"Average generation time: {avg_generation_time:.2f} seconds per plan")

    if args.mode in ["evaluate_only", "both"]:
        print("\nStarting evaluation phase...")

        # Determine input directory
        if args.mode == "evaluate_only":
            input_dir = args.input_dir if args.input_dir else res_dir
            print(f"Input directory: {input_dir}")
        else:
            input_dir = res_dir

        eval_count, succ_count, eval_results, eval_format_scores, eval_preference_scores = evaluate_existing_results(
            args, tokenizer, input_dir, res_dir
        )

        evaluation_results.extend(eval_results)
        format_scores.extend(eval_format_scores)
        preference_scores.extend(eval_preference_scores)

        print(f"\nEvaluation completed: {eval_count} evaluations, {succ_count} successful")

    def _mean(values):
        """Arithmetic mean of ``values``, or 0 for an empty list."""
        return sum(values) / len(values) if values else 0

    # Calculate additional statistics
    dr_count = 0
    cpr_count = 0
    total_scores = []
    commonsense_scores = []
    user_request_scores = []
    generation_times = []

    for result in evaluation_results:
        # Failure records carry no details; a null in a loaded file is treated the same way.
        evaluation_details = result.get("evaluation_details") or {}
        commonsense_score = evaluation_details.get("commonsense_score", 0.0)
        user_request_score = evaluation_details.get("user_request_score", 0.0)

        if result.get("format_score", -3.0) == 1.0:
            dr_count += 1
        if commonsense_score > 0.9:
            cpr_count += 1

        total_scores.append(result.get("total_score", -3.0))
        commonsense_scores.append(commonsense_score)
        user_request_scores.append(user_request_score)
        generation_times.append(result.get("generation_time", 0.0))

    # DR and CPR are rates over the records tallied above, so they use the same denominator.
    n_results = len(evaluation_results)
    dr = dr_count / n_results * 100 if n_results > 0 else 0  # DR ratio
    cpr = cpr_count / n_results * 100 if n_results > 0 else 0  # CPR ratio
    avg_score = _mean(total_scores)  # Average total_score
    avg_time = _mean(generation_times)  # Average generation time read from result files

    # Save evaluation summary
    summary_file = os.path.join(res_dir, "evaluation_summary.json")
    summary_data = {
        "execution_mode": args.mode,
        "experiment_config": {
            "agent": args.agent,
            "llm": args.llm,
            "oracle_translation": args.oracle_translation,
            "preference_search": args.preference_search,
            "refine_steps": args.refine_steps if args.agent == "LLM-modulo" else None,
            "enable_user_request_eval": args.enable_user_request_eval,
            "enable_LLM": args.enable_LLM
        },
        "overall_stats": {
            "generation_count": generation_count,
            "evaluation_count": eval_count,
            "successful": succ_count,
            "success_rate": succ_count/eval_count*100 if eval_count > 0 else 0,
            # New statistical indicators
            "DR": dr,
            "CPR": cpr,
            "avg_score": avg_score,
            "avg_time": avg_time,
            # Keep original indicators
            "avg_format_score": _mean(format_scores),
            "avg_preference_score": _mean(preference_scores),
            "avg_commonsense_score": _mean(commonsense_scores),
            "avg_user_request_score": _mean(user_request_scores),
            "avg_generation_time": avg_generation_time
        },
        "detailed_results": evaluation_results
    }
    save_json_file(json_data=summary_data, file_path=summary_file)

    print("\n" + "="*50)
    print("EXPERIMENT COMPLETED")
    print(f"Mode: {args.mode}")
    if generation_count > 0:
        print(f"Generated: {generation_count} plans")
        print(f"Average generation time: {avg_generation_time:.2f} seconds per plan")
    if eval_count > 0:
        print(f"Evaluated: {eval_count} plans")
        print(f"Successful: {succ_count}")
        print(f"Success rate: {succ_count/eval_count*100:.3f}%")
        print(f"DR (format_score=1): {dr:.3f}%")
        print(f"CPR (commonsense_score>0.9): {cpr:.3f}%")
        print(f"Average total score: {avg_score:.4f}")
        print(f"Average generation time: {avg_time:.3f} seconds")
        print(f"Average format score: {_mean(format_scores):.3f}" if format_scores else "Average format score: N/A")
        print(f"Average preference score: {_mean(preference_scores):.3f}" if preference_scores else "Average preference score: N/A")
        print(f"Average commonsense score: {_mean(commonsense_scores):.3f}" if commonsense_scores else "Average commonsense score: N/A")
        print(f"Average user request score: {_mean(user_request_scores):.3f}" if user_request_scores else "Average user request score: N/A")
    print(f"Results directory: {res_dir}")
    print(f"Evaluation summary: {summary_file}")
    print("="*50)
