import json
import multiprocessing
import random
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Optional, Any
from datetime import datetime

from loguru import logger

from vita.agent.llm_agent import LLMAgent, LLMSoloAgent
from vita.data_model.simulation import (
    AgentInfo,
    Info,
    Results,
    RunConfig,
    SimulationRun,
    UserInfo,
    TerminationReason,
)

from vita.data_model.tasks import Task
from vita.data_model.simulation import EvaluationType
from vita.environment.environment import get_cross_environment, EnvironmentInfo
from vita.evaluator.evaluator import evaluate_simulation, evaluate_simulation_multiple_times
from vita.metrics.agent_metrics import compute_metrics
from vita.orchestrator.orchestrator import Orchestrator
from vita.registry import RegistryInfo, registry
from vita.user.user_simulator import get_global_user_sim_guidelines
from vita.utils.display import ConsoleDisplay
from vita.utils.pydantic_utils import get_pydantic_hash
from vita.utils.utils import DATA_DIR, get_commit_hash, get_now, show_dict_diff, global_time
from vita.utils.csv_utils import save_results_to_csv


def get_options() -> RegistryInfo:
    """
    Returns options for the simulator.
    """
    return registry.get_info()


def get_environment_info(
    domain_name: str, include_tool_info: bool = False
) -> EnvironmentInfo:
    """Get information about the environment for a registered Domain"""
    return EnvironmentInfo(
            domain_name=domain_name,
            tool_defs=None
        )


def load_tasks(task_set_name: str, language: str = None) -> list[Task]:
    """
    Loads the tasks for the given domain.
    """
    global registry
    if ',' in task_set_name:
        task_loader = registry.get_tasks_loader("cross_domain")
    else:
        task_loader = registry.get_tasks_loader(task_set_name)
    tasks = task_loader(language)
    return tasks


def get_tasks(
    task_set_name: str,
    task_ids: Optional[list[str]] = None,
    num_tasks: Optional[int] = None,
    language: str = None,
) -> list[Task]:
    """
    Loads the tasks for the given domain.
    """
    if task_ids is None and num_tasks is None:
        return load_tasks(task_set_name=task_set_name, language=language)
    tasks = []
    if task_ids is not None:
        tasks = [
            task for task in load_tasks(task_set_name=task_set_name, language=language) if task.id in task_ids
        ]
        if len(tasks) != len(task_ids):
            missing_tasks = set(task_ids) - set([task.id for task in tasks])
            raise ValueError(
                f"Not all tasks were found for task set {task_set_name}: {missing_tasks}"
            )
    if num_tasks is not None:
        tasks = load_tasks(task_set_name=task_set_name, language=language)[:num_tasks]
    return tasks


def make_run_name(config: RunConfig) -> str:
    """
    Make a run name from the run config
    """
    clean_llm_agent_name = config.llm_agent.split("/")[-1]
    agent_name = f"{config.agent}_{clean_llm_agent_name}"

    clean_llm_user_name = config.llm_user.split("/")[-1]
    user_name = f"{config.user}_{clean_llm_user_name}"

    # Add think mode indicator to the filename if enable_think is True
    think_suffix = "_think" if config.enable_think else ""
    
    return f"{get_now()}_{config.domain}_{agent_name}_{user_name}{think_suffix}"


def run_domain(config: RunConfig) -> Results:
    """
    Run simulations for a domain
    Returns:
        Results: The simulation results
    """
    config.validate()
    ConsoleDisplay.display_run_config(config)
    
    # Check if this is a re-evaluation mode with optional re-run
    if hasattr(config, 're_evaluate_file') and config.re_evaluate_file:
        results = re_evaluate_simulation(config)
        return results
    
    if config.task_set_name is None:
        task_set_name = config.domain
    else:
        task_set_name = config.task_set_name
    tasks = get_tasks(task_set_name, config.task_ids, config.num_tasks, config.language)

    num_trials = config.num_trials
    save_to = config.save_to
    if save_to is None:
        save_to = f"{make_run_name(config)}.json"
    save_to = DATA_DIR / "simulations" / save_to
    config.save_to = save_to
    
    # Run simulations with the specified evaluation type
    simulation_results = run_tasks(
        domain=config.domain,
        tasks=tasks,
        agent=config.agent,
        user=config.user,
        llm_agent=config.llm_agent,
        llm_args_agent=config.llm_args_agent,
        llm_user=config.llm_user,
        llm_args_user=config.llm_args_user,
        num_trials=num_trials,
        max_steps=config.max_steps,
        max_errors=config.max_errors,
        save_to=save_to,
        console_display=True,
        evaluation_type=config.evaluation_type,
        max_concurrency=config.max_concurrency,
        seed=config.seed,
        log_level=config.log_level,
        enable_think=config.enable_think,
        enable_memory=config.enable_memory,
        memory_llm=config.memory_llm,
        memory_llm_args=config.memory_llm_args,
        llm_evaluator=config.llm_evaluator,
        llm_args_evaluator=config.llm_args_evaluator,
        language=config.language,
    )
    
    metrics = compute_metrics(simulation_results)
    ConsoleDisplay.display_agent_metrics(metrics)

    if config.csv_output_file and simulation_results.simulations:
        try:
            csv_output = config.csv_output_file
            save_results_to_csv(simulation_results, csv_output, config, metrics)
            ConsoleDisplay.console.print(f"\n💾 [bold green]Results appended to CSV: {csv_output}[/bold green]")
        except Exception as e:
            ConsoleDisplay.console.print(f"\n[bold red]Error saving to CSV: {e}[/bold red]")

    return simulation_results


