from causal_discovery_llm import causal_discovery_llm
from utils.causalgraph_utils import save_graph_to_file
from utils.metrics import eval_metrics
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langgraph.graph import StateGraph, START, END
from state import CausalDiscoveryState
import streamlit as st
import pandas as pd
import networkx as nx
import json
import os
import yaml
import datetime
import dotenv

# Page title, shown both in the browser tab and as the main heading
title = "Multi Agent Causal Discovery"

# Directory holding the SVG icons used across the UI (sits next to this file)
icons_dir_path = os.path.join(os.path.dirname(__file__), "ui_icons")
page_icon_path = os.path.join(icons_dir_path, "graph_7.svg")

st.set_page_config(
    page_title=title,
    page_icon=page_icon_path,
    layout="wide",
)
st.title(title)

# Load environment variables from the .env file located by dotenv.find_dotenv()
dotenv_file = dotenv.find_dotenv()
dotenv.load_dotenv(dotenv_file)

# SELECTED_MODEL_NAME either names the model directly or is the literal "openai"
selected_model_name = os.getenv("SELECTED_MODEL_NAME")

# For "openai" the concrete model is read from MODEL_NAME; otherwise the
# selected name itself is used as the model name
model_name = (
    os.getenv("MODEL_NAME")
    if selected_model_name == "openai"
    else selected_model_name
)

# Internal tool names -> labels shown in the UI
tool_name_mapping = dict(
    rag_assistant="Web Search Team",
    human_in_the_loop="User Input",
)

# Names of the agents whose state updates are rendered in the UI
agents_to_display = [
    "Explainer",
    "Divide Hypothesis",
    "Divide Critic",
    "Divide Critic - Data-driven Refinement",
    "Hypothesis",
    "Critic",
    "Critic - Data-driven Refinement",
    "FCI",
    "Merge Hypothesis",
    "Merge Critic",
]

# Function that iterates over the state update from each "named" agent and displays its content in the UI.
# The auto-generated user prompt and the tool results are shown inside collapsible expanders; tool calls and AI responses are shown as chat messages.
def display_messages(messages: list, role: str):
    """
    Display messages in the Streamlit app.

    Each message is rendered according to its type:
      - HumanMessage: the auto-generated user prompt, inside an expander.
      - AIMessage with tool calls: the queries sent to the "rag_assistant"
        and "human_in_the_loop" tools (all research queries first, then all
        user questions).
      - ToolMessage: the tool output, inside an expander. Tools without a
        dedicated icon are still rendered (with the default chat avatar)
        so their output is not silently dropped.
      - Any other AIMessage: the plain response text.
    Messages of any other type are skipped.

    Args:
        messages (list): List of langchain's BaseMessage objects to display
        role (str): Role of the agent (e.g., "Explainer", "Divide")
    """
    person_avatar = os.path.join(icons_dir_path, "person.svg")
    robot_avatar = os.path.join(icons_dir_path, "robot_2.svg")

    # Dedicated icons for the known tools; any other tool gets avatar=None,
    # which lets Streamlit pick its default avatar for the "tool" name
    tool_avatars = {
        "rag_assistant": os.path.join(icons_dir_path, "manage_search.svg"),
        "human_in_the_loop": person_avatar,
    }

    for message in messages:
        # The first message should be the user prompt, which is auto-generated by the system to include the input data
        if isinstance(message, HumanMessage):
            with st.expander(label=f"Auto-generated user prompt for {role}"):
                with st.chat_message(name="user", avatar=person_avatar):
                    st.write("**User prompt**")
                    st.write(message.content)

        elif isinstance(message, AIMessage) and message.tool_calls:
            # The message is a tool call: show the query in a readable form,
            # attributed to the role passed in by the caller
            with st.chat_message(name="assistant", avatar=robot_avatar):
                st.write(f"**{role}**")
                st.write("**Tool Calls:**")
                # Two passes keep the display order: research queries first, user questions second
                for tool_call in message.tool_calls:
                    if tool_call["name"] == "rag_assistant":
                        args = tool_call.get("args") or {}
                        # Fall back to the raw arguments instead of raising KeyError mid-render
                        tool_query = args.get("research_topic", args)
                        st.write(f"Researching: '{tool_query}' ...")
                for tool_call in message.tool_calls:
                    if tool_call["name"] == "human_in_the_loop":
                        args = tool_call.get("args") or {}
                        tool_query = args.get("question", args)
                        st.write(f"Asked user for: {tool_query}")

        # If instead the message is a tool answer, we display the content of the output
        elif isinstance(message, ToolMessage):
            tool_name = message.name
            # Human-friendly label when one is configured, raw tool name otherwise
            tool_label = tool_name_mapping.get(tool_name, tool_name)
            with st.expander(label="Tool Call result"):
                with st.chat_message(name="tool", avatar=tool_avatars.get(tool_name)):
                    if tool_label:
                        st.write(f"**{tool_label}**")
                    if message.content:
                        st.write(message.content)

        elif isinstance(message, AIMessage):
            with st.chat_message(name="assistant", avatar=robot_avatar):
                st.write(f"**{role}**")
                st.write(message.content)

