# from ModularUtils.FunctionsConstant import getdoKey
# from ModularUtils.ControllerConstants import generate_permutations


class CausalGraph():
    """Container for an observed causal DAG plus its latent confounders.

    ``Complete_DAG`` holds one parentless node per latent confounder (named by
    the keys of ``confs``), followed by every observed variable.  Each observed
    variable's parent list is its confounders followed by its observed parents.

    Args:
        name: Description string, stored as both ``DAG_desc`` and
            ``Complete_DAG_desc``.
        dag: Mapping ``observed variable -> list of observed parents``.
        confs: Mapping ``confounder name -> list of observed children``.
        dims: Mapping ``variable -> dimension``.  NOTE: this dict is stored by
            reference as ``label_dim`` and mutated in place: one entry per
            confounder is added with value ``num_latent``.  Callers may rely on
            this aliasing, so it is intentionally not copied.
        num_latent: Dimension assigned to every confounder in ``label_dim``.

    Raises:
        KeyError: if a confounder lists a child that is not a key of ``dag``.
    """

    def __init__(self, name, dag, confs, dims, num_latent):
        self.DAG_desc = name
        self.Complete_DAG_desc = name
        self.Observed_DAG = dag

        self.confTochild = confs
        self.num_confs = len(confs)

        # Confounder nodes come first and have no parents.  They are named by
        # the actual confounder keys.  Synthesising "U0", "U1", ... would leave
        # a name such as "medU0" referenced as a parent but missing as a node.
        self.Complete_DAG = {cnf: [] for cnf in confs}

        # Confounder parents of each observed variable.
        self.latent_conf = {}
        for var in self.Observed_DAG:
            self.Complete_DAG[var] = []
            self.latent_conf[var] = []

        for cnf, children in self.confTochild.items():
            for var in children:
                self.latent_conf[var].append(cnf)
                self.Complete_DAG[var].append(cnf)

        # Observed parents follow the confounder parents.
        for var, parents in self.Observed_DAG.items():
            self.Complete_DAG[var] = self.Complete_DAG[var] + list(parents)

        self.complete_labels = list(self.Complete_DAG.keys())
        self.label_names = list(self.Observed_DAG.keys())

        # Shared with the caller's dict on purpose (see class docstring).
        self.label_dim = dims
        for cnf in self.confTochild:
            self.label_dim[cnf] = num_latent

        # Filled in by the caller after construction.
        self.image_labels = None
        self.rep_labels = None







