import logging
import numpy as np
import torch
import pandas as pd
import pyro
from CFDiVAE_gpu import CFDiVAE_Model
from load_datasets import load_data_CFDiVAE
import statsmodels.api as sm

# Make the "pyro" logger emit DEBUG records, and lower every handler already
# attached to it to DEBUG as well. The handler list may be empty (nothing is
# guaranteed to have attached a handler to "pyro"), so iterate rather than
# index ``handlers[0]``, which would raise IndexError at import time.
_pyro_logger = logging.getLogger("pyro")
_pyro_logger.setLevel(logging.DEBUG)
for _handler in _pyro_logger.handlers:
    _handler.setLevel(logging.DEBUG)


def _to_numpy(tensor):
    """Detach a torch tensor, move it to host memory and return a float64 ndarray."""
    return tensor.detach().cpu().numpy().astype(np.float64)


def run(N, num):
    """Train CFDiVAE on one synthetic dataset and return the regression inputs.

    Loads the dataset selected by ``N`` and ``num`` from ``./Data``, fits a
    fresh CFDiVAE model on it, applies the fitted guide's ``zf`` network to
    the training covariates, and stacks everything column-wise.

    Args:
        N: Sample size, forwarded to ``load_data_CFDiVAE``.
        num: Replication index, forwarded to ``load_data_CFDiVAE``.

    Returns:
        A 2-D float64 ndarray with columns
        ``[t, y, zf_0, zf_1, x[:, 1], x[:, 2], x[:, 3], x[:, 4]]``.
        float64 is used deliberately: casting to float16 would keep only
        ~3 significant digits and quantize the values fed to the ATE
        regressions downstream.
    """
    pyro.enable_validation(__debug__)
    # NOTE(review): makes new tensors default to CUDA, so a GPU is required.
    # ``set_default_tensor_type`` is deprecated in newer PyTorch releases;
    # consider ``torch.set_default_device('cuda')`` when upgrading.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Size of the learned Zf latent; also determines how many zf columns
    # appear in the returned matrix.
    latent_dim_zf = 2

    # Load synthetic data.
    x_train, t_train, y_train = load_data_CFDiVAE(path="./Data", N=N, num=num)

    # Train CFDiVAE from a clean Pyro parameter store.
    pyro.clear_param_store()
    model = CFDiVAE_Model(feature_dim=9,  # NOTE(review): must match x_train's width
                          outcome_dist="normal",
                          latent_dim_Zf=latent_dim_zf,
                          hidden_dim=200,
                          num_layers=3)
    model.fit(x_train,
              t_train,
              y_train,
              num_epochs=30,
              batch_size=256,
              learning_rate=1e-3,
              learning_rate_decay=0.01,
              weight_decay=1e-4)

    # Inference only: no autograd graph is needed for the encoder pass.
    with torch.no_grad():
        zf_res = model.guide.zf(x_train)

    zf_res = _to_numpy(zf_res).reshape(-1, latent_dim_zf)
    t_train = _to_numpy(t_train).reshape(-1, 1)
    y_train = _to_numpy(y_train).reshape(-1, 1)
    # Only covariate columns 1..4 are kept as adjustment variables.
    x_train = _to_numpy(x_train)[:, 1:5]

    return np.concatenate((t_train, y_train, zf_res, x_train), axis=1)


def estimate_ate_frontdoor_linear(df, t, y, z, w1, w2, w3, w4):
    """Estimate the effect of ``t`` on ``y`` transmitted through ``z`` with two OLS fits.

    Stage 1 regresses the mediator ``z`` on an intercept, the treatment ``t``
    and the covariates ``w1..w4``; its coefficient on ``t`` is the t -> z
    slope. Stage 2 regresses ``y`` on an intercept and the stage-1 residual
    of ``z``. Because that residual is orthogonal to ``t`` and the
    covariates, its slope equals the coefficient on ``z`` in a regression of
    ``y`` on ``z``, ``t`` and ``w1..w4`` (Frisch-Waugh-Lovell). The product of
    the two slopes is returned.

    Args:
        df: DataFrame containing every column named below.
        t: Treatment column name.
        y: Outcome column name.
        z: Mediator column name.
        w1, w2, w3, w4: Covariate column names.

    Returns:
        The product of the two fitted slopes (a numpy float).
    """
    t_vals = df[t].values
    y_vals = df[y].values
    z_vals = df[z].values
    covariates = [df[col].values for col in (w1, w2, w3, w4)]

    # has_constant='add' guarantees the intercept is column 0 even when a
    # regressor happens to be constant; with the default 'skip' no intercept
    # would be added and params[1] would no longer be the coefficient on t.
    design = sm.add_constant(np.column_stack((t_vals, *covariates)),
                             has_constant='add')
    z_t_model = sm.OLS(z_vals, design).fit()

    # Part of z not explained by t and the covariates (z - fitted z).
    z_prime = z_t_model.resid

    y_z_model = sm.OLS(y_vals, sm.add_constant(z_prime, has_constant='add')).fit()

    # params[0] is the intercept in both fits; params[1] is the slope of interest.
    return y_z_model.params[1] * z_t_model.params[1]


if __name__ == '__main__':
    from pathlib import Path

    N_list = [500, 1000, 2000, 4000, 6000, 8000, 10000, 20000]
    n_reps = 30          # replications per sample size
    true_effect = 10     # ground-truth ATE used to compute the relative bias
    columns = ['t', 'y', 'z1', 'z2', 'w1', 'w2', 'w3', 'w4']
    latent_cols = ('z1', 'z2')
    covariate_cols = ('w1', 'w2', 'w3', 'w4')

    out_dir = Path('./Res/CFDiVAE_2')
    # Create the output directory up front; otherwise to_csv would only fail
    # after all replications for the first N have already been trained.
    out_dir.mkdir(parents=True, exist_ok=True)

    for N in N_list:
        CE = []
        CE_bias = []
        for i in range(n_reps):
            data = pd.DataFrame(run(N, i), columns=columns)
            print("Finished:", str(i + 1))
            # Total effect is the sum of the front-door estimates through
            # each latent Zf dimension.
            effect = sum(
                estimate_ate_frontdoor_linear(data, 't', 'y', z_col, *covariate_cols)
                for z_col in latent_cols
            )
            bias = abs(effect - true_effect) / true_effect
            CE.append(effect)
            CE_bias.append(bias)
            print(bias * 100)

        pd.DataFrame(CE).to_csv(out_dir / f'CFDiVAE_{N}.csv',
                                sep=',', float_format='%.5f')
        pd.DataFrame(CE_bias).to_csv(out_dir / f'CFDiVAE_bias_{N}.csv',
                                     sep=',', float_format='%.5f')
