from vita.data_model.simulation import RewardInfo, EvaluationType, SimulationRun, TerminationReason
from vita.data_model.tasks import Task
from vita.evaluator.evaluator_rubric import RubricEvaluator
from vita.evaluator.evaluator_traj import TrajectoryEvaluator
from vita.evaluator.evaluator_traj_wo_user import TrajectoryEvaluatorWithoutUser
from collections import Counter
import random
import pandas as pd


def evaluate_simulation(
    simulation: SimulationRun,
    task: Task,
    evaluation_type: EvaluationType,
    domain: str,
    llm_evaluator: str = None,
    llm_args_evaluator: dict = None,
    language: str = None,
) -> RewardInfo:
    """
    Score a single simulation run according to ``evaluation_type``.

    Runs that ended with TOO_MANY_ERRORS, MAX_STEPS or INVALID_AGENT_MESSAGE
    score 0.0, and tasks without evaluation criteria score 1.0; neither case
    invokes an evaluator. Otherwise the run is delegated to:

    * ``"nl_rubrics"`` / ``"all"`` -- ``RubricEvaluator.calculate_reward`` with
      rubric type ``"separate"`` / ``"combined"``;
    * ``"trajectory"``, ``"trajectory_ablation1"``, ``"trajectory_ablation2"``,
      ``"trajectory_ablation3"`` -- the matching
      ``TrajectoryEvaluator.calculate_reward*`` method;
    * ``"trajectory_wo_user"`` -- ``TrajectoryEvaluatorWithoutUser.calculate_reward``.

    ``domain`` is accepted for interface compatibility but is not used here.

    Returns:
        The ``RewardInfo`` produced by the selected evaluator (or by one of
        the short-circuit cases above).

    Raises:
        ValueError: If ``evaluation_type`` is not one of the values above.
    """
    premature_reasons = {
        TerminationReason.TOO_MANY_ERRORS,
        TerminationReason.MAX_STEPS,
        TerminationReason.INVALID_AGENT_MESSAGE,
    }
    if simulation.termination_reason in premature_reasons:
        note = f"Simulation terminated prematurely. Termination reason: {simulation.termination_reason}"
        return RewardInfo(reward=0.0, info={"note": note})

    if task.evaluation_criteria is None:
        return RewardInfo(reward=1.0, info={"note": "No evaluation criteria"})

    # Keyword arguments shared by every evaluator entry point.
    common_kwargs = dict(
        task=task,
        final_state=simulation.states,
        llm_evaluator=llm_evaluator,
        llm_args_evaluator=llm_args_evaluator,
        language=language,
    )

    # Matching uses == (rather than hashing into a dict) so the lookup does
    # not depend on how EvaluationType values hash.
    rubric_dispatch = (
        ("nl_rubrics", "separate"),
        ("all", "combined"),
    )
    for type_name, rubric_type in rubric_dispatch:
        if evaluation_type == type_name:
            return RubricEvaluator.calculate_reward(
                rubric_type=rubric_type,
                final_messages=simulation.messages,
                **common_kwargs,
            )

    trajectory_dispatch = (
        ("trajectory", TrajectoryEvaluator, "calculate_reward"),
        ("trajectory_ablation1", TrajectoryEvaluator, "calculate_reward_ablation1"),
        ("trajectory_ablation2", TrajectoryEvaluator, "calculate_reward_ablation2"),
        ("trajectory_ablation3", TrajectoryEvaluator, "calculate_reward_ablation3"),
        ("trajectory_wo_user", TrajectoryEvaluatorWithoutUser, "calculate_reward"),
    )
    for type_name, evaluator_cls, method_name in trajectory_dispatch:
        if evaluation_type == type_name:
            # Resolve the method lazily so only the selected one is looked up.
            scorer = getattr(evaluator_cls, method_name)
            return scorer(full_trajectory=simulation.messages, **common_kwargs)

    raise ValueError(f"Unknown evaluation type: {evaluation_type}")


def _mode_reward(rewards):
    """Return the most frequent reward; ties resolve to the highest tied value."""
    counts = Counter(rewards)
    top_count = max(counts.values())
    return max(reward for reward, count in counts.items() if count == top_count)


