from __future__ import annotations
from langgraph.graph import StateGraph, END, START
from typing import TypedDict, Literal, List, Dict, Tuple
import os, pandas as pd
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from dotenv import load_dotenv
import json

load_dotenv()


class GraphState(TypedDict, total=False):
    """Shared state passed between the LangGraph nodes of the path-finding loop.

    ``total=False`` makes every key optional: the ``__main__`` driver supplies the
    task inputs and ``initialize_state`` fills in the bookkeeping keys. Nodes mutate
    this dict in place and return it.
    """

    # --- Run configuration ---------------------------------------------------
    # NOTE: the key name is misspelled ("repromting") but is read under exactly
    # this name by reprompt_agent / should_end and set by the driver; do not rename.
    repromting_type: Literal["cut_based", "no", "naive", "strong_cut_based"]
    prompting_type: Literal["vanilla", "cot"]
    llm: Literal["llama3:8b", "gpt-4o-mini"]

    # --- Task as shown to the LLM --------------------------------------------
    # NOTE(review): annotated as an int-keyed adjacency dict and int nodes, but the
    # driver fills these from CSV cells parsed with eval (coordinates), and for
    # "gpt-4o-mini" replaces `graph` with a JSON string and the start/target with
    # lists. The annotations do not describe the runtime values — confirm before
    # relying on them (they are left as-is because LangGraph introspects them).
    graph: Dict[int, List[int]]
    current_state: int
    target_state: int
    # Maximum number of reprompts (initialize_state defaults it to 15).
    max_itr: int

    # --- Ground truth used by validation_agent / reprompt_agent --------------
    # Kept separate from the prompt copies above so validation always uses the
    # unconverted graph/coordinates, even when the prompt copies are re-serialised.
    val_fsm: Dict[int, List[int]]
    val_current_state: int
    val_target_state: int
    # validation_agent requires key_loc to appear in the path before door_loc.
    key_loc: List[int]
    door_loc: List[int]
    # Number of reprompts performed so far (incremented by reprompt_agent).
    iteration: int
    # Latest path proposed by the LLM (list of coordinate pairs).
    path: List[List[int]]
    valid_path: bool
    # True when the LLM response could not be parsed into GraphOutput.
    parse_error: bool
    # Set by validation_agent when the path contains a node missing from val_fsm;
    # while True, reprompt_agent records no new cuts.
    path_error: bool

    # --- Free-text fields ----------------------------------------------------
    message: str
    # "Avoid the following transitions ..." text accumulated by reprompt_agent.
    error: str
    suggestion: str
    recommendation: str
    cot: str

    # --- Cumulative token usage across all LLM calls -------------------------
    tokens: int
    prompt_tokens: int
    output_tokens: int

    # --- Logs ----------------------------------------------------------------
    # action_agent appends to agent/iteration/path/message/token logs and
    # validation_agent to valid_path_log; the remaining logs are initialised
    # (or declared) but not appended to anywhere in this file.
    agent_log: List[str]
    output_log: List
    recommendation_log: List[str]
    cot_log: List[str]
    message_log: List[str]
    path_log: List[List[List[int]]]
    error_log: List[str]
    suggestion_log: List[str]
    valid_path_log: List[bool]
    token_log: List[int]
    prompt_token_log: List[int]
    output_token_log: List[int]
    iteration_log: List[int]

    # --- Dataset metadata (not read by any node in this file) ----------------
    file_name: str
    edges: int
    node_count: int
    shortest_path_len: int

    # --- Reprompting memory --------------------------------------------------
    # source node -> successors the LLM wrongly tried to move to from it ("cuts").
    invalid_transitions: Dict[int, List[List[int]]]
    # Distinct paths already proposed, as produced by _canonical_path_tuple.
    tried_paths: List[List[int]]


# Structured-output schema for the action agent's LLM call
# (used via with_structured_output(GraphOutput, include_raw=True)).
# NOTE(review): documented with comments rather than a class docstring on purpose:
# the model's docstring and field descriptions are presumably forwarded to the LLM
# as part of the tool/JSON schema, so editing them would change the prompt.
# Confirm before adding a docstring or rewording the descriptions.
class GraphOutput(BaseModel):
    # Free-text explanation produced by the model.
    message: str = Field(
        description="Entire AIMessage returned by llm. Keep all lines, including newlines."
    )
    # Proposed route as coordinate pairs, e.g. [[1, 1], [1, 2], [2, 2]].
    path: List[List[int]] = Field(
        description="A valid path from current state to target state in the graph"
    )