def set_imageMediator(noise_states, latent_state, obs_state, Data_intervs):
    """Build the configuration for the front-door image-mediator experiment.

    Graph: ``medD -> I -> medC`` and ``RI <- {I, medD}``, with a latent
    confounder ``medU0`` acting on both ``medD`` and ``medC``.  ``I`` is
    registered as the image label and ``RI`` as the representation label.

    Args:
        noise_states: Second element of every observed variable's
            ``noise_params`` entry.  Also stored as the ``label_dim`` of each
            non-image noise variable (``"n" + label``).
        latent_state: Dimension of each latent confounder in ``label_dim`` and
            second element of its ``noise_params`` entry.
        obs_state: Unused here.  Presumably kept so the signature matches
            sibling configuration functions; TODO confirm.
        Data_intervs: Unused here.  ``train_mech_dict`` is hand-written below
            rather than derived from these interventional datasets.

    Returns:
        A 24-tuple: (DAG_desc, Complete_DAG_desc, Complete_DAG,
        complete_labels, Observed_DAG, label_names, image_labels, rep_labels,
        interv_queries, cf_queries, latent_conf, confTochild, exogenous,
        cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map,
        Twin_Network, cf_exogenous, noise_params, train_mech_dict, label_dim,
        plot_title).  All counterfactual entries are empty.

    NOTE(review): this function calls ``getdoKey`` and ``generate_permutations``,
    but their imports at the top of this file are commented out.  Unless those
    names are provided some other way, calling this raises NameError.
    """
    Observed_DAG = {
        "medD": [],
        "I": ["medD"],
        "medC": ["I"],
        "RI": ["I", "medD"],
    }
    confTochild = {"medU0": ["medD", "medC"]}

    # NOTE: the author flagged the image ('I': 0) dimension and the RI encoder
    # dimension as unresolved.  NOTE(review): 'C' is not a node of Observed_DAG
    # and looks like a leftover key; confirm before removing.
    label_dim = {"medD": 2, "I": 0, "medC": 3, "C": 3, "RI": 10}
    G = CausalGraph(name="imageMediator", dag=Observed_DAG, confs=confTochild,
                    dims=label_dim, num_latent=latent_state)
    G.image_labels = ["I"]
    G.rep_labels = ["RI"]
    plot_title = "Frontdoor image mediator experiment"

    # Queries: the observational joint P(medD, medC) and the interventional
    # P(medC | do(medD)).  The "expr" key is derived by getdoKey.
    intervention_list = [
        {"obs": ["medD", "medC"], "inter_vars": []},
        {"obs": ["medC"], "inter_vars": ["medD"]},
    ]
    for intervention in intervention_list:
        intervention["expr"] = getdoKey(intervention["obs"], intervention["inter_vars"])

    interv_queries = []
    for intervention in intervention_list:
        # Presumably enumerates every value combination of the intervened
        # variables given their dimensions; verify against generate_permutations.
        perms = generate_permutations([G.label_dim[lb] for lb in intervention["inter_vars"]])
        key_val = [dict(zip(intervention["inter_vars"], comb)) for comb in perms]
        interv_queries.append({"obs": intervention["obs"], "intervs": key_val, "expr": intervention["expr"]})

    cf_queries = []

    # Exogenous noise names for every non-image observed variable.
    exogenous = {label: "n" + label for label in G.label_names if label not in G.image_labels}

    # Counterfactual structures are left empty for this experiment.
    cflabel_names = []
    Twin_Network = {}
    cf_exogenous = {}
    cf_intervene = {}
    cf_observe = []
    cf_evidence = {}
    twin_map = {}

    # (0.5, noise_states) for every observed variable, including the image
    # node; (0.1, latent_state) for every confounder.
    noise_params = {"n" + label: (0.5, noise_states) for label in Observed_DAG}
    for conf in confTochild:
        noise_params[conf] = (0.1, latent_state)

    # Hand-written per-mechanism training spec.  Author's convention:
    #   parents - variables to intervene on when training the mechanism,
    #   intv    - presumably the intervention values (all empty here),
    #   compare - variables whose joint distribution is needed.
    # RI is the mechanism being trained, but it is fitted on the (medD, RI) joint.
    # NOTE(review): RI's parents here are ['I'], while Observed_DAG gives
    # ['I', 'medD']; confirm this is intentional.
    train_mech_dict = {
        "medD": [{"parents": [], "intv": {}, "compare": ["medD", "medC", "RI"]}],
        "I": [{"parents": ["medD"], "intv": {}, "compare": ["I"]}],
        "medC": [{"parents": ["I"], "intv": {}, "compare": ["medD", "medC", "RI"]}],
        "RI": [{"parents": ["I"], "intv": {}, "compare": ["medD", "RI"]}],
    }

    # Noise variables of non-image labels get noise_states dimensions.  They are
    # written through G.label_dim, which is the dict returned below.
    for label in Observed_DAG:
        if label not in G.image_labels:
            G.label_dim["n" + label] = noise_states

    return G.DAG_desc, G.Complete_DAG_desc, G.Complete_DAG, G.complete_labels, G.Observed_DAG, G.label_names, G.image_labels, G.rep_labels, interv_queries, cf_queries, G.latent_conf, \
           G.confTochild, exogenous, cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network, cf_exogenous, \
           noise_params, train_mech_dict, G.label_dim, plot_title



