import copy

import networkx as nx
import matplotlib.pyplot as plt

from ModularUtils.ControllerConstants import generate_permutations
from ModularUtils.FunctionsConstant import build_compares


def getdoKey(obs_Var, intv_key):
    """Build the string key that names an interventional query.

    The result has the form ``P(<observed>|do<intervention>)`` where
    ``<observed>`` is the concatenation of the entries of ``obs_Var``.

    Args:
        obs_Var: Iterable of observed variable names (strings); they are
            concatenated without a separator.
        intv_key: Either a mapping ``{variable: value}`` or a sequence of
            variable names. For a mapping the key contains the concatenated
            names, an underscore, and the concatenated ``str()`` of the
            values. For a sequence only the concatenated names are used.

    Returns:
        The query string. When ``intv_key`` is empty, ``"[]"`` is appended
        before the closing parenthesis (an empty mapping therefore yields
        ``"...|do_[])"``).
    """
    query_str = "P(" + "".join(obs_Var) + "|do"

    # isinstance (rather than an exact type comparison) so that dict
    # subclasses such as OrderedDict also get the "names_values" form.
    if isinstance(intv_key, dict):
        names = "".join(intv_key)
        values = "".join(str(value) for value in intv_key.values())
        query_str += names + "_" + values
    else:
        query_str += "".join(intv_key)

    # An empty intervention set is marked explicitly.
    if len(intv_key) == 0:
        query_str += "[]"

    return query_str + ")"

