import random

import numpy as np
import statsmodels.api as sm
import tensorflow as tf

from relational_erm.data_cleaning.wikipedia import load_data_wikipedia_processed


### SIMULATING TREATMENT/OUTCOME VARIABLES:


def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))``.

    Works elementwise when ``x`` is a NumPy array; no clipping is applied,
    so very negative inputs may overflow ``np.exp`` (yielding 0.0).
    """
    return 1 / (1 + np.exp(-x))


def simulate_y(propensities, treatment, beta0=1.0, beta1=1.0, gamma=1.0, set_seed=42):
    """Simulate a continuous outcome from treatment and a propensity confounder.

    The confounding term is ``propensities - 0.5`` (cast to float32), and the
    noise is standard normal. NumPy's global RNG is reseeded with
    ``set_seed`` before the noise is drawn.

    Args:
        propensities: Array of per-unit propensities; its first dimension sets
            the number of noise draws.
        treatment: Treatment (or exposure) values, broadcast against the
            confounding term.
        beta0: Coefficient on ``treatment``.
        beta1: Coefficient on the confounding term.
        gamma: Scale applied to the noise.
        set_seed: Seed for ``np.random.seed``.

    Returns:
        ``(y, y0, y1)`` where ``y0 = beta1 * confounding``,
        ``y1 = beta0 * treatment + y0`` (the noiseless outcome at the given
        treatment) and ``y = y1 + gamma * noise``.
    """
    n_units = propensities.shape[0]
    centred = (propensities - 0.5).astype(np.float32)

    np.random.seed(set_seed)
    eps = np.random.normal(0., 1., size=n_units).astype(np.float32)

    baseline = beta1 * centred
    noiseless = beta0 * treatment + baseline
    observed = noiseless + gamma * eps
    return observed, baseline, noiseless


def simulate_y_binary(propensities, treatment, beta0=1.0, beta1=1.0, gamma=1.0, set_seed=42):
    """Simulate a binary outcome through a logistic link.

    A latent score is built as ``beta0 * treatment + beta1 * (propensities - 0.5)
    + gamma * noise``, where the noise is standard normal. Each outcome is
    then drawn from a Bernoulli distribution with probability
    ``sigmoid(latent)``.

    NumPy's global RNG is reseeded with ``set_seed`` before the noise draw.
    The Bernoulli draws continue from that same stream.

    Returns:
        Integer array of 0/1 outcomes, one per unit.
    """
    n_units = propensities.shape[0]
    centred = (propensities - 0.5).astype(np.float32)

    np.random.seed(set_seed)
    eps = np.random.normal(0., 1., size=n_units).astype(np.float32)

    latent = beta0 * treatment + beta1 * centred + gamma * eps
    return np.random.binomial(1, sigmoid(latent))


def simulate_from_wikipedia_covariate(data_dir, covariate='unique_category', beta0=1.0, beta1=1.0, gamma=1.0, set_seed=42):
    """Simulate a neighbour-aggregated treatment and continuous outcome on Wikipedia.

    A binary treatment is drawn per node with a category-dependent propensity.
    Each node's exposure is the mean treatment over
    ``graph_data.adjacency_list.get_neighbours(i)``. A continuous outcome is
    then generated from that exposure by ``simulate_y``.

    Args:
        data_dir: Directory passed to ``load_data_wikipedia_processed``.
        covariate: Confounding covariate; only ``'unique_category'`` is supported.
        beta0: Coefficient on the aggregated treatment (see ``simulate_y``).
        beta1: Coefficient on the confounder (see ``simulate_y``).
        gamma: Scale of the outcome noise (see ``simulate_y``).
        set_seed: Seed for NumPy's global RNG.

    Returns:
        ``(t, y, y0, y1, propensities)``:
        - ``t``: float32 neighbour-mean treatment.
        - ``y``, ``y0``, ``y1``: float32 outputs of ``simulate_y``.
        - ``propensities``: per-node treatment probabilities.

    Raises:
        ValueError: If ``covariate`` is not recognized.
    """
    # Validate before the (potentially expensive) data load.
    if covariate != 'unique_category':
        raise ValueError("covariate name not recognized")

    graph_data, profiles = load_data_wikipedia_processed(data_dir)
    np.random.seed(set_seed)

    # Map categories to consecutive integers 0..K-1, then shift by -1
    # (three categories become -1, 0, 1).
    categories = profiles['unique_category'].tolist()
    confounder = np.searchsorted(np.unique(categories), categories) - 1.

    # With three categories these propensities are 0.15, 0.5 and 0.85.
    propensities = 0.5 + 0.35 * confounder
    treatment = np.random.binomial(1, propensities)

    # NOTE(review): a node with no neighbours gives np.mean([]) -> NaN (plus a
    # RuntimeWarning); confirm the graph has no isolated nodes.
    treatment_agg = np.empty(shape=(len(treatment)), dtype=np.float32)
    for i in range(len(treatment)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment[neighbours], dtype=np.float32)

    # simulate_y reseeds NumPy with the same set_seed, so its noise stream
    # restarts from the state used for the treatment draws above.
    y, y0, y1 = simulate_y(propensities, treatment_agg, beta0=beta0, beta1=beta1, gamma=gamma, set_seed=set_seed)

    t = treatment_agg.astype(np.float32)
    return t, y.astype(np.float32), y0.astype(np.float32), y1.astype(np.float32), propensities


def simulate_from_wikipedia_covariate_y_binary(data_dir, covariate='unique_category', beta0=1.0, beta1=1.0, gamma=1.0, set_seed=42):
    """Simulate a neighbour-aggregated treatment and a binary outcome on Wikipedia.

    Identical treatment mechanism to ``simulate_from_wikipedia_covariate``, but
    the outcome is drawn by ``simulate_y_binary`` (logistic link).

    Args:
        data_dir: Directory passed to ``load_data_wikipedia_processed``.
        covariate: Confounding covariate; only ``'unique_category'`` is supported.
        beta0, beta1, gamma: Outcome-model parameters (see ``simulate_y_binary``).
        set_seed: Seed for NumPy's global RNG.

    Returns:
        ``(t, y)``: float32 neighbour-mean treatment and float32 0/1 outcome.

    Raises:
        ValueError: If ``covariate`` is not recognized.
    """
    if covariate != 'unique_category':
        raise ValueError("covariate name not recognized")

    graph_data, profiles = load_data_wikipedia_processed(data_dir)
    np.random.seed(set_seed)

    # Map categories to consecutive integers 0..K-1, then shift by -1
    # (three categories become -1, 0, 1).
    categories = profiles['unique_category'].tolist()
    confounder = np.searchsorted(np.unique(categories), categories) - 1.

    propensities = 0.5 + 0.35 * confounder
    treatment = np.random.binomial(1, propensities)

    # NOTE(review): a node with no neighbours gives np.mean([]) -> NaN; confirm
    # the graph has no isolated nodes.
    treatment_agg = np.empty(shape=(len(treatment)), dtype=np.float32)
    for i in range(len(treatment)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment[neighbours], dtype=np.float32)

    y = simulate_y_binary(propensities, treatment_agg, beta0=beta0, beta1=beta1, gamma=gamma, set_seed=set_seed)

    t = treatment_agg.astype(np.float32)
    return t, y.astype(np.float32)


def simulate_from_wikipedia_covariate_treatment_all0(data_dir, covariate='unique_category', beta0=1.0, beta1=1.0, gamma=1.0,
                                                     set_seed=42):
    """Simulate a continuous outcome with every node untreated.

    Every node's treatment is 0. The neighbour-mean exposure is therefore 0,
    except NaN for nodes with no neighbours. The outcome is generated by
    ``simulate_y`` with the usual category-based confounder.

    Args:
        data_dir: Directory passed to ``load_data_wikipedia_processed``.
        covariate: Confounding covariate; only ``'unique_category'`` is supported.
        beta0, beta1, gamma: Outcome-model parameters (see ``simulate_y``).
        set_seed: Seed for NumPy's global RNG.

    Returns:
        ``(t, y, y0, y1)`` as float32 arrays.

    Raises:
        ValueError: If ``covariate`` is not recognized.
    """
    if covariate != 'unique_category':
        raise ValueError("covariate name not recognized")

    graph_data, profiles = load_data_wikipedia_processed(data_dir)
    np.random.seed(set_seed)

    # Map categories to consecutive integers 0..K-1, then shift by -1
    # (three categories become -1, 0, 1).
    categories = profiles['unique_category'].tolist()
    confounder = np.searchsorted(np.unique(categories), categories) - 1.

    propensities = 0.5 + 0.35 * confounder
    treatment = np.zeros(shape=len(propensities), dtype=np.float32)

    treatment_agg = np.empty(shape=(len(treatment)), dtype=np.float32)
    for i in range(len(treatment)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment[neighbours], dtype=np.float32)

    y, y0, y1 = simulate_y(propensities, treatment_agg, beta0=beta0, beta1=beta1, gamma=gamma, set_seed=set_seed)

    t = treatment_agg.astype(np.float32)
    return t, y.astype(np.float32), y0.astype(np.float32), y1.astype(np.float32)


def simulate_from_wikipedia_covariate_treatment_all1(data_dir, covariate='unique_category', beta0=1.0, beta1=1.0, gamma=1.0,
                                                     set_seed=42):
    """Simulate a continuous outcome with every node treated.

    Every node's treatment is 1. The neighbour-mean exposure is therefore 1,
    except NaN for nodes with no neighbours. The outcome is generated by
    ``simulate_y`` with the usual category-based confounder.

    Args:
        data_dir: Directory passed to ``load_data_wikipedia_processed``.
        covariate: Confounding covariate; only ``'unique_category'`` is supported.
        beta0, beta1, gamma: Outcome-model parameters (see ``simulate_y``).
        set_seed: Seed for NumPy's global RNG.

    Returns:
        ``(t, y, y0, y1)`` as float32 arrays.

    Raises:
        ValueError: If ``covariate`` is not recognized.
    """
    if covariate != 'unique_category':
        raise ValueError("covariate name not recognized")

    graph_data, profiles = load_data_wikipedia_processed(data_dir)
    np.random.seed(set_seed)

    # Map categories to consecutive integers 0..K-1, then shift by -1
    # (three categories become -1, 0, 1).
    categories = profiles['unique_category'].tolist()
    confounder = np.searchsorted(np.unique(categories), categories) - 1.

    propensities = 0.5 + 0.35 * confounder
    treatment = np.ones(shape=len(propensities), dtype=np.float32)

    treatment_agg = np.empty(shape=(len(treatment)), dtype=np.float32)
    for i in range(len(treatment)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment[neighbours], dtype=np.float32)

    y, y0, y1 = simulate_y(propensities, treatment_agg, beta0=beta0, beta1=beta1, gamma=gamma, set_seed=set_seed)

    t = treatment_agg.astype(np.float32)
    return t, y.astype(np.float32), y0.astype(np.float32), y1.astype(np.float32)


def simulate_from_wikipedia_covariate_treatment_label(data_dir, covariate='unique_category', set_seed=2):
    """Simulate treatments and use the (unmasked) treatment itself as the label.

    A binary treatment is drawn per node with a category-dependent propensity.
    The returned label ``y`` is that treatment.

    The exposure ``t`` comes from a masked copy of the treatment. Half of the
    treated nodes (rounded down) are reset to 0, chosen at random, and the
    masked treatment is then averaged over each node's neighbours.

    Args:
        data_dir: Directory passed to ``load_data_wikipedia_processed``.
        covariate: Confounding covariate; only ``'unique_category'`` is supported.
        set_seed: Seed for the treatment draws and for choosing the masked subset.

    Returns:
        ``(t, y)``: float32 neighbour-mean masked treatment and float32 0/1 label.

    Raises:
        ValueError: If ``covariate`` is not recognized.
    """
    if covariate != 'unique_category':
        raise ValueError("covariate name not recognized")

    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    # Map categories to consecutive integers 0..K-1, then shift by -1
    # (three categories become -1, 0, 1).
    categories = profiles['unique_category'].tolist()
    confounder = np.searchsorted(np.unique(categories), categories) - 1.

    np.random.seed(set_seed)
    propensities = 0.5 + 0.35 * confounder
    treatment = np.random.binomial(1, propensities)
    y = treatment

    # Reset half of the treated nodes to untreated. A dedicated generator
    # seeded from set_seed makes the subset reproducible without touching the
    # global `random` state.
    treated_idx = np.flatnonzero(treatment == 1)
    n_obs = treated_idx.shape[0]
    subset = random.Random(set_seed).sample(treated_idx.tolist(), n_obs // 2)
    treatment_new = treatment.copy()
    treatment_new[subset] = 0

    # NOTE(review): a node with no neighbours gives np.mean([]) -> NaN; confirm
    # the graph has no isolated nodes.
    treatment_agg = np.empty(shape=(len(treatment_new)), dtype=np.float32)
    for i in range(len(treatment_new)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment_new[neighbours], dtype=np.float32)

    t = treatment_agg.astype(np.float32)
    return t, y.astype(np.float32)


def simulate_from_wikipedia_covariate_treatment_all0_treatment_label(data_dir, covariate='unique_category'):
    """Return an all-zero treatment label and its neighbour-mean exposure.

    Every node's treatment is 0. The label ``y`` is that treatment. The
    exposure ``t`` is the mean over each node's neighbours, which is 0 except
    NaN for nodes with no neighbours.

    Args:
        data_dir: Directory passed to ``load_data_wikipedia_processed``.
        covariate: Only ``'unique_category'`` is accepted.

    Returns:
        ``(t, y)`` as float32 arrays of length equal to the number of profiles.

    Raises:
        ValueError: If ``covariate`` is not recognized.
    """
    if covariate != 'unique_category':
        raise ValueError("covariate name not recognized")

    graph_data, profiles = load_data_wikipedia_processed(data_dir)
    n_nodes = len(profiles['unique_category'])

    treatment = np.zeros(shape=n_nodes, dtype=np.float32)
    y = treatment

    treatment_agg = np.empty(shape=(len(treatment)), dtype=np.float32)
    for i in range(len(treatment)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment[neighbours], dtype=np.float32)

    t = treatment_agg.astype(np.float32)
    return t, y.astype(np.float32)


def simulate_from_wikipedia_covariate_treatment_all1_treatment_label(data_dir, covariate='unique_category'):
    """Return an all-one treatment label and its neighbour-mean exposure.

    Every node's treatment is 1. The label ``y`` is that treatment. The
    exposure ``t`` is the mean over each node's neighbours, which is 1 except
    NaN for nodes with no neighbours.

    Args:
        data_dir: Directory passed to ``load_data_wikipedia_processed``.
        covariate: Only ``'unique_category'`` is accepted.

    Returns:
        ``(t, y)`` as float32 arrays of length equal to the number of profiles.

    Raises:
        ValueError: If ``covariate`` is not recognized.
    """
    if covariate != 'unique_category':
        raise ValueError("covariate name not recognized")

    graph_data, profiles = load_data_wikipedia_processed(data_dir)
    n_nodes = len(profiles['unique_category'])

    treatment = np.ones(shape=n_nodes, dtype=np.float32)
    y = treatment

    treatment_agg = np.empty(shape=(len(treatment)), dtype=np.float32)
    for i in range(len(treatment)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment[neighbours], dtype=np.float32)

    t = treatment_agg.astype(np.float32)
    return t, y.astype(np.float32)


def simulate_from_wikipedia_covariate_binary_region(data_dir, covariate='unique_category', set_seed=2):
    """Return the masked neighbour-mean treatment together with the confounder.

    Same treatment/masking mechanism as
    ``simulate_from_wikipedia_covariate_treatment_label``. A binary treatment
    is drawn with a category-dependent propensity, and half of the treated
    nodes (rounded down) are reset to 0. The masked treatment is averaged over
    each node's neighbours.

    Args:
        data_dir: Directory passed to ``load_data_wikipedia_processed``.
        covariate: Confounding covariate; only ``'unique_category'`` is supported.
        set_seed: Seed for the treatment draws and for choosing the masked subset.

    Returns:
        ``(t, confounder)``, both float32:
        - ``t``: neighbour-mean masked treatment.
        - ``confounder``: shifted category index (-1, 0, 1 for three categories).

    Raises:
        ValueError: If ``covariate`` is not recognized.
    """
    if covariate != 'unique_category':
        raise ValueError("covariate name not recognized")

    graph_data, profiles = load_data_wikipedia_processed(data_dir)

    # Map categories to consecutive integers 0..K-1, then shift by -1.
    categories = profiles['unique_category'].tolist()
    confounder = np.searchsorted(np.unique(categories), categories) - 1.

    np.random.seed(set_seed)
    propensities = 0.5 + 0.35 * confounder
    treatment = np.random.binomial(1, propensities)

    # Reset half of the treated nodes to untreated, reproducibly from set_seed
    # and without touching the global `random` state.
    treated_idx = np.flatnonzero(treatment == 1)
    n_obs = treated_idx.shape[0]
    subset = random.Random(set_seed).sample(treated_idx.tolist(), n_obs // 2)
    treatment_new = treatment.copy()
    treatment_new[subset] = 0

    # NOTE(review): a node with no neighbours gives np.mean([]) -> NaN; confirm
    # the graph has no isolated nodes.
    treatment_agg = np.empty(shape=(len(treatment_new)), dtype=np.float32)
    for i in range(len(treatment_new)):
        neighbours = graph_data.adjacency_list.get_neighbours(i)
        treatment_agg[i] = np.mean(treatment_new[neighbours], dtype=np.float32)

    t = treatment_agg.astype(np.float32)
    return t, confounder.astype(np.float32)


def main():
    """Run one simulation and print an unadjusted OLS fit of outcome on exposure.

    The regression ``y ~ const + t`` does not adjust for the confounder, so
    the printed slope is a naive (confounded) estimate.
    """
    tf.compat.v1.enable_eager_execution()

    data_dir = 'dat/wikipedia'
    t, y, y0, y1, prop = simulate_from_wikipedia_covariate(data_dir, covariate='unique_category', beta0=1.0, beta1=10, gamma=1.0, set_seed=42)

    # add_constant prepends an intercept column, so params are
    # [intercept, coefficient on t] (unless t is itself constant).
    X = sm.add_constant(t)
    model1 = sm.OLS(y, X).fit()
    unadjusted_params = model1.params
    print(unadjusted_params)


if __name__ == '__main__':
    main()
