import unittest

from src.searchlightimprove.graph_search.self_improve_search import *
from src.searchlight.algorithms.best_first_search import BestFirstSearch
from src.searchlight.datastructures.graphs import ValueGraph2
from src.searchlightimprove.value_heuristic_improve import ValueHeuristicsSSGEvaluator
from src.searchlightimprove.llm_utils.llm_api_models import GPT35Multi
from src.searchlightimprove.proposers import LLMImprovementProposer
from src.searchlightimprove.prompts.improvement_prompts import IMPROVEMENT_PROMPTS
from src.searchlight.datastructures.adjusters import QValueAdjuster
from src.searchlight.datastructures.estimators import UtilityEstimatorLast
from src.GOPS.baseline_models_GOPS import *
from src.GOPS.examples.func_list import *
from src.searchlight.gameplay.simulators import GameSimulator



class TestEvaluator(unittest.TestCase):
    """Structural checks on the output of ``ValueHeuristicsSSGEvaluator`` for GOPS."""

    def test_gops_evaluator(self):
        """Evaluate three copies of ``test_func`` and inspect the first notes entry.

        The test wires the GOPS transitor and enumerators into a
        ``GameSimulator``, builds a ``ValueHeuristicsSSGEvaluator`` on top of
        it, and then verifies the shape of ``notes[0]['trajectory_data']``
        returned by ``evaluate``.
        """
        # GOPS game components shared by the simulator and the evaluator.
        transitor = GOPSForwardTransitor2()
        actor_enumerator = GOPSActorEnumerator()
        action_enumerator = GOPSActionEnumerator()
        start_state = GOPSState2({-1}, tuple(), tuple(), tuple(), 3)

        # Game simulator the evaluator uses to play the games.
        simulator = GameSimulator(
            transitor=transitor,
            actor_enumerator=actor_enumerator,
            action_enumerator=action_enumerator,
            start_state=start_state,
        )

        check_function = LLMFunctionalValueHeuristic.test_evaluate_static
        llm_func_value_heuristic_class = LLMFunctionalValueHeuristic

        # Evaluator under test.
        evaluator = ValueHeuristicsSSGEvaluator(
            simulator=simulator,
            num_batch_runs=2,
            players={0, 1},
            against_benchmark=False,
            transitor=transitor,
            actor_enumerator=actor_enumerator,
            action_enumerator=action_enumerator,
            check_function=check_function,
            llm_func_value_heuristic_class=llm_func_value_heuristic_class,
        )

        # evaluate() must return exactly two values; only the notes are
        # inspected here, so the scores are deliberately ignored.
        _scores, notes = evaluator.evaluate([test_func, test_func, test_func])
        first_notes = notes[0]

        # notes[0] must carry a 'trajectory_data' entry ...
        self.assertIn('trajectory_data', first_notes)
        trajectory_data = first_notes['trajectory_data']

        # ... holding 4 items, the first of which is a dict.
        self.assertEqual(len(trajectory_data), 4)
        first_run = trajectory_data[0]
        self.assertIsInstance(first_run, dict)

        # The first item's 'trajectory' has 6 steps.
        trajectory = first_run['trajectory']
        self.assertEqual(len(trajectory), 6)

        # Element 2 of the first step is a GOPSState2.
        first_step = trajectory[0]
        self.assertIsInstance(first_step[2], GOPSState2)

        # Element 1 of the first step is a dict with 2 entries; the type is
        # checked first so a wrong type is reported as such rather than as a
        # len() error.
        self.assertIsInstance(first_step[1], dict)
        self.assertEqual(len(first_step[1]), 2)

        # The first item's 'heuristics_trajectory' has 6 entries, and the
        # first entry is a dict.
        heuristics_trajectory = first_run['heuristics_trajectory']
        self.assertEqual(len(heuristics_trajectory), 6)
        self.assertIsInstance(heuristics_trajectory[0], dict)
