from data import SimData_Missing, miss_simulated_preprocess, true_ate
from torch.utils.data import DataLoader
import numpy as np
import os
import json

def error(missing_pattern, path, n_seeds=10):
    """Return the mean absolute ATE error and its standard error over seeds.

    For every seed in ``1..n_seeds`` the data set is built with
    ``SimData_Missing(seed, missing_pattern)``, the true ATE is computed with
    ``true_ate`` on the third split returned by ``miss_simulated_preprocess``,
    and the estimated ATE is read from
    ``./sim_miss/result/<missing_pattern>/<seed>/<path>/Overall_ATE.txt``,
    taking the entry whose ``Overall_BIC.txt`` value is smallest.

    Args:
        missing_pattern: Passed to ``SimData_Missing`` and used as a result
            sub-directory name (the callers in this file use ``'mar'`` and
            ``'mnar'``).
        path: Sub-directory, inside each seed's result directory, that holds
            ``Overall_BIC.txt`` and ``Overall_ATE.txt``.
        n_seeds: Number of replications; seeds ``1..n_seeds`` are evaluated.
            Defaults to 10.

    Returns:
        A ``(mean, sd)`` pair: the mean of ``|ate_est - ate_true|`` across
        seeds, and ``sqrt(var / n_seeds)`` where ``var`` is NumPy's default
        (``ddof=0``) variance of those absolute errors.

    Raises:
        ValueError: If ``n_seeds`` is smaller than 1, or if the test loader
            yields no batch for some seed.
    """
    if n_seeds < 1:
        raise ValueError("n_seeds must be at least 1")

    ate_true, ate_est = np.zeros(n_seeds), np.zeros(n_seeds)
    # mse for missing at random
    for seed in range(1, n_seeds + 1):
        print('iter', seed)

        # load data
        data = SimData_Missing(seed, missing_pattern)
        _, _, test_set = miss_simulated_preprocess(data, 1)
        test_data = DataLoader(test_set, batch_size=1000)

        # calculate true ate
        # NOTE(review): if the loader yields several batches only the last
        # batch's value is kept; batch_size=1000 presumably covers the whole
        # test split -- confirm.
        seen_batch = False
        for y, treat, _, _, y_count in test_data:
            true_value = true_ate(y, treat, y_count)
            seen_batch = True
        if not seen_batch:
            # Without this guard the value would be undefined on the first
            # seed, or silently carried over from the previous seed.
            raise ValueError(
                "test loader produced no batches for seed %d" % seed)
        print("true ate", true_value)
        ate_true[seed - 1] = true_value

        # load estimated ate
        base_path = os.path.join('.', 'sim_miss', 'result', missing_pattern, str(seed))
        file_dir = os.path.join(base_path, path)
        # ndmin=1 keeps single-value files indexable (loadtxt would otherwise
        # return a 0-d array, which cannot be indexed by the argmin below).
        bic = np.loadtxt(os.path.join(file_dir, 'Overall_BIC.txt'), ndmin=1)
        estimates = np.loadtxt(os.path.join(file_dir, 'Overall_ATE.txt'), ndmin=1)
        # model selection: keep the estimate with the smallest BIC
        ate_est[seed - 1] = estimates[np.argmin(bic)]
        print('ate_est', ate_est[seed - 1])

    # calculate absolute error
    abs_error = np.abs(ate_est - ate_true)
    mean = abs_error.mean()
    # standard error of the mean absolute error across seeds
    sd = np.sqrt(abs_error.var() / n_seeds)

    return mean, sd


# Sub-directory (inside each seed's result directory) holding the
# Overall_BIC.txt / Overall_ATE.txt files that error() reads.
path = "sim_miss"

# error for mar
eps_mean_mar, eps_sd_mar = error('mar', path)
print("erorr mean mar", eps_mean_mar)
print("error sd mar", eps_sd_mar)

# error for mnar
eps_mean_mnar, eps_sd_mnar = error('mnar', path)
print("erorr mean mnar", eps_mean_mnar)
print("error sd mnar", eps_sd_mnar)

# Collect both summaries and write them to ./sim_miss/result as JSON.
file_name = 'sim_miss_result' + '.json'
result = dict(mar=dict(error_mean=eps_mean_mar, error_sd=eps_sd_mar),
              mnar=dict(error_mean=eps_mean_mnar, error_sd=eps_sd_mnar))

# The context manager closes the file even if json.dump raises.
with open(os.path.join('./sim_miss/result', file_name), "w") as result_file:
    json.dump(result, result_file, indent="")
