from __future__ import annotations
from langgraph.graph import StateGraph, END, START
from typing import TypedDict, Literal, List, Dict, Tuple
import os, pandas as pd
from ast import literal_eval
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from dotenv import load_dotenv

# Load variables from a .env file into os.environ at import time
# (presumably API credentials for the LLM clients built below — confirm).
load_dotenv()


class GraphState(TypedDict, total=False):
    """Shared state passed between the LangGraph nodes of the path-finding loop.

    Declared with ``total=False``, so every key is optional. The caller
    supplies the configuration and problem-instance keys; ``initialize_state``
    resets the counters and logs at the start of each run.
    """

    # --- Run configuration (supplied by the caller) ---
    # Reprompting strategy. Note the key is spelled "repromting" everywhere.
    # should_end only loops for "cut_based", "naive" and "strong_cut_based".
    repromting_type: Literal["cut_based", "no", "naive", "strong_cut_based"]
    # "cot" makes action_agent append a step-by-step hint to the prompt.
    prompting_type: Literal["vanilla", "cot"]
    # Model name; _make_llm routes names starting with "gpt" to OpenAI and
    # everything else to a local Ollama server.
    llm: Literal["llama3:8b", "gpt-4o-mini"]

    # --- Problem instance ---
    # Adjacency list: node -> successor nodes.
    graph: Dict[int, List[int]]
    # Start node of the requested path.
    current_state: int
    # Goal node of the requested path.
    target_state: int
    # Reprompt budget compared against `iteration` in should_end (default 15).
    max_itr: int

    # --- Latest attempt ---
    # Number of reprompts performed so far (0 for the first attempt).
    iteration: int
    # Path proposed by the latest attempt; [] when the output was not parsed.
    path: List[int]
    # Result of validation_agent for `path`.
    valid_path: bool
    # True when the model's structured output could not be parsed.
    parse_error: bool
    # Set to True by validation_agent when the path contains a node that is not
    # a key of `graph`; reprompt_agent records no cuts while it is True.
    path_error: bool

    # Free-text part of the latest model answer, or a parse-failure note.
    message: str
    # Text describing banned transitions, built up by reprompt_agent.
    error: str
    # NOTE(review): suggestion/recommendation are only reset to "" and `cot`
    # is never written by the nodes in this module.
    suggestion: str
    recommendation: str
    cot: str

    # Running token totals across all attempts of one run.
    tokens: int
    prompt_tokens: int
    output_tokens: int

    # Per-attempt logs. action_agent appends to agent_log, iteration_log,
    # path_log, message_log and the three token logs (which store the running
    # totals); validation_agent appends to valid_path_log.
    # NOTE(review): output_log, cot_log, error_log and suggestion_log are only
    # initialized, and recommendation_log is never set, in this module.
    agent_log: List[str]
    output_log: List
    recommendation_log: List[str]
    cot_log: List[str]
    message_log: List[str]
    path_log: List[List[int]]
    error_log: List[str]
    suggestion_log: List[str]
    valid_path_log: List[bool]
    token_log: List[int]
    prompt_token_log: List[int]
    output_token_log: List[int]
    iteration_log: List[int]

    # Dataset metadata for the run (edges/node_count are set by the __main__
    # driver; file_name and shortest_path_len are not set in this module).
    file_name: str
    edges: int
    node_count: int
    shortest_path_len: int

    # Cuts: node -> successors the model is told to avoid in later prompts.
    invalid_transitions: Dict[int, List[int]]
    # Distinct non-empty paths proposed so far, stored as tuples.
    tried_paths: List[Tuple[int, ...]]


# Structured-output schema that action_agent passes to with_structured_output.
# NOTE(review): documented with comments rather than a class docstring on
# purpose — pydantic puts a class docstring into the model's JSON schema
# description, which LangChain forwards to the LLM; confirm before adding one.
class GraphOutput(BaseModel):
    # Free-text part of the answer (the prompt asks for `message` as free text).
    message: str = Field(
        description="Entire AIMessage returned by llm. Keep all lines, including newlines."
    )
    # Proposed node sequence from start to goal; checked by validation_agent.
    path: List[int] = Field(
        description="A valid path from current state to target state in the graph"
    )


