from environment.environment import SokobanEnvImpl
from environment.visualization import animate, render
from knowledge_graph.knowledge_graph import KnowledgeGraph
from .selection import selection
from .expansion import expansion
from .simulation import simulation, eval_state
from .backprop import backprop
from .state import InputState, OverallState, OutputState, GlobalState
from langgraph.utils.runnable import RunnableLike
from langgraph.graph import END, START, StateGraph
from langgraph.graph.state import CompiledStateGraph
import time

class MonteCarloTreeSearch:
    """Monte Carlo tree search driver backed by a compiled LangGraph pipeline.

    A single invocation of ``mcts_step`` runs the four phases in a fixed,
    linear order: selection -> expansion -> simulation -> backprop.
    """

    # Compiled four-node graph; one ``invoke`` call performs one MCTS iteration.
    mcts_step: CompiledStateGraph

    def __init__(self, builder) -> None:
        """Wire the builder's phase callables into a linear graph and compile it.

        Args:
            builder: Object exposing ``selection``, ``expansion``,
                ``simulation`` and ``backprop`` attributes; each one is
                registered as the graph node of the same name.
        """
        graph = StateGraph(OverallState, input=InputState, output=OutputState)

        phases = ("selection", "expansion", "simulation", "backprop")
        for phase in phases:
            graph.add_node(phase, getattr(builder, phase))

        # Chain START -> selection -> expansion -> simulation -> backprop -> END.
        for source, target in zip((START, *phases), (*phases, END)):
            graph.add_edge(source, target)

        self.mcts_step = graph.compile()

    def solve(
        self,
        env: SokobanEnvImpl,
        log_path: str = None,
        agent_player_model: str = "qwen3:8b",
        reward_threshold: float = 5,
    ) -> tuple[object, int, float]:
        """Run MCTS iterations until one reports a reward above the threshold.

        Args:
            env: Environment to solve; the search runs on ``env.as_fixated()``.
            log_path: Optional prefix for log and visualization files. It is
                concatenated directly with the file names (no separator is
                inserted), so include a trailing path separator when it names
                a directory. Falsy values disable logging and rendering.
            agent_player_model: Model name forwarded to
                ``GlobalState().set_agent_palyer``.
            reward_threshold: The loop stops as soon as a step's ``"reward"``
                is strictly greater than this value. Defaults to 5.

        Returns:
            A tuple ``(trajectory, num_explored_nodes, needed_time)`` where
            ``trajectory`` is the ``"trajectory"`` property of the first
            ``Path`` record flagged ``done`` (its concrete type is defined by
            the knowledge-graph client), ``num_explored_nodes`` is the number
            of ``Path`` records returned by the client, and ``needed_time``
            is the duration of the search loop in seconds.

        Raises:
            TypeError: If a step result carries no ``"reward"`` (``None``
                cannot be compared with the threshold).
            IndexError: If no ``Path`` record with ``done`` set is returned
                after the loop ends (assuming the client returns a list of
                records -- TODO confirm).
        """
        fixed_env = env.as_fixated()
        # NOTE(review): the graph is invoked with an empty input, so the phase
        # callables presumably read this shared state via GlobalState() -- confirm.
        GlobalState().env = fixed_env
        GlobalState().kg = KnowledgeGraph(fixed_env)
        GlobalState().set_agent_palyer(agent_player_model)

        if log_path:
            GlobalState().agent_player.write_log(f"{log_path}agent_player.log", clear_log_path=True)
            GlobalState().kg.client.write_log(f"{log_path}client_neoj4.log", clear_log_path=True)

        # Evaluate and back-propagate once before the first iteration.
        GlobalState().kg.backprop(eval_state())

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments the way time.time() differences can.
        start_time = time.perf_counter()
        while True:
            result = self.mcts_step.invoke({})
            if result.get("reward") > reward_threshold:
                break

        needed_time = time.perf_counter() - start_time

        # The client returns a 3-tuple; only the records are needed here.
        path_nodes, _, _ = GlobalState().kg.client.read("""MATCH (p:Path) RETURN p""")
        num_explored_nodes = len(path_nodes)
        records, _, _ = GlobalState().kg.client.read("""MATCH (p:Path) WHERE p.done RETURN p """)
        trajectory = records[0]["p"]["trajectory"]

        if log_path:
            # Reset the environment before replaying the trajectory for the figures.
            fixed_env.reset()
            render(env=fixed_env, path=trajectory, save_fig=f"{log_path}solution.png", show_fig=False)
            animate(env=fixed_env, path=trajectory, save_ani=f"{log_path}solution.gif")
        return trajectory, num_explored_nodes, needed_time

class Builder:
    """Fluent builder collecting the four phase callables of the search.

    Every phase starts out as the implementation imported by this module and
    can be replaced through the matching ``set*`` method. Each setter returns
    the builder itself so calls can be chained before ``build()``.
    """

    selection: RunnableLike
    expansion: RunnableLike
    simulation: RunnableLike
    backprop: RunnableLike

    def __init__(self) -> None:
        """Populate every phase with this module's default implementation."""
        defaults = {
            "selection": selection,
            "expansion": expansion,
            "simulation": simulation,
            "backprop": backprop,
        }
        for phase, runnable in defaults.items():
            setattr(self, phase, runnable)

    def _assign(self, phase, runnable):
        """Store ``runnable`` under the attribute ``phase`` and return ``self``."""
        setattr(self, phase, runnable)
        return self

    def setSelection(self, selection: RunnableLike):
        """Replace the selection phase; returns this builder for chaining."""
        return self._assign("selection", selection)

    def setExpansion(self, expansion: RunnableLike):
        """Replace the expansion phase; returns this builder for chaining."""
        return self._assign("expansion", expansion)

    def setSimulation(self, simulation: RunnableLike):
        """Replace the simulation phase; returns this builder for chaining."""
        return self._assign("simulation", simulation)

    def setBackprop(self, backprop: RunnableLike):
        """Replace the backprop phase; returns this builder for chaining."""
        return self._assign("backprop", backprop)

    def build(self) -> MonteCarloTreeSearch:
        """Create a ``MonteCarloTreeSearch`` from the currently configured phases."""
        search = MonteCarloTreeSearch(self)
        return search

