from src.searchlight.utils import AbstractLogged
from ..Avalon.baseline_models_Avalon import AvalonState
from src.searchlight.gameplay.agents import MuteMCTSAgent
from src.searchlight.datastructures.graphs import PartialValueGraph
from src.searchlight.datastructures.adjusters import PUCTAdjuster
from src.searchlight.datastructures.estimators import UtilityEstimatorMean
from .dialogue_discrimination import DialogueDiscriminator
from .dialogue_generator import DialogueGenerator
from .prompt_generator import PromptGenerator as PPromptGenerator


import numpy as np
from collections import defaultdict
from src.Avalon.baseline_models_Avalon import *


import numpy as np


class AvalonActionPlannerAgent(MuteMCTSAgent):
    '''
    Avalon agent that plans actions with the inherited MCTS machinery and
    additionally listens to, summarises and produces dialogue.

    Action selection is delegated to ``MuteMCTSAgent._act``. On top of that this
    class owns:

    * a ``DialogueDiscriminator`` (handed to the base class as
      ``information_prior``) that is fed every batch of observed dialogue,
    * a ``DialogueGenerator`` used to write the agent's own utterances and to
      maintain ``summary``.
    '''

    # (turn, round) -> utterances observed in that round as (speaker, text) pairs
    observed_dialogue: dict[tuple[int, int], list[tuple[int, str]]]
    # role -> tips passed to the dialogue generator when speaking as that role
    role_to_dialogue_guide: dict[int, str]
    information_prior: DialogueDiscriminator
    summary: str # summary of the game so far as observed by the agent

    def __init__(self, config: AvalonBasicConfig, llm_model: LLMModel, player: int, value_heuristic: ValueHeuristic, role_to_dialogue_guide: dict[int, str], num_rollout: int = 100, node_budget: int = 100, rng: "np.random.Generator | None" = None):
        '''
        Builds the simulation environment, the dialogue components and the
        underlying MCTS agent.

        Args:
            config: game configuration; ``config.num_players`` sizes the simulation environment
            llm_model: model shared by the dialogue discriminator and the dialogue generator
            player: index of the player this agent controls
            value_heuristic: forwarded unchanged to ``MuteMCTSAgent``
            role_to_dialogue_guide: dialogue tips, looked up by ``state.self_role`` when speaking
            num_rollout: forwarded unchanged to ``MuteMCTSAgent``
            node_budget: forwarded unchanged to ``MuteMCTSAgent``
            rng: random generator; a fresh ``np.random.default_rng()`` is created when omitted

        Raises:
            ValueError: if ``player`` is not one of ``range(num_players)``
        '''
        # A generator used as a default value would be created once, when the
        # function is defined, and then shared by every agent relying on it.
        if rng is None:
            rng = np.random.default_rng()

        # create new game environment for simulation so that we don't mess up the main game environment
        env = AvalonGameEnvironment.from_num_players(config.num_players)
        num_players = env.config.num_players
        forward_transitor = AvalonTransitor(env)

        # default player order should be random but with the player as the first player
        default_player_order = list(range(num_players))
        default_player_order.remove(player)  # raises ValueError for an unknown player index
        rng.shuffle(default_player_order)
        default_player_order.insert(0, player)
        actor_action_enumerator = AvalonActorActionEnumerator(avalon_env=env, default_player_order=tuple(default_player_order))
        policy_predictor = PolicyPredictor()
        information_function = AvalonInformationFunction(config=config)

        players = set(range(num_players))
        self.config = config
        self.observed_dialogue = defaultdict(list)
        # stored under the declared attribute name; `dialogue_gudies` stays available as an alias
        self.role_to_dialogue_guide = role_to_dialogue_guide

        dialogue_discriminator = DialogueDiscriminator(config=config, llm_model=llm_model, prompt_generator=PPromptGenerator(self.config), player=player, players=players, rng=rng)

        self.dialogue_generator = DialogueGenerator(llm_model=llm_model, prompt_generator=PPromptGenerator(self.config),)
        self.summary = 'Game just started, nothing happened yet.'

        super().__init__(players=players, player=player, forward_transitor=forward_transitor, actor_action_enumerator=actor_action_enumerator, value_heuristic=value_heuristic, policy_predictor=policy_predictor, information_function=information_function, information_prior=dialogue_discriminator, num_rollout=num_rollout, node_budget=node_budget, rng=rng)

        self.logger.debug(f"Creating AvalonActionPlannerAgent for player {player}")

    @property
    def dialogue_gudies(self) -> dict[int, str]:
        '''
        Misspelled legacy name for ``role_to_dialogue_guide``, kept so existing callers keep working
        '''
        return self.role_to_dialogue_guide

    @dialogue_gudies.setter
    def dialogue_gudies(self, value: dict[int, str]):
        self.role_to_dialogue_guide = value

    def _observe_dialogue(self, state: AvalonInformationSet, new_dialogue: list[tuple[int, str]]):
        '''
        Observes new dialogue and updates internal states

        The discriminator is reset first whenever it reports that it is not
        set up for ``state.self_role``; this happens even if ``new_dialogue``
        is empty. Non-empty dialogue is then passed to the discriminator
        (together with a textual description of the state) and recorded under
        ``(state.turn, state.round)``.
        '''
        if not self.information_prior.check_role_equivalent(state.self_role):
            # reset the dialogue discriminator if the role is not equivalent
            self.information_prior.reset(known_sides=state.known_sides, player_role=state.self_role)
            self.logger.debug(f"Resetting dialogue discriminator for player {self.player} with role {state.self_role}")

        # if new_dialogue is empty, skip
        if not new_dialogue:
            return

        # we need to first convert the state and new_dialogue to string form here before passing it to the discriminator
        state_string_description = state.gen_str_description()
        new_dialogue_string = self.dialogue_list_to_str(new_dialogue)
        full_description = f"""\n*Current State*\n{state_string_description}\n\n*New Dialogue*\nThis round, players have said the following so far:\n{new_dialogue_string}"""

        self.logger.debug(f"Observed new dialogue: {new_dialogue_string}")
        self.information_prior.update_beliefs(full_description)
        self.observed_dialogue[(state.turn, state.round)].extend(new_dialogue)

    def _act(self, state: AvalonInformationSet, actions: set[str]) -> str:
        '''
        Queries the actor action enumerator for the intended action

        When ``state.phase == 1`` the running ``summary`` is refreshed from this
        round's recorded dialogue before the decision is delegated to
        ``MuteMCTSAgent._act``.
        '''
        if state.phase == 1: # team voting phase
            dialogue_string = self.dialogue_list_to_str(self.observed_dialogue[(state.turn, state.round)])
            state_string = state.gen_str_description()
            full_description = f"""\n*Current State*\n{state_string}\n\n*Discussion*\nThis round, players have said the following this round:\n{dialogue_string}"""
            # we need update the summary at the end of each round, before the team voting phase
            self.summary = self.dialogue_generator.generated_updated_summary(self.summary, full_description)
            self.logger.debug(f"Updated summary: {self.summary}")

        return super()._act(state, actions)

    @staticmethod
    def dialogue_list_to_str(dialogue: list[tuple[int, str]]) -> str:
        '''
        Formats (speaker, text) pairs as a quoted transcript with ``---`` separators

        Returns a fixed placeholder sentence when ``dialogue`` is empty.
        '''
        if not dialogue:
            return 'Nothing has been spoken yet.'
        # the loop variable deliberately does not reuse the name of the list being iterated
        return '\n---\n'.join([f"> Player {speaker} said:\n{utterance}" for speaker, utterance in dialogue])

    def _produce_utterance(self, state: AvalonInformationSet,) -> str:
        '''
        Produces a dialogue given a history

        The history is this round's recorded dialogue plus a textual description
        of ``state``. The intended action is obtained through ``_act``, so a
        call made while ``state.phase == 1`` also refreshes ``summary``.

        Raises:
            KeyError: if ``state.self_role`` has no entry in ``role_to_dialogue_guide``
        '''
        # get the dialogue for this turn and round
        dialogue = self.observed_dialogue[(state.turn, state.round)]
        self.logger.debug(f"Producing dialogue for player {self.player} at turn {state.turn} and round {state.round} based on dialogue: {dialogue}")

        # we need to first convert the state and dialogue to string form here before passing it to the generator
        state_string_description = state.gen_str_description()
        new_dialogue_string = self.dialogue_list_to_str(dialogue) # NOTE: might need to add preamble
        full_description = f"""\n*Current State*\n{state_string_description}\n\n*Discussion*\nThis round, players have said the following so far:\n{new_dialogue_string}"""

        # query action planner for intended action
        intended_action = self._act(state, actions=set(["not yet decided"])) # NOTE: we might need to change this to actions
        tips = self.role_to_dialogue_guide[state.self_role]

        # only the utterance is returned; the generator's notes are not used here
        response, _notes = self.dialogue_generator.generate_dialogue(history=full_description, phase=state.phase, intended_action=intended_action, tips=tips)
        return response
    
