from dotenv import load_dotenv
load_dotenv()

import os, sys, shutil, time
# Process-wide environment defaults, set before the heavier imports further down.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['USER_AGENT'] = "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:91.0) Gecko/20100101 Firefox/91.0"
# NOTE(review): os.getenv() returns None when HF_EMBEDDING_MODEL is unset, and
# assigning None into os.environ raises TypeError. When the variable is set this
# line is a no-op, so in effect it only asserts that the variable exists.
os.environ["HF_EMBEDDING_MODEL"] = os.getenv("HF_EMBEDDING_MODEL")

from argparse import ArgumentParser, Namespace
# Runtime configuration: one command-line option plus settings read from the
# environment (the .env file is loaded at the very top of this module).
parser = ArgumentParser()
parser.add_argument("--port_num", type=int, default=8888)

# NOTE(review): parse_args() runs at import time, so it parses the command line
# of whatever process imports this module.
args = parser.parse_args()
# String-valued settings; each is None when the variable is not set.
args.CACHE = os.getenv("CACHE")
args.TOKEN = os.getenv("TOKEN")
args.MODEL = os.getenv("MODEL")
args.CODE_MODEL = os.getenv("CODE_MODEL")
args.GUIDANCE_FILE = os.getenv("GUIDANCE_FILE")
# Numeric settings. int()/float() are called directly on the os.getenv() result,
# so an unset variable raises TypeError here; all six must be defined.
args.MODEL_INDEX = int(os.getenv("HF_MODEL_INDEX"))
args.CODE_MODEL_INDEX = int(os.getenv("HF_CODE_MODEL_INDEX"))
args.MAX_TOKENS = int(os.getenv("HF_MAX_TOKENS"))
args.MAX_CODE_TOKENS = int(os.getenv("HF_MAX_CODE_TOKENS"))
args.temperature = float(os.getenv("HF_TEMPERATURE"))
args.top_p = float(os.getenv("HF_TOP_P"))

import numpy as np
import pickle as pkl
import json, yaml, ast, copy
import tempfile, uuid, re
import pathlib, importlib
import uvicorn, requests
import asyncio, subprocess
import traceback
import zipfile
import huggingface_hub
import torch
import accelerate
import nest_asyncio

from os import listdir
# NOTE(review): the GPU index is hard-coded, and this assignment happens after
# `import torch` / `import accelerate` above. Confirm nothing initialises CUDA
# before this line, otherwise the setting may not take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from os.path import isfile, isdir, join
from PIL import Image
from io import BytesIO
from uuid import UUID
from importlib import util
from pydantic import BaseModel, Field
from accelerate.state import PartialState, AcceleratorState
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Annotated, Literal, Tuple, Sequence, TypeVar, Union, Callable

from fastapi import FastAPI, HTTPException, Request, File, UploadFile, Response
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware

from langchain import hub
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace

from langchain.chains.base import Chain
from langchain.memory import ConversationBufferMemory
from langchain.tools import tool

from langchain_core.prompts import PromptTemplate
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts.chat import SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages.base import BaseMessage
from langchain_core.messages.system import SystemMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.tool import ToolMessage
from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod, NodeStyles

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.graph.message import add_messages
# from langgraph.prebuilt import ToolNode

from rag.rag_agent import set_rag_stream, reset_rag_stream
from rag.cmrag import *

from utils.callbacks import IntermediateStepsCallback
from utils.crag_agent import init_crag_llm_state, InformationRetrieverAgent
from utils.agent_templates import EXECUTE_STEP_SYSTEM, ANSWER_PROMPT_TEMPLATE
from utils.tool_definitions import set_tools_llm_state, set_reasoning_llm, get_available_tools, set_rag_update, set_filesystem_root, get_filesystem_root
from utils.node_definitions import set_node_llm_state, set_solver_strategy, set_llm_owner, agentic_graph, kripke_structure, ReWOO, kripke_agent, properties_agent, NUSMV_planner

# FastAPI application on which the route decorators in this module register.
app = FastAPI()
# `/v1/models` listings fetched by initialize_llm() for the main and the code server.
model_dict = {}
code_model_dict = {}
# Not referenced in the visible part of this file.
GLOBAL_RETRIES = 5
GLOBAL_VERBOSE = False
# Populated by initialize_rag().
GLOBAL_VECTORSTORE = None
GLOBAL_RETRIEVER = None
# Directory whose uploaded files /reset_rag removes (when set and existing).
GLOBAL_SQL_UPLOAD_ROOT = None
# LLM handles assigned by initialize_llm().
BASE_LLM: Optional[RunnableLambda] = None
GLOBAL_LLM: Optional[RunnableLambda] = None
CODING_LLM: Optional[RunnableLambda] = None
REASONING_LLM: Optional[RunnableLambda] = None
# Cooperative cancellation flag polled by the streaming coroutines below.
GLOBAL_CANCELLED: Optional[bool] = False
# Latest token stream as returned by the stream callback's helpers.
GLOBAL_STREAM: Optional[List[str]] = [""]
# Callback object that owns the token stream; it is not assigned in the visible
# part of this file, so it must be set elsewhere before a request is served.
GLOBAL_STREAM_CALLBACK: Optional[IntermediateStepsCallback] = None
# Conversation buffer rebuilt per request by process_agent_stream().
GLOBAL_CONVERSATION_MEMORY: Optional[ConversationBufferMemory] = None

# Request body accepted by the /process_agent endpoint.
class QueryInput(BaseModel):
    # User query; process_agent_stream() also forwards it as `sop_file_path`
    # and `nusmv_model_file` to initialize_agent().
    query: str
    # Forwarded to set_rag_update() / set_corrective_multihop() respectively.
    use_rag: bool = False
    use_crag: bool = False
    # Select one of the specialised graphs built by initialize_agent().
    kripke_extractor: bool = False
    properties_extractor: bool = False
    nusmv_planner: bool = False
    # Lower-cased with spaces replaced by '-' before being passed on
    # ("Plan and Execute" -> "plan-and-execute", "ReWOO" -> "rewoo").
    strategy: Literal["ReWOO", "Plan and Execute"] = "ReWOO"
    # An empty list is replaced by ["direct_response"] in process_agent().
    selected_tools: List[str] = []
    # Each entry is read through its 'type' ('human' or 'ai') and 'content' keys.
    chat_history: Optional[List[Dict[str, str]]] = []
    # Filled from args.GUIDANCE_FILE by process_agent() when left empty.
    guidance_prompt: Optional[str] = None
    # A random UUID is generated per request when this is missing.
    thread_id: Optional[str] = None
    # When given, passed to set_filesystem_root() before the agent is built.
    cache_dir: Optional[str] = None