def run_tasks(
    domain: str,
    tasks: list[Task],
    agent: str,
    user: str,
    llm_agent: Optional[str] = None,
    llm_args_agent: Optional[dict] = None,
    llm_user: Optional[str] = None,
    llm_args_user: Optional[dict] = None,
    num_trials: int = 1,
    max_steps: int = 100,
    max_errors: int = 10,
    save_to: Optional[str | Path] = None,
    console_display: bool = True,
    evaluation_type: EvaluationType = "nl_rubrics",
    max_concurrency: int = 1,
    seed: Optional[int] = 300,
    log_level: Optional[str] = "INFO",
    enable_think: bool = False,
    enable_memory: bool = False,
    memory_llm: Optional[str] = None,
    memory_llm_args: Optional[dict] = None,
    llm_evaluator: Optional[str] = None,
    llm_args_evaluator: Optional[dict] = None,
    language: str = None,
) -> Results:
    """
    Runs tasks for a given domain.
    If llm_as_judge is True, the LLM will be used to annotate the simulation run.
    Calculates the reward for the simulation run.
    Args:
        domain (str): The domain to run the simulation on.
        tasks (list[Task]): The tasks to run.
        agent (str): The agent to run the simulation on.
        user (str): The user to run the simulation on.
        llm_agent (str): The model to use for the agent.
        llm_args_agent (dict): The arguments to pass to the LLM for the agent.
        llm_user (str): The model to use for the user.
        llm_args_user (dict): The arguments to pass to the LLM for the user.
        num_trials (int): The number of trials to run the simulation on.
        max_steps (int): The maximum number of steps to run the simulation.
        max_errors (int): The maximum number of errors to allow in the simulation.
        save_to (str | Path): The path to json file where to save the simulation results. If the file already exists, it will try to resume the run.
        console_display (bool): Whether to display the simulation results in the console.
        evaluation_type (EvaluationType): The type of evaluation to use.
        max_concurrency (int): The maximum number of concurrent simulations to run.
        seed (int): The seed to use for the simulation.
        log_level (str): The log level to use.
        enable_think (bool): Whether to enable think mode for the agent LLM.
        enable_memory (bool): Whether to enable memory compression for the agent LLM.
        memory_llm (str): The LLM to use for memory compression.
        memory_llm_args (dict): The arguments to pass to the memory LLM.
    Returns:
        The simulation results and the annotations (if llm_review is True).
    """
    if isinstance(save_to, str):
        save_to = Path(save_to)
    # Set log level from config
    logger.remove()
    logger.add(lambda msg: print(msg), level=log_level)
    if len(tasks) == 0:
        raise ValueError("No tasks to run")
    if num_trials <= 0:
        raise ValueError("Number of trials must be greater than 0")
    if max_steps <= 0:
        raise ValueError("Max steps must be greater than 0")
    if max_errors <= 0:
        raise ValueError("Max errors must be greater than 0")

    random.seed(seed)

    seeds = [random.randint(0, 1000000) for _ in range(num_trials)]
    if "seed" in llm_args_agent:
        logger.warning("Each trial will modify the seed for the agent")

    if "seed" in llm_args_user:
        logger.warning("Each trial will modify the seed for the user")

    lock = multiprocessing.Lock()

    info = get_info(
        domain=domain,
        agent=agent,
        user=user,
        llm_agent=llm_agent,
        llm_args_agent=llm_args_agent,
        llm_user=llm_user,
        llm_args_user=llm_args_user,
        num_trials=num_trials,
        max_steps=max_steps,
        max_errors=max_errors,
        seed=seed,
    )
    simulation_results = Results(
        info=info,
        tasks=tasks,
        simulations=[],
    )
    done_runs = set()
    if save_to is not None:
        # If save_to already exists, check if the user wants to resume the run.
        if save_to.exists():
            response = (
                ConsoleDisplay.console.input(
                    "[yellow]File [bold]{}[/bold] already exists. Do you want to resume the run? (y/n)[/yellow] ".format(
                        save_to
                    )
                )
                .lower()
                .strip()
            )
            if response != "y":
                raise FileExistsError(
                    f"File {save_to} already exists. Please delete it or use a different save_to name."
                )
            with open(save_to, "r") as fp:
                prev_simulation_results = Results.model_validate_json(fp.read())
                # Check if the run config has changed
                if get_pydantic_hash(prev_simulation_results.info) != get_pydantic_hash(
                    simulation_results.info
                ):
                    diff = show_dict_diff(
                        prev_simulation_results.info.model_dump(),
                        simulation_results.info.model_dump(),
                    )
                    ConsoleDisplay.console.print(
                        f"The run config has changed.\n\n{diff}\n\nDo you want to resume the run? (y/n)"
                    )
                    response = (
                        ConsoleDisplay.console.input(
                            "[yellow]File [bold]{}[/bold] already exists. Do you want to resume the run? (y/n)[/yellow] ".format(
                                save_to
                            )
                        )
                        .lower()
                        .strip()
                    )
                    if response != "y":
                        raise ValueError(
                            "The run config has changed. Please delete the existing file or use a different save_to name."
                        )
                # Check if the task set has changed
                if not all(
                    get_pydantic_hash(task) == get_pydantic_hash(prev_task)
                    for task, prev_task in zip(
                        sorted(simulation_results.tasks, key=lambda x: x.id),
                        sorted(prev_simulation_results.tasks, key=lambda x: x.id),
                    )
                ):
                    raise ValueError(
                        "The task set has changed. Please delete the existing file or use a different save_to name."
                    )
                # Check which of the runs have already been done
                done_runs = set(
                    [
                        (sim.trial, sim.task_id, sim.seed)
                        for sim in prev_simulation_results.simulations
                    ]
                )
                simulation_results = prev_simulation_results
                ConsoleDisplay.console.print(
                    f"[bold yellow]Resuming run from {len(done_runs)} runs. {len(tasks) * num_trials - len(done_runs)} runs remaining.[/bold yellow]"
                )
        # Create new save file
        else:
            # Check if save_to exists and create parent directories if needed
            if not save_to.parent.exists():
                save_to.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Saving simulation batch to {save_to}")
            with open(save_to, "w") as fp:
                fp.write(simulation_results.model_dump_json(indent=2))

    def _save(simulation: SimulationRun):
        def serialize_datetime(obj):
            if isinstance(obj, datetime):
                return obj.isoformat()
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
        if save_to is None:
            return
        with lock:
            with open(save_to, "r") as fp:
                ckpt = json.load(fp)
            
            simulation_dict = simulation.model_dump()
            
            ckpt["simulations"].append(simulation_dict)
            with open(save_to, "w", encoding="utf-8") as fp:
                json.dump(ckpt, fp, indent=2, ensure_ascii=False, default=serialize_datetime)

    def _run(task: Task, trial: int, seed: int, progress_str: str) -> dict:
        ConsoleDisplay.console.print(
            f"[bold green]{progress_str} Running task {task.id}, trial {trial + 1}[/bold green]"
        )
        try:
            simulation = run_task(
                domain=domain,
                task=task,
                agent=agent,
                user=user,
                llm_agent=llm_agent,
                llm_args_agent=llm_args_agent,
                llm_user=llm_user,
                llm_args_user=llm_args_user,
                max_steps=max_steps,
                max_errors=max_errors,
                evaluation_type=evaluation_type,
                seed=seed,
                max_retries=3, # 每个任务重试3次
                enable_think=enable_think,
                enable_memory=enable_memory,
                memory_llm=memory_llm,
                memory_llm_args=memory_llm_args,
                llm_evaluator=llm_evaluator,
                llm_args_evaluator=llm_args_evaluator,
                language=language,
            )
            simulation.trial = trial
            if console_display:
                ConsoleDisplay.display_simulation(simulation, show_details=False)
            _save(simulation)
            return {"status": "success", "simulation": simulation}
        except Exception as e:
            logger.error(f"Error running task {task.id}, trial {trial}: {e}")
            logger.warning(f"Task {task.id}, trial {trial} failed but continuing with other tasks")
            if console_display:
                ConsoleDisplay.console.print(f"[bold red]Task {task.id}, trial {trial} failed: {e}[/bold red]")
            return {"status": "failed", "task_id": task.id, "trial": trial, "error": str(e)}

    args = []
    for i, task in enumerate(tasks):
        for trial in range(num_trials):
            if (trial, task.id, seeds[trial]) in done_runs:
                ConsoleDisplay.console.print(
                    f"[bold yellow]Skipping task {task.id}, trial {trial} because it has already been run.[/bold yellow]"
                )
                continue
            progress_str = f"{i}/{len(tasks)} (trial {trial + 1}/{num_trials})"
            args.append((task, trial, seeds[trial], progress_str))

    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        res = list(executor.map(_run, *zip(*args)))
        successful_sims = []
        failed_sims = []
        for sim_result in res:
            if sim_result["status"] == "success":
                successful_sims.append(sim_result["simulation"])
            else:
                failed_sims.append(sim_result)
        
        simulation_results.simulations.extend(successful_sims)
        
        failed_count = len(res) - len(successful_sims)
    
    ConsoleDisplay.console.print(
        f"\n✨ [bold green]Successfully completed all simulations![/bold green]\n"
        f"📊 [bold blue]Statistics:[/bold blue]\n"
        f"  ✅ Successful tasks: {len(successful_sims)}\n"
        f"  ❌ Failed tasks: {len(failed_sims)}\n"
        f"  📝 Total tasks: {len(res)}\n"
        f"To review the simulations, run: [bold blue]vita view[/bold blue]"
    )
    
    if failed_sims:
        ConsoleDisplay.console.print(f"\n[bold red]Failed tasks:[/bold red]")
        for failed_result in failed_sims:
            ConsoleDisplay.console.print(f"  - Task {failed_result['task_id']}, Trial {failed_result['trial']}: {failed_result['error']}")
        
        failed_task_ids = list(set([failed_result['task_id'] for failed_result in failed_sims]))
        ConsoleDisplay.console.print(f"\n[bold red]Failed task IDs:[/bold red] {', '.join(failed_task_ids)}")
    
    return simulation_results


