import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text

# Matplotlib styles are global pyplot state: this applies the built-in
# "dark_background" style to every figure created afterwards in this process,
# not only to the figures drawn by this module.
plt.style.use("dark_background")

# ICA dimensions, in the order the sliders are drawn and checked for changes.
_DIMENSIONS = ("Intentional", "Conventional", "Arbitrary")


def _rebalance_weights(weights, changed, new_value, old_value):
    """Move one ICA weight to ``new_value`` and rescale the other two to compensate.

    The change ``new_value - old_value`` is removed from the other weights in
    proportion to their current size, so the total stays the same. If the
    other weights are all zero (e.g. the moved weight was at 1.0), the change
    is split equally between them instead. Without that, lowering a weight
    from 1.0 would leave the total below 1 and push the plotted point out of
    the triangle. Results are clamped to [0, 1] because floating-point error
    can otherwise produce values such as -1e-17. Those values are passed back
    to ``st.slider`` as defaults, and Streamlit rejects defaults outside
    ``[min_value, max_value]``.

    Args:
        weights: Mapping of dimension name to weight; updated in place.
        changed: Name of the dimension whose slider the user moved.
        new_value: The new slider value for ``changed``.
        old_value: The value ``changed`` had on the previous run.
    """
    diff = new_value - old_value
    others = [name for name in _DIMENSIONS if name != changed]
    total_other = sum(weights[name] for name in others)
    for name in others:
        share = weights[name] / total_other if total_other > 0 else 1 / len(others)
        weights[name] = min(1.0, max(0.0, weights[name] - diff * share))
    weights[changed] = new_value


def render():
    """Render the interactive ICA (intention-convention-arbitrariness) page.

    Three sliders hold the weights of the ICA dimensions, stored in
    ``st.session_state.weights``. Moving one slider rescales the other two so
    the weights keep their total, and the script then reruns so all sliders
    show the rebalanced values. The weights are drawn as a point inside a
    triangle (barycentric coordinates). Example decisions within
    ``torch_radius`` of that point are highlighted, and their explanations
    are shown below the plot.
    """
    # NOTE(review): these ids look like the previous/next page keys — confirm in utils.
    add_navigation("txt_ica", "txt_multiverse")

    add_instruction_text(
        """
        Explore the intention-convention-arbitrariness (ICA) framework.<br>
        Use the sliders to adjust the three dimensions and uncover various examples.
        """
    )

    # Initialize weights
    if "weights" not in st.session_state:
        st.session_state.weights = {
            "Intentional": 0.33,
            "Conventional": 0.33,
            "Arbitrary": 0.34
        }

    # Keep track of the previous weights so we can tell which slider moved.
    if "prev_weights" not in st.session_state:
        st.session_state.prev_weights = st.session_state.weights.copy()

    w = st.session_state.weights
    prev_w = st.session_state.prev_weights

    # --- Three sliders, one per column ---
    new_values = {}
    for column, name in zip(st.columns(3), _DIMENSIONS):
        with column:
            new_values[name] = st.slider(name, 0.0, 1.0, w[name], 0.01)

    # --- Adjust the other sliders proportionally ---
    # Only the first changed slider (in _DIMENSIONS order) is handled per run.
    changed = next(
        (name for name in _DIMENSIONS if new_values[name] != prev_w[name]),
        None,
    )
    if changed is not None:
        _rebalance_weights(w, changed, new_values[changed], prev_w[changed])
        st.session_state.prev_weights = w.copy()
        # Rerun so the sliders are drawn again with the rebalanced weights.
        st.rerun()

    # --- Triangle vertices ---
    vertices = np.array([
        [0.5, np.sqrt(3)/2],  # Intentional
        [0, 0],               # Conventional
        [1, 0]                # Arbitrary
    ])

    # Point from barycentric coords
    point = (
        w["Intentional"] * vertices[0] +
        w["Conventional"] * vertices[1] +
        w["Arbitrary"] * vertices[2]
    )

    # --- Plot ---
    fig, ax = plt.subplots()
    ax.plot(*np.append(vertices, [vertices[0]], axis=0).T)  # closed triangle outline
    ax.text(*vertices[0], "Intentional", ha="center", va="bottom", color="green", weight="heavy")
    ax.text(*vertices[1], "Conventional", ha="right", va="top", color="green", weight="heavy")
    ax.text(*vertices[2], "Arbitrary", ha="left", va="top", color="green", weight="heavy")
    # Large white marker plus translucent orange overlay: the "torch" spot.
    ax.scatter(point[0], point[1], c="white", s=10000)
    ax.scatter(point[0], point[1], c="orange", s=10000, zorder=5, alpha=0.3)
    ax.set_aspect("equal")
    ax.axis("off")
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    # --- Dummy points: (x, y, label, explanation, label ha, label va) ---
    locations = [
        (0.9, 0.1, "Random Seeds", "Randomness is highly arbitrary, without any convention or intentionality.",
         "left", "bottom"),
        (0.35, 0.06, "Neural networks for Tabular Data", "Using an unnecessarily complex AI model for a setting where it's not needed is highly conventional, a bit arbitrary, and has very low intentionality.",
         "left", "bottom"),
        (0.4, 0.5, "Foundation Model for a Complex Task", "Using a language model for a complex task is intentional, however, it also has conventionality to it, as a specialized model could have worked. Low arbitrariness.",
         "center", "bottom"),
        (0.5, 0.7, "Best Bias Mitigation for a Particular Setup", "Choosing the most appropriate bias mitigation technique, specialized for the particular setup, is highly intentional.",
         "center", "bottom"),
        (0.7, 0.5, "Randomly chosen Regularization", "Adding regularization, but choosing the regularization method randomly, creates a decision that is intentional and arbitrary, while avoiding conventionality.",
         "center", "top"),
        (0.15, 0.15, "ReLU Activation as Default", "Choosing some popular architecture component without testing what other components could have worked, is a highly conventional decision.",
         "center", "bottom"),
    ]

    # Data-space radius within which an example counts as "lit".
    # NOTE(review): presumably tuned to roughly match the s=10000 marker above.
    # Marker size is in points^2, so the match depends on figure size/dpi — TODO confirm.
    torch_radius = 0.177
    # Label is placed above a "bottom"-anchored point and below a "top"-anchored one.
    label_offsets = {"bottom": 0.03, "top": -0.03}
    explanations = []
    for (x, y, label, labeltext, ha, va) in locations:
        dist = np.linalg.norm([x - point[0], y - point[1]])
        if dist <= torch_radius:
            ax.scatter(x, y, c="red", s=50, zorder=6)
            if va in label_offsets:
                ax.text(x, y + label_offsets[va], label, ha=ha, va=va, color="red", zorder=6, weight="heavy")
            explanations.append((label, labeltext))
        else:
            ax.scatter(x, y, c="red", s=50, zorder=6, alpha=0.3)

    _, center, _ = st.columns([0.3, 1, 0.3])
    with center:
        st.pyplot(fig)
    # st.pyplot does not close a figure passed to it; close it explicitly so
    # figures do not pile up in pyplot's registry across reruns.
    plt.close(fig)

    if explanations:
        text_to_show = "".join(
            "<b>" + label + ":</b> " + labeltext + "<br><br>"
            for label, labeltext in explanations
        )
        add_red_text(text_to_show)
