import warnings
from lazzzy.ucs import ucs
from utils import enumerate_resume, write_jsonl
from executors import executor_factory
from generators import generator_factory

from typing import List, Set, Tuple


DEBUG = True

def debug_print(*args):
    if DEBUG:
        print(*args, flush=True)

class State:
    def __init__(self, code: str, feedback: str, reflection: str, state: Tuple[bool]):
        self.code = code
        self.feedback = feedback
        self.reflection = reflection
        self.state = state

    def __repr__(self):
        return f"State(code={self.code}, feedback={self.feedback}, reflection={self.reflection}, state={self.state})"

    def is_goal(self):
        return all(self.state)

    def __hash__(self):
        return hash((self.code, self.feedback, self.reflection))

    def get_unique_id(self):
        res = 0
        for i in range(len(self.state)):
            res += self.state[i] * (2**i)

        return res


def run_reflexion_ucs(
    dataset: List[dict],
    model: str,
    language: str,
    max_iters: int,
    pass_at_k: int,
    log_path: str,
    verbose: bool,
    expansion_factor: int,
    is_leetcode: bool = False
) -> None:
    exe = executor_factory(language, is_leet=is_leetcode)
    gen = generator_factory(language)

    num_items = len(dataset)
    num_success = 0
    for i, item in enumerate_resume(dataset, log_path):
        cur_pass = 0
        is_solved = False
        reflections = []
        cur_func_impl = ""
        while cur_pass < pass_at_k and not is_solved:
            debug_print(f"item {i} pass {cur_pass}")
            tests_i = gen.internal_tests(item["prompt"], model, 1)
            if len(tests_i) == 0:
                warnings.warn(f"no internal tests generated for item {i}")

            # first attempt
            debug_print("first attempt")
            cur_func_impl = gen.func_impl(item["prompt"], model, "simple")
            assert isinstance(cur_func_impl, str)  # num_comps of 1
            is_passing, feedback, state = exe.execute(cur_func_impl, tests_i)

            debug_print(f"first attempt: \n{cur_func_impl}\n{feedback}\n{state}")

            # if solved, exit--pass_at_k 1 early
            if is_passing:
                debug_print("solved at first attempt")
                is_solved = exe.evaluate(item["entry_point"], cur_func_impl, item["test"])
                num_success += 1 if is_solved else 0
                break

            reflection = gen.self_reflection(
                cur_func_impl, feedback, model)
            reflections.append(reflection)

            start = State(cur_func_impl, feedback, reflection, state)

            def expand(state: State) -> Set[Tuple[State, float]]:
                nonlocal max_iters
                nonlocal expansion_factor
                nonlocal item
                nonlocal model
                nonlocal tests_i
                nonlocal reflections

                new_states: Set[Tuple[State, float]] = set()

                debug_print(f"start expansion of: {state.state}")
                new_funcs = gen.func_impl(
                    func_sig=item["prompt"],
                    model=model,
                    strategy="reflexion",
                    prev_func_impl=state.code,
                    feedback=state.feedback,
                    self_reflection=state.reflection,
                    num_comps=expansion_factor,
                    temperature=0.75
                )
                assert isinstance(new_funcs, list)
                debug_print(f"generated num of funcs: {len(new_funcs)}")

                already_seen = set()

                for new_func in new_funcs:
                    if new_func in already_seen:
                        debug_print(f"skipping a func because already seen.")
                        continue

                    already_seen.add(new_func)

                    is_passing, feedback, new_state = exe.execute(new_func, tests_i)
                    debug_print(f"expanding: \n{new_func}\n{feedback}\n{new_state}")

                    if is_passing:
                        # return immediately if solved
                        return set([(State(new_func, feedback, "", new_state), 0)])

                    new_reflection = gen.self_reflection(new_func, feedback, model)
                    reflections.append(new_reflection)

                    num_failing = len([x for x in new_state if not x])
                    new_states.add(
                        (State(new_func, feedback, new_reflection, new_state), num_failing))


                debug_print(f"returning new states: {new_states}")

                return new_states

            def when_none(l: List[State]) -> State:
                debug_print(f"when_none called on: {l}")
                return min(l, key=lambda x: len([y for y in x.state if not y]))

            # this is either the goal state, or if not found, the current best state (lowest failed tests)
            best = ucs(
                start=start,
                expand=expand,
                is_goal=lambda x: x.is_goal(),
                # NOTE: this way we reduce our search space significantly
                # the maximum number of nodes is 2^num_tests,
                # which is 2^5 = 32
                get_unique_id=lambda x: x.get_unique_id(),
                when_none=when_none
            )
            assert best is not None  # impossible due to our when_none

            debug_print("BEST CODE:\n\n\n")
            debug_print(best.code)
            is_passing = exe.evaluate(
                item["entry_point"], best.code, item["test"], timeout=5)
            if is_passing:
                item["solution"] = best.code
                is_solved = True
                num_success += 1
                break  # breaking pass@k loop

            cur_pass += 1

        item["is_solved"] = is_solved
        item["reflections"] = reflections
        item["solution"] = cur_func_impl
        write_jsonl(log_path, [item], append=True)

        if verbose:
            print(
                f'completed {i+1}/{num_items}: acc = {round(num_success/(i+1), 2)}')
