from itertools import combinations
from .team import Team
from typing import List, Dict, Set, Optional
from agent.llm import LLM
import unittest

class BeliefModel(Team):
    def __init__(self, name: str = "Team"):
        """Initialise the parent team, then the belief-model state.

        Args:
            name: Team name, forwarded unchanged to the parent constructor.
        """
        super().__init__(name)

        # Matching bookkeeping starts empty, with no re-matching pending.
        self.rematching_required: bool = False
        self.current_matching: Dict[str, Set[str]] = {}

        # Tolerance used when comparing instability values of candidate matchings.
        self.epsilon: float = 0.01

        # Pairwise score table {rater_name: {other_name: score}}, zero-filled
        # from whatever agents the parent constructor registered.
        self.scores: Dict[str, Dict[str, float]] = self._initialize_scores()

    def _initialize_scores(self) -> Dict[str, Dict[str, float]]:
        """Initialize score matrix for all agents"""
        scores = {}
        for agent_i in self.agents:
            scores[agent_i.name] = {}
            for agent_j in self.agents:
                if agent_i.name != agent_j.name:  # Exclude self-scores
                    scores[agent_i.name][agent_j.name] = 0.0
        return scores

        
    def update_score(self, agent_i: LLM|str, agent_j: LLM|str, score: float):
        """Update score between agents, ignoring self-scores"""
        if isinstance(agent_i, LLM):
            agent_i = agent_i.name
        if isinstance(agent_j, LLM):
            agent_j = agent_j.name
            
        if agent_i == agent_j:  # Skip self-scores
            return

        if agent_i not in self.scores:
            self.scores[agent_i] = {}
        self.scores[agent_i][agent_j] = score

    def get_score(self, agent_i: LLM|str, agent_j: LLM|str) -> float:
        """Get score between agents with safety checks"""
        name_i = agent_i.name if isinstance(agent_i, LLM) else agent_i
        name_j = agent_j.name if isinstance(agent_j, LLM) else agent_j

        if name_i not in self.scores or name_j not in self.scores[name_i]:
            return 0.0  # Default score for unknown pairs

        return self.scores[name_i][name_j]


    def find_stable_coalitions(self, min_coalition_size: int, max_coalition_size: int) -> Optional[List[Set[str]]]:
        """Find stable or approximately stable coalitions using social welfare maximization"""
        if not self.scores:
            self.scores = self.initialize_scores()

        # First try exact stable matching
        coalitions = self._find_exact_stable_coalitions(min_coalition_size, max_coalition_size)
        if coalitions:
            return coalitions

        # If no stable matching, find socially optimal matching
        return self._find_socially_optimal_coalitions(min_coalition_size, max_coalition_size)

    def _find_socially_optimal_coalitions(self, min_size: int, max_size: int) -> List[Set[str]]:
        """Find coalitions that maximize social welfare while minimizing instability"""
        best_matching = []
        best_score = float('-inf')
        min_instability = float('inf')

        # Try all possible coalition combinations
        agent_names = [agent.name for agent in self.agents]
        possible_sizes = range(min_size, min(max_size + 1, len(agent_names) + 1))

        for size in possible_sizes:
            for coalition_partition in self._generate_partitions(agent_names, size):
                score = self._calculate_social_welfare(coalition_partition)
                instability = self._measure_instability(coalition_partition)

                # Update best matching if better found
                if (instability < min_instability or 
                    (abs(instability - min_instability) < self.epsilon and score > best_score)):
                    min_instability = instability
                    best_score = score
                    best_matching = coalition_partition

        self.rematching_required = min_instability > self.epsilon
        return best_matching

    def _calculate_social_welfare(self, coalitions: List[Set[str]]) -> float:
        """Calculate total social welfare of a matching"""
        total_welfare = 0.0
        for coalition in coalitions:
            for agent1 in coalition:
                for agent2 in coalition:
                    if agent1 != agent2:
                        total_welfare += self.get_score(agent1, agent2)
        return total_welfare

    def _measure_instability(self, coalitions: List[Set[str]]) -> float:
        """Measure degree of instability in matching"""
        instability = 0.0
        
        # Get mapping of agent to their coalition
        agent_coalitions = {
            agent: coalition
            for coalition in coalitions
            for agent in coalition
        }

        # Check all possible blocking pairs
        for agent1 in agent_coalitions:
            current_partners1 = agent_coalitions[agent1]
            for agent2 in agent_coalitions:
                if agent1 == agent2:
                    continue
                    
                current_partners2 = agent_coalitions[agent2]
                
                # Calculate current utilities
                util1_current = sum(self.get_score(agent1, p) 
                                  for p in current_partners1 if p != agent1)
                util2_current = sum(self.get_score(agent2, p) 
                                  for p in current_partners2 if p != agent2)
                
                # Calculate utilities if they paired
                util1_new = self.get_score(agent1, agent2)
                util2_new = self.get_score(agent2, agent1)
                
                # Add to instability if blocking pair found
                if util1_new > util1_current and util2_new > util2_current:
                    instability += min(util1_new - util1_current,
                                     util2_new - util2_current)

        return instability

    def _generate_partitions(self, agents: List[str], size: int) -> List[List[Set[str]]]:
        """Generate possible partitions of agents into coalitions of given size"""
        if len(agents) < size:
            return []
            
        if len(agents) == size:
            return [[set(agents)]]
            
        result = []
        first = agents[0]
        
        # Try all possible coalitions containing first agent
        for coalition_members in combinations(agents[1:], size - 1):
            coalition = {first} | set(coalition_members)
            remaining = [a for a in agents if a not in coalition]
            
            # Recursively partition remaining agents
            for sub_partition in self._generate_partitions(remaining, size):
                result.append([coalition] + sub_partition)
                
        return result

    def _find_exact_stable_coalitions(self, min_size: int, max_size: int) -> Optional[List[Set[str]]]:
        """Try to find exactly stable coalitions first"""
        unmatched = set(agent.name for agent in self.agents)
        coalitions = []

        while len(unmatched) >= min_size:
            # Find agent with highest total score as proposer
            proposer = max(unmatched,
                         key=lambda x: sum(self.get_score(x, y) 
                                         for y in unmatched if y != x))

            # Find best coalition for proposer
            best_coalition = self._find_best_coalition(
                proposer, 
                unmatched - {proposer},
                min_size - 1,
                max_size - 1
            )

            if not best_coalition:
                break

            # Check if coalition is stable
            coalition = {proposer} | best_coalition
            if self._is_coalition_stable(coalition, unmatched - coalition):
                coalitions.append(coalition)
                unmatched -= coalition
            else:
                break

        return coalitions if len(unmatched) < min_size else None

    def _find_best_coalition(self, agent: str, available: Set[str], 
                           min_size: int, max_size: int) -> Optional[Set[str]]:
        """Find best coalition for an agent"""
        best_coalition = None
        best_score = float('-inf')

        for size in range(min_size, min(max_size + 1, len(available) + 1)):
            for members in combinations(available, size):
                coalition = set(members)
                score = sum(self.get_score(agent, member) for member in coalition)
                
                if score > best_score:
                    best_score = score
                    best_coalition = coalition

        return best_coalition

    def _is_coalition_stable(self, coalition: Set[str], outside: Set[str]) -> bool:
        """Check if a coalition is stable against outside agents"""
        for agent in coalition:
            current_utility = sum(self.get_score(agent, p) 
                                for p in coalition if p != agent)
            
            # Check if agent would prefer to join any other agent
            for other in outside:
                if self.get_score(agent, other) > current_utility:
                    if self.get_score(other, agent) > sum(self.get_score(other, p) 
                                                        for p in coalition if p != other):
                        return False
        return True