def initialize_rag():
    """Create the RAG vector store and retriever and store them in module globals.

    Requires ``GLOBAL_LLM`` to be initialised already (asserted). The stream
    callback is handed to ``get_rag`` as its only callback manager entry.

    Returns:
        The ``(vectorstore, retriever)`` pair that was just stored.
    """
    global GLOBAL_VECTORSTORE, GLOBAL_RETRIEVER
    global GLOBAL_LLM
    assert GLOBAL_LLM is not None
    vectorstore, retriever = get_rag(GLOBAL_LLM, callback_manager=[GLOBAL_STREAM_CALLBACK])
    GLOBAL_VECTORSTORE = vectorstore
    GLOBAL_RETRIEVER = retriever
    return vectorstore, retriever

async def initialize_agent(memory, chat_history = None, 
                    tools = ["All"], 
                    strategy: Literal["ReWOO", "plan-and-execute"] = "ReWOO", 
                    kripke_extractor = False,
                    properties_extractor = False,
                    nusmv_planner = False,
                    sop_file_path = None,
                    nusmv_model_file = None
                    ):
    """Build and compile the LangGraph workflow for one request.

    One of four graph layouts is wired, checked in this order: NuSMV planner,
    Kripke extractor, properties extractor, default planner. The compiled graph
    is also rendered to ``<filesystem_root>/images/<mode name>.png``.

    Args:
        memory: Checkpointer for the compiled graph; a new ``MemorySaver`` is
            created when ``None``.
        chat_history: Previous conversation; only applied when none of the three
            specialised modes is requested.
        tools: Tool names forwarded to ``agentic_graph.set_node_tools``.
            NOTE(review): the default is a shared mutable list.
        strategy: Forwarded to ``set_solver_strategy``.
            NOTE(review): the caller in this file passes a lower-cased value
            ("rewoo"), which is not one of the annotated literals; confirm
            ``set_solver_strategy`` accepts it.
        kripke_extractor: Build the Kripke-extractor graph (needs ``sop_file_path``).
        properties_extractor: Build the properties-extractor graph (needs ``sop_file_path``).
        nusmv_planner: Build the NuSMV graph (needs ``nusmv_model_file``).
        sop_file_path: Path relative to the filesystem root, joined onto it.
        nusmv_model_file: Passed to ``NUSMV_planner`` unchanged.

    Returns:
        ``(workflow, memory, agentic_graph_obj)``.

    Raises:
        AssertionError: if the file argument required by the selected mode is falsy.
    """
    global BASE_LLM
    global GLOBAL_LLM
    global CODING_LLM
    global REASONING_LLM
    global model_dict

    # Publish the current LLM handles and strategy to the node-definition module
    # before any graph node objects are created.
    set_node_llm_state(BASE_LLM, GLOBAL_LLM, CODING_LLM, REASONING_LLM)
    set_solver_strategy(strategy)
    agentic_graph_obj = agentic_graph(filesystem_root=get_filesystem_root())
    agentic_graph_obj.set_node_tools(tools)
    
    # The specialised modes start without the previous conversation.
    if chat_history and not (kripke_extractor or properties_extractor or nusmv_planner):
        agentic_graph_obj.set_previous_conversation(chat_history)
    if memory is None:
        memory = MemorySaver()
    
    builder = StateGraph(ReWOO)
    if nusmv_planner:
        img_name = "NuSMV AGENT"
        print("Using NuSMV AGENT")
        assert nusmv_model_file
        nusmv_planner_obj = NUSMV_planner(filesys_root=get_filesystem_root(),
                                          previous_conversation=agentic_graph_obj.PREVIOUS_CONVERSATION, 
                                          nusmv_model_file=nusmv_model_file)
        agentic_graph_obj.set_auxilary_planner(nusmv_planner=nusmv_planner_obj)

        # All eight nodes are registered, but the only static edges added below are
        # START -> nusmv_react_agent_node -> nusmv_tool_execution_node; any other
        # node is reachable only through the conditional `solver` router.
        builder.add_node("planning_node1", agentic_graph_obj.nusmv_planner.planning_node1)
        builder.add_node("planning_node2", agentic_graph_obj.nusmv_planner.planning_node2)
        builder.add_node("replanning_node", agentic_graph_obj.nusmv_planner.replanning_node)
        builder.add_node("nusmv_react_agent_node", agentic_graph_obj.nusmv_planner.react_agent_node)
        builder.add_node("nusmv_tool_execution_node", agentic_graph_obj.nusmv_planner.tool_execution_node)
        builder.add_node("react_agent_node", agentic_graph_obj.react_agent_node)
        builder.add_node("tool_execution_node", agentic_graph_obj.tool_execution_node)
        builder.add_node("verification_node", agentic_graph_obj.verification_node)

        # builder.add_edge(START, "planning_node1")
        # builder.add_edge("planning_node1", "react_agent_node")
        # builder.add_edge("planning_node2", "react_agent_node")
        # builder.add_edge("react_agent_node", "tool_execution_node")
        # builder.add_conditional_edges("tool_execution_node", agentic_graph_obj.solver)
        # builder.add_edge("replanning_node", "react_agent_node")
        # builder.add_conditional_edges("verification_node", agentic_graph_obj.nusmv_plan_orchestrator)
        builder.add_edge(START, "nusmv_react_agent_node")
        builder.add_edge("nusmv_react_agent_node", "nusmv_tool_execution_node")
        builder.add_conditional_edges("nusmv_tool_execution_node", agentic_graph_obj.nusmv_planner.solver)

    elif kripke_extractor:
        img_name = "KRIPKE EXTRACTOR"
        print("USING KRIPKE_EXTRACTOR")
        assert sop_file_path
        kripke_agent_obj = kripke_agent(filesys_root=get_filesystem_root(), 
                                        previous_conversation=agentic_graph_obj.PREVIOUS_CONVERSATION, 
                                        SOP_pth=join(get_filesystem_root(), sop_file_path))
        agentic_graph_obj.set_auxilary_planner(kripke_extractor=kripke_agent_obj)
        
        builder.add_node("planning_node1", agentic_graph_obj.kripke_extractor.planning_node1)
        builder.add_node("planning_node2", agentic_graph_obj.kripke_extractor.planning_node2)
        builder.add_node("replanning_node", agentic_graph_obj.kripke_extractor.replanning_node)
        builder.add_node("kripke_react_agent_node", agentic_graph_obj.kripke_extractor.react_agent_node)
        builder.add_node("kripke_tool_execution_node", agentic_graph_obj.kripke_extractor.tool_execution_node)
        builder.add_node("react_agent_node", agentic_graph_obj.react_agent_node)
        builder.add_node("tool_execution_node", agentic_graph_obj.tool_execution_node)
        builder.add_node("verification_node", agentic_graph_obj.verification_node)

        # Generic plan/act/verify loop; verification routes through the
        # Kripke-specific orchestrator.
        builder.add_edge(START, "planning_node1")
        builder.add_edge("planning_node1", "react_agent_node")
        builder.add_edge("planning_node2", "react_agent_node")
        builder.add_edge("react_agent_node", "tool_execution_node")
        builder.add_conditional_edges("tool_execution_node", agentic_graph_obj.solver)
        builder.add_edge("replanning_node", "react_agent_node")
        builder.add_conditional_edges("verification_node", agentic_graph_obj.kripke_plan_orchestrator)
        # builder.add_edge(START, "kripke_react_agent_node")
        # The Kripke-specific agent/tool pair has no static incoming edge here.
        builder.add_edge("kripke_react_agent_node", "kripke_tool_execution_node")
        builder.add_conditional_edges("kripke_tool_execution_node", agentic_graph_obj.kripke_extractor.solver)

    elif properties_extractor:
        img_name = "PROPERTIES EXTRACTOR"
        print("USING PROPERTIES EXTRACTOR")
        assert sop_file_path
        properties_agent_obj = properties_agent(filesys_root=get_filesystem_root(), 
                                                previous_conversation= agentic_graph_obj.PREVIOUS_CONVERSATION,
                                                SOP_pth= join(get_filesystem_root(), sop_file_path))
        agentic_graph_obj.set_auxilary_planner(properties_extractor=properties_agent_obj)

        builder.add_node("planning_node1", agentic_graph_obj.properties_extractor.planning_node1)
        builder.add_node("planning_node2", agentic_graph_obj.properties_extractor.planning_node2)
        builder.add_node("replanning_node", agentic_graph_obj.properties_extractor.replanning_node)
        builder.add_node("prop_react_agent_node", agentic_graph_obj.properties_extractor.react_agent_node)
        builder.add_node("prop_tool_execution_node", agentic_graph_obj.properties_extractor.tool_execution_node)
        builder.add_node("react_agent_node", agentic_graph_obj.react_agent_node)
        builder.add_node("tool_execution_node", agentic_graph_obj.tool_execution_node)
        builder.add_node("verification_node", agentic_graph_obj.verification_node)

        # Same layout as the Kripke branch, with the properties orchestrator.
        builder.add_edge(START, "planning_node1")
        builder.add_edge("planning_node1", "react_agent_node")
        builder.add_edge("planning_node2", "react_agent_node")
        builder.add_edge("react_agent_node", "tool_execution_node")
        builder.add_conditional_edges("tool_execution_node", agentic_graph_obj.solver)
        builder.add_edge("replanning_node", "react_agent_node")
        builder.add_conditional_edges("verification_node", agentic_graph_obj.properties_plan_orchestrator)
        # builder.add_edge(START, "prop_react_agent_node")
        builder.add_edge("prop_react_agent_node", "prop_tool_execution_node")
        builder.add_conditional_edges("prop_tool_execution_node", agentic_graph_obj.properties_extractor.solver)

    else:
        img_name = "DEFAULT PLANNER"
        print("USING DEFAULT_PLANNER")
        
        builder.add_node("planning_node", agentic_graph_obj.planning_node)
        builder.add_node("replanning_node", agentic_graph_obj.replanning_node)
        builder.add_node("react_agent_node", agentic_graph_obj.react_agent_node)
        builder.add_node("tool_execution_node", agentic_graph_obj.tool_execution_node)
        builder.add_node("verification_node", agentic_graph_obj.verification_node)

        # plan -> act -> run tools -> (solver decides) ; verification ends the run.
        builder.add_edge(START, "planning_node")
        builder.add_edge("planning_node", "react_agent_node")
        builder.add_edge("react_agent_node", "tool_execution_node")
        builder.add_conditional_edges("tool_execution_node", agentic_graph_obj.solver)
        builder.add_edge("replanning_node", "react_agent_node")
        builder.add_edge("verification_node", END)

    workflow = builder.compile(checkpointer=memory)
    # NOTE(review): the value returned by draw_mermaid_png() is handed to
    # create_task() and awaited, i.e. it is treated as a coroutine. Confirm this
    # matches the installed langchain_core version.
    loop = asyncio.get_event_loop()
    img_task = loop.create_task(workflow.get_graph().draw_mermaid_png(
        curve_style=CurveStyle.LINEAR,
        node_colors=NodeStyles(first="#ffdfba", last="#baffc9", default="#fad7de"),
        wrap_label_n_words=9,
        output_file_path=None,
        draw_method=MermaidDrawMethod.PYPPETEER,
        background_color="white",
        padding=10,
    ))
    img_bytes = await img_task
    img = Image.open(BytesIO(img_bytes))
    
    # Save the rendered graph under <filesystem_root>/images/, creating it if needed.
    if not isdir(join(get_filesystem_root(), "images")):
        os.mkdir(join(get_filesystem_root(), "images"))
    img.save(join(get_filesystem_root(), "images", img_name+".png"))
    
    print("WORKFLOW CREATED SUCCESSFULLY")
    return workflow, memory, agentic_graph_obj