def _canonical_path_tuple(path: List[int]) -> Tuple[int, ...]:
    return tuple(path)


def _accumulate_usage(state: GraphState, raw_resp: object) -> None:
    meta = (
        getattr(raw_resp, "usage_metadata", None)
        or getattr(raw_resp, "response_metadata", None)
        or {}
    )
    total = getattr(meta, "total_tokens", None) or meta.get("total_tokens") or 0
    prompt = getattr(meta, "input_tokens", None) or meta.get("input_tokens") or 0
    output = getattr(meta, "output_tokens", None) or meta.get("output_tokens") or 0
    state["tokens"] = int(state.get("tokens", 0) + (total or 0))
    state["prompt_tokens"] = int(state.get("prompt_tokens", 0) + (prompt or 0))
    state["output_tokens"] = int(state.get("output_tokens", 0) + (output or 0))


def _allowed_successors(
    graph: Dict[int, List[int]], cuts: Dict[int, List[int]], u: int
) -> List[int]:
    allowed = set(graph.get(u, []))
    banned = set(cuts.get(u, []))
    return sorted(list(allowed - banned))


def initialize_state(state: GraphState) -> GraphState:
    """Reset the per-run bookkeeping before the first action-agent call.

    Keeps a caller-supplied ``max_itr`` (default 15) and seeds ``path`` with
    the start node. Zeroes the iteration and token counters, clears the
    flags, and gives every log its own fresh empty list. Mutates and returns
    *state*.
    """
    state.update(
        {
            "iteration": 0,
            "error": "Avoid the following transitions in your next attempt:",
            "suggestion": "",
            "recommendation": "",
            "max_itr": state.get("max_itr", 15),
            "path": [state["current_state"]],
            "valid_path": False,
            "parse_error": False,
            "path_error": False,
            "tokens": 0,
            "prompt_tokens": 0,
            "output_tokens": 0,
        }
    )
    for log_key in (
        "agent_log",
        "path_log",
        "token_log",
        "prompt_token_log",
        "output_token_log",
        "message_log",
        "valid_path_log",
        "error_log",
        "suggestion_log",
        "cot_log",
        "output_log",
        "iteration_log",
    ):
        # A new list per key so the logs never alias one another.
        state[log_key] = []
    state["invalid_transitions"] = {}
    state["tried_paths"] = []
    return state


def _make_llm(model_name: str, temperature: float = 1):
    """Build the chat-model client for *model_name*.

    Names starting with ``"gpt"`` get an OpenAI client with a 60 s timeout.
    Any other name is treated as a model served by the local Ollama instance
    at ``http://localhost:11434``.
    """
    if not model_name.startswith("gpt"):
        return ChatOllama(
            model=model_name,
            temperature=temperature,
            base_url="http://localhost:11434",
        )
    return ChatOpenAI(model=model_name, temperature=temperature, timeout=60)


