#
# Lalonde analysis
#

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

def set_size(width, fraction=1, subplots=(3, 3)):
    """Return a (width, height) figure size in inches for a LaTeX document.

    Sizing the figure to the document width up front means it can be
    included in LaTeX without being rescaled.

    Parameters
    ----------
    width: float or string
            Document width in points, or the name of a predefined document
            type ('thesis' or 'beamer').
    fraction: float, optional
            Fraction of the document width the figure should occupy.
    subplots: array-like, optional
            Number of rows and columns of subplots, as (rows, columns).

    Returns
    -------
    fig_dim: tuple
            (width, height) of the figure in inches.
    """
    # Resolve a named document type to its width in points; any other value
    # is taken to already be a width in points.
    width_pt = width
    for preset_name, preset_width_pt in (('thesis', 426.79135),
                                         ('beamer', 307.28987)):
        if width == preset_name:
            width_pt = preset_width_pt
            break

    n_rows, n_cols = subplots[0], subplots[1]

    # Conversion factor from points to inches.
    pt_to_inch = 1 / 72.27
    # Golden ratio, used as the height/width ratio of a single panel.
    # https://disq.us/p/2940ij3
    golden = (5**.5 - 1) / 2

    width_in = width_pt * fraction * pt_to_inch
    # Scale the single-panel height by the rows/columns ratio of the grid.
    height_in = width_in * golden * (n_rows / n_cols)

    return (width_in, height_in)

# Multiplier applied to every standard error below to form the "95%" intervals.
_Z_95 = 1.96
# Value drawn as the 'Experimental benchmark' reference line in both figures.
_EXPERIMENTAL_BENCHMARK = 1794.34
# Hyper-parameters shared by every random forest fitted in this script.
_FOREST_KWARGS = dict(n_estimators=100, max_depth=5, random_state=42)


def _clipped_proba(features_train, target_train, features_all):
    """Fit a random-forest classifier and return P(target = 1) for all rows.

    The probabilities are clipped to [0.01, 0.99] so that they can be used
    as denominators of inverse-probability weights without blowing up.
    """
    model = RandomForestClassifier(**_FOREST_KWARGS)
    model = model.fit(features_train, target_train)
    return np.clip(model.predict_proba(features_all)[:, 1], 0.01, 0.99)


def _forest_prediction(features_train, target_train, features_all):
    """Fit a random-forest regressor and return its predictions for all rows."""
    model = RandomForestRegressor(**_FOREST_KWARGS)
    model = model.fit(features_train, target_train)
    return model.predict(features_all)


def _att_and_half_width(treat, pi_hat, treated_term, control_term):
    """Return (ATT estimate, half-width of its 95% interval).

    Parameters
    ----------
    treat: pandas.Series
            0/1 treatment indicator, one entry per row.
    pi_hat: array-like
            Estimated probability of treatment for each row.
    treated_term: array-like
            Per-row quantity that is weighted by ``treat / n_treated``.
    control_term: array-like
            Per-row quantity that is weighted by
            ``(1 - treat) * pi_hat / ((1 - pi_hat) * n_treated)``.

    Returns
    -------
    tuple
            The sum of the per-row contributions, and ``_Z_95`` times the
            square root of the sum of squared contributions after
            subtracting ``treat * estimate / n_treated`` from each row.
    """
    n_treated = treat.sum()
    control_weight = (1 - treat) * pi_hat / ((1 - pi_hat) * n_treated)
    influence = (treat / n_treated) * treated_term - control_weight * control_term
    att = influence.sum()
    variance = np.sum((influence - treat * att / n_treated) ** 2)
    return att, _Z_95 * np.sqrt(variance)


def _print_att(tag, att, half_width):
    """Print an ATT estimate and its 95% interval, labelled with ``tag``."""
    print(f"\nLalonde ATT Estimate ({tag}): {att:.2f}")
    print(f"Lalonde 95% Confidence Interval ({tag}): [{att - half_width:.2f}, {att + half_width:.2f}]")