def _build_reasoning_llm(reasoning_server, model_index):
    """Return a streaming ChatOpenAI client for the reasoning model.

    The model id is the entry at ``model_index`` of ``{reasoning_server}/v1/models``.
    """
    reasoning_model_dict = requests.get(f"{reasoning_server}/v1/models").json()
    reasoning_model_path = reasoning_model_dict['data'][model_index]['id']
    reasoning_llm = ChatOpenAI(
        openai_api_key="EMPTY",
        openai_api_base=f"{reasoning_server}/v1",
        model_name=reasoning_model_path,
        streaming=True,
        temperature=0.7,
        top_p=0.95,
    )
    print(f"USING STANDALONE REASONING SERVER {reasoning_model_path}")
    return reasoning_llm


def initialize_llm():
    """Connect to the configured model servers and populate the module-level LLM handles.

    Steps:
      1. Reset the conversation buffer and log in to the HuggingFace hub.
      2. Poll the main server's ``/v1/models`` once per second until it answers,
         then build the main chat client.
      3. Build a separate coding client when ``CODE_MODEL`` is configured; on any
         failure (or when it is blank) the main client is used for coding too.
      4. Optionally build a reasoning client, either from ``REASONING_MODEL`` or,
         when only ``HF_REASONING_MODEL_INDEX`` is >= 0, from the code/main server.
      5. Pass the clients through ``set_tools_llm_state``.

    Returns:
        The ``BASE_LLM`` value returned by ``set_tools_llm_state``.

    Raises:
        ValueError: if ``MODEL`` is not configured.
    """
    global BASE_LLM
    global GLOBAL_LLM
    global CODING_LLM
    global REASONING_LLM
    global GLOBAL_CONVERSATION_MEMORY
    global model_dict, code_model_dict

    GLOBAL_CONVERSATION_MEMORY = ConversationBufferMemory()
    huggingface_hub.login(token=os.getenv("TOKEN"), add_to_git_credential=False)

    # Resolve the server URL before the retry loop: previously a missing MODEL
    # raised inside the `try`, and the handler then failed on the still-unbound
    # `llama_cpp_server` name instead of reporting the real problem.
    if not args.MODEL:
        raise ValueError("MODEL is not configured; cannot locate the language server.")
    llama_cpp_server = args.MODEL.strip('/').split('/v1')[0]

    # Block until the main server answers with a usable model listing.
    while True:
        try:
            model_dict = requests.get(f"{llama_cpp_server}/v1/models").json()
            model_path = model_dict['data'][args.MODEL_INDEX]['id']
            args.MODEL_PATH = model_path
            if model_dict['data'][args.MODEL_INDEX].get('owned_by', ''):
                print("LLM server type:", model_dict['data'][args.MODEL_INDEX]['owned_by'])
            break
        except Exception:
            print(f"Waiting for language server at: {llama_cpp_server}")
            time.sleep(1.)

    llm = ChatOpenAI(
        openai_api_key="EMPTY",
        openai_api_base=f"{llama_cpp_server}/v1",
        model_name=args.MODEL_PATH,
        temperature=args.temperature,
        streaming=True,
        extra_body={'repetition_penalty': 1.03}, 
        top_p=args.top_p,
    )

    # Stays None unless a non-blank CODE_MODEL is configured; the reasoning
    # fallback below relies on this instead of a possibly unbound local.
    code_llama_cpp_server = None
    try:
        if args.CODE_MODEL is None or args.CODE_MODEL.strip() == "":
            GLOBAL_LLM = llm
            CODING_LLM = llm
            code_model_dict = copy.deepcopy(model_dict)
            set_llm_owner(model_dict, args.MODEL_INDEX)
        else:
            code_llama_cpp_server = args.CODE_MODEL.strip('/').split('/v1')[0]
            code_model_dict = requests.get(f"{code_llama_cpp_server}/v1/models").json()
            code_model_path = code_model_dict['data'][args.CODE_MODEL_INDEX]['id']
            args.CODE_MODEL_PATH = code_model_path
            # NOTE(review): this pairs the code server's listing with the *main*
            # model index; confirm whether CODE_MODEL_INDEX (or model_dict) was meant.
            set_llm_owner(code_model_dict, args.MODEL_INDEX)

            # The streaming decision concerns the code server, so read its own
            # listing (this used to index the main server's listing with the
            # code model index).
            code_server_owner = code_model_dict['data'][args.CODE_MODEL_INDEX]['owned_by']
            code_llm = ChatOpenAI(
                openai_api_key="EMPTY",
                openai_api_base=f"{code_llama_cpp_server}/v1",
                model_name=args.CODE_MODEL_PATH,
                temperature=args.temperature,
                streaming=code_server_owner != 'llamacpp',
                top_p=args.top_p,
            )
            GLOBAL_LLM = llm
            CODING_LLM = code_llm
    except Exception as e:
        # Best effort: fall back to the main model for coding.
        print(f"CODE LLM server at {args.CODE_MODEL} failed with error: {e}")
        GLOBAL_LLM = llm
        CODING_LLM = llm

    # Optional settings: an unset variable now means "not configured" instead of
    # raising KeyError.
    reasoning_model_url = os.getenv("REASONING_MODEL", "")
    reasoning_model_index = int(os.getenv("HF_REASONING_MODEL_INDEX", "-1"))
    if reasoning_model_url:
        # Dedicated reasoning server; a negative index selects the first model.
        REASONING_LLM = _build_reasoning_llm(
            reasoning_model_url.split('/v1')[0],
            reasoning_model_index if reasoning_model_index > -1 else 0,
        )
        set_reasoning_llm(REASONING_LLM)
    elif reasoning_model_index > -1:
        # No dedicated server: reuse the code server when one is configured,
        # otherwise the main server.
        if code_llama_cpp_server is not None:
            reasoning_server = code_llama_cpp_server
        else:
            reasoning_server = llama_cpp_server
        REASONING_LLM = _build_reasoning_llm(reasoning_server, reasoning_model_index)
        set_reasoning_llm(REASONING_LLM)
    else:
        print("NO STANDALONE REASONING SERVER")
        REASONING_LLM = None

    print(GLOBAL_LLM)
    print(CODING_LLM)
    has_thinking_tokens = os.getenv("HF_REASONING_MODEL", "") == "true"
    BASE_LLM, GLOBAL_LLM, CODING_LLM = set_tools_llm_state(GLOBAL_LLM, CODING_LLM, has_thinking_tokens=has_thinking_tokens)
    return BASE_LLM

