import os
import io
import gc
import pandas as pd
import streamlit as st
from vllm import LLM, SamplingParams
import pyarrow.feather as feather

st.set_page_config(page_title="vLLM Chat")

############################
# Ensure required session keys
############################
# Streamlit re-executes this script on every interaction, so each key is
# seeded only when it is absent; existing values are left untouched.
_SESSION_DEFAULTS = {
    "messages": list,      # chat history, starts as an empty list
    "llm": lambda: None,   # placeholder until a model is loaded
}
for _key, _make_default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()

############################
# Sidebar – model controls #
############################
with st.sidebar:
    st.header("Model setup")
    model_path = st.text_input("Model path", value="path/to/model")
    cuda_device = st.text_input("CUDA device id", value="0")
    gpu_util = st.slider("GPU memory fraction", 0.1, 1.0, 0.9, 0.05)

    # Load / reload button
    if st.button("Load / Reload model", key="load_model"):
        # If another model is already in memory, release it first.
        previous = st.session_state["llm"]
        if previous is not None:
            # Clear the session slot up front so that a failed reload leaves
            # the app in the "no model loaded" state instead of pointing at a
            # half-torn-down engine.
            st.session_state["llm"] = None
            # NOTE(review): vLLM's `LLM` does not appear to expose `close()`
            # in all versions -- confirm. Call it only when present so a
            # missing method does not surface as a spurious warning.
            close = getattr(previous, "close", None)
            if callable(close):
                try:
                    close()
                except Exception as e:
                    st.warning(f"Previous model cleanup error: {e}")
            # The bound method keeps the engine alive, so drop both names
            # before collecting.
            del previous, close
            gc.collect()

        # NOTE(review): CUDA_VISIBLE_DEVICES is presumably only honoured if it
        # is set before CUDA is initialised in this process, so changing the
        # device on a *reload* may have no effect -- confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
        try:
            with st.spinner(f"Loading model '{model_path}'..."):
                st.session_state["llm"] = LLM(
                    model=model_path, gpu_memory_utilization=gpu_util
                )
        except Exception as e:
            # Report the failure in the UI instead of aborting the whole page
            # with a traceback; the session slot stays None so the chat
            # section shows its "load a model" hint.
            st.session_state["llm"] = None
            st.error(f"Failed to load model '{model_path}': {e}")
        else:
            st.success(f"Model '{model_path}' loaded on GPU {cuda_device}")

    # Reset conversation only (keep model)
    if st.button("Clear conversation", key="clear_chat"):
        st.session_state["messages"] = []
        st.info("Conversation cleared")

############################
# DataFrame loader section #
############################
st.header("DataFrame viewer")


def _read_dataframe(source, name):
    """Read *source* into a DataFrame using the reader implied by *name*.

    Args:
        source: A filesystem path or a binary file-like object.
        name: File name whose extension (case-insensitive) selects the reader.

    Returns:
        The loaded DataFrame, or ``None`` when the extension is not one of
        ``.json``, ``.csv`` or ``.feather``.

    Raises:
        Whatever the underlying pandas / pyarrow reader raises on bad input.
    """
    lowered = name.lower()
    if lowered.endswith(".json"):
        return pd.read_json(source)
    if lowered.endswith(".csv"):
        return pd.read_csv(source)
    if lowered.endswith(".feather"):
        return feather.read_feather(source)
    return None


col1, col2 = st.columns(2)

# --- Option A: Path on disk ---
with col1:
    # NOTE(review): this reads any server-side path the user types; restrict
    # it to an allowed directory if the app is exposed to untrusted users.
    file_path = st.text_input("Path to JSON/CSV/Feather on server")
    if st.button("Load from path", key="load_path") and file_path:
        try:
            df = _read_dataframe(file_path, file_path)
        except Exception as e:
            st.error(f"Failed to load: {e}")
        else:
            if df is None:
                st.error("Unsupported extension. Use .json, .csv or .feather")
            else:
                st.session_state["df"] = df
                st.success("DataFrame loaded from path")

# --- Option B: Upload file ---
with col2:
    upload = st.file_uploader(
        "Upload JSON, CSV, or Feather", type=["json", "csv", "feather"]
    )
    if upload is None:
        # Forget the last upload so re-adding the same file loads it again.
        if "df_upload_key" in st.session_state:
            del st.session_state["df_upload_key"]
    else:
        # The uploader keeps returning the same file on every rerun. Parse it
        # only when it changes; otherwise it would overwrite a DataFrame that
        # was just loaded via "Load from path" on each rerun.
        upload_key = (upload.name, upload.size, getattr(upload, "file_id", None))
        if st.session_state.get("df_upload_key") != upload_key:
            try:
                # A fresh buffer gives every reader a stream positioned at 0.
                df = _read_dataframe(io.BytesIO(upload.getvalue()), upload.name)
            except Exception as e:
                st.error(f"Failed to load: {e}")
            else:
                if df is None:
                    # Mirror Option A instead of assuming Feather.
                    st.error("Unsupported extension. Use .json, .csv or .feather")
                else:
                    st.session_state["df"] = df
                    st.session_state["df_upload_key"] = upload_key
                    st.success("DataFrame loaded from upload")

# Show the DataFrame if available
if "df" in st.session_state:
    st.dataframe(st.session_state["df"])

############################
# Chat interface with vLLM #
############################
st.header("Chat")

system_prompt = st.text_area("System message", "You are a helpful assistant.", height=100)

if st.session_state["llm"] is not None:
    # Replay the stored conversation; Streamlit redraws the page on each rerun.
    for msg in st.session_state.get("messages", []):
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    user_input = st.chat_input("Type your message")
    if user_input:
        user_msg = {"role": "user", "content": user_input}
        with st.chat_message("user"):
            st.markdown(user_input)

        # Build a plain "role: content" transcript from the history plus the
        # new turn, ending with an open "assistant:" cue for the model.
        turns = st.session_state["messages"] + [user_msg]
        parts = (
            [f"system: {system_prompt}"]
            + [f"{m['role']}: {m['content']}" for m in turns]
            + ["assistant:"]
        )
        prompt_text = "\n".join(parts)

        # Stop when the model starts writing another speaker's line, so the
        # reply is a single assistant turn and not an invented continuation of
        # the transcript running up to max_tokens.
        sampling = SamplingParams(
            temperature=0.1,
            max_tokens=2048,
            stop=["\nuser:", "\nsystem:", "\nassistant:"],
        )
        try:
            with st.spinner("Generating..."):
                results = st.session_state["llm"].generate([prompt_text], sampling)
            # The completion follows "assistant:" directly, so trim the
            # surrounding whitespace before storing and rendering it.
            output = results[0].outputs[0].text.strip()
        except Exception as e:
            # History is committed only on success, so a failed generation
            # does not leave an unanswered user turn in later prompts.
            st.error(f"Generation failed: {e}")
        else:
            st.session_state["messages"].extend(
                [user_msg, {"role": "assistant", "content": output}]
            )
            with st.chat_message("assistant"):
                st.markdown(output)
else:
    st.info("Load a model to start chatting.")
