"""Rule-balanced random policy implementation for the FRAME system.

This policy first groups valid actions by the name of the rule they apply, then randomly
selects one of those rule names, and finally randomly selects an action from that rule's
group. This gives every rule name that has at least one valid action an equal chance of
being selected, regardless of how many valid actions it offers.
"""

from typing import List, Dict, Any, Optional
from collections import defaultdict

from frame.policies.base import Policy
from frame.knowledge_base.knowledge_graph import KnowledgeGraph
from frame.environments.math_env import ValidAction, MathEnv


class RuleBalancedRandomPolicy(Policy):
    """
    Policy that ensures balanced random selection across different rule types.

    Selection is two-stage:

    1. All valid actions are grouped by the ``name`` of the production rule they
       apply.
    2. A rule name is drawn from the names that have at least one valid action.
    3. An action is drawn from that name's group.

    This prevents rules with many possible applications (like Specialize) from
    dominating the selection process.

    Attributes:
        production_rules: Sequence of rules indexed by each action's
            ``rule_idx``; ``None`` until ``set_production_rules`` or
            ``set_rules`` is called.
    """

    def __init__(self, **kwargs):
        """
        Initialize the rule-balanced random policy.

        Args:
            **kwargs: Forwarded to ``Policy.__init__``. ``requires_enumeration``
                is always passed as ``True`` (presumably so the environment
                populates ``env.valid_actions``, which ``select_action`` reads);
                supplying it again in ``kwargs`` raises ``TypeError``.
        """
        super().__init__(requires_enumeration=True, **kwargs)
        # No rules until a setter is called; select_action returns None meanwhile.
        self.production_rules: Optional[List[Any]] = None

    def set_production_rules(self, production_rules):
        """
        Set the production rules for the policy.

        Args:
            production_rules: List of production rules. Each action's
                ``rule_idx`` is used to index into it, and each rule must
                expose a ``name`` attribute.
        """
        self.production_rules = production_rules

    def set_rules(self, rules):
        """
        Set the rules for the policy (compatibility method for TheoryBuilder).

        Delegates to ``set_production_rules`` so both entry points stay in sync.

        Args:
            rules: List of production rules.
        """
        self.set_production_rules(rules)

    def select_action(self, env: MathEnv) -> Optional[int]:
        """
        Select an action: first pick a rule name at random, then pick an action
        of that rule at random.

        Args:
            env: The math environment containing the current state; its
                ``valid_actions`` are the candidates.

        Returns:
            Index into ``env.valid_actions`` of the selected action, or ``None``
            if there are no valid actions or no production rules have been set.
        """
        valid_actions = env.valid_actions
        rules = self.production_rules

        if not valid_actions or rules is None:
            return None

        # Bucket action *indices* by rule name, because callers expect an index
        # into env.valid_actions. Distinct rules sharing a name share a bucket.
        rule_groups = defaultdict(list)
        for i, action in enumerate(valid_actions):
            rule_groups[rules[action.rule_idx].name].append(i)

        # Defensive: valid_actions could be truthy yet yield nothing when iterated.
        if not rule_groups:
            return None

        # Stage 1: pick among rule names, so a rule's weight does not grow with
        # the number of actions it offers. Assumes the base Policy's self.rng
        # provides a choice() that samples uniformly.
        rule_name = self.rng.choice(list(rule_groups.keys()))

        # Stage 2: pick among the actions in the chosen rule's group. int()
        # normalizes the result in case rng returns e.g. a numpy integer.
        return int(self.rng.choice(rule_groups[rule_name]))