def save_memory(memory_var, args, mem_name='st_memory.pkl'):
    """Pickle ``memory_var`` into ``<args.CACHE>/<mem_name>``, overwriting any existing file."""
    target_path = os.path.join(args.CACHE, mem_name)
    with open(target_path, 'wb') as sink:
        pkl.dump(memory_var, sink)
    print('Memory Saved.')

def load_memory(memory_var, args, mem_name='st_memory.pkl'):
    """Return the object pickled at ``<args.CACHE>/<mem_name>``, or ``memory_var`` if no such file exists.

    NOTE: unpickling executes arbitrary code, so the cache directory must only
    ever contain files written by this service.
    """
    source_path = os.path.join(args.CACHE, mem_name)
    if not os.path.isfile(source_path):
        return memory_var
    with open(source_path, 'rb') as source:
        restored = pkl.load(source)
    print('Memory Loaded.')
    return restored

def clear_memory(memory_var, args, mem_name='st_memory.pkl'):
    """Empty ``memory_var`` in place (when not ``None``) and delete its pickle file if present.

    Returns:
        The same ``memory_var`` object that was passed in.
    """
    if memory_var is not None:
        memory_var.clear()
    stale_path = os.path.join(args.CACHE, mem_name)
    if os.path.isfile(stale_path):
        os.remove(stale_path)
    print('Memory Cleared')
    return memory_var