def set_mnist_nonId_newgraph(noise_states, latent_state, obs_state, Data_intervs):
    """Assemble the configuration of the "mnist_nonId_newgraph" causal model.

    Args:
        noise_states: Value stored as the dimension of every exogenous noise
            variable (``"n" + label``) in ``label_dim`` and as the second
            element of each noise entry in ``noise_params``.
        latent_state: Value stored as the dimension of the latent confounders
            ``U1`` and ``U2`` (in ``label_dim`` and ``noise_params``).
        obs_state: Unused by this function; kept for interface compatibility.
        Data_intervs: Iterable of mappings, one per data distribution. Each
            mapping's keys are passed to ``build_compares`` and the mapping
            itself is stored under ``"intv"`` in the mechanism entries
            (presumably ``{intervened_variable: value}`` -- confirm with the
            caller). May be empty, in which case every label gets an empty
            mechanism list.

    Returns:
        A 24-tuple, in this order: ``DAG_desc, Complete_DAG_desc,
        Complete_DAG, complete_labels, Observed_DAG, label_names,
        image_labels, rep_labels, interv_queries, cf_queries, latent_conf,
        confTochild, exogenous, cf_intervene, cf_observe, cf_evidence,
        cflabel_names, twin_map, Twin_Network, cf_exogenous, noise_params,
        train_mech_dict, label_dim, plot_title``.

    Side effects:
        Prints the training-mechanism list of every observed label.
    """
    DAG_desc = "mnist_nonId_newgraph"
    Complete_DAG_desc = "mnist_nonId_newgraph"
    plot_title = "mnist_nonId_newgraph"

    # Full graph: each variable maps to its parents, latents U1/U2 included.
    # NOTE: image variables (ImgYdigit1 / ImgYdigit2) are currently disabled
    # in this graph, which is why image_labels below is empty.
    Complete_DAG = {
        "U1": [],
        "U2": [],
        "X1": ["U1"],
        "X2": ["U1", "U2", "X1"],
        "W": ["X1", "X2"],
        "Ydigit1": ["W"],
        "Ydigit2": ["W"],
        "Ycolor": ["U2", "W"],
        "Ythick": ["W"],
    }
    complete_labels = list(Complete_DAG.keys())

    # Same graph restricted to the observed variables (latent parents dropped).
    Observed_DAG = {
        "X1": [],
        "X2": ["X1"],
        "W": ["X1", "X2"],
        "Ydigit1": ["W"],
        "Ydigit2": ["W"],
        "Ycolor": ["W"],
        "Ythick": ["W"],
    }
    label_names = list(Observed_DAG.keys())

    image_labels = []
    rep_labels = []

    # Dimension recorded for every variable; the trailing comments give the
    # value range noted by the original author.
    label_dim = {
        "U1": latent_state,
        "U2": latent_state,
        "X1": 2,  # [0,1]
        "X2": 9,  # [0,8]
        "X1p": 2,  # [0,1]
        "X2p": 9,  # [0,8]
        "W": 10,  # [0,9]
        "Ydigit1": 10,  # [0,9]
        "Ydigit2": 10,  # [0,9]
        "Ycolor": 3,  # [0,2]
        "Ythick": 2,  # [0,1]
    }

    # ---- interventional queries -------------------------------------------
    # Observe every label except X1 while intervening on X1.
    var_list = [label for label in label_names if label != "X1"]
    intervention_list = [{"obs": var_list, "inter_vars": ["X1"]}]
    for intervention in intervention_list:
        intervention["expr"] = getdoKey(intervention["obs"], intervention["inter_vars"])

    interv_queries = []
    for intervention in intervention_list:
        inter_vars = intervention["inter_vars"]
        perms = generate_permutations([label_dim[lb] for lb in inter_vars])
        key_val = [dict(zip(inter_vars, comb)) for comb in perms]
        interv_queries.append(
            {"obs": intervention["obs"], "intervs": key_val, "expr": intervention["expr"]}
        )

    # ---- counterfactual queries -------------------------------------------
    cf_list = [
        {"intv": ["X1", "X2"], "evid": ["X1p", "X2p"], "expr": "P(Ycolor|do(X1,X2),X1p, X2p)"}]

    obs_vars = ["Ycolor"]
    cf_queries = []
    for cf in cf_list:
        intv_perms = generate_permutations([label_dim[lb] for lb in cf["intv"]]).tolist()
        intv_key_val = [dict(zip(cf["intv"], comb)) for comb in intv_perms]

        evid_perms = generate_permutations([label_dim[lb] for lb in cf["evid"]]).tolist()
        ev_key_val = [dict(zip(cf["evid"], comb)) for comb in evid_perms]

        cf_queries.append(
            {"obs": obs_vars, "intervs": intv_key_val, "evidence": ev_key_val, "expr": cf["expr"]}
        )

    # ---- confounding structure --------------------------------------------
    # Latent parents of each observed variable, and the inverse mapping.
    latent_conf = {"X1": ["U1"], "X2": ["U1", "U2"], "W": [], "Ydigit1": [], "Ydigit2": [], "Ycolor": ["U2"],
                   "Ythick": []}
    confTochild = {"U1": ["X1", "X2"], "U2": ["X2", "Ycolor"]}

    # Every non-image label gets an exogenous noise variable named "n<label>".
    exogenous = {label: "n" + label for label in label_names if label not in image_labels}

    # ---- counterfactual (twin-network) variables --------------------------
    cflabel_names = ["U1", "U2", "X1", "X1p", "X2", "X2p", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]
    Twin_Network = {
        "U1": [],
        "U2": [],
        "X1": [],
        "X2": [],
        "X1p": [],
        "X2p": ["U1", "U2", "X1p"],
        "W": ["X1", "X2"],
        "Ydigit1": ["W"],
        "Ydigit2": ["W"],
        "Ycolor": ["U2", "W"],
        "Ythick": ["W"],
    }
    cf_exogenous = {"X2p": "nX2", "W": "nW", "Ydigit1": "nYdigit1", "Ydigit2": "nYdigit2", "Ycolor": "nYcolor","Ythick": "nYthick"}

    cf_intervene = {"X1": 1, "X2": 5}
    cf_observe = ["Ycolor"]
    cf_evidence = {"X1p": 1, "X2p": 1}

    # Pairs each variable with its primed twin, in both directions.
    twin_map = {"X1p": "X1", "X1": "X1p", "X2p": "X2", "X2": "X2p"}

    # Each entry is a (number, state-count) pair; the meaning of the first
    # element is defined by the consumer of noise_params (not visible here).
    noise_params = {"nX1": (0.1, noise_states),
                    "nX2": (0.1, noise_states),
                    "nW": (0.1, noise_states),
                    "nYdigit1": (0.1, noise_states),
                    "nYdigit2": (0.1, noise_states),
                    "nYcolor": (0.1, noise_states),
                    "nYthick": (0.1, noise_states),
                    "U1": (1, latent_state),
                    "U2": (1, latent_state)}

    # ---- mechanism training -----------------------------------------------
    # Seed every label with an empty list up front so that an empty
    # Data_intervs does not make the debug print below raise KeyError.
    train_mech_dict = {label: [] for label in label_names}
    for dist in Data_intervs:
        comp_dict = build_compares(confTochild, Observed_DAG, label_names, list(dist.keys()))
        for label in label_names:
            # Image labels are handled differently from ordinary labels (their
            # own mechanism is not part of "compare"), so none are added here.
            if label in image_labels:
                continue
            train_mech_dict[label].append(
                {"parents": Observed_DAG[label], "intv": dist, "compare": comp_dict[label]}
            )

    print("printing")
    for label in label_names:
        print(label, train_mech_dict[label])

    # Register the dimension of each noise variable after the observed ones.
    label_dim.update(
        {"n" + label: noise_states for label in Observed_DAG if label not in image_labels}
    )

    return DAG_desc, Complete_DAG_desc, Complete_DAG, complete_labels, Observed_DAG, label_names, image_labels, rep_labels, interv_queries, cf_queries, latent_conf, \
           confTochild, exogenous, cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map, Twin_Network, cf_exogenous, \
           noise_params, train_mech_dict, label_dim, plot_title