def _average_reward_breakdown(results):
    """Average ``reward_breakdown`` per key over the results that have one.

    Keys missing from an individual breakdown count as 0.0 for that result.
    Key order follows first appearance across results. Returns None when no
    result carries a breakdown.
    """
    breakdowns = [
        result.reward_breakdown
        for result in results
        if result.reward_breakdown is not None
    ]
    if not breakdowns:
        return None
    # Ordered union of keys (dict preserves insertion order, unlike set).
    keys = dict.fromkeys(key for breakdown in breakdowns for key in breakdown)
    return {
        key: sum(breakdown.get(key, 0.0) for breakdown in breakdowns) / len(breakdowns)
        for key in keys
    }


def evaluate_simulation_multiple_times(
    simulation: SimulationRun,
    task: Task,
    evaluation_type: EvaluationType,
    domain: str,
    llm_evaluator: str = None,
    llm_args_evaluator: dict = None,
    language: str = None,
    num_evaluations: int = 1,
) -> RewardInfo:
    """
    Evaluate the simulation multiple times and return the mode (most frequent) reward.

    Each run calls ``evaluate_simulation`` with a copy of ``llm_args_evaluator``
    whose ``temperature`` is the base temperature (0.7 when unset or None)
    plus a uniform offset in [-0.1, 0.1], clamped to [0.0, 1.0]. The caller's
    dict is never mutated. When several rewards are equally frequent, the
    highest of them is returned.

    Args:
        simulation: The simulation to evaluate
        task: The task being evaluated
        evaluation_type: Type of evaluation to perform
        domain: Domain name (forwarded to ``evaluate_simulation``)
        llm_evaluator: LLM evaluator name
        llm_args_evaluator: LLM evaluator arguments
        language: Language for evaluation
        num_evaluations: Number of independent evaluations to perform
            (default: 1). Must be at least 1 when an evaluation is actually run.

    Returns:
        RewardInfo: ``reward`` is the mode reward. ``reward_breakdown`` is the
        per-key mean over all runs that produced a breakdown (None if none
        did). ``info`` maps ``"evaluation_index_<k>"`` (1-based) to each run's
        details. Prematurely terminated runs (TOO_MANY_ERRORS, MAX_STEPS,
        INVALID_AGENT_MESSAGE) short-circuit to reward 0.0, and tasks without
        evaluation criteria short-circuit to reward 1.0, without calling any
        evaluator.

    Raises:
        ValueError: If ``num_evaluations`` is less than 1 (and no short-circuit
            applies), or if ``evaluate_simulation`` rejects ``evaluation_type``.
    """
    # Same short-circuit set as evaluate_simulation, so no evaluator calls are wasted.
    if simulation.termination_reason in {
        TerminationReason.TOO_MANY_ERRORS,
        TerminationReason.MAX_STEPS,
        TerminationReason.INVALID_AGENT_MESSAGE,
    }:
        return RewardInfo(
            reward=0.0,
            info={
                "note": f"Simulation terminated prematurely. Termination reason: {simulation.termination_reason}"
            },
        )

    if task.evaluation_criteria is None:
        return RewardInfo(
            reward=1.0,
            info={"note": "No evaluation criteria"},
        )

    # Validated only here so short-circuit calls keep working for any value.
    if num_evaluations < 1:
        raise ValueError(f"num_evaluations must be at least 1, got {num_evaluations}")

    base_args = {} if llm_args_evaluator is None else llm_args_evaluator
    base_temp = base_args.get("temperature")
    if base_temp is None:
        base_temp = 0.7

    evaluation_results = []
    for _ in range(num_evaluations):
        # Jitter the temperature on a copy so repeated runs are not identical.
        eval_args = base_args.copy()
        temp_variation = random.uniform(-0.1, 0.1)
        eval_args["temperature"] = max(0.0, min(1.0, base_temp + temp_variation))

        evaluation_results.append(
            evaluate_simulation(
                simulation=simulation,
                task=task,
                evaluation_type=evaluation_type,
                domain=domain,
                llm_evaluator=llm_evaluator,
                llm_args_evaluator=eval_args,
                language=language,
            )
        )

    mode_reward = _mode_reward([result.reward for result in evaluation_results])
    avg_reward_breakdown = _average_reward_breakdown(evaluation_results)

    # Record each individual evaluation result under a 1-based key.
    final_info = {
        f"evaluation_index_{index}": {
            "reward": result.reward,
            "nl_rubrics": result.nl_rubrics,
            "reward_breakdown": result.reward_breakdown,
            "info": result.info,
            "window_evaluations": result.window_evaluations,
        }
        for index, result in enumerate(evaluation_results, start=1)
    }

    return RewardInfo(
        reward=mode_reward,
        reward_breakdown=avg_reward_breakdown,  # Use average reward_breakdown
        info=final_info,
    )