class MuteDeafAvalonActionPlannerAgent(MuteMCTSAgent):
    '''
    Avalon MCTS agent that neither listens to nor produces dialogue.

    ``_observe_dialogue`` is a no-op and ``_produce_utterance`` always returns
    the empty string. The information prior handed to ``MuteMCTSAgent`` is an
    ``AvalonInformationPrior`` built from the beliefs the dialogue
    discriminator reports at construction time.
    '''

    # (turn, round) -> (speaker, text) pairs; initialised empty and never filled by this class
    observed_dialogue: dict[tuple[int, int], list[tuple[int, str]]]
    # role -> dialogue tips; stored for interface parity, not read by this class
    role_to_dialogue_guide: dict[int, str]

    def __init__(self, config: AvalonBasicConfig, llm_model: LLMModel, player: int, value_heuristic: ValueHeuristic, role_to_dialogue_guide: dict[int, str], num_rollout: int = 100, node_budget: int = 100, rng: "np.random.Generator | None" = None):
        '''
        Builds the simulation environment, the information prior and the
        underlying MCTS agent.

        Args:
            config: game configuration; ``config.num_players`` sizes the simulation environment
            llm_model: model given to the dialogue discriminator and the dialogue generator
            player: index of the player this agent controls
            value_heuristic: forwarded unchanged to ``MuteMCTSAgent``
            role_to_dialogue_guide: stored on the instance; not consulted by this class
            num_rollout: forwarded unchanged to ``MuteMCTSAgent``
            node_budget: forwarded unchanged to ``MuteMCTSAgent``
            rng: random generator; a fresh ``np.random.default_rng()`` is created when omitted

        Raises:
            ValueError: if ``player`` is not one of ``range(num_players)``
        '''
        # A generator used as a default value would be created once, when the
        # function is defined, and then shared by every agent relying on it.
        if rng is None:
            rng = np.random.default_rng()

        # create new game environment for simulation so that we don't mess up the main game environment
        env = AvalonGameEnvironment.from_num_players(config.num_players)
        num_players = env.config.num_players
        forward_transitor = AvalonTransitor(env)

        # random player order, except that this agent always comes first
        default_player_order = list(range(num_players))
        default_player_order.remove(player)  # raises ValueError for an unknown player index
        rng.shuffle(default_player_order)
        default_player_order.insert(0, player)
        actor_action_enumerator = AvalonActorActionEnumerator(avalon_env=env, default_player_order=tuple(default_player_order))
        policy_predictor = PolicyPredictor()
        information_function = AvalonInformationFunction(config=config)

        players = set(range(num_players))
        self.config = config
        self.observed_dialogue = defaultdict(list)
        # stored under the declared attribute name; `dialogue_gudies` stays available as an alias
        self.role_to_dialogue_guide = role_to_dialogue_guide

        # `config` and `rng` are passed with the same keywords AvalonActionPlannerAgent
        # in this module uses, so the discriminator draws from the caller's generator
        # NOTE(review): confirm against DialogueDiscriminator's signature
        self.dialogue_discriminator = DialogueDiscriminator(config=config, llm_model=llm_model, prompt_generator=PPromptGenerator(self.config), player=player, players=players, rng=rng)
        # beliefs are read once here; nothing in this class updates them afterwards
        information_prior = AvalonInformationPrior(config=config, belief_p_is_merlin=self.dialogue_discriminator.get_p_is_merlin(), belief_p_is_good=self.dialogue_discriminator.get_p_is_good(),)

        self.dialogue_generator = DialogueGenerator(llm_model=llm_model, prompt_generator=PPromptGenerator(self.config),)

        super().__init__(players=players, player=player, forward_transitor=forward_transitor, actor_action_enumerator=actor_action_enumerator, value_heuristic=value_heuristic, policy_predictor=policy_predictor, information_function=information_function, information_prior=information_prior, num_rollout=num_rollout, node_budget=node_budget, rng=rng)

    @property
    def dialogue_gudies(self) -> dict[int, str]:
        '''
        Misspelled legacy name for ``role_to_dialogue_guide``, kept so existing callers keep working
        '''
        return self.role_to_dialogue_guide

    @dialogue_gudies.setter
    def dialogue_gudies(self, value: dict[int, str]):
        self.role_to_dialogue_guide = value

    def _observe_dialogue(self, state: AvalonInformationSet, new_dialogue: list[tuple[int, str]]):
        '''
        Ignores the dialogue: this agent is deaf, so no internal state is updated
        '''
        pass

    @staticmethod
    def dialogue_list_to_str(dialogue: list[tuple[int, str]]) -> str:
        '''
        Formats (speaker, text) pairs one per line; an empty list yields an empty string
        '''
        # the loop variable deliberately does not reuse the name of the list being iterated
        return '\n'.join([f"Player {speaker} said: {utterance}" for speaker, utterance in dialogue])

    def _produce_utterance(self, state: AvalonInformationSet,) -> str:
        '''
        Returns an empty utterance: this agent is mute
        '''
        return ''