import pandas as pd
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import probplot
# from ace_tools import display_dataframe_to_user

# Read the per-discipline funding table (one row per discipline, with
# separate *_men / *_women / *_total columns).
df = pd.read_csv('research_funding_rates.csv')

# Discipline names, highest overall success rate first; reused below as the
# category order and as the x-axis tick labels.
order = df.sort_values('success_rates_total', ascending=False)['discipline']

# Reshape to long ("tidy") form: one row per (discipline, gender), with the
# gender suffix stripped from the column names so both halves line up.
men_df = (
    df[['discipline', 'applications_men', 'awards_men', 'success_rates_men']]
    .rename(columns={
        'applications_men': 'applications',
        'awards_men': 'awards',
        'success_rates_men': 'success_rate',
    })
    .assign(gender='Men')
)

women_df = (
    df[['discipline', 'applications_women', 'awards_women', 'success_rates_women']]
    .rename(columns={
        'applications_women': 'applications',
        'awards_women': 'awards',
        'success_rates_women': 'success_rate',
    })
    .assign(gender='Women')
)

tidy = pd.concat([men_df, women_df], ignore_index=True)

# Make discipline an ordered categorical so that sorting follows the
# success-rate ranking rather than alphabetical order.
tidy['discipline'] = pd.Categorical(tidy['discipline'], categories=order, ordered=True)

# Sort by that ranking, then expose each discipline's rank position as an
# integer code (used as the numeric x coordinate when plotting).
tidy = tidy.sort_values('discipline').assign(
    disc_code=lambda frame: frame['discipline'].cat.codes
)

# Plot: one scatter series per gender, x = discipline rank code,
# y = success rate, marker size taken from the number of applications.
fig, ax = plt.subplots(figsize=(10, 6))
for gender in tidy['gender'].unique():
    per_gender = tidy.loc[tidy['gender'] == gender]
    ax.scatter(
        x=per_gender['disc_code'],
        y=per_gender['success_rate'],
        s=per_gender['applications'],
        label=gender,
    )

# Formatting: put a tick at every discipline code and label it with the
# discipline name (order holds the names in the same ranking as the codes).
tick_positions = tidy['disc_code'].unique()
ax.set_xticks(tick_positions)
ax.set_xticklabels(order, rotation=45, ha='right')
ax.set_xlabel('Discipline (ordered by overall success rate)')
ax.set_ylabel('Success Rate (%)')
ax.legend(title='Gender')

fig.tight_layout()
fig.savefig('funding_success_rates_by_discipline.png', dpi=300, bbox_inches='tight')
# plt.show() is intentionally not called: the figure is only written to disk.

# Compute standardized log odds ratios (z-scores)
#
# For each discipline, form the 2x2 table of outcome by gender:
#   a = men awarded      b = men not awarded
#   c = women awarded    d = women not awarded
# The log odds ratio is log((a/b) / (c/d)), its standard error is
# sqrt(1/a + 1/b + 1/c + 1/d), and the z-score is their quotient.
# Everything is computed column-wise instead of looping over rows.
cells = pd.DataFrame({
    'a': df['awards_men'],
    'b': df['applications_men'] - df['awards_men'],
    'c': df['awards_women'],
    'd': df['applications_women'] - df['awards_women'],
}).astype(float)

# A zero cell makes the odds ratio and/or its standard error undefined
# (division by zero). For those rows only, add 0.5 to all four cells
# (Haldane-Anscombe continuity correction) so the estimate stays finite.
# Rows without a zero cell are left untouched.
has_zero_cell = cells.eq(0).any(axis=1)
cells = cells.where(~has_zero_cell, cells + 0.5)

log_or = np.log((cells['a'] / cells['b']) / (cells['c'] / cells['d']))
se_log_or = np.sqrt(1 / cells['a'] + 1 / cells['b'] + 1 / cells['c'] + 1 / cells['d'])

# Kept as a plain list, one entry per row of df in row order.
z_scores = (log_or / se_log_or).tolist()

results = pd.DataFrame({
    'discipline': df['discipline'],
    'z_score': z_scores
})

# Display z-scores
# display_dataframe_to_user("Standardized Log Odds Ratios (Z-scores) by Discipline", results)

# QQ-plot: compare the per-discipline z-scores against a normal distribution.
fig, ax = plt.subplots()
probplot(results['z_score'], dist="norm", plot=ax)

# Replace the title and axis labels that probplot put on the axes.
ax.set(
    title="QQ-Plot of Standardized Log Odds Ratios",
    xlabel="Theoretical Quantiles",
    ylabel="Sample Quantiles",
)

fig.tight_layout()
fig.savefig('qq_plot_log_odds_ratios.png', dpi=300, bbox_inches='tight')
# plt.show() is intentionally not called: the figure is only written to disk.