def _build_action_prompt(state: GraphState) -> str:
    """Assemble the human-turn prompt for the current attempt.

    Every attempt states the graph (rendered as ``str(dict)``), the start node
    and the goal node. From iteration 1 on, the prompt also lists:

    - each node with cut successors in ``invalid_transitions``, plus what is
      still allowed from it;
    - the paths proposed so far (``path_log``).

    ``prompting_type == "cot"`` appends a step-by-step hint. The text always
    ends with the JSON-format request.
    """
    graph = state["graph"]
    task = (
        "You are given a graph and a task to find a valid path from the start node to the goal node.\n"
        f"Graph (adjacency list):\n{graph}\n\n"
        f"Start Node: {state['current_state']}\nGoal Node: {state['target_state']}\n"
    )

    if state["iteration"] == 0:
        base_instructions = task
    else:
        cuts = state["invalid_transitions"]
        ban_lines = [
            f"- Do not choose {u}→{vs}. Node {u} is only connected to: "
            f"{sorted(graph.get(u, []))}. Allowed now: {_allowed_successors(graph, cuts, u)}"
            for u, vs in cuts.items()
            if vs
        ]
        cuts_text = "\n".join(ban_lines) if ban_lines else "(no cuts)"
        if state["path_log"]:
            novelty_note = f"Previously generated paths (avoid exact repeats): {state['path_log']}\n"
        else:
            novelty_note = ""
        base_instructions = (
            task
            + "\n"
            + f"Constraints (prompt-level cuts):\n{cuts_text}\n\n{novelty_note}"
        )

    if state.get("prompting_type") == "cot":
        thinking = (
            "Think step-by-step:\n"
            "1) Enumerate legal successors at each step; 2) avoid revisiting nodes; 3) stop at the goal; 4) prefer shortest valid path.\n"
        )
    else:
        thinking = ""

    return (
        base_instructions
        + thinking
        + "Return a JSON object with fields `message` (free text) and `path` (list of integers)."
    )


def action_agent(state: GraphState) -> GraphState:
    """Ask the LLM for a start-to-goal path and record the attempt in *state*.

    Builds the prompt with ``_build_action_prompt`` and calls the model with
    structured output (``GraphOutput``, raw message included). The token
    usage of the call is added to the running totals.

    - On success, ``path`` and ``message`` come from the parsed object.
    - On failure, ``path`` becomes ``[]`` and ``parse_error`` is set.
      ``message`` carries the reported ``parsing_error`` when there is one.

    Appends one entry to each per-attempt log and records new non-empty paths
    in ``tried_paths``. Errors raised by the model call itself are deliberately
    not caught, so misconfiguration (e.g. credentials) is not silently scored
    as failed attempts.
    """
    llm = _make_llm(state["llm"], temperature=1)
    structured_llm = llm.with_structured_output(GraphOutput, include_raw=True)

    messages = [
        (
            "system",
            "You are a Graph traversal expert. The graph is a dict mapping node IDs to lists of successors. Always output valid paths.",
        ),
        ("human", _build_action_prompt(state)),
    ]

    output = structured_llm.invoke(messages)
    _accumulate_usage(state, output["raw"])

    parsed = output.get("parsed")
    if isinstance(parsed, GraphOutput):
        state["path"] = parsed.path
        state["message"] = parsed.message
        state["parse_error"] = False
    else:
        # Keep the reason for the failure instead of a generic note, so failed
        # attempts can be diagnosed from message_log.
        parsing_error = output.get("parsing_error")
        state["path"] = []
        state["message"] = (
            "No parsed output"
            if parsing_error is None
            else f"Parse error: {parsing_error}"
        )
        state["parse_error"] = True

    state.setdefault("agent_log", []).append("ActionAgent")
    state.setdefault("iteration_log", []).append(state["iteration"])
    state.setdefault("path_log", []).append(state["path"])
    state.setdefault("message_log", []).append(state.get("message", ""))
    # Token logs store the running totals after this call, not per-call deltas.
    state.setdefault("token_log", []).append(state.get("tokens", 0))
    state.setdefault("prompt_token_log", []).append(state.get("prompt_tokens", 0))
    state.setdefault("output_token_log", []).append(state.get("output_tokens", 0))

    tp = _canonical_path_tuple(state["path"])
    if tp and tp not in state["tried_paths"]:
        state["tried_paths"].append(tp)

    return state


