from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Data from the regression summaries, written one row per model specification.
# Row layout: (x-tick label, full specification, coefficient of
# abs(Martingale), p-value of that coefficient, model R-squared).
_columns = ("model_name_short", "model_spec", "coef_bias", "pval_bias", "r_squared")
_rows = [
    ("M1\nabs(Martingale)", "abs(Martingale)", -0.174, 0.214, 0.020),
    ("M2\n(+Domain)", "abs(Martingale) + Domain", 0.3393, 0.065, 0.191),
    ("M3\n(+RM)", "abs(Martingale) + Reasoning Mode (RM)", -0.1297, 0.356, 0.060),
    ("M4\n(+Domain +RM)", "abs(Martingale) + Domain + RM", 0.4689, 0.011, 0.269),
    ("M5\n(+Domain +RM +Model)", "abs(Martingale) + Domain + RM + Model", 0.3603, 0.043, 0.344),
    ("M6\n(+Domain +RM +Prompt)", "abs(Martingale) + Domain + RM + Prompt", 0.409, 0.016, 0.398),
]

# Transpose the rows into a column-oriented dict of lists (column name -> values)
# and build the DataFrame from it; column order follows `_columns`.
data = {name: list(values) for name, values in zip(_columns, zip(*_rows))}
df = pd.DataFrame(data)

# Map a p-value onto the bar color of its significance tier
def get_pval_color(p_val):
    """Return the color name for the significance tier that ``p_val`` falls in.

    Tiers are checked from most to least significant and each upper bound is
    exclusive: ``p < 0.01`` -> 'darkgreen', ``p < 0.05`` -> 'mediumseagreen',
    ``p < 0.1`` -> 'lightgreen'. Any value matching no tier (``p >= 0.1``, or
    NaN, which fails every comparison) yields 'silver'.
    """
    tiers = (
        (0.01, 'darkgreen'),       # Highly significant
        (0.05, 'mediumseagreen'),  # Significant
        (0.1, 'lightgreen'),       # Marginally significant
    )
    for upper_bound, tier_color in tiers:
        if p_val < upper_bound:
            return tier_color
    return 'silver'  # Not significant

# Tag each model with the bar color for the significance tier of its p-value
df['pval_color'] = df['pval_bias'].apply(get_pval_color)

# Create the plot: two stacked panels (2 rows x 1 column) sharing the x-axis.
# The style sheet is selected BEFORE the figure is created: axes pick up their
# rcParams (grid, face/edge colors, ...) when they are constructed, so a style
# applied after plt.subplots() would not restyle the already-created axes.
plt.style.use('seaborn-v0_8-whitegrid')  # Using a seaborn style for better aesthetics
fig, axs = plt.subplots(2, 1, figsize=(12, 10), sharex=True)

# Subplot 1 (top panel): one bar per model showing the coefficient of
# abs(Martingale), colored by the significance tier stored in df['pval_color'].
bars_coef = axs[0].bar(
    df['model_name_short'],
    df['coef_bias'],
    color=df['pval_color'],
    alpha=0.8,
)
# Dashed reference line at y=0 separating positive from negative coefficients
axs[0].axhline(y=0, linestyle='--', linewidth=0.8, color='black')
axs[0].set_ylabel('Coefficient of abs(Martingale)', fontsize=12)
axs[0].tick_params(axis='y', labelsize=10)
# The panel title is currently disabled; kept here for reference:
# axs[0].set_title('Impact of Belief Entrenchment on Accuracy with Different Sets of Controls\nRegression Formula: accuracy ~ abs(martingale) + controls', fontsize=16, pad=20)

# Annotate each coefficient bar with its p-value, placed just beyond the bar's
# tip: above the top of positive bars and below the bottom of negative bars,
# so the label never ends up drawn inside a bar.
position_offset = 0.008  # gap between the bar tip and the label, in y-data units
for bar, pval in zip(bars_coef, df['pval_bias']):
    yval = bar.get_height()  # signed: negative for bars extending below zero
    if yval < 0:
        # Negative bar: hang the label underneath the bar's lower edge
        text_y, text_va = yval - position_offset, 'top'
    else:
        # Positive (or zero-height) bar: sit the label on top of the bar
        text_y, text_va = yval + position_offset, 'bottom'
    axs[0].text(bar.get_x() + bar.get_width() / 2.0,  # x-position: center of the bar
                text_y,                               # y-position: just past the bar tip
                f'p={pval:.3f}',                      # Text: p-value with 3 decimals
                ha='center',                          # Horizontal alignment: center
                va=text_va,                           # Vertical alignment: away from the bar
                fontsize=9, color='black')

# Custom legend for the coefficient panel: one colored patch per significance
# tier, listed from most to least significant as (patch color, legend label).
significance_tiers = [
    ('darkgreen', 'p < 0.01 (Highly Significant)'),
    ('mediumseagreen', 'p < 0.05 (Significant)'),
    ('lightgreen', 'p < 0.1 (Marginally Significant)'),
    ('silver', 'p >= 0.1 (Not Significant)'),
]
# The rectangles act only as legend handles; they are not added to the axes.
legend_elements_coef = [
    plt.Rectangle((0, 0), 1, 1, color=tier_color, label=tier_label)
    for tier_color, tier_label in significance_tiers
]
axs[0].legend(
    handles=legend_elements_coef,
    title="Significance of abs(Martingale) Coef.",
    loc='upper left',
    fontsize=9,
)

# Subplot 2 (bottom panel): one bar per model showing the model's R-squared.
bars_r2 = axs[1].bar(
    df['model_name_short'],
    df['r_squared'],
    color='skyblue',
    alpha=0.8,
)
# Axis labels; labelpad adds spacing between the x tick labels and the x-axis label.
# The x tick labels themselves are left unrotated.
axs[1].set_xlabel('Model Specification', fontsize=12, labelpad=10)
axs[1].set_ylabel('Model R-squared', fontsize=12)
axs[1].tick_params(axis='y', labelsize=10)
# Leave 15% headroom above the tallest bar so its value label has room
axs[1].set_ylim(bottom=0, top=df['r_squared'].max() * 1.15)

# Label every R-squared bar with its value (3 decimals), horizontally centered
# on the bar and sitting slightly (0.005 in y-data units) above its top edge.
for r2_bar in bars_r2:
    bar_top = r2_bar.get_height()
    bar_center_x = r2_bar.get_x() + r2_bar.get_width() / 2.0
    axs[1].text(
        bar_center_x,
        bar_top + 0.005,
        f'{bar_top:.3f}',
        ha='center',
        va='bottom',
        fontsize=9,
    )

# Adjust the layout so labels do not overlap. No figure-level suptitle is
# drawn, so the panels may use the full figure area (rect spans 0..1 on both axes).
plt.tight_layout(rect=[0, 0, 1, 1])

# Save the figure as a PDF. The output directory is created first because
# savefig does not create missing parent directories and would otherwise fail
# with FileNotFoundError on a fresh checkout.
output_path = Path('data/figures/per_step/accuracy_regression.pdf')
output_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(output_path)

# For reference, echo the plotted values as a plain-text table
# (the tick-label and color helper columns are left out).
report_columns = ['model_spec', 'coef_bias', 'pval_bias', 'r_squared']
report_table = df[report_columns].to_string()
print("Data used for the plot:", report_table, sep="\n")