# Handles to the three required inputs; each stays None until a matching file is uploaded
dataset_file = None
description_file = None
ground_truth_file = None

# Single file uploader widget for dataset, description, and ground truth files
uploaded_files = st.file_uploader(
    "Upload Dataset (CSV), Description (JSON), and Ground Truth (JSON) Files",
    type=["csv", "json"],
    accept_multiple_files=True,
)

# Classify each upload: CSV -> dataset; JSON -> description or ground truth,
# decided by which keys the JSON object contains
for uploaded_file in uploaded_files or []:
    if uploaded_file.name.endswith(".csv"):
        dataset_file = uploaded_file
        df = pd.read_csv(dataset_file)
        df = df.head(20000)  # Truncate to 20,000 samples

    elif uploaded_file.name.endswith(".json"):
        try:
            # getvalue() returns the whole buffer regardless of the current stream position
            file_content = uploaded_file.getvalue().decode("utf-8")
            json_content = json.loads(file_content)
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            # Report the bad file in the UI instead of crashing the app with a traceback
            st.error(f"Could not parse '{uploaded_file.name}' as JSON: {exc}")
            continue

        # Only JSON objects can carry the expected keys; guarding on dict avoids
        # a TypeError on scalars and accidental matches on list elements
        is_json_object = isinstance(json_content, dict)

        if is_json_object and "domain" in json_content and "description" in json_content:
            description_file = uploaded_file
            description_content = json_content

        elif is_json_object and "edges" in json_content:
            ground_truth_file = uploaded_file
            ground_truth_content = json_content

        else:
            st.warning(
                f"Ignoring '{uploaded_file.name}': expected a JSON object with "
                "'domain' and 'description' keys (description) or an 'edges' key (ground truth)."
            )

