"""
Team Cooperation Module

This module implements team cooperation functionality for software development teams,
including coalition formation, belief modeling, and task execution.
"""

import argparse
import asyncio
import json
import logging
import os
import sys
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple

from agent.agent_abs import Agent
from agent.belief_model import BeliefModel
from agent.software import Engineer, ProjectManager
from agent.utils.logger import setup_logger
from config.const import CONFIGS as configs
from utils.load_dataset import load_task

# Data Models
@dataclass
class CoalitionRecord:
    """Records details about a coalition's composition and work.

    One record is appended to ``TeamCooperation.coalition_history`` by
    ``initialize_belief_scores`` (round 0) and by every successful
    ``execute_round``.
    """
    # Round number; 0 is the initial record created before any round runs.
    round_id: int
    # Coalition members keyed by agent name (engineers and, when included,
    # the project manager).
    members: Dict[str, Agent]
    # Task outlines in effect for the round, as obtained from the project manager.
    task_assignments: Dict[str, str]
    # Result of each engineer's ``run`` keyed by engineer name; empty for round 0.
    implementations: Dict[str, str]
    # NOTE(review): round 0 stores ``json.dumps(belief_model.scores)`` (a string)
    # while later rounds store the ``belief_model.scores`` object itself, so the
    # annotation below may not match either - confirm the intended type.
    belief_alignment_scores: Dict[str, float]

class CooperationState(Enum):
    """Possible states of the team cooperation process"""
    # Initial state; also set at the start of every round.
    PLANNING = "planning"
    # Set while the coalition's engineers produce their implementations.
    IMPLEMENTING = "implementing" 
    # Declared but never assigned anywhere in this file.
    REVIEWING = "reviewing"
    # Set by TeamCooperation.run after the first successful round.
    COMPLETED = "completed"
    # Set when a round raises, or when the rounds run out without success.
    FAILED = "failed"

