import streamlit as st
import pandas as pd
import plotly.graph_objects as go
import numpy as np
import plotly.express as px
import statsmodels.api as sm  # For detailed regression analysis
from math import exp
from scipy.special import erf, erfinv
import math

# -----------------------------
# 1. App Configuration
# -----------------------------

# Shared heading text for the browser tab and the on-page title.
_APP_TITLE = "Benchmark Results Dashboard"

# Use the full browser width and start with the sidebar (where all filters
# live) opened.
st.set_page_config(
    initial_sidebar_state="expanded",
    layout="wide",
    page_title=_APP_TITLE,
)

# Main heading rendered at the top of the page.
st.title(_APP_TITLE)

# -----------------------------
# 2. Data Loading and Processing
# -----------------------------

@st.cache_data
def load_data():
    """
    Load 'processed_results.csv' and add a 'log_accuracy' column.

    Accuracy values <= 0 are replaced with a small epsilon (1e-6) so that the
    natural logarithm stays finite; a warning reports how many rows changed.

    Returns:
        pandas.DataFrame: the CSV contents plus a 'log_accuracy' column.

    If the file is missing, empty, or lacks a column the dashboard relies on,
    an error is shown and the script run is halted with st.stop().
    """
    try:
        df = pd.read_csv('processed_results.csv')
    except FileNotFoundError:
        st.error("The file 'processed_results.csv' was not found in the current directory.")
        st.stop()
    except pd.errors.EmptyDataError:
        st.error("The file 'processed_results.csv' is empty.")
        st.stop()

    # Every downstream filter and plot indexes these columns directly; fail
    # with a readable message instead of a KeyError deep in the script.
    required_columns = ['dataset', 'model', 'length', 'N', 'accuracy']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        st.error(
            "The file 'processed_results.csv' is missing required column(s): "
            f"{', '.join(missing_columns)}."
        )
        st.stop()

    # Small positive stand-in for non-positive accuracies (avoids log(0)).
    epsilon = 1e-6

    # Compute the mask once and reuse it for counting and replacing.
    non_positive = df['accuracy'] <= 0
    num_zero_acc = int(non_positive.sum())
    if num_zero_acc > 0:
        st.warning(f"Found {num_zero_acc} data point(s) with accuracy <= 0. Replacing them with epsilon = {epsilon}.")
        df.loc[non_positive, 'accuracy'] = epsilon

    # Natural log of accuracy, used by the Log(Accuracy) view and regression.
    df['log_accuracy'] = np.log(df['accuracy'])

    return df

df = load_data()

# -----------------------------
# 3. Sidebar Filters and Options
# -----------------------------

st.sidebar.header("Filters")

# a. Dataset Selection
datasets = df['dataset'].unique()
selected_dataset = st.sidebar.selectbox("Select Dataset", options=datasets)

# Rows of the chosen dataset. Model AND length choices are drawn from here so
# the sidebar only offers combinations that actually exist in this dataset
# (otherwise foreign lengths yield empty traces and spurious regression warnings).
dataset_df = df[df['dataset'] == selected_dataset]

# b. Model Selection based on Selected Dataset
models = dataset_df['model'].unique()
selected_models = st.sidebar.multiselect("Select Model(s)", options=models, default=models)

# c. Length Selection (Multi-select) based on Selected Dataset
lengths = sorted(dataset_df['length'].unique())
selected_lengths = st.sidebar.multiselect(
    "Select Length(s)",
    options=lengths,
    default=lengths  # Default to all lengths selected
)

# d. Outlier Removal (Accuracy Range Slider)
st.sidebar.header("Outlier Removal")
accuracy_min, accuracy_max = st.sidebar.slider(
    "Select Accuracy Range for Data Points (including all views)",
    min_value=0.0,
    max_value=1.0,
    value=(0.01, 0.9),
    step=0.01
)

# e. N Range Selection for Regression
st.sidebar.header("Regression N Range")
n_min, n_max = st.sidebar.slider(
    "Select N Range for Linear Regression",
    min_value=int(df['N'].min()),
    max_value=int(df['N'].max()),
    value=(int(df['N'].min()), int(df['N'].max())),
    step=1
)

