import logging
import numpy as np
import torch
import pandas as pd
import pyro
from CFDiVAE_gpu import CFDiVAE_Model
from load_datasets import load_data_CFDiVAE
import statsmodels.api as sm

# Make Pyro's log output verbose: set the "pyro" logger to DEBUG and, when a
# handler is attached to it, lower that first handler's threshold to DEBUG too.
_pyro_logger = logging.getLogger("pyro")
_pyro_logger.setLevel(logging.DEBUG)
# Guard the lookup: indexing handlers[0] on a logger without handlers raises
# IndexError and would abort the script at import time.
if _pyro_logger.handlers:
    _pyro_logger.handlers[0].setLevel(logging.DEBUG)


def run(N, num):
    """Train a CFDiVAE model on one dataset and return learned vs. reference Zf.

    Side effects: toggles Pyro validation according to ``__debug__``, makes
    CUDA float tensors the process-wide torch default, and clears the Pyro
    parameter store before training.

    Args:
        N: Forwarded to ``load_data_CFDiVAE`` as ``N``; also the number of
            rows both outputs are reshaped to.
        num: Forwarded to ``load_data_CFDiVAE`` as ``num``.

    Returns:
        A pair ``(zf_res, zf_real)`` of float16 NumPy arrays, each of shape
        ``(N, 1)``: the output of the trained guide's ``zf`` on the training
        features, and the ``zf_real`` values supplied by the loader.
    """
    pyro.enable_validation(__debug__)
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Load synthetic data.
    x_train, t_train, y_train, zf_real = load_data_CFDiVAE(path="./Data", N=N, num=num)

    # Train CFDiVAE from a clean parameter store.
    pyro.clear_param_store()
    model = CFDiVAE_Model(
        feature_dim=9,
        outcome_dist="normal",
        latent_dim_Zf=1,
        hidden_dim=200,
        num_layers=3,
    )
    fit_options = dict(
        num_epochs=30,
        batch_size=256,
        learning_rate=1e-3,
        learning_rate_decay=0.01,
        weight_decay=1e-4,
    )
    model.fit(x_train, t_train, y_train, **fit_options)

    # Move both tensors to host memory as half-precision NumPy arrays.
    # NOTE(review): float16 keeps only about three significant decimal
    # digits -- confirm this precision is intended for downstream use.
    learned = model.guide.zf(x_train).cpu().detach().numpy().astype(np.float16)
    reference = zf_real.cpu().detach().numpy().astype(np.float16)

    # Return both as (N, 1) column vectors.
    return learned.reshape(N, 1), reference.reshape(N, 1)


def estimate_ate_frontdoor_linear(df, t, y, z, w1, w2, w3, w4):
    """Estimate an effect of ``t`` on ``y`` through ``z`` with two OLS fits.

    Procedure:
      1. Regress ``z`` on an intercept, ``t`` and ``w1..w4``.
      2. Take the residual of ``z`` from that fit.
      3. Regress ``y`` on an intercept and that residual.
      4. Return the product of the residual's coefficient (step 3) and the
         coefficient of ``t`` (step 1).

    Args:
        df: Table indexable by column name whose columns expose ``.values``
            (e.g. a pandas DataFrame).
        t, y, z: Column names in ``df`` used as the regressor of step 1, the
            response of step 3 and the response of step 1, respectively.
        w1, w2, w3, w4: Column names of the additional regressors of step 1.

    Returns:
        The product of the two slope coefficients described above.
    """
    z_vals = df[z].values
    y_vals = df[y].values
    regressors = [df[name].values for name in (t, w1, w2, w3, w4)]

    # has_constant='add' always prepends the intercept column. The default
    # ('skip') silently omits it when any input column is a non-zero
    # constant, which would shift every coefficient by one position and make
    # params[1] below refer to the wrong regressor.
    design = sm.add_constant(np.column_stack(regressors), has_constant='add')
    z_t_model = sm.OLS(z_vals, design).fit()

    # Part of z not explained by t and w1..w4.
    z_prime = z_vals - z_t_model.predict(design)

    y_z_model = sm.OLS(y_vals, sm.add_constant(z_prime, has_constant='add')).fit()

    # Index 0 is the intercept in both fits; index 1 is the first regressor
    # (z_prime in the second fit, t in the first).
    return y_z_model.params[1] * z_t_model.params[1]


if __name__ == '__main__':
    # Script entry point: for each sample size, train via run() and save the
    # learned and reference Zf arrays as CSV files.
    import os

    N_list = [500, 1000, 2000, 4000, 6000, 8000, 10000, 20000]

    # Create the output directory before any training starts, so a missing
    # folder cannot cause a failure only after a full training run.
    out_dir = './Res/CFDiVAE'
    os.makedirs(out_dir, exist_ok=True)

    for N in N_list:
        zf_res, zf_real = run(N, 0)

        file1 = pd.DataFrame(zf_res)
        file2 = pd.DataFrame(zf_real)
        file1.to_csv(out_dir + '/CFDiVAE_learned_' + str(N) + '.csv', sep=',', float_format='%.5f')
        file2.to_csv(out_dir + '/CFDiVAE_truth_' + str(N) + '.csv', sep=',', float_format='%.5f')

        # Report completion only once both files for this N are written.
        print("Finished:", str(N))