def _plot_estimates(rows, width, filename):
    """Draw horizontal point estimates with error bars and save the figure.

    Parameters
    ----------
    rows: iterable of (label, estimate, half_width)
            One entry per estimator, drawn in the given order.
    width: float or string
            Document width passed through to ``set_size``.
    filename: str
            Path the figure is written to.
    """
    # All estimators share the first colour of the active seaborn palette.
    color = sns.color_palette()[0]

    axd = plt.figure(figsize=set_size(width, subplots=(1, 1))).subplot_mosaic(
        [['main']])
    ax = axd['main']

    for label, estimate, half_width in rows:
        ax.errorbar(y=[label], x=[estimate], xerr=[half_width], fmt='o', color=color)
    ax.axvline(x=_EXPERIMENTAL_BENCHMARK, color='grey', linestyle='--', label='Experimental benchmark')
    # Without a legend the label of the reference line is never displayed.
    ax.legend()

    ax.set_xlabel('ATT Estimate')
    ax.set_title('Lalonde Dataset ATT Estimates')

    plt.savefig(filename, bbox_inches='tight')


if __name__ == "__main__":
    # Load the Lalonde dataset
    url = "https://github.com/ljwa2323/MTNN/raw/refs/heads/main/data_lalonde_ps_MAR_0.5_1.csv"
    lalonde = pd.read_csv(url)

    # Drop unnecessary columns: every '*_new' column, plus 'PS' and 're75'.
    lalonde = lalonde.loc[:, ~lalonde.columns.str.contains('_new')]
    lalonde = lalonde.drop(columns=['PS', 're75'])

    # r = 1 when re74 is observed, 0 when it is missing.
    lalonde['r'] = 1 - lalonde['re74'].isna().astype(int)

    is_treated = lalonde['treat'] == 1
    is_control = lalonde['treat'] == 0
    re74_observed = lalonde['r'] == 1

    #
    # Naive diff-in-means
    #

    re78_treated = lalonde.loc[is_treated, 're78']
    re78_control = lalonde.loc[is_control, 're78']
    diff_in_means = re78_treated.mean() - re78_control.mean()
    uq_dim = _Z_95 * np.sqrt(
        re78_treated.var() / len(re78_treated)
        + re78_control.var() / len(re78_control)
    )
    print(f"\nLalonde Diff-in-Means Estimate: {diff_in_means:.2f}")
    print(f"Lalonde Diff-in-Means 95% Confidence Interval: [{diff_in_means - uq_dim:.2f}, {diff_in_means + uq_dim:.2f}]")

    #
    # Diff in diff with complete data only
    #

    complete_data = lalonde.dropna(subset=['re74', 're78'])
    cc_treated = complete_data[complete_data['treat'] == 1]
    cc_control = complete_data[complete_data['treat'] == 0]
    diff_in_diff = (
        (cc_treated['re78'].mean() - cc_treated['re74'].mean())
        - (cc_control['re78'].mean() - cc_control['re74'].mean())
    )
    # NOTE(review): this sums the four variances as if re74 and re78 were
    # independent, although both come from the same rows. Using the variance
    # of the per-row difference (re78 - re74) would account for their
    # covariance -- confirm which is intended.
    uq_did = _Z_95 * np.sqrt(
        cc_treated['re78'].var() / len(cc_treated)
        + cc_treated['re74'].var() / len(cc_treated)
        + cc_control['re78'].var() / len(cc_control)
        + cc_control['re74'].var() / len(cc_control)
    )
    print(f"\nLalonde Diff-in-Diff Estimate (Complete Cases): {diff_in_diff:.2f}")
    print(f"Lalonde Diff-in-Diff 95% Confidence Interval (Complete Cases): [{diff_in_diff - uq_did:.2f}, {diff_in_diff + uq_did:.2f}]")

    #
    # Estimate nuisance functions
    #

    # Every column except re74, treat and r (so re78 is still included).
    features_post = lalonde.drop(columns=['re74', 'treat', 'r'])
    # The same columns without re78.
    features_pre = features_post.drop(columns=['re78'])

    # Probability of treatment. The features must not contain 'treat' itself
    # (the label being predicted); 'r' is excluded as well so that this model
    # uses the same columns as mu10 below.
    pi_hat = _clipped_proba(features_pre, lalonde['treat'], features_pre)

    # Probability that re74 is observed, fitted separately within each
    # treatment arm. The features must not contain 'r' itself (the label).
    gamma1_hat = _clipped_proba(features_post[is_treated], lalonde.loc[is_treated, 'r'], features_post)
    gamma0_hat = _clipped_proba(features_post[is_control], lalonde.loc[is_control, 'r'], features_post)

    # Regression of re74 on the remaining columns (including re78), fitted on
    # rows of each treatment arm where re74 is observed.
    mu00_hat = _forest_prediction(
        features_post[is_control & re74_observed],
        lalonde.loc[is_control & re74_observed, 're74'],
        features_post,
    )
    mu01_hat = _forest_prediction(
        features_post[is_treated & re74_observed],
        lalonde.loc[is_treated & re74_observed, 're74'],
        features_post,
    )

    # Regression of re78 among controls. The features must not contain
    # 're78' itself (the regression target).
    mu10_hat = _forest_prediction(features_pre[is_control], lalonde.loc[is_control, 're78'], features_pre)

    # From here on missing values (in particular missing re74) are zero;
    # every use of re74 in the MissingDiD estimators is multiplied by r.
    lalonde = lalonde.fillna(0)
    treat = lalonde['treat']
    r = lalonde['r']
    re74 = lalonde['re74']
    re78 = lalonde['re78']

    # Under "Ass. 2.3" the eta term is taken to be mu00_hat directly.
    eta00_hat23 = mu00_hat

    # Under "Ass. 2.4" the eta term is a second-stage regression, either on
    # the augmented pseudo-outcome or on mu00_hat alone.
    # NOTE(review): these features include the zero-filled re74 and re78 and
    # the model is fitted on all rows -- confirm this is the intended design.
    eta_features = lalonde.drop(columns=['treat', 'r'])
    eta00_hat24 = _forest_prediction(
        eta_features,
        mu00_hat + r * (re74 - mu00_hat) / gamma0_hat,
        eta_features,
    )
    eta00_hat24_noaugmentation = _forest_prediction(eta_features, mu00_hat, eta_features)

    def dr_terms(eta_hat):
        """Treated/control terms that combine outcome models and weighting by gamma."""
        treated = re78 - (mu01_hat + r * (re74 - mu01_hat) / gamma1_hat) - mu10_hat + eta_hat
        control = re78 - mu10_hat - mu00_hat + eta_hat - r * (re74 - mu00_hat) / gamma0_hat
        return treated, control

    def ipw_terms(eta_hat):
        """Treated/control terms that use weighting by gamma only for re74."""
        treated = re78 - (r * (re74) / gamma1_hat) - mu10_hat + eta_hat
        control = re78 - mu10_hat + eta_hat - r * (re74) / gamma0_hat
        return treated, control

    def or_terms(eta_hat):
        """Treated/control terms that use the outcome models only for re74."""
        treated = re78 - mu01_hat - mu10_hat + eta_hat
        control = re78 - mu10_hat - mu00_hat + eta_hat
        return treated, control

    #
    # DR-DiD complete data
    #

    # NOTE(review): despite the "Complete Cases" label this runs on all rows,
    # with missing re74 replaced by 0 by the fillna above -- confirm whether
    # rows with r == 0 should be excluded here.
    did_residual = re78 - re74 - mu10_hat + mu00_hat
    att_cd, uq_cd = _att_and_half_width(treat, pi_hat, did_residual, did_residual)

    print(f"\nLalonde DR-DiD Estimate (Complete Cases): {att_cd:.2f}")
    print(f"Lalonde DR-DiD 95% Confidence Interval (Complete Cases): [{att_cd - uq_cd:.2f}, {att_cd + uq_cd:.2f}]")

    #
    # MissingDiD 2.3
    #

    att23, uq23 = _att_and_half_width(treat, pi_hat, *dr_terms(eta00_hat23))
    _print_att("2.3", att23, uq23)

    #
    # MissingDiD 2.4
    #

    att24, uq24 = _att_and_half_width(treat, pi_hat, *dr_terms(eta00_hat24))
    _print_att("2.4", att24, uq24)

    #
    # Plotting results
    #

    sns.set_theme(style="whitegrid", palette="pastel", font_scale=1)
    width = 396

    # Plot the main confidence intervals and point estimates horizontally
    _plot_estimates(
        [
            ('Our estimate (Ass. 2.4)', att24, uq24),
            ('Our estimate (Ass. 2.3)', att23, uq23),
            ('DR-DiD (Complete Cases)', att_cd, uq_cd),
            ('Difference-In-Means', diff_in_means, uq_dim),
            ('Difference-In-Differences (Complete Cases)', diff_in_diff, uq_did),
        ],
        width,
        'lalonde_att_estimates.pdf',
    )

    #
    # MissingDiD 2.4, no augmentation
    #

    att24_noaug, uq24_noaug = _att_and_half_width(treat, pi_hat, *dr_terms(eta00_hat24_noaugmentation))
    _print_att("2.4, no augmentation", att24_noaug, uq24_noaug)

    #
    # MissingDiD 2.3 IPW
    #

    att23_ipw, uq23_ipw = _att_and_half_width(treat, pi_hat, *ipw_terms(eta00_hat23))
    _print_att("2.3, IPW", att23_ipw, uq23_ipw)

    #
    # MissingDiD 2.4 IPW
    #

    # The non-augmented eta is used in both the treated and the control term,
    # mirroring the augmented variant below which uses eta00_hat24 in both.
    att24_ipw, uq24_ipw = _att_and_half_width(treat, pi_hat, *ipw_terms(eta00_hat24_noaugmentation))
    _print_att("2.4, IPW", att24_ipw, uq24_ipw)

    #
    # MissingDiD 2.4 IPW eta_augmented
    #

    att24_ipw_aug, uq24_ipw_aug = _att_and_half_width(treat, pi_hat, *ipw_terms(eta00_hat24))
    _print_att("2.4, IPW, augmented", att24_ipw_aug, uq24_ipw_aug)

    #
    # Missing DiD 2.3 OR
    #

    att23_or, uq23_or = _att_and_half_width(treat, pi_hat, *or_terms(eta00_hat23))
    _print_att("2.3, OR", att23_or, uq23_or)

    #
    # Missing DiD 2.4 OR
    #

    # The non-augmented eta is used in both the treated and the control term,
    # mirroring the augmented variant below which uses eta00_hat24 in both.
    att24_or, uq24_or = _att_and_half_width(treat, pi_hat, *or_terms(eta00_hat24_noaugmentation))
    _print_att("2.4, OR", att24_or, uq24_or)

    #
    # Missing DiD 2.4 OR eta augmented
    #

    att24_or_aug, uq24_or_aug = _att_and_half_width(treat, pi_hat, *or_terms(eta00_hat24))
    _print_att("2.4, OR, augmented", att24_or_aug, uq24_or_aug)

    #
    # Plot confidence intervals and point estimates horizontally
    #

    _plot_estimates(
        [
            ('DR (Ass. 2.4), augmented', att24, uq24),
            ('DR (Ass. 2.4)', att24_noaug, uq24_noaug),
            ('DR (Ass. 2.3)', att23, uq23),
            ('OR (Ass. 2.4), augmented', att24_or_aug, uq24_or_aug),
            ('OR (Ass. 2.4)', att24_or, uq24_or),
            ('OR (Ass. 2.3)', att23_or, uq23_or),
            ('IPW (Ass. 2.4), augmented', att24_ipw_aug, uq24_ipw_aug),
            ('IPW (Ass. 2.4)', att24_ipw, uq24_ipw),
            ('IPW (Ass. 2.3)', att23_ipw, uq23_ipw),
        ],
        width,
        'lalonde_att_estimates_variants.pdf',
    )