# f. Confidence Level Selection (percent; used to scale the coefficient error bars)
st.sidebar.header("Confidence Level for Regression")
confidence_level = st.sidebar.selectbox(
    "Select Confidence Level (%)",
    options=[50, 90, 95, 99],
    index=2  # Default to 95%
)

# g. View Option Selection
st.sidebar.header("View Options")
view_option = st.sidebar.radio(
    "Select View",
    options=["Accuracy", "Log(Accuracy)"]
)

# h. Regression Analysis Toggle
regression_toggle = st.sidebar.checkbox("Show Linear Regression (Log-View Only)")

# i. Option to Hide Original Data
hide_original_data = st.sidebar.checkbox("Hide Original Data Lines", value=False)

# -----------------------------
# 4. Data Filtering Based on Selections
# -----------------------------

# Combine every sidebar criterion into a single row mask. Both accuracy
# bounds are inclusive.
keep_rows = (
    (df['dataset'] == selected_dataset)
    & df['model'].isin(selected_models)
    & df['length'].isin(selected_lengths)
    & df['accuracy'].between(accuracy_min, accuracy_max)
)
filtered_df = df[keep_rows]

# Nothing left to show: tell the user and end this script run.
if filtered_df.empty:
    st.warning("No data available for the selected filters.")
    st.stop()

# Rows used for fitting are further limited to the chosen N window
# (inclusive on both ends).
regression_df = filtered_df[filtered_df['N'].between(n_min, n_max)]

# -----------------------------
# 5. Plotting Functions
# -----------------------------

def plot_log_accuracy_with_regression(filtered_df, regression_df, selected_models, selected_lengths, confidence_level):
    """
    Plot log(accuracy) vs N, optionally with a per-(model, length) OLS fit.

    For each selected model/length pair the observed log-accuracy series is
    drawn, unless the module-level ``hide_original_data`` checkbox is set.
    When the module-level ``regression_toggle`` checkbox is set,
    ``log_accuracy = a * N + b`` is fitted with statsmodels OLS on the matching
    rows of ``regression_df``. The fit is drawn as a dashed line, and the
    implied curve ``exp(a * N + b)`` is drawn on a second figure. Titles use
    the module-level ``selected_dataset``.

    Args:
        filtered_df: rows to plot; needs 'model', 'length', 'N', 'log_accuracy'.
        regression_df: rows (same columns) used for fitting.
        selected_models: models to iterate over.
        selected_lengths: lengths to iterate over.
        confidence_level: not used by this function; kept for interface
            compatibility.

    Returns:
        tuple: ``(fig_log, fig_exp, regression_metrics)``.
            ``fig_exp`` is None when no fit was performed.
            ``regression_metrics`` is a list of dicts with keys 'Model',
            'Length', 'Coefficient (a)', 'Intercept (b)', 'Std. Error (a)',
            'Std. Error (b)'.
    """
    fig_log = go.Figure()
    fig_exp = go.Figure()

    regression_metrics = []

    for model in selected_models:
        for length in selected_lengths:
            # Sort by N so 'lines' traces are drawn left to right instead of in
            # whatever order the rows appear in the CSV.
            subset_df = filtered_df[
                (filtered_df['model'] == model) &
                (filtered_df['length'] == length)
            ].sort_values('N')

            # Plot log(acc) vs N data; empty combinations are skipped so they do
            # not add blank legend entries.
            if not hide_original_data and not subset_df.empty:
                fig_log.add_trace(go.Scatter(
                    x=subset_df['N'],
                    y=subset_df['log_accuracy'],
                    mode='lines+markers',
                    name=f'Data {model} Length {length}',
                    marker=dict(size=6)
                ))

            if not regression_toggle:
                continue

            # Restrict the fit to rows inside the selected N range.
            reg_subset_df = regression_df[
                (regression_df['model'] == model) &
                (regression_df['length'] == length)
            ].sort_values('N')

            # A slope is only identifiable with at least two distinct N values.
            # With a single distinct N, sm.add_constant's default
            # has_constant='skip' would also mistake N for the constant column
            # and omit 'const', so params['const'] would raise KeyError.
            if reg_subset_df['N'].nunique() < 2:
                st.warning(
                    f"Not enough data points for regression on Model '{model}' Length '{length}' "
                    "(requires ≥2 distinct N values)."
                )
                continue

            X = reg_subset_df['N']
            y = reg_subset_df['log_accuracy']

            # Always prepend an intercept column named 'const'.
            X_const = sm.add_constant(X, has_constant='add')

            # Fit log(acc) = a * N + b by ordinary least squares.
            # NOTE: with exactly two points there are no residual degrees of
            # freedom, so the reported standard errors are not meaningful.
            model_reg = sm.OLS(y, X_const).fit()
            a = model_reg.params['N']
            b = model_reg.params['const']
            stderr_a = model_reg.bse['N']
            stderr_b = model_reg.bse['const']

            regression_metrics.append({
                'Model': model,
                'Length': length,
                'Coefficient (a)': a,
                'Intercept (b)': b,
                'Std. Error (a)': stderr_a,
                'Std. Error (b)': stderr_b
            })

            # Fitted log(acc) at the (sorted) observed N values.
            y_pred = model_reg.predict(X_const)
            fig_log.add_trace(go.Scatter(
                x=X,
                y=y_pred,
                mode='lines',
                name=f'Regression {model} Length {length}',
                line=dict(dash='dash')
            ))

            # Back-transform the fit to accuracy space: acc = exp(a * N + b).
            acc_pred = np.exp(a * X + b)
            fig_exp.add_trace(go.Scatter(
                x=X,
                y=acc_pred,
                mode='lines',
                name=f'Exp Fit {model} Length {length}',
                line=dict(dash='dot')
            ))

    fig_log.update_layout(
        title=f"Log(Accuracy) vs N for Dataset: {selected_dataset}",
        xaxis_title="N",
        yaxis_title="Log(Accuracy)",
        legend_title="Legend",
        hovermode='x unified'
    )

    # Only return the exponential figure when at least one fit was produced.
    if regression_toggle and regression_metrics:
        fig_exp.update_layout(
            title=f"Exponential Fit (acc = exp(a*N + b)) vs N for Dataset: {selected_dataset}",
            xaxis_title="N",
            yaxis_title="Predicted Accuracy (exp(a*N + b))",
            legend_title="Legend",
            hovermode='x unified'
        )
    else:
        fig_exp = None

    return fig_log, fig_exp, regression_metrics

