import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text
from my_pages.rashomon_effect import plot_scatter
from my_pages.rashomon_effect import income, credit, labels, colors

# Switch matplotlib's global rcParams to the built-in "dark_background" style;
# figures created after this point (including the scatter plot below) use it.
plt.style.use(style="dark_background")

def _choice_buttons(options, state_key):
    """Render one toggle-style button per option in equal-width columns.

    Args:
        options: Mapping of the value stored in ``st.session_state[state_key]``
            to the button label shown for it. Iteration order is column order.
        state_key: ``st.session_state`` key that holds the current selection.

    Returns:
        The currently selected value, or ``None`` if nothing was chosen yet.
        The button matching the selection is drawn with ``type="primary"``.
        Clicking a button stores its value and calls ``st.rerun()``, which
        halts this run, so the new highlight is shown immediately.
    """
    selected = st.session_state.get(state_key)
    columns = st.columns([1] * len(options))
    for column, (value, label) in zip(columns, options.items()):
        with column:
            button_type = "primary" if value == selected else "secondary"
            if st.button(label, type=button_type):
                st.session_state[state_key] = value
                st.rerun()
    return selected


def render():
    """Render the page where the user makes model-development choices.

    The user picks a regularization method (stored as ``"l1"``/``"l2"`` in
    ``st.session_state["regularization_method"]``). For L1 only, the user also
    picks a simulated coin flip (``"Heads"``/``"Tails"`` in
    ``st.session_state["random_seed"]``).

    Once the choices are complete, ``plot_scatter`` is called with
    ``boundary_type`` set from them:

    - L2 -> ``"slant"``
    - L1 + Heads -> ``"vertical"``
    - L1 + Tails -> ``"horizontal"``
    """
    add_navigation("txt_rashomon_developer", "txt_developer_decisions")

    add_instruction_text(
        """
        Consider the same data as before. <br>
        Instead of directly choosing a model, you make development choices now.
        """
    )

    #### Choosing regularization
    st.markdown("""
        **Regularization:** Regularization is a technique commonly used to stop AI models from learning the noise or small quirks in the data that might not generalize. 
        
        Choose a regularization method:
        - L1 Regularization: Force your AI model to use fewer features, thus avoiding irrelevant features.
        - L2 Regularization: Force your AI model to rely less on each feature, even though you use all features, thus avoiding noisy dominance of any single feature.
    """
    )
    regularization_method = _choice_buttons(
        {"l1": "L1 Regularization", "l2": "L2 Regularization"},
        "regularization_method",
    )

    #### Choosing random seed (only asked for L1)
    # Initialised unconditionally so it is always bound when read below.
    random_seed = None
    if regularization_method == "l1":
        st.markdown("""
            **Randomness:** Sometimes there is randomness in the learning process. Let's flip a coin 
            (You can just choose Heads or Tails, and we will assume we flipped a coin. It'll be our little secret).
        """
        )
        random_seed = _choice_buttons(
            {"Heads": "Heads", "Tails": "Tails"},
            "random_seed",
        )

    #### Plot the final figure
    if regularization_method == "l2":
        plot_chosen = "slant"
    elif regularization_method == "l1":
        # None until the coin flip has been chosen.
        plot_chosen = {"Heads": "vertical", "Tails": "horizontal"}.get(random_seed)
    else:
        plot_chosen = None

    if plot_chosen is not None:
        _, center, _ = st.columns([1.5, 1, 1.5])
        with center:
            fig = plot_scatter(income, credit, colors, boundary_type=plot_chosen, highlight_point=None)
            st.pyplot(fig)
            # NOTE(review): assumes plot_scatter builds a new figure per call.
            # Closing it keeps pyplot from accumulating one open figure per
            # Streamlit rerun. Revisit if plot_scatter starts caching figures.
            plt.close(fig)

        multiplicity_message = """
            Your choices during model development lead you to this model.
        """
        add_red_text(multiplicity_message)