def _canonical_path_tuple(path: List[List[int]]) -> Tuple[int, ...]:
    return tuple(path)


def _accumulate_usage(state: GraphState, raw_resp: object) -> None:
    meta = (
        getattr(raw_resp, "usage_metadata", None)
        or getattr(raw_resp, "response_metadata", None)
        or {}
    )
    total = getattr(meta, "total_tokens", None) or meta.get("total_tokens") or 0
    prompt = getattr(meta, "input_tokens", None) or meta.get("input_tokens") or 0
    output = getattr(meta, "output_tokens", None) or meta.get("output_tokens") or 0
    state["tokens"] = int(state.get("tokens", 0) + (total or 0))
    state["prompt_tokens"] = int(state.get("prompt_tokens", 0) + (prompt or 0))
    state["output_tokens"] = int(state.get("output_tokens", 0) + (output or 0))


def _allowed_successors(
    graph: Dict[int, List[int]], cuts: Dict[int, List[int]], u: int
) -> List[int]:
    allowed = set(graph.get(u, []))
    banned = set(cuts.get(u, []))
    return sorted(list(allowed - banned))


def initialize_state(state: GraphState) -> GraphState:
    """Reset all per-run bookkeeping in *state* and return the same dict.

    The caller's ``max_itr`` is kept (default 15). ``path`` is seeded with
    ``current_state``, which must be present. Flags, token counters, every log
    list and the reprompting memory are reset to fresh empty values.
    """
    state.update(
        {
            "iteration": 0,
            "error": "Avoid the following transitions in your next attempt:",
            "suggestion": "",
            "recommendation": "",
            "max_itr": state.get("max_itr", 15),
        }
    )
    state.update(
        path=[state["current_state"]],
        valid_path=False,
        parse_error=False,
        path_error=False,
        tokens=0,
        prompt_tokens=0,
        output_tokens=0,
    )

    log_keys = (
        "agent_log",
        "path_log",
        "token_log",
        "prompt_token_log",
        "output_token_log",
        "message_log",
        "valid_path_log",
        "error_log",
        "suggestion_log",
        "cot_log",
        "output_log",
        "iteration_log",
    )
    for log_key in log_keys:
        state[log_key] = []  # a distinct list per key

    state["invalid_transitions"] = {}
    state["tried_paths"] = []
    return state


def _make_llm(model_name: str, temperature: float = 0):
    """Build a chat-model client for *model_name*.

    Names starting with ``"gpt"`` go to OpenAI (60 s request timeout). Every other
    name is served by a local Ollama daemon at ``http://localhost:11434``.
    """
    if not model_name.startswith("gpt"):
        # Local Ollama model.
        return ChatOllama(
            model=model_name,
            temperature=temperature,
            base_url="http://localhost:11434",
        )
    return ChatOpenAI(model=model_name, temperature=temperature, timeout=60)


def _task_header(state: GraphState) -> str:
    """Describe the graph, start and goal, plus the key/door rule when the task has one."""
    header = (
        "You are given a graph and a task to find a valid path from the start node to the goal node.\n"
        f"Graph (adjacency list):\n{state['graph']}\n\n"
        f"Start Node: {state['current_state']}\nGoal Node: {state['target_state']}\n"
    )
    key_loc, door_loc = state.get("key_loc"), state.get("door_loc")
    if key_loc is not None and door_loc is not None:
        # validation_agent rejects every path that does not visit key_loc before
        # door_loc, so the model has to be told about that requirement.
        header += (
            f"Key Location: {key_loc}\nDoor Location: {door_loc}\n"
            "The path must pass through the key location before it reaches the door location.\n"
        )
    return header


def _cuts_text(state: GraphState) -> str:
    """Render the accumulated invalid transitions as one ban line per source node."""
    fsm = state["val_fsm"]
    ban_lines = []
    for u, banned in state["invalid_transitions"].items():
        if not banned:
            continue
        neighbours = fsm.get(u, [])
        allowed = sorted(set(neighbours) - set(banned))
        ban_lines.append(
            f"- Do not choose {u}→{banned}. Node {u} is only connected to: {sorted(neighbours)}. Allowed now: {allowed}"
        )
    return "\n".join(ban_lines) if ban_lines else "(no cuts)"


