import numpy as np
import pandas as pd
from statsmodels.api import Logit, OLS, add_constant
import warnings
from tqdm import tqdm
import itertools

import matplotlib.pyplot as plt

# Suppress convergence and future warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

def generate_data(n, att=5, ass='2.3'):
    """
    Simulate one two-period dataset with a binary treatment and a partially
    observed baseline outcome.

    Four independent standard-normal latent covariates (z1-z4) drive the
    treatment assignment, both outcomes and the response indicator. Nonlinear
    transformations of them are returned as x1-x4.

    Args:
        n (int): Number of rows to draw.
        att (float): Constant added to the period-1 outcome of treated rows.
        ass (str): '2.3' lets the response probability depend only on the
            latent covariates and the treatment; any other value additionally
            adds 0.01 * y1 to the response linear predictor.

    Returns:
        pd.DataFrame: One row per unit with columns z1-z4, x1-x4, y0 (NaN
        where r == 0), y1, t, r, pi, gamma0, gamma1, mu00, mu10 and mu01.
    """

    def expit(linear_predictor):
        # Inverse-logit link shared by the treatment and response models.
        return 1 / (1 + np.exp(-linear_predictor))

    # NOTE: the order of the np.random calls below determines the data
    # produced under a fixed global seed, so it must not be rearranged.
    latent = np.random.normal(0, 1, size=(n, 4))
    z1 = latent[:, 0]
    z2 = latent[:, 1]
    z3 = latent[:, 2]
    z4 = latent[:, 3]

    # Treatment assignment
    pi = expit(- z1 + 0.5 * z2 - 0.25 * z3 - 0.1 * z4)
    treated = np.random.binomial(1, pi)

    # Baseline outcome, untreated follow-up outcome, and realised follow-up
    baseline = 210 + 27.4 * z1 + 13.7 * z2 + 13.7 * z3 + 13.7 * z4 + np.random.normal(0, 1, n)
    followup_untreated = baseline + np.random.normal(0, 1, n)
    followup_treated = followup_untreated + treated * att
    y1 = np.where(treated == 1, followup_treated, followup_untreated)

    # Response model: r == 1 means the baseline outcome is observed.
    logit_r0 = - 0.25 * z1 - 0.1 * z2 - 0.5 * z3 + 0.3 * z4
    if ass != '2.3':
        logit_r0 = logit_r0 + 0.01 * y1
    logit_r1 = logit_r0 - 0.2  # treated rows get a lower response logit

    response_prob = expit(np.where(treated == 1, logit_r1, logit_r0))
    responded = np.random.binomial(1, response_prob)

    gamma0 = expit(logit_r0)
    gamma1 = expit(logit_r1)

    # Baseline outcome is blanked out for non-respondents
    baseline_observed = np.where(responded == 1, baseline, np.nan)

    columns = {
        'z1': z1,
        'z2': z2,
        'z3': z3,
        'z4': z4,
        # Transformed covariates
        'x1': np.exp(z1 / 2),
        'x2': z2 / (1 + np.exp(z1)) + 10,
        'x3': (z1 * z3 / 25 + 0.6)**3,
        'x4': (z2 + z4 + 20)**2,
        'y0': baseline_observed,
        'y1': y1,
        't': treated,
        'r': responded,
        'pi': pi,
        'gamma0': gamma0,
        'gamma1': gamma1,
        'mu00': baseline,
        'mu10': followup_untreated,
        'mu01': baseline,
    }
    return pd.DataFrame(columns)

def estimate_propensity_scores(data, use_correct_model):
    """
    Fit a logistic regression of the treatment indicator 't' on four
    covariates plus an intercept and return the fitted probabilities.

    Args:
        data (pd.DataFrame): Must contain 't' and the chosen covariates.
        use_correct_model (bool): If truthy, regress on z1-z4; otherwise on
            x1-x4.

    Returns:
        Fitted probabilities for every row of data, clipped to [0.01, 0.99].
    """
    if use_correct_model:
        covariates = ['z1', 'z2', 'z3', 'z4']
    else:
        covariates = ['x1', 'x2', 'x3', 'x4']

    design = add_constant(data[covariates])
    fitted = Logit(data['t'], design).fit(disp=0)
    probabilities = fitted.predict(design)
    # Clipping keeps the scores away from 0 and 1.
    return np.clip(probabilities, 0.01, 0.99)

    