def plot_accuracy(filtered_df, selected_models, selected_lengths):
    """
    Plot accuracy vs N for every selected model/length pair (no regression).

    Args:
        filtered_df: rows to plot; needs 'model', 'length', 'N', 'accuracy'.
        selected_models: models to iterate over.
        selected_lengths: lengths to iterate over.

    Returns:
        plotly.graph_objects.Figure with one line+marker trace per non-empty
        model/length pair. The title uses the module-level ``selected_dataset``.
    """
    fig = go.Figure()

    for model in selected_models:
        for length in selected_lengths:
            # Sort by N so lines run left to right regardless of CSV row order.
            subset_df = filtered_df[
                (filtered_df['model'] == model) &
                (filtered_df['length'] == length)
            ].sort_values('N')

            # Skip combinations with no rows; they would only add empty
            # legend entries.
            if subset_df.empty:
                continue

            fig.add_trace(go.Scatter(
                x=subset_df['N'],
                y=subset_df['accuracy'],
                mode='lines+markers',
                name=f'{model} Length {length}',
                marker=dict(size=6)
            ))

    fig.update_layout(
        title=f"Accuracy vs N for Dataset: {selected_dataset}",
        xaxis_title="N",
        yaxis_title="Accuracy",
        legend_title="Legend",
        hovermode='x unified'
    )

    return fig