def action_agent(state: GraphState) -> GraphState:
    """Ask the LLM for a start-to-goal path and record the attempt in *state*.

    Iteration 0 sends only the task description. Later iterations also send the
    accumulated cuts (``invalid_transitions``) and every previously proposed path.
    With ``prompting_type == "cot"``, step-by-step guidance is appended.

    Side effects on *state*:

    * sets ``path``, ``message`` and ``parse_error``; ``path`` is ``[]`` when the
      response could not be parsed;
    * adds this call's token usage to the running totals;
    * appends to the agent, iteration, path, message and token logs;
    * adds the canonical path to ``tried_paths`` if it is new.

    Errors raised by the model call itself are not caught.
    """
    llm = _make_llm(state["llm"], temperature=0)
    action_agent_llm = llm.with_structured_output(GraphOutput, include_raw=True)

    base_instructions = _task_header(state)
    if state["iteration"] != 0:
        # NOTE(review): the ban lines are rendered the same way for "cut_based"
        # and "strong_cut_based". The type-specific state["error"] text built by
        # reprompt_agent is never put into the prompt, so those two modes
        # currently send identical prompts — confirm this is intended.
        novelty_note = (
            f"Previously generated paths (avoid exact repeats): {state['path_log']}\n"
            if state["path_log"]
            else ""
        )
        base_instructions += (
            f"\nConstraints (prompt-level cuts):\n{_cuts_text(state)}\n\n{novelty_note}"
        )

    if state.get("prompting_type") == "cot":
        thinking = (
            "Think step-by-step:\n"
            "1) Enumerate legal successors at each step; 2) avoid revisiting nodes; 3) stop at the goal; 4) prefer shortest valid path.\n"
        )
    else:
        thinking = ""

    human_prompt = (
        base_instructions
        + thinking
        + "Return a JSON object with fields `message` (free text) and `path` (list of coordinate pairs as [x, y], e.g. [[1,1],[1,2],[2,2]])."
    )

    messages = [
        (
            "system",
            "You are a Graph traversal expert. The graph is a dict mapping node IDs to lists of successors. Always output valid paths.",
        ),
        ("human", human_prompt),
    ]

    # include_raw=True returns {"raw": ..., "parsed": ..., "parsing_error": ...}
    # instead of raising when the response does not fit GraphOutput.
    output = action_agent_llm.invoke(messages)
    _accumulate_usage(state, output.get("raw"))
    parsed = output.get("parsed")
    if parsed is not None:
        state["path"] = parsed.path
        state["message"] = parsed.message
        state["parse_error"] = False
    else:
        parsing_error = output.get("parsing_error")
        state["path"] = []
        # Keep the parser's reason instead of discarding it.
        state["message"] = (
            f"Parse error: {parsing_error}" if parsing_error else "No parsed output"
        )
        state["parse_error"] = True

    state.setdefault("agent_log", []).append("ActionAgent")
    state.setdefault("iteration_log", []).append(state["iteration"])
    state.setdefault("path_log", []).append(state["path"])
    state.setdefault("message_log", []).append(state.get("message", ""))
    state.setdefault("token_log", []).append(state.get("tokens", 0))
    state.setdefault("prompt_token_log", []).append(state.get("prompt_tokens", 0))
    state.setdefault("output_token_log", []).append(state.get("output_tokens", 0))

    tp = _canonical_path_tuple(state["path"])
    if tp and tp not in state["tried_paths"]:
        state["tried_paths"].append(tp)

    return state


def _as_node(value):
    """Return *value* as a tuple if it is a list/tuple (coordinate pair), else unchanged.

    This makes list coordinates (LLM output, or start/target converted for gpt
    models) compare equal to, and hash like, the tuple keys of ``val_fsm``.
    """
    return tuple(value) if isinstance(value, (list, tuple)) else value


