import argparse
from collections import Counter
from pathlib import Path
from typing import Any

import pandas as pd
import streamlit as st
import yaml

from deception_detection.utils import REPO_ROOT


def load_file_data(file_path: Path) -> None:
    """Load scenario data from a YAML file into ``st.session_state.data``.

    The file must parse to a YAML list. On any failure (unreadable file, invalid
    YAML, or a top-level value that is not a list) an error is shown in the app
    and ``st.session_state.data`` is left untouched, so data loaded earlier
    survives a failed reload.

    Args:
        file_path: Path of the YAML file to read.
    """
    try:
        # Read as UTF-8 to match the encoding used when the file is saved.
        with open(file_path, encoding="utf-8") as file:
            data = yaml.safe_load(file)
    except Exception as e:
        st.error(f"Error loading file: {e}")
        return

    if not isinstance(data, list):
        # Do not store the invalid value: the rest of the app iterates over the list.
        st.error("File must contain a list of dictionaries")
        return

    st.session_state.data = data


def save_yaml_data(data: list[dict[str, Any]], file_path: Path) -> bool:
    """Write the scenario list to ``file_path`` as YAML.

    Args:
        data: Scenario dictionaries to serialize (key order is preserved).
        file_path: Destination file; it is overwritten.

    Returns:
        True if the file was written, False if serialization or writing failed
        (in which case an error is shown in the app).
    """
    try:
        # Serialize before opening the file: opening with "w" truncates it, so a
        # serialization error after that point would wipe the existing contents.
        serialized = yaml.dump(data, sort_keys=False, allow_unicode=True)
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(serialized)
        return True
    except Exception as e:
        st.error(f"Error saving YAML: {e}")
        return False


def get_colored_value(value: float, min: float = 1, max: float = 5) -> str:
    """Wrap ``value`` in a Streamlit markdown color directive.

    The value is red when it equals ``min``, green when it equals ``max`` and
    orange otherwise. The ``min`` comparison is made first.
    """
    if value == min:
        color = "red"
    elif value == max:
        color = "green"
    else:
        color = "orange"
    return f":{color}[{value}]"


def update_status(id: int, new_status: str, file_path: Path) -> None:
    """Set the status of one scenario and persist the whole data set to YAML.

    The first item in ``st.session_state.data`` whose ``"id"`` equals ``id`` is
    updated. If no item matches, an error is shown and nothing is written.

    Args:
        id: Identifier of the scenario to update.
        new_status: Status value to store under the item's ``"status"`` key.
        file_path: YAML file the updated data is saved to.
    """
    for item in st.session_state.data:
        if item["id"] == id:
            item["status"] = new_status
            break
    else:
        # No matching scenario: avoid rewriting the file and reporting a false success.
        st.error(f"Scenario {id} not found; no status change was saved")
        return

    if save_yaml_data(st.session_state.data, file_path):
        st.success(f"Updated and saved status for Scenario {id} to {new_status}")
    else:
        st.error(f"Failed to save status change for Scenario {id}")