def estimate_missingness_scores(data, ass, t, use_correct_model):
    """
    Fit a logistic regression of the response indicator 'r' within one
    treatment arm and return fitted probabilities for the whole sample.

    Args:
        data (pd.DataFrame): Must contain 't', 'r' and the chosen covariates.
        ass (str): '2.3' uses the four covariates only; any other value also
            includes 'y1' as a regressor.
        t (int): Treatment arm (value of data['t']) the model is fitted on.
        use_correct_model (bool): If truthy, use z1-z4; otherwise x1-x4.

    Returns:
        Fitted probabilities for every row of data (not only the fitting
        arm), clipped to [0.01, 0.99].
    """
    if use_correct_model:
        covariates = ['z1', 'z2', 'z3', 'z4']
    else:
        covariates = ['x1', 'x2', 'x3', 'x4']
    if ass != '2.3':
        covariates = covariates + ['y1']

    in_arm = data['t'] == t
    design_arm = add_constant(data.loc[in_arm, covariates])
    design_full = add_constant(data[covariates])

    fitted = Logit(data.loc[in_arm, 'r'], design_arm).fit(disp=0)
    scores = fitted.predict(design_full)
    return np.clip(scores, 0.01, 0.99)
    
def get_y_model_predictions(data, ass, t, target, use_correct_model):
    """
    Fit an OLS outcome model inside one treatment arm and predict for the
    entire sample.

    Args:
        data (pd.DataFrame): Must contain 't', 'r', the target column and the
            chosen covariates.
        ass (str): For targets other than 'y1', '2.3' uses the four
            covariates only and any other value also includes 'y1' as a
            regressor. Ignored when target == 'y1'.
        t (int): Treatment arm (value of data['t']) the model is fitted on.
        target (str): Column to regress. 'y1' is fitted on every row of the
            arm; any other target is fitted only on rows with r == 1.
        use_correct_model (bool): If truthy, use z1-z4; otherwise x1-x4.

    Returns:
        OLS predictions for every row of data.
    """
    if use_correct_model:
        covariates = ['z1', 'z2', 'z3', 'z4']
    else:
        covariates = ['x1', 'x2', 'x3', 'x4']

    fit_rows = data['t'] == t
    if target != 'y1':
        # The target is only usable where r == 1, so fit on respondents.
        fit_rows = fit_rows & (data['r'] == 1)
        if ass != '2.3':
            covariates = covariates + ['y1']
    # For target == 'y1' the model never uses 'y1' as its own regressor and
    # is fitted on the whole arm, so each design matrix is built once here
    # instead of being built and then overwritten.

    X_resp = add_constant(data.loc[fit_rows, covariates])
    X_all = add_constant(data[covariates])

    ols_model = OLS(data.loc[fit_rows, target], X_resp).fit()
    return ols_model.predict(X_all)

def fit_missingdid(data, ass, use_correct_model_mu, use_correct_model_gamma, use_correct_model_pi, use_correct_model_eta, augmentation, oracle=False):
    """
    Compute the ATT point estimate and its standard-error estimate from
    fitted (or oracle) nuisance functions.

    Args:
        data (pd.DataFrame): Sample with columns 't', 'r', 'y0', 'y1', the
            covariates z1-z4 / x1-x4 and, when oracle is set, 'pi', 'gamma0'
            and 'gamma1'.
        ass (str): '2.3' sets eta_hat00 to the fitted m_hat00; any other
            value obtains eta_hat00 from an extra OLS projection on the four
            covariates.
        use_correct_model_mu (bool): Use z1-z4 (True) or x1-x4 (False) in the
            outcome regressions.
        use_correct_model_gamma (bool): Same choice for the response models.
            Ignored when oracle is True.
        use_correct_model_pi (bool): Same choice for the propensity model.
            Ignored when oracle is True.
        use_correct_model_eta (bool): Same choice for the eta projection
            (only used when ass != '2.3').
        augmentation (bool): When ass != '2.3', project the inverse-
            probability-augmented pseudo-outcome instead of m_hat00 itself.
        oracle (bool): Use the 'pi', 'gamma1' and 'gamma0' columns of data in
            place of fitted propensity / response scores. Outcome models are
            still estimated.

    Returns:
        tuple: (estimate, std) where estimate is the sum of the per-row
        contributions and std is the square root of the summed squared
        centred contributions.
    """
    if oracle:
        # The stored scores are used as-is (unclipped). The logistic fits are
        # skipped because their output would be discarded, and a failed fit
        # should not abort an oracle run.
        pi_hat = data['pi']
        r_hat1 = data['gamma1']
        r_hat0 = data['gamma0']
    else:
        pi_hat = estimate_propensity_scores(data, use_correct_model=use_correct_model_pi)
        r_hat1 = estimate_missingness_scores(data, ass=ass, t=1, use_correct_model=use_correct_model_gamma)
        r_hat0 = estimate_missingness_scores(data, ass=ass, t=0, use_correct_model=use_correct_model_gamma)

    m_hat01 = get_y_model_predictions(data, ass=ass, t=1, target='y0', use_correct_model=use_correct_model_mu)
    m_hat10 = get_y_model_predictions(data, ass=ass, t=0, target='y1', use_correct_model=use_correct_model_mu)
    m_hat00 = get_y_model_predictions(data, ass=ass, t=0, target='y0', use_correct_model=use_correct_model_mu)

    # Every use of y0 below is multiplied by r, so replacing its NaNs with 0
    # stops NaN * 0 from propagating while leaving those terms at zero.
    data = data.fillna(0)

    if ass == '2.3':
        eta_hat00 = m_hat00
    else:
        eta_covariates = ['z1', 'z2', 'z3', 'z4'] if use_correct_model_eta else ['x1', 'x2', 'x3', 'x4']
        X_all = add_constant(data[eta_covariates])
        if augmentation:
            eta_target = m_hat00 + data['r'] * (data['y0'] - m_hat00) / r_hat0
        else:
            eta_target = m_hat00
        eta_hat00 = OLS(eta_target, X_all).fit().predict(X_all)

    n_treated = data['t'].sum()

    # Treated-arm contribution
    treated_weight = data['t'] / n_treated
    y0_augmented_treated = m_hat01 + data['r'] * (data['y0'] - m_hat01) / r_hat1
    treated_term = data['y1'] - y0_augmented_treated - m_hat10 + eta_hat00

    # Control-arm contribution, reweighted by the propensity odds
    control_weight = (1 - data['t']) * pi_hat / ((1 - pi_hat) * n_treated)
    control_term = data['y1'] - m_hat10 - m_hat00 + eta_hat00 - data['r'] * (data['y0'] - m_hat00) / r_hat0

    influence_function = treated_weight * treated_term - control_weight * control_term

    estimate = influence_function.sum()
    centred = influence_function - data['t'] * estimate / n_treated
    delta_method_var = np.sum(centred ** 2)

    return estimate, np.sqrt(delta_method_var)