def chat_history_to_conversation_buffer(chat_history):
    """Replay ``chat_history`` into a fresh ``ConversationBufferMemory``.

    Each entry is read through its ``'type'`` and ``'content'`` keys; entries
    whose type is neither ``'human'`` nor ``'ai'`` are skipped.
    """
    buffer = ConversationBufferMemory()
    for entry in chat_history:
        role = entry['type']
        if role == 'human':
            buffer.chat_memory.add_user_message(entry['content'])
        elif role == 'ai':
            buffer.chat_memory.add_ai_message(entry['content'])
    return buffer

async def sleep_coroutine(future, sleep_timer=0.1):
    """Sleep for ``sleep_timer`` seconds, then resolve ``future`` (if one was given).

    A failure to resolve the future, e.g. because it was cancelled in the
    meantime, is printed rather than raised.
    """
    await asyncio.sleep(sleep_timer)
    if future is None:
        return
    try:
        future.set_result("Waited for tokens.")
    except Exception as err:
        print(f"future cancelled with error: {err}")
        
async def add_thought_tokens(IntermediateStepsCallbackObj_fn, stream_type:Literal["RAG", "STREAM"] = "STREAM"):
    """Append an empty token to the callback's stream every 0.1 s.

    Stops as soon as the last stream entry is ``<END_OF_{stream_type}>`` or the
    global cancellation flag is set. ``GLOBAL_STREAM`` is refreshed with the
    value returned by each append.
    """
    global GLOBAL_CANCELLED, GLOBAL_STREAM
    end_marker = f"<END_OF_{stream_type}>"
    while True:
        snapshot = IntermediateStepsCallbackObj_fn.get_global_stream()
        if snapshot[-1] == end_marker or GLOBAL_CANCELLED:
            return
        GLOBAL_STREAM = IntermediateStepsCallbackObj_fn.add_to_global_stream("")
        await asyncio.sleep(0.1)

async def check_cancellation(stream_type:Literal["RAG", "STREAM"] = "STREAM"):
    """Return once the run is cancelled or ``GLOBAL_STREAM`` ends with ``<END_OF_{stream_type}>``.

    Both conditions are polled every 0.1 s.
    """
    global GLOBAL_STREAM, GLOBAL_CANCELLED
    end_marker = f"<END_OF_{stream_type}>"
    while True:
        if GLOBAL_CANCELLED:
            return
        if len(GLOBAL_STREAM) > 0 and GLOBAL_STREAM[-1] == end_marker:
            return
        await asyncio.sleep(0.1)

async def stream_tokens(stream_idx, IntermediateStepsCallbackObj_fn, event_loop=None, 
                        stream_type:Literal["RAG", "STREAM", "MESSAGE"] = "STREAM"):
    """Yield the callback's stream entries from ``stream_idx`` onwards as they appear.

    In ``"MESSAGE"`` mode each entry is yielded as ``{"token": entry}``; otherwise
    it is yielded as a ``data: {...}`` server-sent-event line. The generator ends
    on the pass after ``<END_OF_{stream_type}>`` was emitted, or when the global
    cancellation flag is set. Between passes it waits either with a plain 0.1 s
    sleep (no ``event_loop``) or on a future resolved by ``sleep_coroutine``.
    """
    global GLOBAL_CANCELLED
    end_marker = f"<END_OF_{stream_type}>"
    raw_tokens = stream_type == "MESSAGE"
    finished = False
    while True:
        snapshot = IntermediateStepsCallbackObj_fn.get_global_stream()
        if finished or GLOBAL_CANCELLED:
            return
        # len() is re-evaluated each round so entries appended while this
        # generator is suspended at a yield are picked up in the same pass.
        while stream_idx < len(snapshot):
            token = snapshot[stream_idx]
            sse_payload = json.dumps({'type': 'token', 'content': token})
            if raw_tokens:
                yield {"token": token}
            else:
                yield f"data: {sse_payload}\n\n"
            if token == end_marker:
                finished = True
            stream_idx += 1
        if event_loop is None:
            await asyncio.sleep(0.1)
            continue
        wake_up = event_loop.create_future()
        # Keep a reference to the task for as long as the future is awaited.
        sleeper = event_loop.create_task(sleep_coroutine(wake_up))
        waited = await wake_up
        if not raw_tokens:
            print(waited)
                
