from contextlib import closing
from threading import Thread

from .channel import CommunicationChannel
from .colmind import ColMind

class MasterProcess:
    """MasterProcess handles all the MindForge agents and acts as a master process.

    Each entry of ``config["players"]`` is driven in its own thread. When
    cooperation is enabled, an extra thread hosts the communication channel.
    ``start`` blocks until every thread it launched has finished.
    """

    def __init__(self, config):
        # Full configuration dict: start() reads "env" and "players", learn()
        # reads "resume", and the whole dict is forwarded (with overrides) to
        # the communication channel when cooperation is enabled.
        self.config = config

    def start(self, cooperation: bool = False, learning: bool = False):
        """Launch one worker thread per configured player and wait for all of them.

        Args:
            cooperation: Forwarded to every worker. When True, a communication
                channel thread is started as well.
            learning: When True, each worker runs ``learn``; otherwise ``run``.
        """
        worker = self.learn if learning else self.run
        env_config = self.config["env"]
        threads = []

        for agent_name, agent_config in self.config["players"].items():
            # The dict union builds a fresh per-agent config that adds the
            # player name as "username" without mutating self.config.
            thread = Thread(
                target=worker,
                args=(
                    agent_config | {"username": agent_name},
                    env_config,
                    agent_config["port"],
                    cooperation,
                ),
            )
            threads.append(thread)
            thread.start()

        # launch communication channel
        if cooperation:
            # NOTE(review): participants and turn count are hard-coded and
            # override any "agents"/"num_conversation_turns" already in the
            # config; confirm they match the keys of config["players"].
            channel_config = self.config | {"agents": ["weak", "strong"], "num_conversation_turns": 6}
            thread = Thread(target=self.communication_channel, args=(channel_config,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

    def run(self, agent_config: dict, env_config: dict, server_port: int, cooperation: bool = False):
        """Execute one agent's fixed plan through ``ColMind.inference``, then close it.

        Args:
            agent_config: Per-agent settings; its "plan" entry is passed as
                ``sub_goals``.
            env_config: Environment settings passed to ``ColMind``.
            server_port: Not used here; kept so ``start`` can call ``run`` and
                ``learn`` with the same arguments.
            cooperation: Forwarded to ``ColMind.inference``.
        """
        # TODO: this does not support continual learning
        # closing() guarantees colmind.close() (the agent connection) runs even
        # when reading the plan or running inference raises.
        with closing(ColMind(agent_config, env_config)) as colmind:
            colmind.inference(
                sub_goals=agent_config["plan"],
                reset_env=False,
                evaluate=True,
                cooperation=cooperation,
            )

    def learn(
        self,
        agent_config: dict,
        env_config: dict,
        server_port: int,
        cooperation: bool = False,
        max_iterations: int = 160,
    ):
        """Run ``ColMind.learn`` for one agent, then close it.

        Args:
            agent_config: Per-agent settings passed to ``ColMind``.
            env_config: Environment settings passed to ``ColMind``.
            server_port: Not used here; kept so ``start`` can call ``run`` and
                ``learn`` with the same arguments.
            cooperation: Forwarded to ``ColMind.learn`` as ``cooperate``.
            max_iterations: Forwarded to ``ColMind.learn`` (default 160).
        """
        # Echo the configured resume value to stdout.
        print(self.config["resume"])
        # closing() guarantees colmind.close() (the agent connection) runs even
        # when learning raises.
        with closing(ColMind(agent_config, env_config, resume=self.config["resume"])) as colmind:
            colmind.learn(max_iterations=max_iterations, cooperate=cooperation)

    def communication_channel(self, config: dict):
        """Build a ``CommunicationChannel`` from ``config`` and run it."""
        channel = CommunicationChannel(config)
        channel.run()