def create_index_table(data: list[dict[str, Any]]) -> None:
    """Display an index table of all scenarios with status and ratings.

    Each row shows a link to the scenario's detail view, a status emoji, the
    clarity and goal ratings, and a short preview of the scenario text.

    Args:
        data: Scenario dictionaries. Each must provide ``"id"``, ``"scenario"``,
            ``"clarity_rating"["rating"]`` and ``"goal_rating"["rating"]``;
            ``"status"`` is optional and defaults to ``"unexamined"``.
    """
    if not data:
        # An empty frame has no columns, so the column handling below would fail.
        st.info("No scenarios to display.")
        return

    preview_length = 50
    status_emoji = {"approved": "✅", "rejected": "❌", "unexamined": "❔"}

    table_data = []
    for item in data:
        scenario_text = item["scenario"]
        preview = scenario_text[:preview_length]
        if len(scenario_text) > preview_length:
            # Only mark the preview as truncated when text was actually cut off.
            preview += "..."

        status = item.get("status", "unexamined")
        table_data.append(
            {
                # Relative link; the LinkColumn regex below shows only the numeric id.
                "ID": f"?view=detail&id={item['id']}",
                # Unknown statuses are shown verbatim instead of raising KeyError.
                "Status": status_emoji.get(status, f"❔ {status}"),
                "Clarity": item["clarity_rating"]["rating"],
                "Goal": item["goal_rating"]["rating"],
                "Scenario": preview,
            }
        )

    df = pd.DataFrame(table_data)

    st.dataframe(
        df,
        column_config={
            "ID": st.column_config.LinkColumn(
                "ID", width="small", display_text=r"\?view=detail&id=(\d*)"
            ),
            "Status": st.column_config.TextColumn("Status", width="small"),
            "Clarity": st.column_config.ProgressColumn(
                "Clarity", width="small", min_value=2, max_value=5, format="%f"
            ),
            "Goal": st.column_config.ProgressColumn(
                "Goal", width="small", min_value=2, max_value=5, format="%f"
            ),
            "Scenario": st.column_config.TextColumn("Scenario"),
        },
        hide_index=True,
    )


def parse_args() -> Path:
    """Return the scenario file path from ``--file``, or the default under REPO_ROOT."""
    cli = argparse.ArgumentParser(description="YAML Scenario Viewer")
    default_file = REPO_ROOT / "data/roleplaying/output/output_20241025_105712.scored.yaml"
    cli.add_argument(
        "--file",
        type=Path,
        default=default_file,
        help="Path to the YAML file containing scenario data",
    )
    # Streamlit adds its own command line args; parse_known_args ignores those.
    known, _unrecognized = cli.parse_known_args()
    return known.file


def main() -> None:
    """Run the Streamlit scenario viewer.

    Loads the scenario file once per session, then renders either the index
    table or the detail view of one scenario, chosen by the ``view`` and ``id``
    URL query parameters. The sidebar offers reload, view switching and
    previous/next navigation.
    """
    file_path = parse_args()

    st.set_page_config(page_title="Roleplaying Data")

    st.title("YAML Scenario Viewer")

    # Store file_path in session state for use in other functions
    if "file_path" not in st.session_state:
        st.session_state.file_path = file_path

    # Load data
    if "data" not in st.session_state:
        load_file_data(file_path)

    # Sidebar for navigation
    st.sidebar.title("Navigation")

    # Handle the reload before any scenario data is read, so this run never
    # works from IDs that belong to the previous version of the file.
    if st.sidebar.button("Reload Data"):
        load_file_data(st.session_state.file_path)

    # Without a non-empty list of scenarios nothing below can be rendered; stop
    # here (with the reload button already visible) instead of crashing.
    if (
        "data" not in st.session_state
        or not isinstance(st.session_state.data, list)
        or not st.session_state.data
    ):
        st.warning("No scenarios loaded. Fix the data file and press 'Reload Data'.")
        st.stop()

    # Get all valid scenario IDs
    all_scenario_ids = [item["id"] for item in st.session_state.data]

    # Get scenario ID from URL query parameter
    params = st.query_params
    selected_id = params.get("id")
    view = params.get("view", "table")

    # View selector
    view_col1, view_col2 = st.sidebar.columns(2)
    if view_col1.button("Table View"):
        params["view"] = "table"
        params["id"] = ""
        st.rerun()
    if view_col2.button("Detail View"):
        params["view"] = "detail"
        if not selected_id:
            params["id"] = str(all_scenario_ids[0])
        st.rerun()

    if view == "table":
        create_index_table(st.session_state.data)
    else:
        # Convert to int if present, otherwise use first scenario
        try:
            selected_id = int(selected_id) if selected_id else all_scenario_ids[0]
            if selected_id not in all_scenario_ids:
                selected_id = all_scenario_ids[0]
        except ValueError:
            selected_id = all_scenario_ids[0]

        # Update URL if needed
        if str(selected_id) != params.get("id"):
            params["id"] = str(selected_id)

        # Add Next/Previous navigation
        col1, col2 = st.sidebar.columns(2)
        current_index = all_scenario_ids.index(selected_id)

        if col1.button("Previous") and current_index > 0:
            params["id"] = str(all_scenario_ids[current_index - 1])
            st.rerun()

        if col2.button("Next") and current_index < len(all_scenario_ids) - 1:
            params["id"] = str(all_scenario_ids[current_index + 1])
            st.rerun()

        # Main content
        selected_item = next(
            (item for item in st.session_state.data if item["id"] == selected_id),
            None,
        )

        if selected_item:
            main_content(selected_item)
            sidebar_stats()