def validation_agent(state: GraphState) -> GraphState:
    """Check the latest LLM path against the ground-truth graph and task constraints.

    A path is valid when all of the following hold:

    * it is non-empty;
    * it starts at ``val_current_state`` and ends at ``val_target_state``;
    * every node is a key of ``val_fsm``;
    * every consecutive pair is an edge of ``val_fsm``;
    * when both ``key_loc`` and ``door_loc`` are given, it visits ``key_loc``
      before ``door_loc``.

    Side effects on *state*:

    * sets ``valid_path`` and appends it to ``valid_path_log``;
    * sets ``path_error``: True only if this path reaches a node missing from
      ``val_fsm`` before any bad edge. It is reset to False on every call.
    """
    G = state["val_fsm"]
    path = state.get("path") or []
    # path_error describes only the path being checked now. Without this reset a
    # single bad node would leave it True forever, and reprompt_agent would stop
    # recording cuts for all later iterations.
    state["path_error"] = False

    valid = bool(path)
    if valid:
        tuple_path = [_as_node(node) for node in path]
        if (
            _as_node(state["val_current_state"]) != tuple_path[0]
            or _as_node(state["val_target_state"]) != tuple_path[-1]
        ):
            valid = False
        for i, node in enumerate(tuple_path):
            if node not in G:
                valid = False
                state["path_error"] = True
                break
            if i < len(tuple_path) - 1 and tuple_path[i + 1] not in G.get(node, []):
                valid = False
                break

        key_loc = state.get("key_loc")
        door_loc = state.get("door_loc")
        if key_loc is not None and door_loc is not None:
            # Normalise so list-typed locations still match the tuple path.
            key, door = _as_node(key_loc), _as_node(door_loc)
            # The key must be collected before the door can be passed.
            if (
                key not in tuple_path
                or door not in tuple_path
                or tuple_path.index(key) > tuple_path.index(door)
            ):
                valid = False

    state.setdefault("valid_path_log", []).append(valid)
    state["valid_path"] = valid
    return state

def reprompt_agent(state: GraphState) -> GraphState:
    """Count a reprompt and record the illegal transitions of the latest path.

    ``iteration`` is always incremented. Cuts are recorded only when
    ``repromting_type`` is ``"cut_based"``, ``"strong_cut_based"`` or ``"naive"``
    and the last path had no unknown node (``path_error`` falsy). For each
    consecutive pair ``u -> v`` that is not an edge of ``val_fsm``:

    * ``v`` is added (once) to ``invalid_transitions[u]``;
    * a note is appended to ``error``: a detailed "Allowed now" line in strong
      mode, or a short ``" u->v,"`` item otherwise.

    ``"naive"`` mode first clears both ``error`` and ``invalid_transitions``, so
    only the latest attempt's transitions are kept.
    """
    state["iteration"] = int(state.get("iteration", 0)) + 1

    mode = state.get("repromting_type")
    if mode not in {"cut_based", "strong_cut_based", "naive"} or state.get(
        "path_error", False
    ):
        return state

    nodes = [tuple(node) for node in state.get("path", [])]
    fsm = state["val_fsm"]
    if mode == "naive":
        state["error"] = "Avoid the following transitions in your next attempt:"
        state["invalid_transitions"] = {}

    for src, dst in zip(nodes, nodes[1:]):
        if dst in fsm.get(src, []):
            continue  # legal move, nothing to cut
        banned = state["invalid_transitions"].setdefault(src, [])
        if dst not in banned:
            banned.append(dst)
        if mode == "strong_cut_based":
            allowed_now = _allowed_successors(fsm, state["invalid_transitions"], src)
            state["error"] += (
                f"\n- Do not choose {src}→{dst}; node {src} is only connected to "
                f"{sorted(fsm.get(src, []))}. Allowed now: {allowed_now}"
            )
        else:
            state["error"] += f" {src}->{dst},"

    return state


def should_end(state: GraphState):
    """Choose the node that follows validation_agent.

    Returns ``END`` in any of these cases:

    * a valid path has been found;
    * reprompting is disabled (``repromting_type`` not one of the cut/naive modes);
    * the reprompt budget ``max_itr`` (default 15) is used up.

    Otherwise returns ``"reprompt"``.
    """
    if state.get("valid_path"):
        return END
    if state.get("repromting_type") not in {"cut_based", "naive", "strong_cut_based"}:
        return END
    iterations_done = int(state.get("iteration", 0))
    if iterations_done >= int(state.get("max_itr", 15)):
        return END
    return "reprompt"


