import streamlit as st
import json
from PIL import Image
import io

# --- Page Config ---
# Browser-tab title, plus the wide layout so page content can use the full
# width of the window.
st.set_page_config(layout="wide", page_title="CerebraGloss-Bench Visualizer")

# --- Data Loading Functions (with cache) ---
@st.cache_data
def _read_benchmark_file(filepath):
    """Read and parse the benchmark JSON file.

    Only successful results are cached: any exception raised here propagates
    to the caller and leaves no cache entry behind, so a later rerun retries
    the read instead of replaying a stale failure.

    Raises:
        FileNotFoundError: If `filepath` does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_benchmark_data(filepath):
    """Load the benchmark JSON file, reporting failures in the UI.

    Args:
        filepath: Path to the benchmark JSON file.

    Returns:
        The parsed JSON content, or None when the file is missing or is not
        valid JSON (an error message is shown via st.error in that case).
    """
    # Error handling lives outside the cached reader on purpose: returning
    # None from a cached function would cache the failure itself, and the
    # error would persist even after the file has been added or repaired.
    try:
        return _read_benchmark_file(filepath)
    except FileNotFoundError:
        st.error(f"Error: Benchmark file '{filepath}' not found. Please ensure it's in the correct location.")
    except json.JSONDecodeError:
        st.error(f"Error: File '{filepath}' is not a valid JSON file.")
    return None

def parse_model_output(uploaded_file):
    """Parse an uploaded .jsonl model-output file into a per-sample mapping.

    Every non-blank line is expected to be a JSON object with an "image" key
    (the sample identifier) and an optional "output" object.

    Args:
        uploaded_file: Object whose ``getvalue()`` returns the raw file bytes
            (this is what the caller receives from ``st.file_uploader``).

    Returns:
        dict mapping each "image" string to its "output" object ({} when the
        line has no "output" key). Blank lines and lines without a usable
        "image" string are skipped silently; unparsable lines, non-object
        lines and non-object outputs are skipped with a warning. A file that
        is not valid UTF-8 yields an empty dict and an error message.
    """
    try:
        # utf-8-sig also drops a leading BOM, which would otherwise make the
        # first line fail to parse as JSON.
        text = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        st.error("Error: Uploaded file is not valid UTF-8 text.")
        return {}

    model_data = {}
    # Iterate through StringIO rather than str.splitlines(): splitlines()
    # also breaks on characters such as U+2028, which may legally appear
    # unescaped inside a JSON string.
    for line in io.StringIO(text):
        stripped = line.strip()
        if not stripped:
            # Blank lines are common in hand-edited files; not an error.
            continue
        try:
            record = json.loads(stripped)
        except json.JSONDecodeError:
            st.warning(f"Skipping unparsable line: {stripped}")
            continue
        if not isinstance(record, dict):
            # Valid JSON, but a list/number/string has no "image" key.
            st.warning(f"Skipping line that is not a JSON object: {stripped}")
            continue

        key = record.get("image")
        if not isinstance(key, str) or not key:
            # Sample identifiers are looked up as strings; anything else
            # (missing, empty, or an unhashable value) can never match.
            continue

        output = record.get("output", {})
        if not isinstance(output, dict):
            # The viewer reads fields from this value with .get(), so only
            # JSON objects are usable.
            st.warning(f"Skipping '{key}': \"output\" is not a JSON object.")
            continue
        model_data[key] = output
    return model_data

# --- Main Program ---
# Renders the page: a sidebar for choosing a sample and uploading model
# output, the sample's image on the left, and the benchmark-vs-model
# comparison (summary, conversation, multiple choice) on the right.
st.title("CerebraGloss-Bench Visualizer")

# Load benchmark data (None when the file is missing or invalid)
benchmark_data = load_benchmark_data("benchmark.json")

if benchmark_data:
    # Initialize session_state so the parsed model output has a home
    if 'model_data' not in st.session_state:
        st.session_state.model_data = {}

    # --- Sidebar ---
    with st.sidebar:
        st.header("Control Panel")

        # Sample selection
        sample_ids = list(benchmark_data)
        selected_sample_id = st.selectbox(
            "Select a sample to view:",
            options=sample_ids,
            index=0
        )

        # Model output upload
        uploaded_file = st.file_uploader(
            "Upload model output file (.jsonl)",
            type=["jsonl"]
        )
        if uploaded_file is not None:
            # Parse the uploaded file and store the result in session_state
            st.session_state.model_data = parse_model_output(uploaded_file)
            st.success(f"Successfully loaded and parsed {uploaded_file.name}!")

    # --- Main Layout ---
    col1, col2 = st.columns([2, 3])

    # --- Left Column: Image Display ---
    with col1:
        st.header("EEG Image")

        # The image is looked up under img/ with the sample id's ".npy"
        # replaced by ".jpg".
        image_path = f"img/{selected_sample_id.replace('.npy', '.jpg')}"

        try:
            image = Image.open(image_path)
            st.image(image, caption=selected_sample_id, width='stretch')
        except FileNotFoundError:
            st.warning(f"Image file not found: {image_path}")
            # Grey placeholder keeps the layout stable when the image is missing
            st.image(Image.new('RGB', (800, 400), color='grey'), caption="Image not found")

    # --- Right Column: Information Display ---
    with col2:
        st.markdown("""
            <style>
            /* Target the column container generated by Streamlit and make it scrollable */
            div[data-testid="stHorizontalBlock"] > div:nth-child(2) > div {
                height: 95vh; /* Set a height close to the screen height */
                overflow-y: auto; /* Show vertical scrollbar when content overflows */
                padding-right: 1rem; /* Add space for the scrollbar */
            }
            </style>
            """, unsafe_allow_html=True)

        benchmark_sample_data = benchmark_data.get(selected_sample_id, {})
        model_sample_data = st.session_state.model_data.get(selected_sample_id, {})
        if not isinstance(model_sample_data, dict):
            # Every field below is read with .get(); a malformed entry
            # (e.g. a bare string) cannot be rendered, so treat it as absent.
            model_sample_data = {}

        # --- 1. Summary Comparison ---
        st.subheader("Summary")
        if benchmark_sample_data:
            with st.expander("Standard Answer Summary", expanded=True):
                st.markdown(f'{benchmark_sample_data.get("summary", "N/A")}')

        if model_sample_data:
            with st.expander("Model Output Summary", expanded=True):
                st.markdown(f'{model_sample_data.get("summary", "N/A")}')
        elif st.session_state.model_data:
            # A file was uploaded but has nothing usable for this sample
            st.warning("No corresponding model output for this sample.")

        st.divider()

        # --- 2. Conversation Comparison ---
        st.subheader("Conversation")
        if benchmark_sample_data:
            # Turns are strings tagged with a "**User:**" or "**Agent:**" prefix
            conversation = benchmark_sample_data.get("conversation", [])
            user_questions = [turn for turn in conversation if turn.startswith("**User:**")]
            benchmark_answers = [turn for turn in conversation if turn.startswith("**Agent:**")]

            # Display user questions
            for q in user_questions:
                with st.chat_message("user"):
                    st.markdown(q.replace("**User:**", "", 1).strip())

            # Display standard answer
            with st.chat_message("assistant"):
                st.markdown("**Standard Answer:**")
                for ans in benchmark_answers:
                    st.markdown(ans.replace("**Agent:**", "", 1).strip())

            # Display model output
            if model_sample_data:
                with st.chat_message("assistant"):
                    st.markdown("**Model Output:**")
                    model_convo_str = model_sample_data.get("conversation", "")
                    st.markdown(model_convo_str)

        st.divider()

        # --- 3. Multiple Choice Question Comparison ---
        st.subheader("Multiple Choice Question")
        if benchmark_sample_data:
            # "selection" holds [question text, answer]. A missing key keeps
            # the original ["", ""] default; an explicit null or empty list
            # falls back to [] so neither index below can fail.
            selection_list = benchmark_sample_data.get("selection", ["", ""]) or []
            full_question_text = selection_list[0] if selection_list else ""
            benchmark_answer = selection_list[1] if len(selection_list) > 1 else "N/A"

            # Parse question and options
            question_parts = full_question_text.split('\n')
            question_title = ""
            options = []
            # Find the position of the first option (e.g., A))
            first_option_index = -1
            for i, part in enumerate(question_parts):
                if part.strip().startswith('A)'):
                    first_option_index = i
                    break

            if first_option_index != -1:
                question_title = "\n".join(question_parts[:first_option_index])
                options = question_parts[first_option_index:]
            else:
                # If parsing fails, display as is
                question_title = full_question_text

            # Create two columns
            q_col, a_col = st.columns(2)

            with q_col:
                st.markdown("**Question:**")
                st.markdown(question_title)
                if options:
                    # Use a code block for nice formatting of options
                    st.code("\n".join(options), language=None)

            with a_col:
                st.markdown("**Answers:**")
                st.markdown(f"**Standard Answer:** {benchmark_answer}")

                if model_sample_data:
                    model_answer = model_sample_data.get("selection", "N/A")
                    st.markdown(f"**Model Output:** {model_answer}")

    # Shown below the columns while no model output has been loaded
    if not st.session_state.model_data:
        st.info("Upload a model output file in the sidebar to see a comparison.")
