import logging

import numpy as np

from src.searchlight.gameplay.simulators import DialogueGameSimulator
from src.Avalon.baseline_models_Avalon import *
from src.Avalon.engine import AvalonBasicConfig, AvalonGameEnvironment
from src.dialogue_improve.action_planner import AvalonActionPlannerAgent, MuteDeafAvalonActionPlannerAgent
from src.searchlight.gameplay.agents import HumanDialogueAgent
from src.searchlightimprove.llm_utils.llm_api_models import GPT35Multi
from good_examples.Avalon.value_heuristics.list import functions as value_heuristics
# from good_examples.Avalon.dialogue_guide.list import guides as dialogue_guides
from good_examples.Avalon.dialogue_guide.english import role_to_guide as search_guide
from good_examples.Avalon.dialogue_guide.recon import role_to_guide as recon_guide

# Evaluation script: play `total_games` Avalon games in which every seat is an
# AvalonActionPlannerAgent, then report per-game scores, the per-key mean
# score across games, and how many games counted as a "search win".

# set logging level to debug
logging.basicConfig(level=logging.DEBUG)

# Experiment parameters.
NUM_PLAYERS = 5
RNG_SEED = 12
NUM_ROLLOUT = 32
NODE_BUDGET = 32


def _select_dialogue_guide(player, role):
    """Return the role-to-dialogue-guide mapping for ``player``'s agent.

    Merlin is given the search guide and the Assassin the recon guide; every
    other role (e.g. Servant, Minion) falls back to the search guide.

    Args:
        player: Seat index, used only in the printed message.
        role: Role name as returned by ``get_roles_in_str_list``.
    """
    if role == "Merlin":
        print(f"Player {player} is a good role. Assigning search guide.")
        return search_guide
    if role == "Assassin":
        print(f"Player {player} is an evil role. Assigning recon guide.")
        return recon_guide
    return search_guide


# create config
avalon_config = AvalonBasicConfig.from_num_players(NUM_PLAYERS)
avalon_env = AvalonGameEnvironment(avalon_config)
total_games = 5
num_games = 1  # games simulated per environment reset

# The LLM client and the value heuristic do not depend on the game state (they
# are already shared by all five agents within a game), so build them once
# instead of once per game.
llm_model = GPT35Multi(model="gpt-4o-2024-08-06")  # also tried: "gpt-4", "gpt-4o-mini"
value_heuristic = AvalonLLMFunctionalValueHeuristic(value_heuristics[0])

avg_score_list = []
trajectory_list = []
search_win_counter = 0
for game_idx in range(total_games):
    avalon_env.reset()
    start_state = AvalonState.init_from_env(avalon_env)

    # Per-game components: they are bound to the freshly reset environment and
    # the new start state, so they are rebuilt every game.
    actor_action_enumerator = AvalonActorActionEnumerator(
        avalon_env=avalon_env,
        default_player_order=tuple(range(avalon_config.num_players)),  # player order is just natural order
    )
    forward_transitor = AvalonTransitor(env=avalon_env)
    speaker_enumerator = AvalonSpeakerEnumerator(avalon_env=avalon_env)
    information_function = AvalonInformationFunction(config=avalon_config)

    # create game simulator
    simulator = DialogueGameSimulator(
        transitor=forward_transitor,
        actor_action_enumerator=actor_action_enumerator,
        speaker_enumerator=speaker_enumerator,
        information_function=information_function,
        start_state=start_state,
    )

    # Seed from (RNG_SEED, game_idx): each game is reproducible on its own,
    # but games no longer replay one identical random stream (the previous
    # per-game `default_rng(12)` gave every game the same agent randomness).
    rng = np.random.default_rng([RNG_SEED, game_idx])

    # define the agents here; str_roles is indexed by seat number
    str_roles = start_state.get_roles_in_str_list()
    agents = {
        player: AvalonActionPlannerAgent(
            config=avalon_config,
            llm_model=llm_model,
            player=player,
            value_heuristic=value_heuristic,
            role_to_dialogue_guide=_select_dialogue_guide(player, str_roles[player]),
            rng=rng,
            num_rollout=NUM_ROLLOUT,
            node_budget=NODE_BUDGET,
        )
        for player in range(avalon_config.num_players)
    }

    # simulate games
    avg_scores, trajectories = simulator.simulate_games(agents, num_games, display=False)
    # NOTE(review): a positive total score is counted as a "search win";
    # confirm this matches how simulate_games scores the two teams.
    if sum(avg_scores.values()) > 0:
        search_win_counter += 1
    avg_score_list.append(avg_scores)
    trajectory_list.append(trajectories)

    print(f"Final scores: {avg_scores}")

print(f"Per-game scores: {avg_score_list}")
if avg_score_list:
    # Mean of each score key across all games (assumes every game reports the
    # same keys, which holds when the agents dict is keyed by the same seats).
    mean_scores = {
        key: sum(scores[key] for scores in avg_score_list) / len(avg_score_list)
        for key in avg_score_list[0]
    }
    print(f"Average scores: {mean_scores}")
print(f"Number of games search win: {search_win_counter}")