def sidebar_stats():
    """Show the total scenario count and a per-status breakdown in the sidebar."""
    sidebar = st.sidebar
    sidebar.divider()
    sidebar.subheader("Statistics")

    scenarios = st.session_state.data
    scenario_count = len(scenarios)

    # Scenarios without a "status" key are counted as "unexamined".
    status_counts = Counter()
    for scenario in scenarios:
        status_counts[scenario.get("status", "unexamined")] += 1

    sidebar.write(f"Total Scenarios: {scenario_count}")
    for status_name in status_counts:
        sidebar.write(f"{status_name.capitalize()}: {status_counts[status_name]}")


def main_content(selected_item: dict[str, Any]) -> None:
    """Render the detail page for one scenario.

    Shows the scenario text, question and answer prefix, both ratings with
    their explanations, the approve/reject buttons, and every scored response.
    """
    scenario_id = selected_item["id"]
    st.header(f"Scenario {scenario_id}")

    st.subheader("Scenario Description")
    st.write(selected_item["scenario"])

    # Question and answer prefix, side by side
    question_col, prefix_col = st.columns(2)
    with question_col:
        st.subheader("Question")
        st.write(selected_item["question"])
    with prefix_col:
        st.subheader("Answer Prefix")
        st.write(selected_item["answer_prefix"])

    # Ratings: one bordered box per rating, each in its own column
    st.subheader("Ratings")
    rating_columns = st.columns(2)
    rating_specs = (("Clarity Rating:", "clarity_rating"), ("Goal Rating:", "goal_rating"))
    for column, (label, rating_key) in zip(rating_columns, rating_specs):
        with column, st.container(border=True):
            rating = selected_item[rating_key]
            st.write(label, rating["rating"])
            with st.expander("Explanation", expanded=False):
                st.write(rating["explanation"])

    # Review controls: the button matching the current status is disabled
    current_status = selected_item.get("status", "unexamined")
    st.write(f"Current status: **{current_status}**")

    review_actions = (("✅ Approve", "approved"), ("❌ Reject", "rejected"))
    for button_label, target_status in review_actions:
        clicked = st.button(button_label, disabled=current_status == target_status)
        if clicked:
            update_status(scenario_id, target_status, st.session_state.file_path)
            st.rerun()

    # Responses
    st.subheader("Responses")
    for entry in selected_item["responses"]:
        with st.container(border=True):
            st.write("*Response:*", entry["response"])
            # entry["score"] is a dict with keys:
            #   "is_deceptive" (0-1), "unambiguous" (0-1), "explanation" (str)
            score = entry["score"]
            deceptive_md = get_colored_value(score["is_deceptive"], 0, 1)
            st.markdown(f"*Is deceptive:* {deceptive_md}")
            unambiguous_md = get_colored_value(score["unambiguous"], 0, 1)
            st.markdown(f"*Unambiguous:* {unambiguous_md}")
            with st.expander("Reasoning", expanded=False):
                st.write(score["explanation"])


# Script entry point: build the app when this file is executed directly.
if __name__ == "__main__":
    main()