def run_task(
    domain: str,
    task: Task,
    agent: str,
    user: str,
    llm_agent: Optional[str] = None,
    llm_args_agent: Optional[dict] = None,
    llm_user: Optional[str] = None,
    llm_args_user: Optional[dict] = None,
    max_steps: int = 100,
    max_errors: int = 10,
    evaluation_type: EvaluationType = "all",
    seed: Optional[int] = None,
    max_retries: int = 3,
    enable_think: bool = False,
    enable_memory: bool = False,
    memory_llm: Optional[str] = None,
    memory_llm_args: Optional[dict] = None,
    llm_evaluator: Optional[str] = None,
    llm_args_evaluator: Optional[dict] = None,
    language: str = None,
) -> SimulationRun:
    """
    Runs tasks for a given domain.
     If llm_as_judge is True, the LLM will be used to annotate the simulation run.
     Calculates the reward for the simulation run.
     Args:
         domain (str): The domain to run the simulation on.
         task (Task): The task to run.
         agent (str): The agent to run the simulation on.
         user (str): The user to run the simulation on.
         llm_agent (str): The model to use for the agent.
         llm_args_agent (dict): The arguments to pass to the LLM for the agent.
         llm_user (str): The model to use for the user.
         llm_args_user (dict): The arguments to pass to the LLM for the user.
         max_steps (int): The maximum number of steps to run the simulation.
         max_errors (int): The maximum number of errors to allow in the simulation.
         evaluation_type (EvaluationType): The type of evaluation to use.
         seed (int): The seed to use for the simulation.
         max_retries (int): The maximum number of retries if an error occurs.
     Returns:
         The simulation run.
    """

    if max_steps <= 0:
        raise ValueError("Max steps must be greater than 0")
    if max_errors <= 0:
        raise ValueError("Max errors must be greater than 0")

    for attempt in range(max_retries + 1):  # +1 是因为第一次不算重试
        try:
            return _run_task_internal(
                domain=domain,
                task=task,
                agent=agent,
                user=user,
                llm_agent=llm_agent,
                llm_args_agent=llm_args_agent,
                llm_user=llm_user,
                llm_args_user=llm_args_user,
                max_steps=max_steps,
                max_errors=max_errors,
                evaluation_type=evaluation_type,
                seed=seed,
                enable_think=enable_think,
                enable_memory=enable_memory,
                memory_llm=memory_llm,
                memory_llm_args=memory_llm_args,
                llm_evaluator=llm_evaluator,
                llm_args_evaluator=llm_args_evaluator,
                language=language
            )
        except Exception as e:
            if attempt < max_retries:
                logger.warning(f"Task {task.id} failed on attempt {attempt + 1}/{max_retries + 1}: {e}. Retrying...")
                _clear_global_state()
                continue
            else:
                logger.error(f"Task {task.id} failed after {max_retries + 1} attempts. Last error: {e}")
                raise e


