from src.dialogue_improve.data_loader import DataLoader
from src.dialogue_improve.prompting_improve import PromptSSGEvaluator
from src.dialogue_improve.prompt_generator import PromptGenerator
from src.searchlightimprove.llm_utils.llm_api_models import GPT35Multi
from src.Avalon.baseline_models_Avalon import AvalonBasicConfig

import unittest
import numpy as np


class TestPromptSSGEvaluator(unittest.TestCase):
    """Shape/type checks for ``PromptSSGEvaluator.evaluate``.

    Each test builds an evaluator for a single role id, evaluates a short
    list of strategy prompts, and verifies only the structure of the result:
    one ``float`` score and one note per prompt. The score values themselves
    are not asserted.
    """

    FILE_PATH = 'src/dialogue_improve/test_data.json'
    TEST_PROMPTS_MERLIN = [
        'As Merlin you should not reveal that you know who is Evil.',
        'As Merlin you should focus on convincing the good players.',
        'As Merlin you should try to point how the Evil players are acting using evidence.',
    ]
    TEST_PROMPTS_EVIL = [
        'As Evil you should try to blend in with the good players.',
        'As Evil you should try to point out the good players that are acting suspiciously.',
        'As Evil you should try to confuse the good players.',
    ]

    @classmethod
    def setUpClass(cls):
        """Create the fixtures shared by every test in this class.

        They are built once because no test here modifies them.
        """
        cls.config = AvalonBasicConfig.from_num_players(5)
        cls.data_loader = DataLoader()
        cls.data_loader.load_data(cls.FILE_PATH)
        cls.prompt_generator = PromptGenerator(cls.config)
        cls.llm_model = GPT35Multi()
        # Not read by the tests visible here (each test passes its own role
        # id); kept so any external user of this attribute keeps working.
        cls.role_to_evaluate = 0
        cls.num_batch_runs = 1
        # Unseeded generator: any randomness inside the evaluator is not
        # reproducible between runs, so only result shape/type is checked.
        cls.rng = np.random.default_rng()
        cls.players = {0, 1, 2, 3, 4}

    def _check_evaluation(self, role_to_evaluate, prompts, label):
        """Evaluate ``prompts`` for one role and assert the result structure.

        Args:
            role_to_evaluate: Role id handed to ``PromptSSGEvaluator``.
            prompts: Strategy prompts passed to ``evaluate``.
            label: Word inserted into the failure messages (e.g. ``'Merlin'``).
        """
        evaluator = PromptSSGEvaluator(
            players=self.players,
            role_to_evaluate=role_to_evaluate,
            data_loader=self.data_loader,
            llm_model=self.llm_model,
            prompt_generator=self.prompt_generator,
            num_batch_runs=self.num_batch_runs,
            rng=self.rng,
        )
        scores, notes = evaluator.evaluate(prompts)

        # assertEqual / assertIsInstance report the offending values on
        # failure, unlike assertTrue on a pre-computed boolean.
        self.assertEqual(
            len(scores), len(prompts),
            f"Scores should be returned for each {label} prompt.",
        )
        for score in scores:
            self.assertIsInstance(score, float, "All scores should be float values.")
        self.assertEqual(
            len(notes), len(prompts),
            f"Notes should be returned for each {label} prompt.",
        )

    def test_evaluate_merlin(self):
        """Role id 0 evaluated against the Merlin prompts."""
        self._check_evaluation(0, self.TEST_PROMPTS_MERLIN, 'Merlin')

    def test_evaluate_assassin(self):
        """Role id 7 evaluated against the Evil prompts."""
        self._check_evaluation(7, self.TEST_PROMPTS_EVIL, 'Evil')




# Run the test suite only when this file is executed directly as a script.
if '__main__' == __name__:
    unittest.main()

