# pages/multiverse.py
import random
from functools import lru_cache

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from sklearn.tree import DecisionTreeClassifier

from utils import add_navigation, add_instruction_text, add_red_text

# Decision stages of the multiverse, in the order they are taken. Each stage has
# a display "label" and the option labels a user can choose from; the option
# strings are also compared verbatim by the training code in this module.
choices_list = [
    {
        "label": "Data Scaling",
        "options": ["MinMax Scaler", "Standard Scaler", "Robust Scaler"],
    },
    {
        "label": "Feature Selection",
        "options": ["Select K Best (k=5)", "PCA (n=5)", "All Features"],
    },
    {
        "label": "Model Architecture",
        "options": ["Logistic Regression", "Decision Tree", "Neural Network (Small)"],
    },
    {
        "label": "Random Seed",
        # Seeds 1..10 as strings; the training code converts them with int().
        "options": [str(seed) for seed in range(1, 11)],
    },
]

def build_tree_and_trace_path(selected_path, spread_factor=10000, choices=None):
    """
    Build the full tree of stage choices and trace one path through it.

    Every node at stage k gets one child per option of stage k+1, so the tree
    enumerates every combination of options. ``selected_path`` (one option
    label per stage) is then followed from the root by matching labels among
    the current node's children; tracing stops at the first label that has no
    matching child.

    Parameters
    ----------
    selected_path : list of str
        The option label chosen at each stage, in stage order.
    spread_factor : float (default=10000)
        Controls vertical spread. Stage ``k`` uses a node spacing of
        ``spread_factor ** (1 / k**0.2)``, so larger values give more
        separation at early stages and the spacing tapers off with depth.
    choices : list of dict, optional
        Stage definitions, each with ``"label"`` and ``"options"`` keys.
        Defaults to the module-level ``choices_list``.

    Returns
    -------
    node_labels : list of str
        Label of every node; index 0 is ``"Start"``.
    node_positions : list of (float, float)
        ``(x, y)`` of every node; x is the stage index (0 for the root).
    edges : list of (int, int)
        ``(parent, child)`` node-index pairs, in creation order.
    highlight_edges : set of (int, int)
        Edges on the traced path.
    highlight_nodes : set of int
        Nodes on the traced path; always contains the root (0).
    """
    stages = choices_list if choices is None else choices

    node_labels = ["Start"]
    node_positions = [(0.0, 0.0)]
    edges = []
    # node index -> child node indices, so tracing does not rescan all edges
    children = {0: []}

    prev_nodes = [0]  # nodes at the previous stage (initially just the root)
    y_spacing_base = 1.0

    # Build nodes and edges stage by stage
    for stage_idx, stage in enumerate(stages, start=1):
        options = stage["options"]
        next_nodes = []

        # scaling: huge spread at stage 1, tapering off deeper
        scale = spread_factor ** (1.0 / stage_idx**0.2)
        # each parent gets its own block of vertical space
        parent_block_size = len(options) * y_spacing_base * scale
        parent_center = (len(prev_nodes) - 1) / 2.0
        option_center = (len(options) - 1) / 2.0
        child_x = float(stage_idx)

        for parent_order, parent_idx in enumerate(prev_nodes):
            _, py = node_positions[parent_idx]
            base_y = py + (parent_order - parent_center) * parent_block_size

            for opt_idx, opt in enumerate(options):
                offset = (opt_idx - option_center) * (y_spacing_base * scale)
                node_index = len(node_labels)
                node_labels.append(opt)
                node_positions.append((child_x, base_y + offset))
                edges.append((parent_idx, node_index))
                children[parent_idx].append(node_index)
                children[node_index] = []
                next_nodes.append(node_index)

        prev_nodes = next_nodes

    # Trace the single chosen path by walking children (first match wins)
    highlight_edges = set()
    highlight_nodes = {0}
    current_node = 0

    for chosen_label in selected_path:
        found_child = next(
            (c for c in children[current_node] if node_labels[c] == chosen_label),
            None,
        )
        if found_child is None:
            break
        highlight_edges.add((current_node, found_child))
        highlight_nodes.add(found_child)
        current_node = found_child

    return node_labels, node_positions, edges, highlight_edges, highlight_nodes



def _split_and_scale(features, label, test_split=0.2, preprocess_scale=None):
    """Split into train/test (fixed random_state=0) and optionally scale.

    ``preprocess_scale`` is one of the "Data Scaling" option labels or None.
    The scaler is fit on the training split only and applied to both splits.

    Raises
    ------
    ValueError
        If ``preprocess_scale`` is not None and not a known scaler label.
    """
    X_train, X_test, y_train, y_test = train_test_split(
        features, label, test_size=test_split, random_state=0
    )

    if preprocess_scale is not None:
        scaler_classes = {
            "MinMax Scaler": MinMaxScaler,
            "Standard Scaler": StandardScaler,
            "Robust Scaler": RobustScaler,
        }
        try:
            scaler = scaler_classes[preprocess_scale]()
        except KeyError:
            raise ValueError(f"Unknown scaler: {preprocess_scale!r}") from None
        scaler.fit(X_train)
        X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

    return X_train, X_test, y_train, y_test


