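# Experiment script: simulates data, estimates causal graphs and interventional effects, and plots graph-recovery accuracy and effect MSE across numbers of environments.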
from data_simulator import simulate_data
from analytic_trunc_fac import analytic_causal_effect_computation
from density_estimation import joint_density_estimation
from generalized_trunc_fac import exch_causal_effect_computation
from iid_trunc_fac import iid_causal_effect_computation
from graph_estimate import causal_de_finetti, iid_cd
from utils import *
import numpy as np

# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    np.random.seed(42)
    nsample = 2
    alpha = 1
    beta = 3
    nexp = 100
    nenv_lists = [100 , 1000, 2000, 5000]
    std_graph_acc_env = dict.fromkeys(nenv_lists, None)
    std_effect_mse_env = dict.fromkeys(nenv_lists, None)
    graph_acc_env = dict.fromkeys(nenv_lists, None)
    effect_mse_env = dict.fromkeys(nenv_lists, None)
    for nenv in nenv_lists:
        graph_acc = {'iid': [], 'CdF': []}
        effect_mse = {'iid': [], 'doF': [], 'doF-ablation': [], 'iid-ablation': []}
        mean_dag_acc = {'iid': 0, 'CdF': 0}
        mean_effect_mse = {'iid': 0, 'doF': 0, 'doF-ablation': 0, 'iid-ablation': 0}
        std_dag_acc = {'iid': 0, 'CdF': 0}
        std_effect_mse = {'iid': 0, 'doF': 0, 'doF-ablation': 0, 'iid-ablation': 0}
        for _ in range(nexp):
            graph = np.random.choice(['xtoy', 'ytox', 'xindy'])
            data = simulate_data(nenv, nsample, alpha, beta, graph)
            oracle_dag = data['true_dag']
            #print(oracle_dag)
            dag_estimate = causal_de_finetti(data)
            dag_estimate_iid = iid_cd(data)
            # test restrict to one element
            intervened_var = np.random.choice(['x1', 'x2', 'y1', 'y2'])
            intervened_val = np.random.choice([0, 1])
            intervention_desc = ('do(' + intervened_var + ')' + '=' + str(intervened_val))
            oracle_trunc_fac = analytic_causal_effect_computation(intervened_var, intervened_val, oracle_dag, alpha, beta)
            joint_density = joint_density_estimation(data)
            iid_trunc_fac = iid_causal_effect_computation(joint_density, intervened_var, intervened_val, dag_estimate_iid)
            generalized_trunc_fac = exch_causal_effect_computation(joint_density, intervened_var, intervened_val, dag_estimate)
            generalized_trunc_fac_given_oracle_dag = exch_causal_effect_computation(joint_density, intervened_var, intervened_val, oracle_dag)
            iid_trunc_fac_given_oracle_dag = iid_causal_effect_computation(joint_density, intervened_var, intervened_val, oracle_dag)
            compute_graph_acc(oracle_dag, dag_estimate, dag_estimate_iid, graph_acc)
            mse_causal_effect(oracle_trunc_fac, generalized_trunc_fac, iid_trunc_fac, generalized_trunc_fac_given_oracle_dag,
                              iid_trunc_fac_given_oracle_dag, effect_mse, intervention_desc)
        for key in graph_acc.keys():
            mean_dag_acc[key] = np.mean(graph_acc[key])
            std_dag_acc[key] = np.var(graph_acc[key])
        for key in effect_mse.keys():
            mean_effect_mse[key] = np.mean(effect_mse[key])
            std_effect_mse[key] = np.var(effect_mse[key])

        std_graph_acc_env[nenv] = std_dag_acc
        std_effect_mse_env[nenv] = std_effect_mse
        graph_acc_env[nenv] = mean_dag_acc
        effect_mse_env[nenv] = mean_effect_mse

    plot(graph_acc_env, effect_mse_env, std_graph_acc_env, std_effect_mse_env)