# Main Team Cooperation Class
class TeamCooperation:
    """Manages cooperation between team members including coalition formation and task execution.

    Each round, the engineers of the current coalition implement the project
    manager's task outlines concurrently, every coalition member scores the
    resulting implementations, and the belief model is updated with those
    scores. A round succeeds when a coalition is available and no score fell
    below ``eps``; ``run`` stops at the first successful round.
    """

    def __init__(
        self,
        project_manager: ProjectManager,
        engineers: Dict[str, Engineer],
        min_coalition_size: int = 2,
        eps: float = 0.1,
        max_rounds: int = 5,
        use_matching: bool = True
    ):
        """Set up the team, its belief model and the bookkeeping state.

        Args:
            project_manager: Agent that produces the task outlines.
            engineers: Engineers keyed by their name.
            min_coalition_size: Lower bound passed to the stable-coalition search.
            eps: Alignment threshold; a score strictly below it marks the
                scored engineer as unaligned.
            max_rounds: Maximum number of rounds ``run`` will attempt.
            use_matching: When False, every agent is put into a single coalition
                instead of running the stable-coalition search.
        """
        self.project_manager = project_manager
        self.engineers = engineers
        self.belief_model = BeliefModel("Development Team")
        self.min_coalition_size = min_coalition_size
        self.max_rounds = max_rounds
        self.use_matching = use_matching

        # Add all agents to belief model
        self.belief_model.add_agent(project_manager)
        for eng in engineers.values():
            self.belief_model.add_agent(eng)

        self.coalition_history: List[CoalitionRecord] = []
        self.current_round = 0
        # Failure counter: -1 after a successful round, incremented on every
        # failed round; coalitions are re-formed when it equals 1.
        self.rematch = -1
        # Honour the caller-supplied threshold (previously hard-coded to 0.1).
        self.eps = eps
        self.state = CooperationState.PLANNING
        self.logger = setup_logger(
            name="TeamCooperation",
            module_path="agent.team_cooperation",
            log_level=logging.DEBUG
        )
        self.api_semaphore = asyncio.Semaphore(5)  # Limit to 5 concurrent requests

    async def initialize_belief_scores(self):
        """Initialize belief scores based on initial task understanding.

        Obtains the task outlines from the project manager, lets every engineer
        analyze them, sets every pairwise belief score to 1.0 and appends the
        round-0 coalition record containing all engineers and the project manager.
        """
        self.logger.info("Initializing belief scores...")
        task_outlines = await self.project_manager.run()

        self.logger.debug(f"Initial task outlines: {json.dumps(task_outlines, indent=2)}")

        # Setup initial beliefs for each engineer
        for eng_i in self.engineers.values():
            beliefs = await eng_i.analyze_task(task_outlines)
            self.logger.debug(f"Engineer {eng_i.name} beliefs: {json.dumps(beliefs, indent=2)}")

        # Set high scores for initial beliefs
        for agent in self.belief_model.agents:
            for teammate in self.belief_model.agents:
                if agent != teammate:
                    self.belief_model.update_score(agent, teammate, 1.0)
                    self.logger.debug(
                        f"Compatibility score between {agent.name} and {teammate.name}: 1.0"
                    )

        # Set the initial coalition with all engineers and project manager
        # NOTE(review): the scores are stored as a JSON string here, whereas
        # execute_round stores the scores object itself - confirm which is intended.
        self.coalition_history.append(CoalitionRecord(
            round_id=0,
            implementations={},
            belief_alignment_scores=json.dumps(self.belief_model.scores),
            members=self.engineers | {self.project_manager.name: self.project_manager},
            task_assignments=task_outlines,
        ))

        self.logger.info(f"Initial coalition: {self.coalition_history[0].members}")

    async def form_coalitions(self) -> List[Set[str]]:
        """Form coalitions based on current belief model.

        Returns:
            A list of coalitions, each a set of agent names. Without matching,
            a single coalition holding every agent of the belief model. With
            matching, the result of the stable-coalition search, or an empty
            list when it finds nothing.
        """
        if not self.use_matching:
            return [{agent.name for agent in self.belief_model.agents}]

        coalitions = self.belief_model.find_stable_coalitions(
            min_coalition_size=self.min_coalition_size,
            max_coalition_size=len(self.engineers)
        )

        if not coalitions:
            self.logger.warning("No stable coalitions found")
            return []

        return coalitions

    async def _run_implementations(self, engineer_names: List[str], task_outlines) -> Dict[str, str]:
        """Run the named engineers concurrently and return their results keyed by name."""
        if not engineer_names:
            return {}

        async def run_with_semaphore(name):
            # The semaphore caps how many engineer runs are in flight at once.
            async with self.api_semaphore:
                return await self.engineers[name].run(task_outlines, self.project_manager.name)

        # gather() returns results in the order of its arguments, so zipping
        # with the same name list pairs every engineer with its own result.
        results = await asyncio.gather(*(run_with_semaphore(name) for name in engineer_names))
        return dict(zip(engineer_names, results))

    async def _update_agent_scores(self, agent, implementations: Dict[str, str], is_manager: bool = False) -> bool:
        """Let ``agent`` score the implementations and record the scores.

        Returns True when at least one score fell below ``self.eps``.
        """
        if is_manager:
            scores = await agent.check_belief_alignment(implementations)
        else:
            scores = await agent.analyze_teammate_implementations(implementations)

        unaligned_msg = f"Your action seems not aligned with {agent.name} in last round."
        found_unaligned = False

        for name, score in scores.items():
            score_value = float(score["score"])
            target = self.engineers[name]
            self.belief_model.update_score(agent, target, score_value)

            if score_value < self.eps:
                target.update_default_prompts(unaligned_msg)
                found_unaligned = True
            elif target.default_prompts and target.default_prompts[-1]["content"] == unaligned_msg:
                # Drop the stale warning from this agent once the engineer is
                # aligned again; the emptiness check avoids an IndexError for
                # engineers that have no prompts yet.
                target.default_prompts.pop()

        return found_unaligned

    async def execute_round(self) -> bool:
        """Execute one round of team cooperation.

        Returns:
            True when the round produced a coalition and no unaligned engineer
            (a new record is then appended to ``coalition_history``); False when
            the round limit is exceeded, no coalition is available, an engineer
            is unaligned, or any step raised (the error is logged, not re-raised).
        """
        self.current_round += 1
        self.logger.info(f"\n=== Starting Round {self.current_round} ===")

        if self.current_round > self.max_rounds:
            self.logger.warning("Maximum rounds reached without convergence")
            self.state = CooperationState.FAILED
            return False

        try:
            # Planning phase
            self.state = CooperationState.PLANNING
            self.logger.info("Phase: Engineers take actions")
            task_outlines = self.project_manager.action

            # Implementation phase
            self.state = CooperationState.IMPLEMENTING
            self.logger.info("Phase: Implementing solutions within coalitions")

            current_coalition = self.coalition_history[-1].members
            manager_name = self.project_manager.name
            engineer_names = [name for name in current_coalition if name != manager_name]

            # Execute all engineer implementations in parallel
            implementations = await self._run_implementations(engineer_names, task_outlines)

            # Update belief scores based on implementations
            has_unaligned = False

            # Process project manager first
            if manager_name in current_coalition:
                has_unaligned = await self._update_agent_scores(
                    self.project_manager, implementations, is_manager=True
                )

            # Process other engineers
            if engineer_names:
                unaligned_flags = await asyncio.gather(*(
                    self._update_agent_scores(self.engineers[name], implementations)
                    for name in engineer_names
                ))
                has_unaligned = has_unaligned or any(unaligned_flags)

            # Form new coalitions
            self.logger.info("Phase: Updating coalitions")
            if self.current_round == 1 or self.rematch == 1:
                coalitions = await self.form_coalitions()
            else:
                # keep the same coalitions if no unaligned engineers
                coalitions = [set(self.coalition_history[-1].members.keys())]

            if not coalitions or has_unaligned:
                self.rematch += 1
                self.logger.warning(f"Round {self.current_round}: No stable coalitions formed")
                return False

            self.rematch = -1

            self.logger.info(f"Formed coalitions: {coalitions}")

            # Update prompts based on collaboration context
            if self.belief_model.collaboration_context:
                observation_prompts = [
                    "\nTeam Collaboration Context:",
                    "Required Interfaces:",
                    json.dumps(self.belief_model.collaboration_context.interfaces, indent=2),
                    "\nDependencies:",
                    json.dumps(self.belief_model.collaboration_context.dependencies, indent=2)
                ]

                for agent in self.belief_model.agents:
                    agent.update_default_prompts(observation_prompts)

            # Update task outlines
            await self.project_manager.run()

            # Record keeping: only the first coalition is recorded as the team
            # for the next round.
            coalition_record = CoalitionRecord(
                round_id=self.current_round,
                members={
                    agent_name: self.engineers[agent_name] if agent_name in self.engineers else self.project_manager
                    for agent_name in coalitions[0]
                },
                task_assignments=task_outlines,
                implementations=implementations,
                belief_alignment_scores=self.belief_model.scores
            )
            self.coalition_history.append(coalition_record)

            self.logger.info(f"=== Round {self.current_round} completed successfully ===\n")
            return True

        except Exception as e:
            self.logger.error(f"Round {self.current_round} failed: {str(e)}", exc_info=True)
            self.state = CooperationState.FAILED
            return False

    async def run(self) -> Tuple[bool, Dict[str, str]]:
        """Run the team cooperation process.

        Initializes the belief scores, then executes rounds until one succeeds
        or ``max_rounds`` rounds have been attempted.

        Returns:
            ``(True, implementations)`` from the first successful round, or
            ``(False, {})`` when no round succeeded or an error occurred.
        """
        try:
            await self.initialize_belief_scores()

            # execute_round() increments current_round itself, so failed rounds
            # are simply retried until the budget is used up.
            while self.current_round < self.max_rounds:
                if await self.execute_round():
                    self.state = CooperationState.COMPLETED
                    return True, self._get_final_solution()

            self.state = CooperationState.FAILED
            return False, {}

        except Exception as e:
            # Include the traceback, matching the error logging in execute_round.
            self.logger.error(f"Team cooperation failed: {str(e)}", exc_info=True)
            self.state = CooperationState.FAILED
            return False, {}

    def _get_final_solution(self) -> Dict[str, str]:
        """Get the final implementation solution.

        Returns the implementations of the most recent coalition record, or an
        empty dict when no record exists.
        """
        if not self.coalition_history:
            return {}

        return self.coalition_history[-1].implementations