def plot_regression_coefficients(regression_metrics, confidence_level):
    """
    Render slope ``a`` and intercept ``b`` for every model-length pair.

    Pairs are sorted by (Model, Length) and placed at consecutive integer x
    positions labelled "<model>-<length>". Each point has a symmetric error bar
    of ``z * standard_error``, where ``z = sqrt(2) * erfinv(confidence_level / 100)``.
    That ``z`` is the two-sided standard-normal quantile, so the bars are
    normal-approximation confidence intervals, not raw standard errors. The
    normal approximation understates uncertainty for fits with few points. The
    two figures are written to the page side by side.

    Args:
        regression_metrics: list of dicts with keys 'Model', 'Length',
            'Coefficient (a)', 'Intercept (b)', 'Std. Error (a)', 'Std. Error (b)'.
        confidence_level: confidence level in percent (e.g. 95).
    """
    if not regression_metrics:
        st.info("No regression metrics to display.")
        return

    # Deterministic ordering; the resulting row index doubles as the x position.
    metrics_df = (
        pd.DataFrame(regression_metrics)
        .sort_values(['Model', 'Length'])
        .reset_index(drop=True)
    )
    metrics_df['Rank'] = metrics_df.index
    metrics_df['Model_Length'] = [
        f"{model}-{length}" for model, length in zip(metrics_df['Model'], metrics_df['Length'])
    ]
    rank_labels = metrics_df['Model_Length'].tolist()
    rank_values = metrics_df['Rank'].tolist()

    # Half-width multiplier for the requested two-sided confidence level.
    z = math.sqrt(2) * erfinv(confidence_level / 100)

    def _coefficient_figure(column, stderr_column, color):
        """Build one coefficient-vs-rank figure with confidence-interval error bars."""
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=metrics_df['Rank'],
            y=metrics_df[column],
            mode='markers+lines',
            name=column,
            marker=dict(size=8, color=color),
            error_y=dict(
                type='data',
                array=metrics_df[stderr_column] * z,
                visible=True,
                color=color,
                thickness=1.5,
                width=3
            )
        ))
        fig.update_layout(
            title=f"{column} vs Model-Length Rank with {confidence_level}% Confidence Intervals",
            xaxis=dict(
                title="Model-Length Rank",
                tickmode='array',
                tickvals=rank_values,
                ticktext=rank_labels,
                tickangle=45
            ),
            yaxis_title=column,
            legend_title="",
            hovermode='x unified'
        )
        return fig

    fig_a = _coefficient_figure('Coefficient (a)', 'Std. Error (a)', 'blue')
    fig_b = _coefficient_figure('Intercept (b)', 'Std. Error (b)', 'red')

    st.subheader("Regression Coefficients Visualization")
    col1, col2 = st.columns(2)

    with col1:
        st.plotly_chart(fig_a, use_container_width=True)

    with col2:
        st.plotly_chart(fig_b, use_container_width=True)

def plot_regression_coefficients_comparison(regression_metrics, confidence_level):
    """
    Compare slope ``a`` and intercept ``b`` across models, grouped by length.

    Each distinct length gets an evenly spaced ordinal x position (ticks show
    the actual length), and each model is one colored trace. Error bars are
    ``z * standard_error`` with ``z = sqrt(2) * erfinv(confidence_level / 100)``,
    the two-sided standard-normal quantile. The two figures are written to the
    page side by side.

    Args:
        regression_metrics: list of dicts with keys 'Model', 'Length',
            'Coefficient (a)', 'Intercept (b)', 'Std. Error (a)', 'Std. Error (b)'.
        confidence_level: confidence level in percent (e.g. 95).
    """
    if not regression_metrics:
        st.info("No regression metrics to display.")
        return

    # Sort by Length then Model so each model's trace runs in length order.
    metrics_df = (
        pd.DataFrame(regression_metrics)
        .sort_values(['Length', 'Model'])
        .reset_index(drop=True)
    )

    # Map each distinct length to an ordinal x position.
    sorted_lengths = sorted(metrics_df['Length'].unique())
    length_to_rank = {length: rank for rank, length in enumerate(sorted_lengths)}
    metrics_df['Length_Rank'] = metrics_df['Length'].map(length_to_rank)

    models = metrics_df['Model'].unique()
    colors = px.colors.qualitative.Dark24  # qualitative palette, cycled if needed
    color_map = {model: colors[i % len(colors)] for i, model in enumerate(models)}

    # Half-width multiplier for the requested two-sided confidence level.
    z = math.sqrt(2) * erfinv(confidence_level / 100)

    def _comparison_figure(column, stderr_column):
        """Build one coefficient-vs-length figure with one colored trace per model."""
        fig = go.Figure()
        for model in models:
            model_df = metrics_df[metrics_df['Model'] == model]
            fig.add_trace(go.Scatter(
                x=model_df['Length_Rank'],
                y=model_df[column],
                mode='markers+lines',
                name=f"{model}",
                marker=dict(size=8, color=color_map[model]),
                error_y=dict(
                    type='data',
                    array=model_df[stderr_column] * z,
                    visible=True,
                    color=color_map[model],
                    thickness=1.5,
                    width=3
                )
            ))
        fig.update_layout(
            title=f"{column} vs Length Rank with {confidence_level}% Confidence Intervals",
            xaxis=dict(
                title="Length Rank",
                tickmode='array',
                tickvals=list(length_to_rank.values()),
                ticktext=list(length_to_rank.keys()),
                tickangle=45
            ),
            yaxis_title=column,
            legend_title="Model",
            hovermode='x unified'
        )
        return fig

    fig_a = _comparison_figure('Coefficient (a)', 'Std. Error (a)')
    fig_b = _comparison_figure('Intercept (b)', 'Std. Error (b)')

    st.subheader("Comparison of Regression Coefficients Across Models")
    col1, col2 = st.columns(2)

    with col1:
        st.plotly_chart(fig_a, use_container_width=True)

    with col2:
        st.plotly_chart(fig_b, use_container_width=True)