async def combined_generators(workflow_generator, stream_generator):
    """Interleave items from the workflow generator and the token-stream generator.

    One pending ``__anext__()`` task is kept per generator (workflow at index 0,
    stream after it). Whichever finishes first has its result yielded and is
    re-armed. When either task raises (including the normal
    ``StopAsyncIteration`` at exhaustion), ``<END_OF_MESSAGE>`` is appended to
    the global stream, the remaining tasks are cancelled and the generator ends.
    Setting ``GLOBAL_CANCELLED`` ends it the same way.

    Unlike normal exhaustion, a real exception from either generator is printed
    with its stack trace, so failures are no longer reported only as
    "Workflow is complete".
    """
    global GLOBAL_STREAM, GLOBAL_STREAM_CALLBACK
    global GLOBAL_CANCELLED
    tasks = [asyncio.create_task(workflow_generator.__anext__()), asyncio.create_task(stream_generator.__anext__())]
    workflow_is_complete = False
    
    while tasks:
        if workflow_is_complete or GLOBAL_CANCELLED:
            # Task.cancel() is a no-op on tasks that have already finished.
            for t in tasks:
                t.cancel()
            tasks = []
            break
        
        done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            try:
                # .result() re-raises whatever the generator step raised.
                yield task.result()
                if task==tasks[0]:
                    tasks.remove(task)
                    try:
                        tasks.insert(0, asyncio.create_task(workflow_generator.__anext__()))
                    except Exception:
                        GLOBAL_STREAM = GLOBAL_STREAM_CALLBACK.add_to_global_stream("<END_OF_MESSAGE>")
                        workflow_is_complete = True
                        print("Workflow is complete")
                        break
                else:
                    tasks.remove(task)
                    try:
                        tasks.append(asyncio.create_task(stream_generator.__anext__()))
                    except Exception:
                        print("Stream is complete")
                # Handle one finished task per wait(); another task that is also
                # done is returned immediately by the next wait() call.
                break
            except Exception as e:
                if not isinstance(e, StopAsyncIteration):
                    # A genuine failure, not normal exhaustion: make it visible.
                    print(f"Generator task failed with error: {e}")
                    print("Stack trace:")
                    print(traceback.format_exc())
                GLOBAL_STREAM = GLOBAL_STREAM_CALLBACK.add_to_global_stream("<END_OF_MESSAGE>")
                workflow_is_complete = True
                print("Workflow is complete")
                break

def _sse_event(event_type, content):
    """Format one ``{'type', 'content'}`` payload as a server-sent-event frame."""
    output_json = json.dumps({'type': event_type, 'content': content})
    return f"data: {output_json}\n\n"


def _step_header_event(agent_state):
    """Build the ``## Step N. <step>`` token frame for the latest recorded result."""
    completed = len(agent_state["results"])
    return _sse_event("token", '## Step {}. {} \n\n'.format(completed, agent_state["steps"][completed - 1]))


async def process_agent_stream(input: QueryInput):
    """Run the agent graph for one request and yield its progress as SSE frames.

    Frames are ``data: {"type": ..., "content": ...}`` lines:
      * ``token``: streamed tokens and ``## Step N.`` headers,
      * ``<node name>``: the latest message emitted by a graph node,
      * ``DONE``: the assembled final answer (skipped when the run was cancelled).

    The conversation buffer is rebuilt from ``input.chat_history`` at the start
    and extended with the query and the final answer at the end.
    """
    global GLOBAL_LLM, CODING_LLM
    global GLOBAL_STREAM, GLOBAL_STREAM_CALLBACK
    global GLOBAL_CONVERSATION_MEMORY, GLOBAL_CANCELLED
    GLOBAL_CANCELLED = False
    
    if input.chat_history:
        GLOBAL_CONVERSATION_MEMORY = chat_history_to_conversation_buffer(input.chat_history)
    else:
        GLOBAL_CONVERSATION_MEMORY = ConversationBufferMemory()
    chat_history = [(f.type, f.content) for f in GLOBAL_CONVERSATION_MEMORY.chat_memory.messages]
    if input.cache_dir is not None:
        set_filesystem_root(input.cache_dir)
    
    # "Plan and Execute" -> "plan-and-execute".
    # NOTE(review): "ReWOO" becomes "rewoo", which is not one of the literals in
    # initialize_agent's annotation; confirm set_solver_strategy accepts it.
    normalized_strategy = input.strategy.lower().replace(' ', '-')
    # The query doubles as the SOP / NuSMV model path for the specialised modes.
    workflow, memory, agentic_graph_obj = await initialize_agent(None, chat_history=chat_history, 
                                                            tools=input.selected_tools, 
                                                            strategy=normalized_strategy, 
                                                            kripke_extractor=input.kripke_extractor,
                                                            properties_extractor=input.properties_extractor,
                                                            nusmv_planner=input.nusmv_planner,
                                                            sop_file_path=input.query,
                                                            nusmv_model_file=input.query,
                                                           )
    messages = [HumanMessage(input.query)]
    event_loop = asyncio.get_event_loop()
    GLOBAL_STREAM = GLOBAL_STREAM_CALLBACK.init_global_stream()
    
    config = {"configurable": {"thread_id": input.thread_id if input.thread_id else str(uuid.uuid4())},
            "recursion_limit": 100}
    
    workflow_generator = workflow.astream(ReWOO(messages=messages), config)
    stream_generator = stream_tokens(0, GLOBAL_STREAM_CALLBACK, event_loop, stream_type="MESSAGE")
    combined_generator = combined_generators(workflow_generator, stream_generator)
    print("Generator initialized")
    
    current_step = 0
    async for chunk in combined_generator:
        if GLOBAL_CANCELLED:
            break
        message_node = list(chunk.keys())[0]
        if message_node=="token":
            response = chunk['token']
            # Unwrap tokens that already arrive as an SSE frame.
            if response.startswith("data: ") and response.endswith("\n\n"):
                response = json.loads(response[len("data: "):].strip())["content"]
            yield _sse_event(message_node, response)
            continue

        # agent_state is re-read after every yield because the workflow keeps
        # running while this generator is suspended.
        agent_state = agentic_graph_obj.agent_state
        if "results" in agent_state.keys():
            if current_step==0 and agent_state["results"] is not None:
                header = _step_header_event(agent_state)
                current_step = len(agent_state["results"])
                yield header
        
        # Only the most recent message of the node update is forwarded.
        messages_list = chunk[message_node]['messages']
        latest_messages = messages_list[-1:] if isinstance(messages_list, list) else [messages_list]
        response = '\n\n'.join(m.pretty_repr(html=False).replace('\n', '\n\n') for m in latest_messages)
        yield _sse_event(message_node, response+'\n\n')
        
        agent_state = agentic_graph_obj.agent_state
        if "results" in agent_state.keys():
            if agent_state["results"] is None:
                current_step = 0
                yield _sse_event("token", "")
            elif len(agent_state["results"])>current_step:
                header = _step_header_event(agent_state)
                current_step = len(agent_state["results"])
                yield header
        
    # Both strategies assemble the final answer identically, so a single block
    # replaces the two duplicated ones.
    if not GLOBAL_CANCELLED and (input.strategy=="ReWOO" or normalized_strategy=="plan-and-execute"):
        answers = agentic_graph_obj.GLOBAL_ANSWERS
        # Index-based access is kept on purpose: it works whether GLOBAL_ANSWERS
        # is a list or a mapping keyed by step index.
        step_response = ["\n\n".join([g.content.strip() for g in answers[f]]) for f in range(len(answers))]
        # A plan step without a recorded answer (e.g. the run stopped early) gets
        # an empty body instead of raising IndexError and losing the DONE frame.
        answer = "\n\n".join([
            f"## Step {i_f+1}. {f.strip()} \n\n{step_response[i_f] if i_f < len(step_response) else ''}"
            for i_f, f in enumerate(agentic_graph_obj.GLOBAL_PLAN)
        ])
        chat_history.append(("human", input.query))
        chat_history.append(("ai", answer))
        GLOBAL_CONVERSATION_MEMORY = chat_history_to_conversation_buffer([{'type':f[0], 'content':f[1]} for f in chat_history])
        yield _sse_event('DONE', answer)
        
    