def validation_agent(state: GraphState) -> GraphState:
    """Check whether the latest proposed path is a valid start-to-goal walk.

    A path is valid when all of these hold:

    - it is non-empty;
    - it starts at ``current_state`` and ends at ``target_state``;
    - every node is a key of ``graph``;
    - for every consecutive pair ``(u, v)``, ``v`` is in ``graph[u]``.

    ``path_error`` is recomputed on every call and is True exactly when the
    path contains a node that is not a key of ``graph``. Appends the verdict
    to ``valid_path_log``, stores it in ``valid_path``, and returns *state*.
    """
    path = state.get("path") or []
    graph = state["graph"]

    # Recomputed per call: a hallucinated node in one attempt must not keep
    # reprompt_agent from recording cuts on every later attempt. Every node is
    # checked, not just those before the first bad edge.
    unknown_node = any(node not in graph for node in path)
    state["path_error"] = unknown_node

    valid = (
        bool(path)
        and not unknown_node
        and path[0] == state["current_state"]
        and path[-1] == state["target_state"]
        and all(v in graph.get(u, []) for u, v in zip(path, path[1:]))
    )

    state.setdefault("valid_path_log", []).append(valid)
    state["valid_path"] = valid
    return state


def reprompt_agent(state: GraphState) -> GraphState:
    """Prepare the next attempt: bump ``iteration`` and record banned edges.

    This only applies to the "cut_based", "strong_cut_based" and "naive"
    strategies, and only while ``path_error`` is falsy. Every consecutive
    pair ``u -> v`` of the last path that is not an edge of ``graph`` is
    added to ``invalid_transitions[u]`` (without duplicates) and described
    in ``state["error"]``:

    - "strong_cut_based" appends a verbose line listing the still-allowed
      successors;
    - the other two append a terse ``" u->v,"`` fragment.

    "naive" first discards earlier cuts and resets the error text, so only
    the latest attempt's cuts survive. Other strategies only increment
    ``iteration``. Mutates and returns *state*.
    """
    state["iteration"] = int(state.get("iteration", 0)) + 1

    mode = state.get("repromting_type")
    if mode not in {"cut_based", "strong_cut_based", "naive"}:
        return state
    if state.get("path_error", False):
        return state

    walked = state.get("path", [])
    graph = state["graph"]

    if mode == "naive":
        state["error"] = "Avoid the following transitions in your next attempt:"
        state["invalid_transitions"] = {}

    for src, dst in zip(walked, walked[1:]):
        if dst in graph.get(src, []):
            continue
        banned = state["invalid_transitions"].setdefault(src, [])
        if dst not in banned:
            banned.append(dst)
        if mode == "strong_cut_based":
            allowed_now = _allowed_successors(graph, state["invalid_transitions"], src)
            state["error"] += (
                f"\n- Do not choose {src}→{dst}; node {src} is only connected to "
                f"{sorted(graph.get(src, []))}. Allowed now: {allowed_now}"
            )
        else:
            state["error"] += f" {src}->{dst},"

    return state


def should_end(state: GraphState):
    """Route after validation: ``"reprompt"`` to try again, or ``END`` to stop.

    The run stops as soon as a valid path is found. Otherwise another attempt
    is made only for the "cut_based", "naive" and "strong_cut_based"
    strategies, and only while ``iteration`` is below ``max_itr`` (default
    15). Every other strategy, including "no", stops after a single attempt.
    """
    if not state.get("valid_path"):
        strategy = state.get("repromting_type")
        if strategy in {"cut_based", "naive", "strong_cut_based"}:
            reprompts_done = int(state.get("iteration", 0))
            if reprompts_done < int(state.get("max_itr", 15)):
                return "reprompt"
    return END


def _build_fsm_graph():
    """Compile the agent loop.

    The flow is initialize -> action -> validation, then either END or
    reprompt -> action -> validation again, as decided by ``should_end``.
    """
    builder = StateGraph(GraphState)
    builder.add_node("initialize", initialize_state)
    builder.add_node("action_agent", action_agent)
    builder.add_node("validation_agent", validation_agent)
    builder.add_node("reprompt", reprompt_agent)

    builder.add_edge(START, "initialize")
    builder.add_edge("initialize", "action_agent")
    builder.add_edge("action_agent", "validation_agent")
    builder.add_edge("reprompt", "action_agent")
    builder.add_conditional_edges("validation_agent", should_end, ["reprompt", END])
    return builder.compile()