# -----------------------------
# 6. Visualization Based on View Option
# -----------------------------

if view_option == "Accuracy":
    # Plot Accuracy vs N
    fig_accuracy = plot_accuracy(filtered_df, selected_models, selected_lengths)
    st.plotly_chart(fig_accuracy, use_container_width=True)
else:
    # Plot Log(Accuracy), plus regression fits when the toggle is on.
    fig_log, fig_exp, regression_metrics = plot_log_accuracy_with_regression(
        filtered_df,
        regression_df,
        selected_models,
        selected_lengths,
        confidence_level
    )
    st.plotly_chart(fig_log, use_container_width=True)

    if regression_toggle and regression_metrics:
        if fig_exp is not None:
            st.plotly_chart(fig_exp, use_container_width=True)

        # Regression coefficients table. The metric dicts already use these
        # display column names, so no renaming is needed.
        st.subheader("Linear Regression Coefficients (log(acc) = a * N + b)")
        metrics_df = pd.DataFrame(regression_metrics)
        st.table(metrics_df.style.format({
            'Coefficient (a)': "{:.4f}",
            'Intercept (b)': "{:.4f}",
            'Std. Error (a)': "{:.4f}",
            'Std. Error (b)': "{:.4f}"
        }))

        # Plot Regression Coefficients with Error Bars
        plot_regression_coefficients(regression_metrics, confidence_level)

        # Compare Regression Coefficients Across Models
        plot_regression_coefficients_comparison(regression_metrics, confidence_level)

    elif regression_toggle:
        st.info("Regression was toggled on, but no sufficient data points were available for any model-length combination.")

# -----------------------------
# 7. Additional Features
# -----------------------------

# a. Data Table Display
with st.expander("View Data"):
    st.dataframe(filtered_df.sort_values(['model', 'length', 'N']))

# b. Download Data as CSV
csv = filtered_df.to_csv(index=False).encode('utf-8')
st.download_button(
    label="Download Data as CSV",
    data=csv,
    file_name='filtered_results.csv',
    mime='text/csv',
)

# c. Download Regression Metrics as CSV (if available).
# `regression_metrics` is only bound in the Log(Accuracy) view, so the view
# check must come first to short-circuit before the name is read. The metric
# dicts already carry their display column names, so they are written as-is.
if view_option == "Log(Accuracy)" and regression_toggle and regression_metrics:
    csv_metrics = pd.DataFrame(regression_metrics).to_csv(index=False).encode('utf-8')
    st.download_button(
        label="Download Regression Metrics as CSV",
        data=csv_metrics,
        file_name='regression_metrics.csv',
        mime='text/csv',
    )