@app.post("/process_agent")
async def process_agent(input: QueryInput):
    """Start an agent run and stream its progress as ``text/event-stream``.

    Before the stream is created the request is normalised:
      * an empty ``guidance_prompt`` is filled from ``args.GUIDANCE_FILE`` when
        that file is configured and exists,
      * an empty ``selected_tools`` list becomes ``["direct_response"]``,
      * the RAG / corrective-multihop switches are applied.

    Raises:
        HTTPException: 400 for a ``ValueError`` during setup, 422 for any other
            setup error. Errors raised while the stream is being consumed are
            not covered by this handler.
    """
    try:
        # GUIDANCE_FILE is optional; os.path.isfile(None) raises TypeError, which
        # previously turned every request into a 422 when it was unset.
        guidance_file = args.GUIDANCE_FILE
        if guidance_file and not input.guidance_prompt and os.path.isfile(guidance_file):
            with open(guidance_file, 'r') as fp:
                input.guidance_prompt = fp.read()
        if not input.selected_tools:
            input.selected_tools = ["direct_response"]
        
        set_rag_update(input.use_rag)
        set_corrective_multihop(input.use_crag)
        current_task = process_agent_stream(input)
        return StreamingResponse(current_task, media_type="text/event-stream")
    except ValueError as ve:
        print(str(ve))
        print("Stack trace:")
        print(traceback.format_exc())
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        print(f"Error processing request: {str(e)}")
        print("Stack trace:")
        print(traceback.format_exc())
        raise HTTPException(status_code=422, detail=str(e))
    
@app.post("/update_rag")
async def update_rag(files: List[UploadFile] = File(...)):
    """Store the uploaded files under ``<CACHE>/temp_uploads`` and stream their embedding progress.

    Each upload is written to disk, wrapped in a ``DocumentFile`` and handed to
    ``process_and_embed_documents`` together with the global vector store.

    Raises:
        HTTPException: 500 when no files were sent, a filename is unusable, or
            storing the files fails.
    """
    global GLOBAL_VECTORSTORE
    try:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not files:
            raise ValueError("No files were uploaded.")
        temp_dir = f"{args.CACHE}/temp_uploads"
        os.makedirs(temp_dir, exist_ok=True)
        file_paths = []
        for file in files:
            # The filename is client-controlled: keep only its final component so
            # values such as "../../x" or an absolute path cannot escape temp_dir.
            safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
            if safe_name in ("", ".", ".."):
                raise ValueError(f"Invalid upload filename: {file.filename!r}")
            file_path = os.path.join(temp_dir, safe_name)
            with open(file_path, "wb", buffering = 0) as buffer:
                buffer.write(await file.read())
            file_paths.append(DocumentFile(file_path))
        current_task = process_and_embed_documents(file_paths, GLOBAL_VECTORSTORE)
        return StreamingResponse(current_task, media_type="text/event-stream")
    except Exception as e:
        print(f"RAG upload terminated with error: {e}")
        print("Stack trace:")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    
@app.post("/reset_rag")
async def reset_rag_endpoint():
    """Reset the knowledge base and delete the files cached for it.

    Steps, in order:

    1. Run ``reset_rag`` in the event loop's default executor and wait for it.
    2. Delete every regular file directly inside ``{args.CACHE}/temp_uploads``
       and ``{args.CACHE}/papers`` (sub-directories are left alone).
    3. If ``GLOBAL_SQL_UPLOAD_ROOT`` is an existing directory, delete every
       regular file in it except the ones named by ``patient_file``,
       ``clinvar_file`` and ``<database_name>.db`` in ``./sql.yml``.

    Returns:
        ``{"message": "Knowledge base reset successful."}`` on success.

    Raises:
        HTTPException: 500 when any step fails.
    """
    global GLOBAL_SQL_UPLOAD_ROOT
    global GLOBAL_VECTORSTORE, GLOBAL_RETRIEVER
    try:
        await asyncio.get_running_loop().run_in_executor(None, reset_rag)

        # Both cache folders are cleaned the same way: top-level files only.
        for cache_subdir in ("temp_uploads", "papers"):
            folder = f"{args.CACHE}/{cache_subdir}"
            if not isdir(folder):
                continue
            for entry in listdir(folder):
                candidate = join(folder, entry)
                if isfile(candidate):
                    os.unlink(candidate)

        sql_root = GLOBAL_SQL_UPLOAD_ROOT
        if sql_root is not None and isdir(sql_root):
            with open("./sql.yml", 'r') as fp:
                sql_metadata = yaml.safe_load(fp)
                for entry in listdir(sql_root):
                    if not isfile(join(sql_root, entry)):
                        continue
                    # Files named in the SQL metadata are kept.
                    keep = (
                        sql_metadata["patient_file"],
                        sql_metadata["clinvar_file"],
                        sql_metadata['database_name']+'.db',
                    )
                    if entry not in keep:
                        os.unlink(join(sql_root, entry))
        return {"message": "Knowledge base reset successful."}
    except Exception as e:
        print(f"RAG reset terminated with error: {e}")
        print("Stack trace:")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    