def _result_row(seed, ass, data_sim, att, estimate, std, *, mu, gamma, pi, eta, augmentation, oracle):
    """Build one output record: specification flags, estimate and 95% interval summary."""
    width = 1.96 * std
    coverage = (att <= estimate + width) and (att >= estimate - width)
    return {
        # Replication index; the RNG itself is seeded once for the whole run.
        'seed': seed,
        'ass': ass,
        'mu': mu,
        'gamma': gamma,
        'pi': pi,
        'eta': eta,
        'augmentation_eta': augmentation,
        'estimate': estimate,
        'std': std,
        'coverage': coverage,
        'width': 2 * width,
        # NOTE: this is the count of rows with r == 1, i.e. rows whose y0 is
        # observed. The key is kept so existing readers of the CSV still work.
        'n_missing': data_sim['r'].sum(),
        'n_treated': data_sim['t'].sum(),
        'oracle': oracle,
    }


if __name__ == "__main__":

    n = 500
    att = 5
    n_sims = 500
    np.random.seed(6)

    result_columns = ['seed', 'ass', 'mu', 'gamma', 'pi', 'eta', 'augmentation_eta', 'estimate', 'std',
                      'coverage', 'width', 'n_missing', 'n_treated', 'oracle']

    # Rows are accumulated in a list and converted once at the end; growing a
    # DataFrame with pd.concat per row copies the whole frame every time.
    rows = []
    for seed in tqdm(range(n_sims)):
        for ass in ['2.3', '2.4']:
            data_sim = generate_data(n=n, att=att, ass=ass)

            # r == 1 marks an observed y0. A sample where r is constant is
            # skipped; each assumption draws its own dataset, so only this
            # one is skipped rather than the rest of the replication.
            n_observed = data_sim['r'].sum()
            if n_observed == len(data_sim['r']):
                print("No missing data")
                continue

            if n_observed == 0:
                print("All data missing")
                continue

            # Every combination of correct / misspecified nuisance models
            for mu, gamma, pi, eta, augmentation in itertools.product([True, False], repeat=5):
                estimate, std = fit_missingdid(data_sim, ass=ass, use_correct_model_mu=mu, use_correct_model_gamma=gamma, use_correct_model_pi=pi, use_correct_model_eta=eta, augmentation=augmentation)
                rows.append(_result_row(
                    seed, ass, data_sim, att, estimate, std,
                    mu=mu, gamma=gamma, pi=pi, eta=eta,
                    augmentation=augmentation, oracle=False,
                ))

            # Benchmark using the true propensity / response scores
            oracle_estimate, oracle_std = fit_missingdid(data_sim, ass=ass, use_correct_model_mu=False, use_correct_model_gamma=True, use_correct_model_pi=True, use_correct_model_eta=False, augmentation=False, oracle=True)
            rows.append(_result_row(
                seed, ass, data_sim, att, oracle_estimate, oracle_std,
                mu=False, gamma=True, pi=True, eta=False,
                augmentation=False, oracle=True,
            ))

    results = pd.DataFrame(rows, columns=result_columns)
    results.to_csv(f'sim_results_{n}.csv', index=False)