import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

# Section 1: Input parameters
st.title("Meritocracy and Bias")

st.header("1. Enter Institution, Group, and Bias Parameters")

# The three model parameters are read from sliders on every rerun of the script.
c = st.slider("Selection coefficient (c)", min_value=0.01, max_value=1.0, value=0.1, step=0.01)
# alpha is capped at 0.99: the threshold and metric formulas further down
# divide by (1 - alpha), so alpha == 1.0 would raise ZeroDivisionError.
alpha = st.slider("Fraction of candidates in Group G1 (α)", min_value=0.01, max_value=0.99, value=0.5, step=0.01)
rho = st.slider("Bias parameter (ρ)", min_value=0.0, max_value=1.0, value=0.8, step=0.01)
# NOTE: the distribution selector is disabled; the distribution is pinned to
# the uniform case below. Re-enable this line to expose the choice again.
#dist = st.selectbox("Select Utility Distribution (D)", ["Uniform [0,1]", "Truncated Gaussian [0,1]"])

dist = "Uniform [0,1]"

# Section 2: Compute Nash equilibrium
st.header("2. Compute Nash Equilibrium")

if dist == "Uniform [0,1]":
    # Closed-form threshold with two regimes, split at rho = 1 - c / (1 - alpha).
    # The quotient is undefined at alpha == 1; its limit as alpha -> 1 is +inf,
    # so the cutoff becomes -inf and the second regime always applies there.
    regime_cutoff = 1 - c / (1 - alpha) if alpha < 1 else float("-inf")
    if rho < regime_cutoff:
        t = regime_cutoff
    else:
        t = (rho * (1 - c)) / (rho - alpha * rho + alpha)
    st.write(f"**Threshold (t):** {t:.3f}")
else:
    # Truncated Gaussian parameters
    mu = 0.5
    sigma = 0.2

    def truncated_gaussian(v):
        """Return the residual whose zero is taken as the threshold.

        ``v`` may be a scalar or a NumPy array; the result has the same shape.
        Reads ``alpha``, ``rho`` and ``c`` from the enclosing script scope.
        """
        phi_v = norm.pdf((v - mu) / sigma)
        phi_1 = norm.pdf((1 - mu) / sigma)
        phi_0 = norm.pdf((-mu) / sigma)
        Phi_1 = norm.cdf((1 - mu) / sigma)
        Phi_0 = norm.cdf((-mu) / sigma)
        term1 = (1 - alpha) * (phi_v - phi_1) / (Phi_1 - Phi_0)
        term2 = alpha * (norm.pdf((v - rho * mu) / sigma) - norm.pdf((1 - rho * mu) / sigma)) / (
            norm.cdf((1 - rho * mu) / sigma) - norm.cdf((-rho * mu) / sigma)
        )
        return term1 + term2 - (1 - c)

    # Numerical solution: scan a grid on [0, 1] and keep the first point whose
    # residual is within 1e-3 of zero. The residual is evaluated on the whole
    # grid in one vectorized call instead of 1000 scalar calls.
    v_vals = np.linspace(0, 1, 1000)
    near_zero = np.isclose(truncated_gaussian(v_vals), 0, atol=1e-3)
    v = v_vals[near_zero][0] if near_zero.any() else None
    # Compare against None explicitly: a root at v == 0.0 is falsy but valid.
    st.write(f"**Threshold (v):** {v:.3f}" if v is not None else "No solution found.")

# Section 3: Metrics
st.header("3. Metrics")

st.subheader("Uniform Distribution Metrics")
# Regime boundary for rho. At alpha == 1 the quotient c / (1 - alpha) is
# undefined (ZeroDivisionError); its limit is +inf, so the cutoff is -inf and
# the second set of formulas applies.
metrics_cutoff = 1 - c / (1 - alpha) if alpha < 1 else float("-inf")
if rho < metrics_cutoff:
    # First regime: only reachable with alpha < 1, so (1 - alpha) is non-zero.
    r_rep = 0
    s1 = c**2 / (2 * (1 - alpha)**2)
    s2 = 0
    r_social = 0
else:
    # Sub-expressions shared by the four closed-form metrics.
    mix = rho - alpha * rho + alpha
    surplus = mix + c - 1
    g1_share = alpha - alpha * rho + c * rho
    r_rep = surplus / g1_share
    s1 = g1_share**2 / (2 * mix**2)
    s2 = rho * surplus**2 / (2 * mix**2)
    r_social = rho * surplus**2 / g1_share**2

st.write(f"**Representation Ratio:** {r_rep:.3f}")
st.write(f"**Social Welfare (Group 1):** {s1:.3f}")
st.write(f"**Social Welfare (Group 2):** {s2:.3f}")
st.write(f"**Social Welfare Ratio :** {r_social:.3f}")

# Section for plots
st.header("4. Plots for Each Parameter")

# Create three columns for the three parameters
col1, col2, col3 = st.columns(3)