@lru_cache(maxsize=1)
def _read_loan_data():
    """Read loan_approval_dataset.csv once; return (features, label, class_names).

    Features are every column except ``loan_id`` and ``loan_status``, with the
    education / self_employed columns one-hot encoded (first level dropped).
    The label is ``loan_status`` encoded by LabelEncoder; ``class_names[i]`` is
    the original value encoded as ``i``. The arrays are shared by the cache and
    must not be modified in place.
    """
    data = pd.read_csv("loan_approval_dataset.csv")

    features = data.drop(columns=["loan_id", " loan_status"])
    features = pd.get_dummies(
        features, columns=[" education", " self_employed"], drop_first=True
    )

    encoder = LabelEncoder()
    label = encoder.fit_transform(data[" loan_status"])

    # Mixed int/bool dummy columns would otherwise become an object array.
    return features.to_numpy(dtype=float), np.asarray(label), tuple(encoder.classes_)


def _load_loan_dataset(test_split=0.2, preprocess_scale=None):
    """Return (X_train, X_test, y_train, y_test) for the loan approval dataset."""
    features, label, _class_names = _read_loan_data()
    return _split_and_scale(features, label, test_split, preprocess_scale)


def _approved_code(class_names):
    """Return the encoded label whose class name is "approved" (case and
    surrounding whitespace ignored), or 1 when no such class exists."""
    for code, name in enumerate(class_names):
        if str(name).strip().lower() == "approved":
            return code
    return 1


@lru_cache(maxsize=32)
def _model_train_and_pred(scaler, feature_sel, arch, seed):
    """Train one pipeline of the multiverse and return its test-set predictions.

    ``scaler``, ``feature_sel`` and ``arch`` are option labels from
    ``choices_list``; ``seed`` is the integer random_state given to the model.
    Results are memoised, so callers must not modify the returned array.
    """
    X_train, X_test, y_train, _y_test = _load_loan_dataset(preprocess_scale=scaler)

    if feature_sel == "Select K Best (k=5)":
        selector = SelectKBest(score_func=f_classif, k=5)
        X_train = selector.fit_transform(X_train, y_train)
        X_test = selector.transform(X_test)
    elif feature_sel == "PCA (n=5)":
        # n_components must match the "n=5" shown in the option label.
        # NOTE(review): if all_predictions.npy was produced with n_components=2,
        # regenerate it with _compute_all_predictions so the majority matches.
        pca = PCA(n_components=5)
        X_train = pca.fit_transform(X_train)  # PCA is unsupervised; no y needed
        X_test = pca.transform(X_test)

    # Build only the requested model; an unknown label raises KeyError.
    model_factories = {
        "Neural Network (Small)": lambda: MLPClassifier([10], random_state=seed, max_iter=500),
        # NOTE(review): SGDClassifier's default loss is "hinge" (a linear SVM),
        # not logistic loss. For a true logistic regression pass
        # loss="log_loss" (scikit-learn >= 1.1) and regenerate all_predictions.npy.
        "Logistic Regression": lambda: SGDClassifier(random_state=seed, max_iter=500),
        "Decision Tree": lambda: DecisionTreeClassifier(random_state=seed),
    }
    model = model_factories[arch]()
    model.fit(X_train, y_train)
    return model.predict(X_test)


@lru_cache(maxsize=1)
def _load_all_predictions():
    """Load the precomputed prediction grid (one row per pipeline)."""
    return np.load("all_predictions.npy")


def _compute_all_predictions():
    """Recompute the prediction grid stored in all_predictions.npy.

    Not called by the page: run it offline and ``np.save`` the result whenever
    the training pipeline changes. Rows follow the nesting
    scaler -> feature selection -> architecture -> seed over ``choices_list``.
    """
    from itertools import product

    option_lists = [stage["options"] for stage in choices_list]
    return np.array([
        _model_train_and_pred(scaler, feature_sel, arch, int(seed))
        for scaler, feature_sel, arch, seed in product(*option_lists)
    ])


def _select_path():
    """Show one selectbox per stage and return the chosen labels in stage order."""
    columns = st.columns([1, 1])
    selected_path = []
    for stage_no, stage in enumerate(choices_list):
        # Stages alternate between the two columns.
        with columns[stage_no % 2]:
            # use a stable key per stage to avoid conflicts
            key = f"multiverse_choice_{stage['label']}"
            selected_path.append(st.selectbox(f"{stage['label']}", stage["options"], key=key))
    return selected_path