if __name__ == "__main__":
    # Agent loop: initialize -> action -> validation -> (reprompt -> action ->
    # validation)* -> END; should_end decides whether to reprompt.
    builder = StateGraph(GraphState)
    builder.add_node("initialize", initialize_state)
    builder.add_node("action_agent", action_agent)
    builder.add_node("validation_agent", validation_agent)
    builder.add_node("reprompt", reprompt_agent)

    builder.add_edge(START, "initialize")
    builder.add_edge("initialize", "action_agent")
    builder.add_edge("action_agent", "validation_agent")
    builder.add_edge("reprompt", "action_agent")
    builder.add_conditional_edges("validation_agent", should_end, ["reprompt", END])
    fsm_graph = builder.compile()

    data_dir = "./data/minigrid/key_door"
    # os.listdir order is arbitrary; sort so runs and outputs are reproducible.
    for file in sorted(os.listdir(data_dir)):
        if not file.endswith(".csv") or not file.startswith("case"):
            continue
        # NOTE(review): positional slice — assumes the environment tag starts at
        # character 28 of the file name; confirm against the data set's naming.
        env_tag = file[28:-4]
        if env_tag != "5x5-v0":
            continue
        file_path = os.path.join(data_dir, file)

        # SECURITY NOTE(review): eval executes arbitrary Python found in the CSV
        # cells. Only run this on trusted data (ast.literal_eval would suffice if
        # the cells are plain literals).
        df = pd.read_csv(
            file_path,
            converters={
                "fsm": eval,
                "door_loc": eval,
                "goal_loc": eval,
                "key_loc": eval,
                "agent_loc": eval,
            },
        )
        results_rows = []
        for cnt in range(min(10, len(df))):
            for trial in range(1):
                print(f"Processing {file} — row {cnt} — trial {trial}")
                inputs: GraphState = {
                    "graph": df["fsm"].iloc[cnt],
                    "val_fsm": df["fsm"].iloc[cnt],
                    "current_state": df["agent_loc"].iloc[cnt],
                    "target_state": df["goal_loc"].iloc[cnt],
                    "key_loc": df["key_loc"].iloc[cnt],
                    "door_loc": df["door_loc"].iloc[cnt],
                    "val_current_state": df["agent_loc"].iloc[cnt],
                    "val_target_state": df["goal_loc"].iloc[cnt],
                    "repromting_type": "no",
                    "prompting_type": "vanilla",
                    "llm": "llama3:8b",
                    "max_itr": 15,
                }
                if inputs["llm"] == "gpt-4o-mini":
                    # Only the prompt copies are re-serialised; val_* keep the
                    # original tuple-based graph for validation.
                    graph_serializable = {}
                    for key, neighbors in inputs["graph"].items():
                        list_key = list(key)
                        list_neighbors = [list(n) for n in neighbors]
                        graph_serializable[json.dumps(list_key)] = list_neighbors

                    inputs["graph"] = json.dumps(graph_serializable, indent=2)
                    inputs["current_state"] = list(inputs["current_state"])
                    inputs["target_state"] = list(inputs["target_state"])

                # One initialize step, action+validation per attempt and a
                # reprompt between attempts: 3 * max_itr + 3 steps in the worst
                # case. Keep headroom so LangGraph's recursion limit never cuts
                # the reprompt budget short.
                recursion_limit = max(50, 3 * int(inputs["max_itr"]) + 10)
                result = fsm_graph.invoke(
                    input=inputs, config={"recursion_limit": recursion_limit}
                )
                print(
                    f"Result: valid_path={result.get('valid_path')}, path={result.get('path')}, iterations={result.get('iteration')}, tokens={result.get('tokens')}"
                )
                row = {
                    "file": file,
                    "row_id": cnt,
                    "trial": trial,
                    "graph": str(result.get("graph")),
                    "paths": result.get("path_log"),
                    "valid_path_log": result.get("valid_path_log"),
                    "final_valid": result.get("valid_path"),
                    "reprompts": result.get("iteration"),
                    "tokens_total": result.get("tokens"),
                    "tokens_prompt": result.get("prompt_tokens"),
                    "tokens_output": result.get("output_tokens"),
                }
                results_rows.append(row)

            # Checkpoint after every dataset row: writes all rows collected so far
            # for this file into a row-numbered CSV.
            # NOTE(review): llm[:-3] maps "llama3:8b" -> "llama3" but
            # "gpt-4o-mini" -> "gpt-4o-m"; kept so existing result folders match.
            model_tag = inputs["llm"][:-3]
            out_dir = (
                f"./results/{model_tag}/minigrid/{env_tag}/"
                f"{inputs['repromting_type']}_reprompting/{inputs['prompting_type']}_prompting"
            )
            # to_csv does not create missing directories.
            os.makedirs(out_dir, exist_ok=True)
            pd.DataFrame(results_rows).to_csv(
                os.path.join(out_dir, f"{model_tag}_minigrid_{cnt}.csv"),
                index=False,
            )

        # Per-file summary. It lives inside the file loop so every matching file
        # gets one, and results_rows/file are always defined here.
        if results_rows:
            os.makedirs("./results", exist_ok=True)
            summary_path = f"./results/summary_{file}.csv"
            pd.DataFrame(results_rows).to_csv(summary_path, index=False)
            print(f"Saved {summary_path}")