def _run_task_internal(
    domain: str,
    task: Task,
    agent: str,
    user: str,
    llm_agent: Optional[str] = None,
    llm_args_agent: Optional[dict] = None,
    llm_user: Optional[str] = None,
    llm_args_user: Optional[dict] = None,
    max_steps: int = 100,
    max_errors: int = 10,
    evaluation_type: EvaluationType = "all",
    seed: Optional[int] = None,
    enable_think: bool = False,
    enable_memory: bool = False,
    memory_llm: Optional[str] = None,
    memory_llm_args: Optional[dict] = None,
    llm_evaluator: Optional[str] = None,
    llm_args_evaluator: Optional[dict] = None,
    language: str = None,
) -> SimulationRun:
    """
    Internal implementation of run_task without retry logic.
    """
    _clear_global_state()
    
    global registry
    logger.info(
        f"STARTING SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent}, User: {user}"
    )
    if "," in domain:
        environment = get_cross_environment(domain, task.environment, language)
    else:
        environment_constructor = registry.get_env_constructor(domain)
        environment = environment_constructor(task.environment, language)
    AgentConstructor = registry.get_agent_constructor(agent)

    solo_mode = False
    time = environment.tools.db.time
    global global_time
    global_time = time
    logger.info(f"|| Time Set To: {time}")

    if issubclass(AgentConstructor, LLMAgent):
        agent = AgentConstructor(
            tools=environment.get_tools(),
            domain_policy=environment.get_policy(),
            llm=llm_agent,
            llm_args=llm_args_agent,
            time=time,
            enable_think=enable_think,
            enable_memory=enable_memory,
            memory_llm=memory_llm,
            memory_llm_args=memory_llm_args,
            language=language,
        )
    elif issubclass(AgentConstructor, LLMSoloAgent):
        solo_mode = True
        agent = AgentConstructor(
            tools=environment.get_tools(),
            domain_policy=environment.get_policy(),
            llm=llm_agent,
            llm_args=llm_args_agent,
            time=time,
            enable_think=enable_think,
            language=language,
        )
    else:
        raise ValueError(
            f"Unknown agent type: {AgentConstructor}. Should be LLMAgent or LLMSoloAgent"
        )

    UserConstructor = registry.get_user_constructor(user)

    user = UserConstructor(
        persona=str(task.user_scenario.user_profile),
        instructions=str(task.instructions),
        llm=llm_user,
        llm_args=llm_args_user,
        language=language,
    )

    orchestrator = Orchestrator(
        domain=domain,
        agent=agent,
        user=user,
        environment=environment,
        task=task,
        max_steps=max_steps,
        max_errors=max_errors,
        seed=seed,
        solo_mode=solo_mode,
        language=language
    )
    simulation = orchestrator.run()

    # Handle all_types evaluation - run simulation once but evaluate with all methods
    if evaluation_type == "all_types":
        # Run simulation once, then evaluate with all methods using multiple evaluations
        nl_rubrics_reward = evaluate_simulation_multiple_times(
            domain=domain,
            task=task,
            simulation=simulation,
            evaluation_type="nl_rubrics",
            llm_evaluator=llm_evaluator,
            llm_args_evaluator=llm_args_evaluator,
            language=language,
        )
        
        trajectory_reward = evaluate_simulation_multiple_times(
            domain=domain,
            task=task,
            simulation=simulation,
            evaluation_type="trajectory",
            llm_evaluator=llm_evaluator,
            llm_args_evaluator=llm_args_evaluator,
            language=language,
        )
        
        trajectory_wo_user_reward = evaluate_simulation_multiple_times(
            domain=domain,
            task=task,
            simulation=simulation,
            evaluation_type="trajectory_wo_user",
            llm_evaluator=llm_evaluator,
            llm_args_evaluator=llm_args_evaluator,
            language=language,
        )

        # Create a combined reward info with all evaluation results
        # Store the different evaluation results in the info field
        combined_info = {
            "nl_rubrics_evaluation": {
                "reward": nl_rubrics_reward.reward,
                "nl_rubrics": nl_rubrics_reward.nl_rubrics,
                "reward_breakdown": nl_rubrics_reward.reward_breakdown,
                "info": nl_rubrics_reward.info
            },
            "trajectory_evaluation": {
                "reward": trajectory_reward.reward,
                "nl_rubrics": trajectory_reward.nl_rubrics,
                "reward_breakdown": trajectory_reward.reward_breakdown,
                "window_evaluations": trajectory_reward.window_evaluations,
                "info": trajectory_reward.info
            },
            "trajectory_wo_user_evaluation": {
                "reward": trajectory_wo_user_reward.reward,
                "nl_rubrics": trajectory_wo_user_reward.nl_rubrics,
                "reward_breakdown": trajectory_wo_user_reward.reward_breakdown,
                "window_evaluations": trajectory_wo_user_reward.window_evaluations,
                "info": trajectory_wo_user_reward.info
            },
            "evaluation_methods": ["nl_rubrics", "trajectory", "trajectory_wo_user"]
        }
        
        # Use the trajectory reward as the main reward, but store all results in info
        trajectory_reward.info = combined_info
        simulation.reward_info = trajectory_reward
        
        # Extract multiple evaluation details for logging
        nl_all_rewards = []
        traj_all_rewards = []
        traj_wo_user_all_rewards = []
        
        if nl_rubrics_reward.info:
            # Count how many evaluations we have
            num_evaluations = len([k for k in nl_rubrics_reward.info.keys() if k.startswith("evaluation_index_")])
            nl_all_rewards = [nl_rubrics_reward.info.get(f"evaluation_index_{i+1}", {}).get("reward", 0) 
                             for i in range(num_evaluations) if f"evaluation_index_{i+1}" in nl_rubrics_reward.info]
        
        if trajectory_reward.info:
            # Count how many evaluations we have
            num_evaluations = len([k for k in trajectory_reward.info.keys() if k.startswith("evaluation_index_")])
            traj_all_rewards = [trajectory_reward.info.get(f"evaluation_index_{i+1}", {}).get("reward", 0) 
                               for i in range(num_evaluations) if f"evaluation_index_{i+1}" in trajectory_reward.info]
        
        if trajectory_wo_user_reward.info:
            # Count how many evaluations we have
            num_evaluations = len([k for k in trajectory_wo_user_reward.info.keys() if k.startswith("evaluation_index_")])
            traj_wo_user_all_rewards = [trajectory_wo_user_reward.info.get(f"evaluation_index_{i+1}", {}).get("reward", 0) 
                                       for i in range(num_evaluations) if f"evaluation_index_{i+1}" in trajectory_wo_user_reward.info]
        
        logger.info(
            f"FINISHED SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent.__class__.__name__}, User: {user.__class__.__name__}. "
            f"Rewards (Mode) - NL: {nl_rubrics_reward.reward} | {nl_rubrics_reward.reward_breakdown}, "
            f"Trajectory: {trajectory_reward.reward} | {trajectory_reward.reward_breakdown}, "
            f"Trajectory_wo_user: {trajectory_wo_user_reward.reward} | {trajectory_wo_user_reward.reward_breakdown} | "
            f"All Rewards - NL: {nl_all_rewards}, Trajectory: {traj_all_rewards}, Trajectory_wo_user: {traj_wo_user_all_rewards}"
        )
    else:
        # Use multiple evaluations for single evaluation type
        reward_info = evaluate_simulation_multiple_times(
            domain=domain,
            task=task,
            simulation=simulation,
            evaluation_type=evaluation_type,
            llm_evaluator=llm_evaluator,
            llm_args_evaluator=llm_args_evaluator,
            language=language,
        )
        simulation.reward_info = reward_info
        
        # Extract multiple evaluation details for logging
        all_rewards = []
        if reward_info.info:
            # Count how many evaluations we have
            num_evaluations = len([k for k in reward_info.info.keys() if k.startswith("evaluation_index_")])
            all_rewards = [reward_info.info.get(f"evaluation_index_{i+1}", {}).get("reward", 0) 
                          for i in range(num_evaluations) if f"evaluation_index_{i+1}" in reward_info.info]
        
        # Calculate variance
        if len(all_rewards) > 1:
            import statistics
            variance = statistics.variance(all_rewards)
        else:
            variance = 0.0
        
        logger.info(
            f"FINISHED SIMULATION: Domain: {domain}, Task: {task.id}, Agent: {agent.__class__.__name__}, User: {user.__class__.__name__}. "
            f"Reward (Mode): {reward_info.reward} | {reward_info.reward_breakdown} | "
            f"All Rewards: {all_rewards} | Variance: {variance:.4f}"
        )
    
    return simulation