def _build_multiverse_figure(selected_path):
    """Return a Plotly figure of the whole choice tree with *selected_path* highlighted."""
    labels, positions, edges, highlight_edges, highlight_nodes = build_tree_and_trace_path(selected_path)

    # Check the start node first: highlight_nodes always contains node 0, so
    # testing path membership first would never give it its own colour.
    node_colors = []
    for idx, _label in enumerate(labels):
        if idx == 0:
            node_colors.append("rgba(30,144,255,0.9)")  # start node distinct
        elif idx in highlight_nodes:
            node_colors.append("rgba(34,139,34,0.95)")  # green for selected path nodes
        else:
            node_colors.append("rgba(135,206,250,0.6)")  # default skyblue

    node_trace = go.Scatter(
        x=[x for x, _ in positions],
        y=[y for _, y in positions],
        mode="markers",
        text=labels,
        marker=dict(size=18, color=node_colors, line=dict(width=1, color="black")),
        hoverinfo="text",
    )

    # Two line traces (rest of the tree, then the selected path on top) instead
    # of one trace per edge; a None coordinate breaks the line between segments.
    tree_x, tree_y, path_x, path_y = [], [], [], []
    for src, dst in edges:
        xs, ys = (path_x, path_y) if (src, dst) in highlight_edges else (tree_x, tree_y)
        (x0, y0), (x1, y1) = positions[src], positions[dst]
        xs.extend((x0, x1, None))
        ys.extend((y0, y1, None))
    edge_traces = [
        go.Scatter(
            x=tree_x, y=tree_y, mode="lines",
            line=dict(width=1.5, color="rgba(120,120,120,0.4)"),
            hoverinfo="none",
        ),
        go.Scatter(
            x=path_x, y=path_y, mode="lines",
            line=dict(width=4, color="rgba(0,128,0,0.9)"),  # bright green
            hoverinfo="none",
        ),
    ]

    # Stage titles sit a fixed offset above the top-most node of each stage.
    stage_top = {}
    for x, y in positions:
        stage_top[x] = max(y, stage_top.get(x, y))
    title_x, title_y, title_text = [], [], []
    for stage_idx, stage in enumerate(choices_list, start=1):
        x = float(stage_idx)
        if x in stage_top:
            title_x.append(x)
            title_y.append(stage_top[x] + 20000)
            title_text.append(stage["label"])
    stage_label_trace = go.Scatter(
        x=title_x, y=title_y,
        text=title_text,
        mode="text",
        textfont=dict(size=16, color="white"),
        hoverinfo="none",
        showlegend=False,
    )

    fig = go.Figure(data=edge_traces + [stage_label_trace, node_trace])
    fig.update_layout(
        showlegend=False,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        paper_bgcolor="black",
        plot_bgcolor="black",
        font=dict(color="white"),
        margin=dict(l=10, r=10, t=10, b=10),
        hovermode="closest",
    )
    return fig


def _report_divergence(selected_path):
    """Train the chosen pipeline and report where it disagrees with the majority.

    The majority is taken per test applicant over the rows of
    all_predictions.npy, which is assumed to hold one row of test-set
    predictions per pipeline, aligned with the split from _load_loan_dataset.
    """
    scaler, feature_sel, arch = selected_path[0], selected_path[1], selected_path[2]
    seed = int(selected_path[3])
    y_pred = _model_train_and_pred(scaler, feature_sel, arch, seed)
    all_preds_numpy = _load_all_predictions()

    # LabelEncoder assigns codes in sorted class order, so labels such as
    # "Approved"/"Rejected" would encode approval as 0, not 1. Look the code up
    # instead of assuming 1 means "accepted" (falls back to 1 if not found).
    approved = _approved_code(_read_loan_data()[2])
    prop_approved = np.mean(all_preds_numpy == approved, axis=0)

    uniq_count_rej = int(np.sum((y_pred != approved) & (prop_approved > 0.5)))
    uniq_count_acc = int(np.sum((y_pred == approved) & (prop_approved < 0.5)))

    add_red_text(f"""
        <b>Based on your choices:</b><br>
        Number of loans accepted by the majority, but rejected by you: {uniq_count_rej}<br>
        Number of loans rejected by the majority, but accepted by you: {uniq_count_acc}<br><br>
        <b>Reasons you might want to conform:</b><br>
        To take lower risks and to avoid facing a justification crisis, i.e., 
        not able to explain why you rejected an applicant who would have been accepted by most other models.<br><br>
        <b>Reasons you might want to be unique:</b><br>
        To avoid competing for the same loan applicants with others.<br> 
        To give a chance to unique applicants and deal with the concerns of homogenization.<br><br>
    """)


def render():
    """Render the multiverse page.

    Shows one dropdown per decision stage, draws the full tree of pipelines
    with the chosen path highlighted, trains the chosen pipeline, and reports
    how many test applicants it decides differently from the majority of the
    precomputed pipelines in all_predictions.npy.
    """
    add_navigation("txt_multiverse", "txt_conclusion")

    add_instruction_text(
        """
        Visually explore the multiverse of AI models to judge loan applications.<br> 
        We are using a publicly available loan approval dataset.<br>
        Make a choice, and scroll down to see the properties of the trained model.<br>
        Not sure what choice to make? Just pick something and see what happens.
        """
    )

    selected_path = _select_path()
    st.plotly_chart(_build_multiverse_figure(selected_path), use_container_width=True)
    _report_divergence(selected_path)