# Run the causal discovery pipeline once all three inputs (dataset, description, ground truth) are available
if dataset_file and description_file and ground_truth_file:

    # Echo the inputs back to the user
    with st.chat_message(name="user", avatar=os.path.join(icons_dir_path, "person.svg")):
        st.write("**CSV File Header:**")
        st.write(df.head())

        st.write("**Description Details:**")
        st.write(f"Domain: {description_content['domain']}")
        st.write(f"Description: {description_content['description']}")

    # Load config.yaml
    yaml_file = "src/config.yaml"
    with open(yaml_file, 'r') as file:
        config = yaml.safe_load(file)

    # Single graph node: forwards the state together with the uploaded inputs and the config
    causal_discovery = lambda state: causal_discovery_llm(
        mode="UI",
        state=state,
        dataset_content=df,
        description_content=description_content,
        ground_truth_content=ground_truth_content,
        config=config,
    )

    builder = StateGraph(CausalDiscoveryState)
    builder.add_edge(START, "Causal Discovery")
    builder.add_node("Causal Discovery", causal_discovery)
    builder.add_edge("Causal Discovery", END)
    graph = builder.compile()

    # Since the above graph is a simple call to causal_discovery_llm, here we just use the same state and initialize it as empty
    # Within causal_discovery_llm, the state will be updated with the messages and the causal graph
    causal_discovery_state = CausalDiscoveryState(messages=[], causal_graph=None, elapsed_time=0, input_token_count=0, output_token_count=0, tool_calls=[])

    state_stream = graph.stream(input=causal_discovery_state, stream_mode="updates", subgraphs=True)

    # Latest update emitted by the "Causal Discovery" node; it carries the final results.
    # Tracking it explicitly avoids relying on the loop variable after the loop ends.
    final_update = None

    for state in state_stream:
        # While streaming subgraphs, each item is a tuple: the first element identifies the stategraph node (name and ID),
        # while the second element contains the update to the state
        # Since we are using Langgraph's prebuilt ReAct agents, each one constitutes its own subgraph, composed of "agent" and "tools" nodes
        update = state[1]

        # We first get the name of the active agent from the keys of the update (second element of the tuple)
        # It will be either "agent", "tools", the name of an agent in the causal discovery stategraph(s), or the name of an agent within the RAG system
        # (None when the update is empty, so an empty update is simply skipped)
        role = next(iter(update), None)

        # We then check if the role is one of the agents we want to display in the UI
        # If it is, we display the messages contained in the state update
        if role in agents_to_display:
            st.write("\n\n\n\n\n")  # Add whitespace between agents
            display_messages(update[role]['messages'], role)

        # Handle the case where the agent is making a tool call to the human_in_the_loop tool so that the message is displayed in real time in the ui
        if role == "agent" and update["agent"]["messages"][-1].tool_calls:
            for tool_call in update["agent"]["messages"][-1].tool_calls:
                if tool_call["name"] == "human_in_the_loop":
                    with st.chat_message(name="tool", avatar=os.path.join(icons_dir_path, "help.svg")):
                        tool_query = tool_call['args']['question']
                        st.write(f"**Asking for user input (check terminal)**\n{tool_query}")

        # Remember the result of the causal discovery node for the post-processing below
        if "Causal Discovery" in update:
            final_update = update["Causal Discovery"]

    if final_update is None:
        # Without a final update there is nothing to save or evaluate
        st.error("Causal discovery finished without returning a final state update.")
        st.stop()

    causal_graph = final_update['causal_graph']

    causal_graph_file_path = save_graph_to_file(
        input_file_name=dataset_file.name,
        causal_graph=causal_graph,
        filetype="json",
    )

    # Check if the file is JSON or Pickle
    if causal_graph_file_path.endswith(".json"):
        with open(causal_graph_file_path, "r") as file:
            file_content = file.read()
        mime_type = "application/json"
        file_extension = "json"
    elif causal_graph_file_path.endswith(".pkl"):
        with open(causal_graph_file_path, "rb") as file:
            file_content = file.read()
        mime_type = "application/octet-stream"
        file_extension = "pkl"
    else:
        # Any other extension would leave mime_type / file_extension undefined below
        st.error(f"Unsupported causal graph file type: {causal_graph_file_path}")
        st.stop()

    # Generate the file name based on dataset_file name and current date-time
    dataset_name = os.path.splitext(dataset_file.name)[0]  # Remove the ".csv" extension
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M")  # Format date and time

    st.download_button(
        label="Download the discovered causal graph",
        data=file_content,
        file_name=f"{dataset_name}_{timestamp}_causal_graph.{file_extension}",
        mime=mime_type,
        on_click="ignore",
        icon=":material/graph_3:",
    )

    # Build the two graphs to compare; the representation depends on the "fci_conquer" config flag
    if config["fci_conquer"]:
        from utils.causalgraph_utils import edgelist_to_generalgraph
        ground_truth_digraph = edgelist_to_generalgraph(ground_truth_content["edges"])
        learned_causal_graph = causal_graph
    else:
        ground_truth_digraph = nx.DiGraph(ground_truth_content["edges"])
        learned_causal_graph = nx.DiGraph(causal_graph)

    metrics_file_path = eval_metrics(
        save_path=os.path.dirname(causal_graph_file_path),
        ground_truth_graph=ground_truth_digraph,
        learned_causal_graph=learned_causal_graph,
        dataset_name=os.path.basename(dataset_file.name),
        discovery_time=final_update["elapsed_time"],
        input_token_count=final_update['input_token_count'],
        output_token_count=final_update['output_token_count'],
        tool_calls=final_update['tool_calls'],
        model_name=model_name,
        tools=config["tool_list"],
    )

    with open(metrics_file_path, "r") as file:
        file_content = file.read()

    st.download_button(
        label="Download the evaluation metrics",
        data=file_content,
        file_name=f"{dataset_name}_{timestamp}_cd_metrics.csv",
        mime="text/csv",
        on_click="ignore",
        icon=":material/readiness_score:",
    )

    # Intermediate metrics are optional: only offered when the final update contains them
    intermediate_metrics = final_update.get("intermediate_metrics", None)
    if intermediate_metrics:
        intermediate_metrics_df = pd.DataFrame(intermediate_metrics)
        # Serialize the intermediate metrics DataFrame to CSV text
        intermediate_metrics_csv = intermediate_metrics_df.to_csv(index=False)

        # Add a download button for the intermediate metrics CSV
        st.download_button(
            label="Download Intermediate (leaf-by-leaf) Metrics",
            data=intermediate_metrics_csv,
            file_name=f"{dataset_name}_{timestamp}_intermediate_metrics.csv",
            mime="text/csv",
            on_click="ignore",
            icon=":material/file_download:",
        )