import streamlit as st
from transformers import AutoTokenizer

import prometheus_client
# Some prometheus_client versions do not provide `disable_created_metrics`.
# Install a no-op stand-in so callers of it (presumably the vllm import below —
# confirm) do not fail with AttributeError; an existing implementation is kept.
if not hasattr(prometheus_client, "disable_created_metrics"):
    prometheus_client.disable_created_metrics = lambda: None


from vllm import LLM, SamplingParams
from data_utils import SYSTEM_MESSAGE_INSTRUCTION_MODEL
from grpo_data_util import generate_tictactoe_prompt, extract_final_answer

# Default checkpoint shown in the sidebar's "Model Checkpoint Path" field.
MODEL_CHECKPOINT = (
    "/mnt/data/data/stlm-logic/updated-checkpoints/meta-llama-Llama-3.2-1B-Instruct_canconical-symmetry-grouping_legal_move_nl_grpo-nl-expt-final-fixed-prompt-and-data-loading/checkpoint-600"
)

@st.cache_resource
def load_tokenizer(model_checkpoint):
    """Load the Hugging Face tokenizer for ``model_checkpoint``.

    Memoized with ``st.cache_resource``, so each distinct path is loaded once
    and reused across Streamlit reruns.

    If the tokenizer defines no pad token, its EOS token is registered as the
    pad token.

    NOTE(review): ``trust_remote_code=True`` executes code shipped with the
    checkpoint, and the path comes from a free-text sidebar field — only point
    this at trusted checkpoints.
    """
    tok = AutoTokenizer.from_pretrained(model_checkpoint, trust_remote_code=True)
    if tok.pad_token is not None:
        return tok
    # No pad token configured: reuse EOS for padding.
    tok.add_special_tokens({'pad_token': tok.eos_token})
    return tok

@st.cache_resource
def load_model(model_checkpoint):
    """Build the vLLM inference engine for ``model_checkpoint``.

    Memoized with ``st.cache_resource``: the engine is constructed once per
    distinct checkpoint path and reused across reruns.

    NOTE(review): entering a different path in the sidebar creates an
    additional engine while previously cached ones stay referenced by the
    cache — confirm the GPU has room for that, or limit the cache size.
    """
    engine = LLM(model=model_checkpoint)
    return engine

# Plain-text, role-prefixed chat prompt builder (not the tokenizer's Llama chat template).
def format_chat_message(history, user_input, system_message=None):
    """Build a chat prompt and the updated conversation history.

    The prompt is a transcript with one ``"<Role>: <content>"`` line per
    message. It starts with the system prompt and ends with an open
    ``"Assistant:"`` line, so the model continues as the assistant rather
    than extending the user's turn.

    Args:
        history: Prior turns as a list of ``{"role": ..., "content": ...}``
            dicts. Any ``"system"`` entries are ignored, because the system
            prompt is always prepended exactly once here.
        user_input: The new user message.
        system_message: System prompt to prepend. Defaults to
            ``SYSTEM_MESSAGE_INSTRUCTION_MODEL``.

    Returns:
        ``(prompt, updated_history)``. ``updated_history`` is a NEW list of
        the non-system turns plus the new user message. It deliberately
        excludes the system prompt, so storing it and passing it back in on
        the next turn does not duplicate the system message. ``history`` is
        not modified.
    """
    if system_message is None:
        system_message = SYSTEM_MESSAGE_INSTRUCTION_MODEL

    # Drop stale system entries (e.g. from history saved by an older version)
    # so the system prompt appears exactly once in the prompt.
    turns = [m for m in history if m["role"] != "system"]
    turns.append({"role": "user", "content": user_input})

    messages = [{"role": "system", "content": system_message}, *turns]
    lines = [f"{m['role'].capitalize()}: {m['content']}" for m in messages]
    # Open assistant turn: cue the model to answer instead of continuing the user text.
    lines.append("Assistant:")
    return "\n".join(lines), turns

# ---- Streamlit App UI ----------------------------------------------------
st.title("🧠 Tic-Tac-Toe AI - Interactive Testing")

# Sidebar: checkpoint path plus decoding controls; values are re-read on every rerun.
st.sidebar.header("⚙️ Settings")
model_checkpoint = st.sidebar.text_input("Model Checkpoint Path", value=MODEL_CHECKPOINT)
max_new_tokens = st.sidebar.slider("Max Tokens", min_value=10, max_value=512, value=256)
temperature = st.sidebar.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7)
top_k = st.sidebar.slider("Top-K", min_value=1, max_value=100, value=50)
top_p = st.sidebar.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9)

# Both loaders are st.cache_resource-decorated, so reruns with an unchanged path reuse them.
tokenizer = load_tokenizer(model_checkpoint)
engine = load_model(model_checkpoint)

# Decoding configuration rebuilt from the current slider values.
sampling_params = SamplingParams(
    max_tokens=max_new_tokens,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
)

# Conversation memory lives in session_state so it survives Streamlit reruns;
# create it only on the first run of a session.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# "Game" asks for a single move suggestion; "Chat" keeps multi-turn context.
mode = st.radio("Choose Interaction Mode:", options=["Game", "Chat"])

# Free-text message, processed when the Submit button is pressed.
user_input = st.text_input(label="Enter your message:")

# Handle a submission: build the prompt for the selected mode, generate, and show the result.
if st.button("Submit"):
    if not user_input.strip():
        st.warning("Please enter a message before submitting.")
    else:
        pending_history = None
        if mode == "Game":
            # Tic-Tac-Toe move suggestion from the free-text instruction (no board state supplied).
            sample = {"text_instruction": user_input, "board": []}
            prompt_dict = generate_tictactoe_prompt(sample, tokenizer, "nl", instruct_model=True)
            prompt_text = prompt_dict.get("prompt")
            st.session_state.chat_history = []  # Game mode keeps no conversational context.
        else:
            prompt_text, messages = format_chat_message(st.session_state.chat_history, user_input)
            # Keep only user/assistant turns. The system prompt is prepended on every call,
            # so storing it here would duplicate it in each later prompt (and render it as "AI").
            pending_history = [m for m in messages if m["role"] != "system"]

        if not prompt_text:
            # .get("prompt") may yield None; don't hand that to the engine.
            st.error("Could not build a prompt for this input.")
        else:
            with st.spinner("🤖 Thinking..."):
                outputs = engine.generate(prompt_text, sampling_params)
                completion = outputs[0].outputs[0].text if outputs else "No response."

            if mode == "Game":
                predicted_move = extract_final_answer(completion)
                st.success(f"✨ Suggested Move: {predicted_move}")
            else:
                st.markdown(f"**🤖 AI:** {completion}")
                # Commit the turn only after generation succeeded, so a failure
                # does not leave an unanswered user message in the history.
                pending_history.append({"role": "assistant", "content": completion})
                st.session_state.chat_history = pending_history

# Render recent turns on every rerun in Chat mode (not only right after Submit),
# so the conversation stays visible when a sidebar widget changes.
if mode == "Chat" and st.session_state.chat_history:
    st.subheader("💬 Chat History")
    for message in st.session_state.chat_history[-10:]:
        role = "🧑‍💻 You" if message["role"] == "user" else "🤖 AI"
        st.text(f"{role}: {message['content']}")
