import unittest
# from src.GOPS.value_heuristic_evaluators import GOPSValueHeuristicsSSGEvaluator
from src.searchlightimprove.llm_utils.llm_api_models import GPT35Multi
from src.searchlightimprove.proposers import LLMImprovementProposer
from src.searchlightimprove.prompts.improvement_prompts import IMPROVEMENT_PROMPTS
from src.GOPS.baseline_models_GOPS import *
from src.GOPS.examples.func_list import *
from src.searchlight.gameplay.simulators import GameSimulator
from src.GOPS.examples.abstract_list3 import abstract_list
from src.GOPS.examples.func_list3 import func_list
from src.searchlightimprove.analyzers import HeuristicsAnalyzer
from src.searchlightimprove.evolvers import BeamEvolver
from src.searchlightimprove.prompts.prompt_generators import PromptGenerator
from src.searchlightimprove.value_heuristic_improve import ValueHeuristicsSSGEvaluator

class TestImprovementProcess(unittest.TestCase):
    """Integration tests for the GOPS value-heuristic improvement pipeline.

    ``setUpClass`` wires together a GOPS game simulator, a
    ``ValueHeuristicsSSGEvaluator``, a ``HeuristicsAnalyzer`` and a
    ``BeamEvolver`` seeded with three example heuristic functions. The tests
    exercise evaluation, one feedback-driven improvement round, and one
    evolution step. These objects are shared by every test in the class, so
    tests that mutate them must restore the original state.

    NOTE(review): the pipeline constructs a ``GPT35Multi`` client and asks the
    evolver's model to generate text, so these tests presumably need LLM API
    access — confirm before running in CI.
    """

    @classmethod
    def setUpClass(cls):
        """Build the simulator, evaluator, analyzer and evolver once per class."""
        # NOTE(review): cls.gpt is neither passed to BeamEvolver nor used by the
        # tests below (they go through self.evolver.model); confirm it is needed.
        cls.gpt = GPT35Multi(temperature=0.7, num_responses=1)
        check_function = LLMFunctionalValueHeuristic.test_evaluate_static
        parse_function = LLMFunctionalValueHeuristic.parse_llm_function
        prompt_generator = PromptGenerator(GOPS_RULES, GOPS_FUNCTION_SIGNATURE, SYS_PROMPT)

        transitor = GOPSForwardTransitor2()
        actor_enumerator = GOPSActorEnumerator()
        action_enumerator = GOPSActionEnumerator()
        # NOTE(review): the trailing 3 presumably sets the number of cards/rounds
        # in this GOPS game — confirm against GOPSState2.
        start_state = GOPSState2({-1}, tuple(), tuple(), tuple(), 3)

        cls.simulator = GameSimulator(
            transitor=transitor,
            actor_enumerator=actor_enumerator,
            action_enumerator=action_enumerator,
            start_state=start_state,
        )

        cls.evaluator = ValueHeuristicsSSGEvaluator(
            simulator=cls.simulator,
            num_batch_runs=1,
            players={0, 1},
            against_benchmark=True,
            search_budget=8,
            random_rollouts=4,
            transitor=transitor,
            actor_enumerator=actor_enumerator,
            action_enumerator=action_enumerator,
            check_function=check_function,
            llm_func_value_heuristic_class=LLMFunctionalValueHeuristic,
        )

        cls.analyzer = HeuristicsAnalyzer(prompt_generator=prompt_generator)

        # Seed the evolver with the first three example functions, each paired
        # with its matching abstract.
        seed_functions = [
            (func_list[i], {'abstract': abstract_list[i]}) for i in range(3)
        ]
        cls.evolver = BeamEvolver(
            evaluator=cls.evaluator,
            analyzer=cls.analyzer,
            seed_functions=seed_functions,
            prompt_generator=prompt_generator,
            check_function=check_function,
            parse_function=parse_function,
            batch_size=2,
        )

    def test_evaluator(self):
        """Evaluate three seed functions and check one float score and one note each."""
        scores, notes = self.evaluator.evaluate(func_list[:3])
        self.assertEqual(len(scores), 3)
        self.assertEqual(len(notes), 3)
        # Every score, not just the first, should be a float.
        for index, score in enumerate(scores):
            with self.subTest(index=index):
                self.assertIsInstance(score, float)
        # Add more specific assertions here based on expected scores and notes.

    def test_proposer_with_feedback(self):
        """Run one feedback -> conclusions -> improved-function round trip.

        Evaluates a seed function and translates its notes into textual
        feedback. It then asks the evolver's model for conclusions and
        generates an improved function from them. Finally it checks that the
        new function passes ``test_evaluate`` and can be benchmarked.
        """
        func = func_list[0]
        _, notes = self.evaluator.evaluate([func])
        self.assertEqual(len(notes), 1)
        processed_feedback = self.analyzer.translate(notes[0])

        # The analyzer's translated feedback should be plain text.
        self.assertIsInstance(processed_feedback, str)

        prompt_generator = self.evolver.prompt_generator
        prompt = prompt_generator.gen_draw_conclusions_from_feedback_prompt(func, processed_feedback)
        conclusions = self.evolver.model.generate(prompt, 1)[0]
        # The improvement prompt is built from the conclusions prompt followed
        # by the model's conclusions.
        prompt = prompt_generator.gen_improved_function_prompt(prompt + conclusions)
        new_func = self.evolver.generate_function(prompt)

        self.assertIsInstance(new_func, str)
        # Add more specific assertions here based on expected outcomes.

        test_function_heuristic = LLMFunctionalValueHeuristic(func=new_func)
        result = test_function_heuristic.test_evaluate(new_func, safe_mode=True)
        self.assertIsNotNone(result)

        # The evaluator is shared by all tests in this class. Register the
        # reset as a cleanup so it runs even if an assertion below fails;
        # otherwise the 4-run setting would leak into later tests.
        self.evaluator.set_num_batch_runs(4)
        self.addCleanup(self.evaluator.set_num_batch_runs, 1)

        function_scores, _, benchmark_scores = self.evaluator.evaluate_with_benchmark([new_func])

        # function_scores: a list holding exactly one float (one per function).
        # Check the length before indexing so an empty result fails cleanly.
        self.assertIsInstance(function_scores, list)
        self.assertEqual(len(function_scores), 1)
        self.assertIsInstance(function_scores[0], float)

        # benchmark_scores: a non-empty dict whose values are all floats.
        self.assertIsInstance(benchmark_scores, dict)
        self.assertTrue(benchmark_scores, "expected at least one benchmark score")
        for benchmark, score in benchmark_scores.items():
            with self.subTest(benchmark=benchmark):
                self.assertIsInstance(score, float)
        # Add more specific assertions here based on expected benchmarking outcomes.

    def test_evolution_process(self):
        """Check that one ``evolve_once`` step changes the size of ``get_fittest(-1)``.

        NOTE(review): assumes ``get_fittest(-1)`` returns every function the
        evolver currently tracks — confirm against BeamEvolver.
        """
        initial_function_count = len(self.evolver.get_fittest(-1))
        self.evolver.evolve_once()
        new_function_count = len(self.evolver.get_fittest(-1))
        self.assertNotEqual(initial_function_count, new_function_count)
        # Add more specific assertions here to validate the evolution process.

if __name__ == "__main__":
    # Executed directly as a script: run this module's TestCase classes.
    unittest.main()