def get_info(
    domain: str,
    agent: str,
    user: str,
    llm_agent: Optional[str] = None,
    llm_args_agent: Optional[dict] = None,
    llm_user: Optional[str] = None,
    llm_args_user: Optional[dict] = None,
    num_trials: int = 1,
    max_steps: int = 100,
    max_errors: int = 10,
    seed: Optional[int] = None,
) -> Info:
    def clean_llm_args(llm_args: Optional[dict]) -> Optional[dict]:
        """Clean LLM arguments to make them JSON serializable"""
        if llm_args is None:
            return None
        
        cleaned = {}
        for key, value in llm_args.items():
            if hasattr(value, '__class__') and value.__class__.__name__ == 'type':
                # Replace type objects with their class name
                cleaned[key] = value.__name__
            else:
                cleaned[key] = value
        return cleaned
    
    user_info = UserInfo(
        implementation=user,
        llm=llm_user,
        llm_args=clean_llm_args(llm_args_user),
        global_simulation_guidelines=get_global_user_sim_guidelines(),
    )
    agent_info = AgentInfo(
        implementation=agent,
        llm=llm_agent,
        llm_args=clean_llm_args(llm_args_agent),
    )
    environment_info = get_environment_info(
        domain, include_tool_info=False
    )
    return Info(
        git_commit=get_commit_hash(),
        num_trials=num_trials,
        max_steps=max_steps,
        max_errors=max_errors,
        user_info=user_info,
        agent_info=agent_info,
        environment_info=environment_info,
        seed=seed,
    )