# Command Line Interface
def parse_arguments(argv: Optional[List[str]] = None):
    """Parse command line arguments.

    Args:
        argv: Argument strings to parse. When None (the default), argparse
            reads the arguments from ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Team Cooperation Simulation')
    parser.add_argument('--pm_tom_k', type=int, default=0,
                      help='Theory of Mind level for Project Manager (default: 0)')
    parser.add_argument('--eng_tom_k', type=int, default=1,
                      help='Theory of Mind level for Engineers (default: 1)')
    parser.add_argument('--num_engineers', type=int, default=5,
                      help='Number of engineers in the team (default: 5)')

    # matching arguments
    parser.add_argument('--matching', action='store_true',
                      help='Use matching algorithm to form coalitions (default: False)')
    parser.add_argument('--min_coalition_size', type=int, default=3,
                      help='Minimum size of coalitions (default: 3)')
    parser.add_argument('--max_rounds', type=int, default=5,
                      help='Maximum number of cooperation rounds (default: 5)')
    parser.add_argument('--task', type=str,
                      help='Task identifier')
    parser.add_argument('--output_path', type=str, default="results",
                      help='Output path (default: results)')
    parser.add_argument('--overwrite', action='store_true',
                      help='Overwrite existing results (default: False)')

    return parser.parse_args(argv)

async def main():
    """Main entry point for team cooperation simulation.

    Parses the command line, determines where to resume from, then runs one
    ``TeamCooperation`` per loaded task. For every task that succeeds, three
    JSONL files under ``<output_path>/<task>`` receive one line each: the
    per-engineer implementations (``tmp_`` file), the project manager's output
    for the initial (round 0) implementations and its output for the final
    implementations.
    """
    args = parse_arguments()

    # Setup logging
    logger = setup_logger(
        name="Main",
        module_path="agent.team_cooperation.main",
        log_level=logging.DEBUG
    )

    logger.info("Starting team cooperation simulation...")
    logger.info(f"Parameters: PM ToM={args.pm_tom_k}, Eng ToM={args.eng_tom_k}, "
                f"#Engineers={args.num_engineers}, Min Coalition={args.min_coalition_size}, "
                f"Max Rounds={args.max_rounds}")

    # Honour --output_path; its default ("results") matches the directory that
    # used to be hard-coded here.
    results_output_path = os.path.join(args.output_path, f"{args.task}")
    os.makedirs(results_output_path, exist_ok=True)

    append_file_name = f"pm_ToM{args.pm_tom_k}_{args.num_engineers}_eng_ToM{args.eng_tom_k}_matching{args.matching}.jsonl"
    results_file_path = os.path.join(results_output_path, f"tmp_{append_file_name}")

    if args.overwrite:
        start_idx = 0
        # clear file ends with append_file_name
        for file_name in os.listdir(results_output_path):
            if file_name.endswith(append_file_name):
                os.remove(os.path.join(results_output_path, file_name))
    elif os.path.exists(results_file_path):
        # Resume after the tasks already recorded in the progress file.
        # NOTE(review): a line is only written for successful tasks, so this
        # count may lag behind the number of tasks actually attempted - confirm
        # this matches what load_task expects for start_idx.
        with open(results_file_path, "r") as f:
            start_idx = sum(1 for _ in f)
    else:
        # First run without --overwrite: nothing to resume from.
        start_idx = 0

    tasks = load_task(args.task, start_idx=start_idx)

    for task_dict in tasks:
        # Initialize team
        pm = ProjectManager(id=1, tom_k=args.pm_tom_k, task=task_dict["idea"])
        engineers = [
            Engineer(id=i, task=task_dict["idea"], tom_k=args.eng_tom_k)
            for i in range(args.num_engineers)
        ]

        logger.info(f"Created team with PM and {len(engineers)} engineers")

        # Create cooperation manager
        team_coop = TeamCooperation(
            project_manager=pm,
            engineers={eng.name: eng for eng in engineers},
            min_coalition_size=args.min_coalition_size,
            max_rounds=args.max_rounds,
            use_matching=args.matching
        )

        # Run cooperation process
        success, final_solution = await team_coop.run()

        if not success:
            logger.error("Failed to reach stable solution")
            continue

        logger.info("Successfully reached stable solution!")
        logger.info("Final implementations:")

        # The summaries are only ever written for successful runs, so they are
        # requested from the project manager only in that case.
        final_output_solution = await pm.get_final_solution(
            {"idea": task_dict["idea"], "implementation": final_solution}
        )
        final_output_solution_dict = {
            **task_dict,
            "output": final_output_solution,
        }

        initial_output_solution = await pm.get_final_solution(
            {"idea": task_dict["idea"], "implementation": team_coop.coalition_history[0].implementations}
        )
        initial_output_solution_dict = {
            **task_dict,
            "output": initial_output_solution,
        }

        # save final solution to file: record each engineer's final implementation.
        # A new dict is built so the implementations stored in the coalition
        # history are not modified.
        final_record = {**final_solution, "task_id": task_dict["task_id"]}

        with open(results_file_path, "a") as f:
            f.write(json.dumps(final_record) + "\n")

        # append initial and final output solutions to jsonl file
        with open(os.path.join(results_output_path, f"Round_{0}_{append_file_name}"), "a") as f:
            f.write(json.dumps(initial_output_solution_dict) + "\n")

        with open(os.path.join(results_output_path, f"Round_{args.max_rounds}_{append_file_name}"), "a") as f:
            f.write(json.dumps(final_output_solution_dict) + "\n")

# Script entry point: run the async main() coroutine to completion when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    asyncio.run(main())