@app.post("/reset_chat_memory")
async def reset_chat_memory():
    """Replace the global conversation memory with the result of ``clear_memory``.

    Returns:
        ``{"message": "Conversation memory reset"}`` on success.

    Raises:
        HTTPException: 500 when ``clear_memory`` raises.
    """
    global GLOBAL_CONVERSATION_MEMORY
    try:
        GLOBAL_CONVERSATION_MEMORY = clear_memory(GLOBAL_CONVERSATION_MEMORY, args)
    except Exception as e:
        # Log text names this operation (it previously reported a RAG reset).
        print(f"Chat memory reset terminated with error: {e}")
        print("Stack trace:")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": "Conversation memory reset"}
    
@app.post("/get_chat_memory")
async def get_chat_memory() -> Dict:
    """Return the stored conversation, excluding system messages.

    Returns:
        ``{'chat_history': [...]}`` where each item is
        ``{'type': <message type>, 'content': <message content>}`` taken from
        ``GLOBAL_CONVERSATION_MEMORY.chat_memory.messages``. The list is empty
        when there are no non-system messages.

    Raises:
        HTTPException: 500 when the memory cannot be read.
    """
    try:
        # An empty message list simply yields an empty history, so no separate
        # length check is needed.
        chat_history = [
            {'type': message.type, 'content': message.content}
            for message in GLOBAL_CONVERSATION_MEMORY.chat_memory.messages
            if message.type != 'system'
        ]
    except Exception as e:
        # Log text names this operation (it previously reported a RAG reset).
        print(f"Chat memory retrieval terminated with error: {e}")
        print("Stack trace:")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    return {'chat_history': chat_history}

@app.post("/cancel_query")
async def cancel_query():
    """Flag the running query as cancelled and cancel outstanding asyncio tasks.

    Sets ``GLOBAL_CANCELLED`` to ``True``, waits one second, then requests
    cancellation of every task in the running loop except:

    * tasks named ``Task-1`` .. ``Task-5``, and
    * the task executing this handler, so the response can still be sent.

    Returns:
        ``{"message": "Query cancellation requested."}``.

    Raises:
        HTTPException: 500 when enumerating or cancelling tasks fails.
    """
    global GLOBAL_CANCELLED
    GLOBAL_CANCELLED = True
    try:
        await asyncio.sleep(1.0)
        # NOTE(review): presumably Task-1..Task-5 are the server's own
        # long-lived tasks created at startup -- confirm; default task names
        # depend on creation order.
        protected_names = {f'Task-{index}' for index in range(1, 6)}
        this_task = asyncio.current_task()
        for task in asyncio.all_tasks():
            # Cancelling our own task would abort this request before its
            # response is delivered.
            if task is this_task or task.get_name() in protected_names:
                continue
            print(task)
            # Task.cancel() only schedules the cancellation; it does not raise
            # CancelledError in the caller, so no inner try/except is needed.
            task.cancel()
    except Exception as e:
        print(f"Query cancelled with error: {e}")
        print("Stack trace:")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": "Query cancellation requested."}

@app.post("/get_thread_ids")
async def get_thread_ids():
    """List the asyncio tasks that are not yet finished in the running loop.

    Each task is also printed to stdout.

    Returns:
        ``{"message": <str() of the list of tasks>}``.

    Raises:
        HTTPException: 500 when the tasks cannot be enumerated.
    """
    task_ids = []
    try:
        for task in asyncio.all_tasks():
            print(task)
            task_ids.append(task)
    except Exception as e:
        # Log text names this operation (it previously reported a cancelled
        # query).
        print(f"Listing tasks terminated with error: {e}")
        print("Stack trace:")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": str(task_ids)}

@app.get("/get_tools")
async def get_tools() -> Dict[str,List[str]]:
    """Return the result of ``get_available_tools()`` under the ``tool_names`` key."""
    available = get_available_tools()
    return {"tool_names": available}
    
if __name__ == "__main__":
    # Script entry point: build the shared runtime objects, then serve `app`
    # with uvicorn. The statements below run in this exact order.

    # Refuse to start unless the USER environment variable is set and non-empty.
    assert os.getenv("USER")
    # Create the stream callback and obtain the stream from it; the stream is
    # passed to set_rag_stream() below.
    GLOBAL_STREAM_CALLBACK = IntermediateStepsCallback()
    GLOBAL_STREAM = GLOBAL_STREAM_CALLBACK.init_global_stream()
    llm = initialize_llm()
    vectorstore, retriever = initialize_rag()
    set_rag_stream(GLOBAL_STREAM)
    # Called with False at startup; /process_agent calls it again on every
    # request with input.use_crag.
    set_corrective_multihop(False)
    
    # NOTE(review): model_dict is defined outside this section -- presumably a
    # model listing whose 'data' entries carry an 'id'; confirm its origin.
    init_crag_llm_state(os.getenv("MODEL"), model_dict['data'][args.MODEL_INDEX]['id'])
    agent_obj = InformationRetrieverAgent("Model Checking")
    # agent_obj.run("")
    try:
        port_num = args.port_num
        # Ports outside [8000, 8888] are remapped to 8000 + (port % 1000).
        # Note the remapped value can still exceed 8888 (e.g. 9999 -> 8999).
        if port_num<8000 or port_num>8888:
            port_num = 8000 + (port_num%1000)
    except Exception as e:
        # Fall back to the default port if the computation above fails.
        port_num = 8888
    print(f'Starting uvicorn at port number: {port_num}')
    # host="localhost": the server is bound to the local interface.
    uvicorn.run(app, host="localhost", port=port_num)