# Dropdown for each parameter
with col1:
    metric_c = st.selectbox("Plots for c (Selection Coefficient)", ["Revenue", "Representation Ratio", "Social Welfare Ratio"])
    c_vals = np.linspace(0.01, 1, 100)

    # Sweep c while alpha and rho stay at their slider values.
    # divide="ignore": with alpha == 1 the quotient is +inf, which makes the
    # cutoff -inf and routes every point to the second-regime formula.
    with np.errstate(divide="ignore"):
        cutoff_c = 1 - c_vals / (1 - alpha)
    below_cutoff_c = rho < cutoff_c
    mix_c = rho - alpha * rho + alpha
    if metric_c == "Revenue":
        values_c = np.where(below_cutoff_c, cutoff_c, (rho * (1 - c_vals)) / mix_c)
    elif metric_c == "Representation Ratio":
        values_c = np.where(
            below_cutoff_c,
            0,
            (mix_c + c_vals - 1) / (alpha - alpha * rho + c_vals * rho),
        )
    elif metric_c == "Social Welfare Ratio":
        values_c = np.where(
            below_cutoff_c,
            0,
            rho * (mix_c + c_vals - 1)**2 / (alpha - alpha * rho + c_vals * rho)**2,
        )
    # The axis label is the selected metric name itself.
    ylabel = metric_c

    # Plot for c: draw on an explicit Figure and close it after rendering so
    # figures do not pile up in pyplot's registry across script reruns.
    fig_c, ax_c = plt.subplots()
    ax_c.plot(c_vals, values_c)
    ax_c.set_xlabel("c (Selection Coefficient)")
    ax_c.set_ylabel(ylabel)
    ax_c.set_title(f"{ylabel} vs c")
    st.pyplot(fig_c)
    plt.close(fig_c)

with col2:
    metric_alpha = st.selectbox("Plots for α (Fraction in G1)", ["Revenue", "Representation Ratio", "Social Welfare Ratio"])
    alpha_vals = np.linspace(0.01, 1, 100)

    # Sweep alpha while c and rho stay at their slider values.
    # The grid ends at alpha == 1, where c / (1 - alpha) is +inf; the resulting
    # -inf cutoff selects the second-regime formula, so the divide warning is
    # silenced rather than treated as an error.
    with np.errstate(divide="ignore"):
        cutoff_alpha = 1 - c / (1 - alpha_vals)
    below_cutoff_alpha = rho < cutoff_alpha
    mix_alpha = rho - alpha_vals * rho + alpha_vals
    if metric_alpha == "Revenue":
        values_alpha = np.where(below_cutoff_alpha, cutoff_alpha, (rho * (1 - c)) / mix_alpha)
    elif metric_alpha == "Representation Ratio":
        values_alpha = np.where(
            below_cutoff_alpha,
            0,
            (mix_alpha + c - 1) / (alpha_vals - alpha_vals * rho + c * rho),
        )
    elif metric_alpha == "Social Welfare Ratio":
        values_alpha = np.where(
            below_cutoff_alpha,
            0,
            rho * (mix_alpha + c - 1)**2 / (alpha_vals - alpha_vals * rho + c * rho)**2,
        )
    # The axis label is the selected metric name itself.
    ylabel = metric_alpha

    # Plot for alpha: draw on an explicit Figure and close it after rendering
    # so figures do not pile up in pyplot's registry across script reruns.
    fig_alpha, ax_alpha = plt.subplots()
    ax_alpha.plot(alpha_vals, values_alpha)
    ax_alpha.set_xlabel("α (Fraction in G1)")
    ax_alpha.set_ylabel(ylabel)
    ax_alpha.set_title(f"{ylabel} vs α")
    st.pyplot(fig_alpha)
    plt.close(fig_alpha)

with col3:
    metric_rho = st.selectbox("Plots for ρ (Bias Parameter)", ["Revenue", "Representation Ratio", "Social Welfare Ratio"])
    rho_vals = np.linspace(0.01, 1, 100)

    # Sweep rho while c and alpha stay at their slider values, so the regime
    # cutoff is a single scalar. At alpha == 1 the quotient c / (1 - alpha)
    # would raise ZeroDivisionError; its limit gives a cutoff of -inf, which
    # routes every point to the second-regime formula.
    cutoff_rho = 1 - c / (1 - alpha) if alpha < 1 else float("-inf")
    below_cutoff_rho = rho_vals < cutoff_rho
    mix_rho = rho_vals - alpha * rho_vals + alpha
    if metric_rho == "Revenue":
        values_rho = np.where(below_cutoff_rho, cutoff_rho, (rho_vals * (1 - c)) / mix_rho)
    elif metric_rho == "Representation Ratio":
        values_rho = np.where(
            below_cutoff_rho,
            0,
            (mix_rho + c - 1) / (alpha - alpha * rho_vals + c * rho_vals),
        )
    elif metric_rho == "Social Welfare Ratio":
        values_rho = np.where(
            below_cutoff_rho,
            0,
            rho_vals * (mix_rho + c - 1)**2 / (alpha - alpha * rho_vals + c * rho_vals)**2,
        )
    # The axis label is the selected metric name itself.
    ylabel = metric_rho

    # Plot for rho: draw on an explicit Figure and close it after rendering so
    # figures do not pile up in pyplot's registry across script reruns.
    fig_rho, ax_rho = plt.subplots()
    ax_rho.plot(rho_vals, values_rho)
    ax_rho.set_xlabel("ρ (Bias Parameter)")
    ax_rho.set_ylabel(ylabel)
    ax_rho.set_title(f"{ylabel} vs ρ")
    st.pyplot(fig_rho)
    plt.close(fig_rho)