def _clear_global_state():
    from vita.data_model.tasks import (
        StoreBaseModel, ProductBaseModel
    )

    base_classes = [
        StoreBaseModel, ProductBaseModel
    ]
    
    for base_class in base_classes:
        try:
            if hasattr(base_class, 'clear_thread_data'):
                base_class.clear_thread_data()
        except Exception as e:
            pass


def re_evaluate_simulation(config: RunConfig) -> Results:
    """
    Re-evaluate simulations from a saved simulation file, with optional re-running of specific tasks.
    
    Args:
        config (RunConfig): The run configuration containing:
            - re_evaluate_file (str): Path to the simulation file to load
            - evaluation_type (EvaluationType): The type of evaluation to use
            - save_to (Optional[str | Path]): Path to save the re-evaluation results
            - re_run (bool): Whether to re-run tasks specified by task_ids
            - task_ids (Optional[list[str]]): Task IDs to re-run (only used if re_run is True)

    Returns:
        Results: The re-evaluation results
    """
    re_evaluate_file = config.re_evaluate_file
    evaluation_type = config.evaluation_type
    save_to = config.save_to
    re_run = getattr(config, 're_run', False)
    task_ids_to_rerun = config.task_ids if re_run else None
    
    # Load the original simulation results
    simulation_path = Path(re_evaluate_file)
    if not simulation_path.exists():
        raise FileNotFoundError(f"Simulation file not found: {re_evaluate_file}")
    
    # Load the original results
    with open(simulation_path, "r") as fp:
        original_results = Results.model_validate_json(fp.read())
    
    logger.info(f"Loaded simulation file: {re_evaluate_file}")
    logger.info(f"Found {len(original_results.simulations)} simulations")
    
    # Handle re-running specific tasks if requested
    if re_run and task_ids_to_rerun:
        logger.info(f"Re-running tasks: {task_ids_to_rerun}")
        
        # Get tasks to re-run
        if config.task_set_name is None:
            task_set_name = config.domain
        else:
            task_set_name = config.task_set_name
        
        tasks_to_rerun = get_tasks(task_set_name, task_ids_to_rerun, None, config.language)
        
        # Run the specific tasks
        rerun_results = run_tasks(
            domain=config.domain,
            tasks=tasks_to_rerun,
            agent=config.agent,
            user=config.user,
            llm_agent=config.llm_agent,
            llm_args_agent=config.llm_args_agent,
            llm_user=config.llm_user,
            llm_args_user=config.llm_args_user,
            num_trials=config.num_trials,
            max_steps=config.max_steps,
            max_errors=config.max_errors,
            save_to=None,  # Don't save intermediate results
            console_display=True,
            evaluation_type=evaluation_type,
            max_concurrency=config.max_concurrency,
            seed=config.seed,
            log_level=config.log_level,
            enable_think=config.enable_think,
            llm_evaluator=config.llm_evaluator,
            llm_args_evaluator=config.llm_args_evaluator,
            language=config.language,
        )
        
        # Remove old simulations for the re-run task IDs
        original_simulations = [
            sim for sim in original_results.simulations 
            if sim.task_id not in task_ids_to_rerun
        ]
        
        # Combine original simulations (excluding re-run tasks) with new simulations
        combined_simulations = original_simulations + rerun_results.simulations
        
        # Update original_results with combined simulations
        original_results.simulations = combined_simulations
        
        logger.info(f"Combined {len(original_simulations)} existing simulations with {len(rerun_results.simulations)} re-run simulations")
    
    logger.info(f"Total simulations to re-evaluate: {len(original_results.simulations)}")
    
    # Update tasks list if we re-ran any tasks (to ensure we have the latest task definitions)
    final_tasks = original_results.tasks
    if re_run and task_ids_to_rerun:
        # Create a mapping of task_id to task for efficient lookup
        existing_task_ids = {task.id for task in original_results.tasks}
        new_tasks = [task for task in tasks_to_rerun if task.id not in existing_task_ids]
        if new_tasks:
            final_tasks = original_results.tasks + new_tasks
            logger.info(f"Added {len(new_tasks)} new tasks to the task list")
    
    # Create new results object for re-evaluation
    re_eval_results = Results(
        timestamp=get_now(),
        info=original_results.info,
        tasks=final_tasks,
        simulations=[],
    )
    
    def _re_evaluate_single(simulation, task_dict, domain_name, progress_str):
        logger.info(f"{progress_str} Re-evaluating simulation: {simulation.task_id}")
        
        task = task_dict.get(simulation.task_id)
        if task is None:
            logger.warning(f"Task {simulation.task_id} not found, skipping simulation")
            return {"status": "skipped", "simulation": simulation, "reason": "task_not_found"}
        
        try:
            reward_info = evaluate_simulation(
                simulation=simulation,
                task=task,
                evaluation_type=evaluation_type,
                domain=domain_name,
                llm_evaluator=config.llm_evaluator,
                llm_args_evaluator=config.llm_args_evaluator,
                language=config.language,
            )
            
            # Create a new simulation run with updated reward info
            re_eval_simulation = SimulationRun(
                id=simulation.id,
                task_id=simulation.task_id,
                timestamp=simulation.timestamp,
                start_time=simulation.start_time,
                end_time=simulation.end_time,
                duration=simulation.duration,
                termination_reason=simulation.termination_reason,
                agent_cost=simulation.agent_cost,
                user_cost=simulation.user_cost,
                reward_info=reward_info,  # Updated reward info
                messages=simulation.messages,  # Keep original messages
                states=simulation.states,  # Keep original states
                trial=simulation.trial,
                seed=simulation.seed,
            )
            
            logger.info(f"Re-evaluation completed for {simulation.task_id}: reward = {reward_info.reward}")
            return {"status": "success", "simulation": re_eval_simulation}
            
        except Exception as e:
            logger.error(f"Error re-evaluating simulation {simulation.task_id}: {e}")
            return {"status": "failed", "simulation": simulation, "error": str(e)}
    
    task_dict = {task.id: task for task in final_tasks}
    domain_name = original_results.info.environment_info.domain_name
    
    args = []
    for i, simulation in enumerate(original_results.simulations):
        progress_str = f"({i + 1}/{len(original_results.simulations)})"
        args.append((simulation, task_dict, domain_name, progress_str))
    
    max_concurrency = getattr(config, 'max_concurrency', 4) 
    with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
        results = list(executor.map(_re_evaluate_single, *zip(*args)))
    
    successful_count = 0
    failed_count = 0
    skipped_count = 0
    
    for result in results:
        if result["status"] == "success":
            re_eval_results.simulations.append(result["simulation"])
            successful_count += 1
        elif result["status"] == "failed":
            re_eval_results.simulations.append(result["simulation"])
            failed_count += 1
            print(f"Error of {result['simulation'].task_id} trial {result['simulation'].trial} re-evaluate: {result['error']}")
        elif result["status"] == "skipped":
            re_eval_results.simulations.append(result["simulation"])
            skipped_count += 1
    
    ConsoleDisplay.console.print(
        f"\n✨ [bold green]Re-evaluation completed![/bold green]\n"
        f"📊 [bold blue]Statistics:[/bold blue]\n"
        f"  ✅ Successfully re-evaluated: {successful_count}\n"
        f"  ❌ Failed: {failed_count}\n"
        f"  ⏭️  Skipped: {skipped_count}\n"
        f"  📝 Total: {len(results)}"
    )
    
    metrics = compute_metrics(re_eval_results)
    ConsoleDisplay.display_agent_metrics(metrics)

    if config.csv_output_file and re_eval_results.simulations:
        try:
            csv_output = config.csv_output_file
            save_results_to_csv(re_eval_results, csv_output, config, metrics)
            ConsoleDisplay.console.print(f"\n💾 [bold green]Results appended to CSV: {csv_output}[/bold green]")
        except Exception as e:
            ConsoleDisplay.console.print(f"\n[bold red]Error saving to CSV: {e}[/bold red]")

    # Save results if save_to is specified
    if save_to is not None:
        if isinstance(save_to, str):
            save_to = Path(save_to)
        
        # Create parent directories if needed
        if not save_to.parent.exists():
            save_to.parent.mkdir(parents=True, exist_ok=True)
        
        # Generate filename if not provided
        if save_to.is_dir() or save_to.name == "":
            original_name = simulation_path.stem
            save_to = save_to / f"{original_name}_re_eval_{evaluation_type}.json"
        
        logger.info(f"Saving re-evaluation results to: {save_to}")
        re_eval_results.save(save_to)
    
    return re_eval_results