import unittest
# from src.GOPS.value_heuristic_evaluators import GOPSValueHeuristicsSSGEvaluator
from src.searchlightimprove.value_heuristic_improve import ValueHeuristicsSSGEvaluator
from src.searchlightimprove.llm_utils.llm_api_models import GPT35Multi
from src.searchlightimprove.proposers import LLMImprovementProposer
from src.searchlightimprove.prompts.improvement_prompts import IMPROVEMENT_PROMPTS
from src.GOPS.baseline_models_GOPS import *
from src.GOPS.examples.func_list import *
from src.searchlight.gameplay.simulators import GameSimulator
from src.GOPS.examples.abstract_list3 import abstract_list
from src.GOPS.examples.func_list3 import func_list
from src.searchlightimprove.analyzers import HeuristicsAnalyzer
from src.searchlightimprove.evolvers import *
from src.searchlightimprove.prompts.prompt_generators import PromptGenerator


class TestEvolvers(unittest.TestCase):
    """Smoke tests for the heuristic evolvers on a GOPS game simulator.

    Each test performs a single ``evolve_once`` step and checks that the
    number of entries returned by ``get_fittest(-1)`` is different afterwards.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared model, simulator, evaluator, analyzer and evolvers.

        Runs once for the whole class, so both tests reuse the same objects.
        """
        cls.gpt = GPT35Multi(temperature=0.7, num_responses=1)

        heuristic_cls = LLMFunctionalValueHeuristic
        check_function = heuristic_cls.test_evaluate_static
        parse_function = heuristic_cls.parse_llm_function
        prompt_generator = PromptGenerator(
            GOPS_RULES, GOPS_FUNCTION_SIGNATURE, SYS_PROMPT
        )

        # The same game components feed both the simulator and the evaluator.
        game_parts = {
            "transitor": GOPSForwardTransitor2(),
            "actor_enumerator": GOPSActorEnumerator(),
            "action_enumerator": GOPSActionEnumerator(),
        }
        start_state = GOPSState2({-1}, tuple(), tuple(), tuple(), 3)
        cls.simulator = GameSimulator(**game_parts, start_state=start_state)

        # A GOPS-specific evaluator (GOPSValueHeuristicsSSGEvaluator) was wired
        # in here previously; the generic evaluator below replaces it.
        cls.evaluator = ValueHeuristicsSSGEvaluator(
            simulator=cls.simulator,
            num_batch_runs=1,
            players={0, 1},
            against_benchmark=True,
            search_budget=8,
            random_rollouts=4,
            **game_parts,
            check_function=check_function,
            llm_func_value_heuristic_class=heuristic_cls,
        )

        cls.analyzer = HeuristicsAnalyzer(prompt_generator=prompt_generator)

        # Pair the first three example functions with their abstracts.
        seed_functions = [
            (func_list[idx], {'abstract': abstract_list[idx]})
            for idx in range(3)
        ]

        # Both evolvers are configured identically (and share the seed list).
        evolver_kwargs = dict(
            evaluator=cls.evaluator,
            analyzer=cls.analyzer,
            seed_functions=seed_functions,
            prompt_generator=prompt_generator,
            check_function=check_function,
            parse_function=parse_function,
            model=cls.gpt,
            batch_size=2,
        )
        cls.evolver1 = ThoughtBeamEvolver(**evolver_kwargs)
        cls.evolver2 = ImprovementLibraryEvolver(**evolver_kwargs)

    def _assert_evolve_changes_count(self, evolver):
        """Run one evolution step and assert the ``get_fittest(-1)`` size changed.

        NOTE(review): ``-1`` presumably asks for every stored function --
        confirm against the evolver implementation.
        """
        count_before = len(evolver.get_fittest(-1))
        evolver.evolve_once()
        count_after = len(evolver.get_fittest(-1))
        self.assertNotEqual(count_before, count_after)

    # One step of the thought-beam evolver must change the function count.
    # TODO: add more specific assertions to validate the evolution process.
    def test_thought_beam_evolver(self):
        self._assert_evolve_changes_count(self.evolver1)

    # One step of the improvement-library evolver must change the function count.
    # TODO: add more specific assertions to validate the evolution process.
    def test_improvement_library_evolver(self):
        self._assert_evolve_changes_count(self.evolver2)

# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