def main(
    data_dir: str = "./data",
    results_dir: str = "./results",
    *,
    rows_per_file: int = 10,
    trials_per_row: int = 10,
    repromting_type: str = "no",
    prompting_type: str = "vanilla",
    llm: str = "gpt-4o-mini",
    max_itr: int = 15,
) -> None:
    """Run the benchmark over every ``*.csv`` in *data_dir* and write results.

    The input CSVs have the columns ``Graphs``, ``Start``, ``End``, ``Edges``
    and ``Nodes``. For each of the first *rows_per_file* rows of each file,
    the agent graph is run *trials_per_row* times.

    - The trials of one row are written to
      ``<results_dir>/<llm>/<repromting_type>_reprompting/<prompting_type>_prompting/<model>_graph_<nodes>_<row>.csv``.
    - Every trial is also collected into ``<results_dir>/summary.csv``.

    Directories are created as needed.
    """
    fsm_graph = _build_fsm_graph()
    # Roughly one LangGraph step per node visit: 3 for the first attempt and
    # 3 per reprompt. Size the limit from max_itr (never below the historical
    # 50) so a larger budget ends via should_end, not a recursion error.
    recursion_limit = max(50, 3 * (max_itr + 1) + 5)
    # Strip an Ollama size tag ("llama3:8b" -> "llama3") without truncating
    # names that have none (a fixed [:-3] slice turns "gpt-4o-mini" into "gpt-4o-m").
    model_slug = llm.split(":", 1)[0]
    row_dir = os.path.join(
        results_dir,
        llm,
        f"{repromting_type}_reprompting",
        f"{prompting_type}_prompting",
    )

    results_rows = []
    # Sorted so repeated runs process the input files in the same order.
    for file in sorted(os.listdir(data_dir)):
        if not file.endswith(".csv"):
            continue
        df = pd.read_csv(
            os.path.join(data_dir, file), converters={"Graphs": literal_eval}
        )
        for cnt in range(min(rows_per_file, len(df))):
            node_count = int(df["Nodes"].iloc[cnt])
            row_results = []
            for trial in range(trials_per_row):
                print(f"Processing {file} — row {cnt} — trial {trial}")
                inputs: GraphState = {
                    "graph": df["Graphs"].iloc[cnt],
                    "current_state": int(df["Start"].iloc[cnt]),
                    "target_state": int(df["End"].iloc[cnt]),
                    "edges": int(df["Edges"].iloc[cnt]),
                    "node_count": node_count,
                    "repromting_type": repromting_type,
                    "prompting_type": prompting_type,
                    "llm": llm,
                    "max_itr": max_itr,
                }

                result = fsm_graph.invoke(
                    input=inputs, config={"recursion_limit": recursion_limit}
                )

                row_results.append(
                    {
                        "file": file,
                        "row_id": cnt,
                        "trial": trial,
                        "graph": str(result.get("graph")),
                        "edges": result.get("edges"),
                        "nodes": result.get("node_count"),
                        "paths": result.get("path_log"),
                        "valid_path_log": result.get("valid_path_log"),
                        "final_valid": result.get("valid_path"),
                        "reprompts": result.get("iteration"),
                        "tokens_total": result.get("tokens"),
                        "tokens_prompt": result.get("prompt_tokens"),
                        "tokens_output": result.get("output_tokens"),
                    }
                )
            if not row_results:
                continue
            results_rows.extend(row_results)
            # The per-row file holds only this row's trials, not every row
            # accumulated so far across all input files.
            os.makedirs(row_dir, exist_ok=True)
            pd.DataFrame(row_results).to_csv(
                os.path.join(row_dir, f"{model_slug}_graph_{node_count}_{cnt}.csv"),
                index=False,
            )

    if results_rows:
        os.makedirs(results_dir, exist_ok=True)
        summary_path = os.path.join(results_dir, "summary.csv")
        pd.DataFrame(results_rows).to_csv(summary_path, index=False)
        print(f"Saved {summary_path}")


if __name__ == "__main__":
    main()
