"""
Analyze Framingham data
"""

# imports
import pandas as pd
from methods import *

def process_data():
    """
    Function to process the longitudinal data and combine rows corresponding
    to the first 2 time points of each individual

    Reads "framingham_data.csv" from the current working directory. For every
    individual (identified by "RANDID") with at least two rows, the covariates
    are taken from that individual's first row and the outcomes from the
    second row, "first" and "second" being in file order. Any later rows are
    ignored, and individuals with a single row are dropped.

    The two rows are paired by "RANDID", so the result is correct even when
    rows belonging to different individuals are interleaved in the file.

    Returns a DataFrame with one row per retained individual (ordered by
    first appearance in the file) and the columns age, sex, bpmeds, smoking,
    bmi, glucose, hype, anychd, prevchd, prevstroke, prevhype, diabetes,
    stroke, educ, mifchd. Rows with any missing value among these columns are
    removed. Column dtypes are the ones inferred when reading the CSV.
    """

    # output column -> source column, info collected at first time point
    first_visit_columns = {
        "age": "AGE",
        "sex": "SEX",
        "bpmeds": "BPMEDS",
        "smoking": "CURSMOKE",
        "prevhype": "PREVHYP",
        "prevchd": "PREVCHD",
        "diabetes": "DIABETES",
        "educ": "educ",
        "prevstroke": "PREVSTRK",
        "glucose": "GLUCOSE",
        "bmi": "BMI",
    }
    # output column -> source column, info collected at second time point
    second_visit_columns = {
        "hype": "HYPERTEN",
        "anychd": "ANYCHD",
        "stroke": "STROKE",
        "mifchd": "MI_FCHD",
    }
    # column order of the returned data frame
    output_columns = ["age", "sex", "bpmeds", "smoking", "bmi", "glucose",
                      "hype", "anychd", "prevchd", "prevstroke",
                      "prevhype", "diabetes", "stroke", "educ", "mifchd"]

    # read in the data; rows without an identifier cannot be paired up
    orig_data = pd.read_csv("framingham_data.csv")
    orig_data = orig_data[orig_data["RANDID"].notna()]

    # 0 for an individual's first row, 1 for the second, ... (file order)
    visit = orig_data.groupby("RANDID").cumcount()

    first = orig_data.loc[visit == 0, ["RANDID"] + list(first_visit_columns.values())]
    first = first.rename(columns={src: out for out, src in first_visit_columns.items()})
    second = orig_data.loc[visit == 1, ["RANDID"] + list(second_visit_columns.values())]
    second = second.rename(columns={src: out for out, src in second_visit_columns.items()})

    # the inner join only keeps individuals with at least one follow up and
    # matches rows by id rather than by position; it preserves the order of
    # the left (first visit) keys and yields a fresh 0..n-1 index
    data_processed = first.merge(second, on="RANDID", how="inner", validate="one_to_one")

    # drop missing rows and return
    return data_processed[output_columns].dropna()


if __name__ == "__main__":

    # seed numpy's global RNG up front so repeated runs give the same output
    np.random.seed(0)

    # significance level for the interval and number of bootstrap replicates
    alpha = 0.05
    num_bootstraps = 200
    data = process_data()

    # column names plugged into the calls below
    A = "smoking"
    M = "hype"
    Y = "anychd"
    C = ["age", "sex", "prevchd", "bmi"]
    Z = "prevhype"
    # variable sets built from the names above (only mpM is used below)
    mpA = C + [Z]
    mpM = C + [Z, A]
    mpY = C + [Z, A, M]
    print("Proportion smokers", np.mean(data[A]))

    # weights for A=0 and A=1 on the full data, followed by the tests and
    # the point estimate
    weights_a0 = ml_dual_weights(data, M, mpM, A, 0, trim=False)
    weights_a1 = ml_dual_weights(data, M, mpM, A, 1, trim=False)
    dual_p_val = fcit_test(data, Y, Z, C + [M], weights_a1)
    print("ML dual pval", dual_p_val)
    print("ML dual IPW point estimate", dual_ipw(data, Y, weights_a0, weights_a1))
    print("IV pval", fcit_test(data, Z, M, C + [A]))
    print("Backdoor pval", fcit_test(data, Z, Y, C + [A]))

    # bootstrap: redo the weights and the estimate on each resampled data set
    boot_estimates = []
    for _ in range(num_bootstraps):

        # draw len(data) rows with replacement and renumber them from 0
        resampled = data.sample(len(data), replace=True).reset_index(drop=True)
        boot_a0 = ml_dual_weights(resampled, M, mpM, A, 0, trim=False)
        boot_a1 = ml_dual_weights(resampled, M, mpM, A, 1, trim=False)
        boot_estimates.append(dual_ipw(resampled, Y, boot_a0, boot_a1))

    # interval end points are the alpha/2 and 1 - alpha/2 quantiles
    q_low, q_up = np.quantile(boot_estimates, q=[alpha/2, 1 - alpha/2])
    print("ML dual IPW 95% CIs", q_low, q_up)
