import unittest
from src.searchlightimprove.graph_search.self_improve_search import SelfImprovementInitialInferencer
from src.searchlightimprove.llm_utils.llm_api_models import GPT35Multi
from src.searchlightimprove.proposers import LLMImprovementProposer
from src.searchlightimprove.evaluators import DummyEvaluator
from src.searchlight.headers import State
from src.GOPS.baseline_models_GOPS import *
from src.searchlightimprove.prompts.prompt_generators import PromptGenerator

class TestSelfImproveSearch(unittest.TestCase):
    """Tests for ``SelfImprovementInitialInferencer.predict``.

    The collaborators (LLM wrapper, prompt generator, improvement proposer,
    evaluator and the inferencer under test) are built once per class in
    ``setUpClass`` rather than in the class body, so that merely importing or
    collecting this module does not instantiate them.
    """

    @classmethod
    def setUpClass(cls):
        """Build the fixtures shared by every test in this class.

        The objects are stored as class attributes under the same names the
        tests read through ``self`` (``gpt``, ``check_function``,
        ``prompt_generator``, ``proposer``, ``evaluator``,
        ``initial_inferencer``).
        """
        super().setUpClass()

        # create LLM improvement proposer
        # NOTE(review): GPT35Multi presumably talks to a remote LLM service;
        # confirm whether this test needs network access / credentials.
        gpt = GPT35Multi(temperature=0.7, num_responses=1)
        check_function = LLMFunctionalValueHeuristic.test_evaluate_static
        prompt_generator = PromptGenerator(GOPS_RULES, GOPS_FUNCTION_SIGNATURE, SYS_PROMPT)
        proposer = LLMImprovementProposer(
            gpt,
            ['function to improve:', 'improve function:', 'improve the function:'],
            check_function=check_function,
            prompt_generator=prompt_generator,
        )

        # create evaluator
        evaluator = DummyEvaluator()

        # create initial inferencer
        initial_inferencer = SelfImprovementInitialInferencer(gpt, proposer, evaluator)

        cls.gpt = gpt
        cls.check_function = check_function
        cls.prompt_generator = prompt_generator
        cls.proposer = proposer
        cls.evaluator = evaluator
        cls.initial_inferencer = initial_inferencer

    def test_predict(self):
        """``predict`` on a seed state yields one actor and four uniform actions.

        Checks, for a state wrapping a trivial identity function:
        the actor set is ``{0}``; actions 0-3 each have probability 1/4;
        every next-state value is ``{0: 0.5}``; the reward is ``{0: 0.5}`` for
        the joint action ``((0, 3),)`` and ``{0: 0}`` for every other action;
        and there is one next state per action (four in total).
        """
        initial_function = 'def f(x):\n    return x\n'
        initial_state = State(initial_function, {'score': 0.5, 'notes': 'initial', 'done': False})

        (
            actors,
            actor_to_action_to_prob,
            next_state_values,
            action_to_actor_to_reward,
            action_to_next_state,
        ) = self.initial_inferencer.predict(initial_state)

        # the only actor is actor 0
        self.assertEqual(actors, {0})

        # actor 0 picks uniformly among the four actions 0..3
        expected_probs = {0: {0: 1/4, 1: 1/4, 2: 1/4, 3: 1/4}}
        self.assertEqual(actor_to_action_to_prob, expected_probs)

        # next-state values are 0.5 for actor 0 no matter the next state;
        # subTest makes a failure report which next state was wrong
        for next_state, value in next_state_values.items():
            with self.subTest(next_state=next_state):
                self.assertEqual(value, {0: 0.5})

        # reward is 0 unless it's the last action, which should have score 0.5
        last_action = ((0, 3),)
        for action, reward in action_to_actor_to_reward.items():
            expected_reward = {0: 0.5} if action == last_action else {0: 0}
            with self.subTest(action=action):
                self.assertEqual(reward, expected_reward)

        # there is a next state for each of the four actions
        self.assertEqual(len(action_to_next_state), 4)


