import copy
import json
import os
import pickle
import numpy

import numpy as np
from matplotlib import pyplot as plt



import pyAgrum as gum

import random

from Causal_Partial_Mnist.CausalGraph_Mnist import set_mnist_nonId_newgraph
from ModularUtils.ControllerConstants import generate_permutations
from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import top_sort_dict
from ModularUtils.FunctionsTraining import top_sort_list

trial = 0  # NOTE(review): not referenced anywhere in this module; may be read by importers — confirm before removing.


# mnist_state = {"digit1": 10, "digit2": 10, "sign": 2, "result": 10}


def plot_labels(label_name, *args):
    """Draw a relative-frequency histogram for each data sequence on one figure and show it.

    Args:
        label_name: text used as the x-axis label.
        *args: one or more array-like sequences of observed values; each gets
            its own histogram on the current matplotlib axes.

    Calls ``plt.show()`` once after all histograms are drawn.
    """
    for data in args:
        # Weight each sample by 1/len(data) so the bars show proportions, not counts.
        weights = np.ones_like(data) / len(data)
        plt.hist(data, weights=weights)
        # plt.hist(data, density=True, bins=30)  # density=False would make counts

    plt.xlabel(label_name)
    plt.show()


def get_funcparents():
    """Return a fresh mapping from each SCM variable to its functional parents.

    ``R_func`` uses this to decide which CPT parents feed ``f_<label>``; any
    other CPT parent is dropped before the function is evaluated.
    """
    functional_parents = {
        "X1": ["U1"],
        "X2": ["U2", "X1"],
        "W": ["X1", "X2"],
        "Ydigit1": ["W"],
        "Ydigit2": ["W"],
        "Ycolor": ["U2"],
        "Ythick": ["W"],
    }
    return functional_parents


def f_X1(dictt):
    """Draw X1 uniformly from {0, 1}; the parent values in *dictt* are ignored."""
    return random.choice([0, 1])


def f_X2(dictt):  # [0,8] , X1-> X2 <-->U2,
    """Structural function of X2 given X1 and the shared noise U2.

    ``U2 % 3`` (the same quantity ``f_Ycolor`` returns) gives the offset:
      * X1 == 0 -> X2 = U2 % 3, in [0, 1, 2]
      * X1 == 1 -> X2 = random choice of {3, 6} + U2 % 3, in [3 .. 8]

    Raises:
        ValueError: if X1 is not 0 or 1. Previously such a value fell through
            both branches and surfaced as an UnboundLocalError.
    """
    x1 = dictt["X1"]
    if x1 == 0:
        return 0 + dictt["U2"] % 3  # [0,1,2]
    if x1 == 1:
        return random.choice([3, 6]) + dictt["U2"] % 3  # [3,4,5,6,7,8]
    raise ValueError(f"f_X2 expects X1 in {{0, 1}}, got {x1!r}")


def f_W(dictt):
    """Structural function of W: the sum X1 + X2."""
    x1 = dictt["X1"]
    x2 = dictt["X2"]
    return x1 + x2


def f_Ydigit1(dictt):
    """Return the tens part of W*W, i.e. int(W*W / 10) (W=7 -> 49 -> 4)."""
    w = dictt["W"]
    squared = w * w
    return int(squared / 10)


def f_Ydigit2(dictt):
    """Return the units digit of W*W (W=7 -> 49 -> 9)."""
    w = dictt["W"]
    squared = w * w
    return squared % 10


def f_Ycolor(dictt):
    """Structural function of Ycolor: U2 modulo 3, a value in [0, 2]."""
    u2 = dictt["U2"]
    return u2 % 3


def f_Ythick(dictt):  # [0,1] , W-> thickness
    """Structural function of Ythick, derived from int(W*W / 10).

    Returns 0 when that tens value is 0 or 1 (W*W < 20 for integer W).
    Returns 1 when it is 2 or more.
    """
    tens_value = int(dictt["W"] * dictt["W"] / 10)
    if tens_value > 1:
        return 1
    return 0


# func_pars =get_funcparents
# def R_func(Exp, cpt, label, label_dim):
#     allowed_noise = 0.25
#
#     newlst = []
#     used = {}
#     zero_lst={}
#
#     for i in cpt.loopIn():
#         dictt = i.todict()
#
#         label_func = "f_" + label
#         res = globals()[label_func](dictt)
#
#         del dictt[label]
#         key = tuple(sorted(dictt.items()))
#         if key in used:
#             continue
#         used[key] = 1
#
#         cur_row = [0 for i in range(label_dim)]
#         noisy_value = random.randint(0, label_dim - 1) if random.random() < Exp.corr_thresh[label] else res
#         #             noisy_value =  res
#
#         cur_row[noisy_value] = 1
#         newlst += cur_row
#
#     return newlst

def R_func(Exp, cpt, label, lb_dim):
    """Generate the flat CPT entries of ``label`` from its structural function ``f_<label>``.

    Each parent configuration of ``cpt`` receives a one-hot row of length ``lb_dim``.
    The hot index normally comes from ``f_<label>``, evaluated on the *functional*
    parents ``get_funcparents()[label]``. The CPT's other parents (``del_vars``) are
    dropped before the call.

    Noise injection: for every distinct functional-parent configuration, the first
    ``per_state * Exp.label_dim[label]["feature"]`` rows encountered take their state
    from ``[0, 1, ..., dim - 1] * per_state`` (in that order) instead of from the
    function. ``per_state`` is derived from ``Exp.allowed_noise`` times the number of
    joint states of ``del_vars``.

    Args:
        Exp: experiment configuration; reads ``allowed_noise`` and
            ``label_dim[var]["feature"]``.
        cpt: CPT of ``label`` (pyAgrum potential). Only ``var_names`` and
            ``loopIn()`` are used; the CPT itself is not modified here.
        label: variable whose CPT is generated; a module-level function
            ``f_<label>`` must exist.
        lb_dim: number of states of ``label`` (length of each one-hot row).

    Returns:
        list of 0/1 ints: one one-hot row per parent configuration, concatenated
        in the order configurations are first seen in ``cpt.loopIn()``.
        NOTE(review): the caller writes these values back one per instantiation of
        ``enumerate(cpt.loopIn())``. That is only consistent if ``label`` is the
        fastest-varying variable in ``loopIn()`` — confirm against pyAgrum.
    """
    newlst = []
    used = {}  # parent configurations (label removed) already emitted
    zero_lst = {}  # functional-parent configuration -> noise states still to hand out

    allowed_noise = Exp.allowed_noise
    rem_dim = 1  # number of joint states of the non-functional parents
    func_pars = get_funcparents()

    # CPT variables that are neither the label nor one of its functional parents.
    del_vars = []
    for key in cpt.var_names:
        if key not in func_pars[label] + [label]:
            del_vars.append(key)
            rem_dim = rem_dim * Exp.label_dim[key]["feature"]

    # How many of the rem_dim rows sharing a functional-parent configuration are
    # overridden by noise, split evenly over the label's states.
    distribute = int(rem_dim * allowed_noise)
    per_state = int(distribute / Exp.label_dim[label]["feature"])

    lst = [st for st in range(Exp.label_dim[label]["feature"])] * per_state

    for i in cpt.loopIn():
        dictt = copy.deepcopy(i.todict())

        # loopIn() visits every label value of a parent configuration; emit one
        # row per configuration only.
        del dictt[label]
        key = tuple(sorted(dictt.items()))
        if key in used:
            continue
        used[key] = 1

        # Keep only the functional parents.
        for var in del_vars:
            del dictt[var]

        key22 = tuple(sorted(dictt.items()))
        if key22 not in zero_lst:
            zero_lst[key22] = copy.deepcopy(lst)

        if len(zero_lst[key22]) > 0:
            # Noise: take the next state from this configuration's noise list.
            res = zero_lst[key22].pop(0)
        else:
            # getting value according to function
            label_func = "f_" + label
            res = globals()[label_func](dictt)

        cur_row = [0 for i in range(lb_dim)]
        # Earlier variant, left disabled: random corruption with probability corr_thresh[label]
        #       noisy_value = random.randint(0,lb_dim - 1) if random.random() < corr_thresh[label] else res
        noisy_value = res
        cur_row[noisy_value] = 1
        newlst += cur_row

    return newlst


def R(Exp, label, feature):
    """Return a one-hot list of length ``Exp.label_dim[label][feature]`` with a uniformly random hot index."""
    size = Exp.label_dim[label][feature]
    hot = random.randint(0, size - 1)
    return [1 if idx == hot else 0 for idx in range(size)]


def get_true_noise_dist(Exp):
    """Load the saved SCM JSON at ``Exp.SCM_PATH`` and return its ``"noise_dist"`` entry."""
    with open(Exp.SCM_PATH) as scm_file:
        instance = json.load(scm_file)
    return instance["noise_dist"]


def test_cpt_nonzero_prob(Exp, bn, target, parents, feature):
    """Check that no entry of P(target | parents) in *bn* is (near) zero.

    The threshold is one third of the uniform probability over the joint state
    space of ``target`` and ``parents``: ``1 / prod(label_dim[v][feature]) / 3``.

    Args:
        Exp: experiment configuration; reads ``label_dim[var][feature]``.
        bn: pyAgrum BayesNet to query.
        target: name of the child variable.
        parents: list of parent names. It is concatenated with ``[target]``, so it
            must be a list.
        feature: key into ``Exp.label_dim[var]``.

    Returns:
        False if the smallest conditional probability is <= the threshold, True otherwise.

    NOTE(review): the name starts with ``test_``, so pytest will try to collect it
    if it is ever star-imported into a test module.
    """
    ie = gum.LazyPropagation(bn)
    ie.addJointTarget(set([target] + parents))
    ie.makeInference()
    cond_dist = ie.evidenceImpact(target, parents)

    # The unused joint-posterior query and max value computed previously were dropped.
    min_val = min(cond_dist.toarray().ravel())
    thresh = 1 / np.prod([Exp.label_dim[lb][feature] for lb in [target] + parents])
    if min_val <= thresh / 3:
        return False
    return True


# implement intervened and latents
def get_bayesian_network(Exp, intervened, load_scm):
    """Build the pyAgrum Bayesian network of the synthetic SCM, optionally intervened on.

    Structure:
      * one node per label in ``Exp.Complete_DAG``. Labels present in the noise
        distribution get ``Exp.latent_state`` states; the others get
        ``Exp.label_dim[label]["feature"]`` states. Arcs come from
        ``Exp.Complete_DAG[label]``.
      * for every label in ``Exp.exogenous``, an extra node
        ``Exp.exogenous[label]`` with ``Exp.noise_states`` states and an arc into
        that label.
      * every node named in the noise distribution has its CPT filled with it.

    Parameters depend on ``load_scm``:
      * ``1`` -- noise distributions and CPTs are read back from the JSON file at
        ``Exp.SCM_PATH``.
      * ``0`` -- each noise in ``Exp.noise_params`` (``(const, states)``) is drawn
        from a symmetric Dirichlet and CPTs are generated: ``R_func`` for X2, W,
        Ydigit1, Ydigit2, Ycolor and Ythick, random one-hot rows from ``R``
        otherwise. The result is written to ``Exp.SCM_PATH``.
      * any other value -- ``noise_dist`` stays empty, CPTs are generated as for
        ``0``, and nothing is written.

    Args:
        Exp: experiment configuration object (attributes listed above).
        intervened: mapping variable name -> state index. Each such variable has
            its incoming arcs erased and its CPT replaced by a point mass.
        load_scm: 0 or 1 as described above.

    Returns:
        ``(bn_dict, INSTANCE)``. ``bn_dict`` is ``{"feature": gum.BayesNet}``;
        ``INSTANCE`` is the dict holding ``"noise_dist"`` and ``"cpt"``.
    """
    INSTANCE = {}
    INSTANCE["cpt"] = {}

    for label in Exp.complete_labels:
        INSTANCE["cpt"][label] = {"feature": {}}

    noise_dist = {}
    if load_scm == 1:
        # Reuse the SCM saved by an earlier load_scm == 0 run.
        with open(Exp.SCM_PATH) as f:
            data = f.read()
        INSTANCE = json.loads(data)
        noise_dist = INSTANCE["noise_dist"]
    elif load_scm == 0:
        # Draw each noise distribution from a symmetric Dirichlet(const) over `states` outcomes.
        for noise in Exp.noise_params:
            const, states = Exp.noise_params[noise]
            noise_dist[noise] = np.random.dirichlet(const * np.ones(states), size=1)[0].tolist()
        INSTANCE["noise_dist"] = noise_dist

    bn_dict = {"feature": gum.BayesNet(Exp.Complete_DAG_desc)}

    # NOTE(review): addArc needs `parent` to be a node already, so this loop assumes
    # Exp.Complete_DAG iterates parents before children — confirm.
    for label in Exp.Complete_DAG:
        if label in noise_dist:
            bn_dict["feature"].add(label, Exp.latent_state)  # Exp.latent_state
        else:
            bn_dict["feature"].add(label, Exp.label_dim[label]["feature"])  # Exp.label_state

        for parent in Exp.Complete_DAG[label]:
            for feat in bn_dict:
                bn_dict[feat].addArc(*(parent, label))

    # One exogenous noise node per label in Exp.exogenous, pointing into that label.
    for label in Exp.exogenous:
        for feat in bn_dict:
            bn_dict[feat].add(Exp.exogenous[label], Exp.noise_states)  # Exp.noise_states
            bn_dict[feat].addArc(*(Exp.exogenous[label], label))

    # assign probabilities to noise
    for noise in noise_dist:
        for feat in bn_dict:
            bn_dict[feat].cpt(noise).fillWith(noise_dist[noise])

    for label in Exp.complete_labels:

        for feature in bn_dict:

            # The acceptance check near the end is commented out, so this loop
            # always finishes after a single pass.
            successful = False
            while not successful:
                probs = []

                if load_scm == 1:
                    probs = INSTANCE["cpt"][label][feature]
                elif label in ["X2", "W", "Ydigit1", "Ydigit2", "Ycolor", "Ythick"]:
                    # Labels with a structural function f_<label>: deterministic rows plus injected noise.
                    probs = R_func(Exp, bn_dict[feature].cpt(label), label, Exp.label_dim[label][feature])

                else:
                    # Any other label: an independent random one-hot row per parent configuration.
                    domainSize = int(bn_dict[feature].cpt(label).domainSize() / Exp.label_dim[label][feature])

                    for i in range(domainSize):
                        output = R(Exp, label, feature)
                        probs += output

                INSTANCE["cpt"][label][feature] = probs

                if label in noise_dist:  # common confounders
                    # The noise distribution overrides the generated probs (which remain stored in INSTANCE).
                    bn_dict[feature].cpt(label).fillWith(noise_dist[label])
                    successful = True
                    continue

                # Write one probability per CPT instantiation, in loopIn() order.
                for i, ins in enumerate(bn_dict[feature].cpt(label).loopIn()):
                    bn_dict[feature].cpt(label).set(ins, probs[i])

                # successful = test_cpt_nonzero_prob(Exp, bn_dict[feature], label, Exp.Observed_DAG[label], feature)
                # if label in ["W", "Ydigit1", "Ydigit2", "Ythick"]:
                successful = True

    # Persist a freshly generated SCM so later calls can use load_scm == 1.
    if load_scm == 0:
        with open(Exp.SCM_PATH, 'w') as fp:
            fp.write(json.dumps(INSTANCE))

    # Intervention do(var = intervened[var]): cut incoming arcs and force a point mass.
    # Only bn_dict['feature'] is intervened on.
    for var in intervened:
        for parent in bn_dict['feature'].cpt(var).var_names:
            if parent != var:
                bn_dict['feature'].eraseArc(*(parent, var))

        lst = [0 for i in range(Exp.label_dim[var]["feature"])]
        lst[intervened[var]] = 1
        bn_dict['feature'].cpt(var).fillWith(lst)

    return bn_dict, INSTANCE
    # return bn_dict['feature']


def get_bn(Exp, bn2, intervened):
    """Apply the do-operator to *bn2* in place and return it.

    For every ``var -> state`` pair in *intervened*, all arcs into ``var`` are
    erased and its CPT is replaced by a point mass on ``state``.

    Args:
        Exp: experiment configuration; reads ``label_dim[var]["feature"]``.
        bn2: pyAgrum BayesNet that is modified. Pass a copy if the original
            must stay untouched.
        intervened: mapping variable name -> forced state index.

    Returns:
        The same *bn2* object, after mutation.
    """
    for node, forced_state in intervened.items():
        incoming = [p for p in bn2.cpt(node).var_names if p != node]
        for parent in incoming:
            bn2.eraseArc(parent, node)

        point_mass = [0] * Exp.label_dim[node]["feature"]
        point_mass[forced_state] = 1
        bn2.cpt(node).fillWith(point_mass)

    return bn2


def check_queries(Exp, bn, intervened=None):
    """Run a few sanity-check queries on *bn*, print the interventional marginals and return the results.

    The queries use ``observed_vars = ["Ydigit1", "Ydigit2"]``:
      * the observational joint P(Ydigit1, Ydigit2), computed on a copy of *bn*;
      * the joint posterior over all ``Exp.label_names``;
      * the interventional joint P(Ydigit1, Ydigit2 | do(intervened)), computed on
        a copy of *bn* mutated by ``get_bn``.
    It then prints the single-variable marginals of X1, X2, W, Ydigit1, Ydigit2,
    Ycolor and the joint of (Ythick, W) under the intervention.

    Args:
        Exp: experiment configuration; reads ``label_names`` and, via ``get_bn``,
            ``label_dim``.
        bn: pyAgrum BayesNet. It is copied before the intervention, so it is not modified.
        intervened: mapping variable -> forced state index. Defaults to
            ``{"X1": 1, "W": 4}``, the intervention this function always applied.

    Returns:
        dict with keys "marginal", "joint" and "interventional" holding the three
        queried potentials. Previously they were computed and discarded.
    """
    if intervened is None:
        intervened = {"X1": 1, "W": 4}
    # observed_vars = ["Ydigit1","Ydigit2", "Ycolor", "Ythick"]
    observed_vars = ["Ydigit1", "Ydigit2"]

    # P(Ydigit1, Ydigit2)
    ie = gum.LazyPropagation(gum.BayesNet(bn))
    ie.addJointTarget(set(Exp.label_names))
    ie.makeInference()
    r_marg_org = ie.evidenceJointImpact(observed_vars, [])

    # joint posterior over every label
    ie = gum.LazyPropagation(bn)
    ie.addJointTarget(set(Exp.label_names))
    ie.makeInference()
    joint_post = ie.jointPosterior(set(Exp.label_names))

    # P(Ydigit1, Ydigit2 | do(intervened)); the default is do(X1=1, W=4)
    true_bn_interv = get_bn(Exp, gum.BayesNet(bn), intervened)

    ie = gum.LazyPropagation(true_bn_interv)
    ie.addJointTarget(set(Exp.label_names))
    ie.makeInference()
    r_interv_org = ie.evidenceJointImpact(observed_vars, [])

    print(ie.evidenceJointImpact(["X1"], []),
          ie.evidenceJointImpact(["X2"], []),
          ie.evidenceJointImpact(["W"], []),
          ie.evidenceJointImpact(["Ydigit1"], []),
          ie.evidenceJointImpact(["Ydigit2"], []),
          ie.evidenceJointImpact(["Ycolor"], []),
          ie.evidenceJointImpact(["Ythick", "W"], [])
          )

    return {"marginal": r_marg_org, "joint": joint_post, "interventional": r_interv_org}


def get_synthetic_dist(Exp, joint_vars, bn):
    """Compute the joint distribution of *joint_vars* in *bn* by exact (lazy-propagation) inference.

    Args:
        Exp: experiment configuration; ``Exp.Complete_DAG`` keys give the ordering.
        joint_vars: names of the variables of the joint.
        bn: pyAgrum BayesNet to query.

    Returns:
        ``1`` when *joint_vars* is empty. Otherwise the 4-tuple
        ``(variables, combinations, joint_dist, dist_dict)``:
          * variables -- posterior variable descriptions ordered by ``top_sort_list``;
          * combinations -- numpy array with one row per assignment, values ordered
            by ``top_sort_dict``, rows in ``loopIn()`` order;
          * joint_dist -- ``toarray().ravel()`` of the posterior;
          * dist_dict -- assignment tuple -> probability.
        NOTE(review): ``generate_synthetic_samples`` indexes *combinations* with
        indices drawn from *joint_dist*. That assumes ``loopIn()`` order and
        ``toarray().ravel()`` order coincide — confirm against pyAgrum.
    """
    if len(joint_vars) == 0:
        return 1

    var_set = set(joint_vars)
    engine = gum.LazyPropagation(bn)
    engine.addJointTarget(var_set)
    engine.makeInference()
    posterior = engine.jointPosterior(var_set)

    dag_order = Exp.Complete_DAG.keys()
    dist_dict = {}
    rows = []
    for inst in posterior.loopIn():
        ordered = top_sort_dict(inst.todict(), dag_order)
        assignment = tuple(ordered.values())
        dist_dict[assignment] = posterior[inst]
        rows.append(list(assignment))

    # Latent confounders and observed variables, in topological order.
    variables = top_sort_list([v.description() for v in posterior.variablesSequence()], dag_order)
    joint_dist = posterior.toarray().ravel()

    return variables, np.array(rows), joint_dist, dist_dict


def get_cond_synthetic_dist(joint_vars, conditions, arrangeKeys, bn):
    """Compute and print ``evidenceJointImpact(joint_vars, conditions)`` on *bn*.

    pyAgrum documents that call as P(joint_vars | conditions) for every value of
    the conditioning variables.

    Args:
        joint_vars: names of the target variables.
        conditions: names of the conditioning variables.
        arrangeKeys: variable order used by ``top_sort_dict`` / ``top_sort_list``.
        bn: pyAgrum BayesNet to query.

    Returns:
        ``(variables, combinations, joint_dist, dist_dict)``, shaped like
        ``get_synthetic_dist`` but ordered by *arrangeKeys*.
    """
    engine = gum.LazyPropagation(bn)
    engine.addJointTarget(set(joint_vars))
    engine.makeInference()

    impact = engine.evidenceJointImpact(joint_vars, conditions)
    print(impact)

    dist_dict = {}
    rows = []
    for inst in impact.loopIn():
        ordered = top_sort_dict(inst.todict(), arrangeKeys)
        assignment = tuple(ordered.values())
        dist_dict[assignment] = impact[inst]
        rows.append(list(assignment))

    # Latent confounders and observed variables, in arrangeKeys order.
    variables = top_sort_list([v.description() for v in impact.variablesSequence()], arrangeKeys)
    joint_dist = impact.toarray().ravel()

    return variables, np.array(rows), joint_dist, dist_dict


def generate_synthetic_samples(Exp, feature, varialbes, combinations, joint_dist, sample_size, plot=True):
    """Draw i.i.d. joint assignments and split them into one sample array per variable.

    Args:
        Exp: unused; kept for interface compatibility.
        feature: suffix appended to each plot label (e.g. "feature").
        varialbes: variable names, one per column of *combinations*.
        combinations: 2-D array with one row per joint assignment.
        joint_dist: probability of each row of *combinations*. It is renormalized
            before sampling, because exact-inference results can drift from 1 by
            rounding error, which ``np.random.choice`` rejects.
        sample_size: number of samples to draw.
        plot: when True (the default, and the previous behavior), show a blocking
            histogram per variable via ``plot_labels``.

    Returns:
        dict mapping each variable name to a 1-D numpy array of its sampled values.
    """
    print("For " + feature + " Producing samples:", sample_size, " joint dist:", joint_dist)
    print(sum(joint_dist))

    probs = np.asarray(joint_dist, dtype=float)
    probs = probs / probs.sum()

    samples = np.random.choice(len(probs), sample_size, p=probs)
    observations = np.asarray(combinations)[samples]

    label_dict = {}
    for col, name in enumerate(varialbes):
        if plot:
            plot_labels(name + feature, observations[:, col])
        label_dict[name] = observations[:, col]

    return label_dict


def save_datasets(SAVE_DATASET, label_save_dir, feature, true_data):
    """Pickle each label's samples to ``label_save_dir + label + feature + ".pkl"``.

    Does nothing when ``SAVE_DATASET == False``. Each value is wrapped in
    ``np.array`` before pickling, and one confirmation line is printed per file.
    """
    if SAVE_DATASET == False:  # noqa: E712 -- equality, not truthiness, is the existing contract
        return

    for label in true_data:
        out_path = label_save_dir + label + feature + ".pkl"
        with open(out_path, 'wb') as handle:
            pickle.dump(np.array(true_data[label]), handle)
        print(out_path, " saved")


def _sample_and_save_dataset(Exp, intervened, sample_size, label_save_dir, SAVE_DATASET):
    """Sample one dataset from the SCM stored at ``Exp.SCM_PATH`` under *intervened* and optionally pickle it.

    Returns:
        The INSTANCE dict returned by ``get_bayesian_network`` (loaded from ``Exp.SCM_PATH``).
    """
    bn_dict, instance = get_bayesian_network(Exp, intervened, load_scm=1)
    variables, combs, jdist, _ = get_synthetic_dist(Exp, Exp.label_names, bn_dict["feature"])
    true_obs = generate_synthetic_samples(Exp, "feature", variables, combs, jdist, sample_size)
    save_datasets(SAVE_DATASET, label_save_dir, "feature", true_obs)
    return instance


def produce_datasets(Exp, load_scm, SAVE_DATASET=True):
    """Create or load the SCM, then sample and save every dataset in ``Exp.Data_intervs``.

    Dataset 0 uses ``Exp.Synthetic_Sample_Size`` samples. Datasets 1..num_datasets-1
    use ``Exp.intv_Sample_Size``. Dataset ``i`` is pickled under
    ``Exp.file_roots[i]``. When ``SAVE_DATASET`` is true, the SCM is also dumped as
    JSON to ``Exp.file_roots[0] + "scm.txt"``.

    Args:
        Exp: experiment configuration.
        load_scm: passed to the first ``get_bayesian_network`` call. 0 generates a
            new SCM and writes it to ``Exp.SCM_PATH``; 1 loads the existing one.
        SAVE_DATASET: whether to write the pickles and the SCM dump.
    """
    intervened = Exp.Data_intervs[0]
    bn_dict, INSTANCES = get_bayesian_network(Exp, intervened, load_scm)
    print(bn_dict)
    # check_queries(Exp, bn_dict["feature"])

    # Every dataset is sampled from the SCM as re-read from Exp.SCM_PATH (load_scm=1).
    INSTANCES = _sample_and_save_dataset(Exp, intervened, Exp.Synthetic_Sample_Size,
                                         Exp.file_roots[0], SAVE_DATASET)

    for i in range(1, Exp.num_datasets):
        print(i, Exp.Data_intervs[i])
        INSTANCES = _sample_and_save_dataset(Exp, Exp.Data_intervs[i], Exp.intv_Sample_Size,
                                             Exp.file_roots[i], SAVE_DATASET)

    if SAVE_DATASET:
        with open(Exp.file_roots[0] + "scm.txt", 'w') as fp:
            fp.write(json.dumps(INSTANCES))


def save_queryscm(Exp, query):
    """Compute P(query["obs"] | do(query["interv"])) from the stored SCM and write it as JSON.

    The output file is ``Exp.Intv_SCMs + str(intervention) + ".txt"``. Assignment
    tuples are stringified because JSON object keys must be strings.
    """
    intervention = query["interv"]

    bn_dict, _ = get_bayesian_network(Exp, intervention, load_scm=1)
    _, _, _, dist = get_synthetic_dist(Exp, query["obs"], bn_dict["feature"])

    serialisable = {str(assignment): prob for assignment, prob in dist.items()}

    file_name = Exp.Intv_SCMs + str(intervention) + ".txt"
    print(f"Saving {query['obs']}|do({intervention}) at {file_name}")
    with open(file_name, 'w') as fp:
        json.dump(serialisable, fp)


def get_intv_dist(Exp, obs_vars, intv_key, expr):
    """Return P(obs_vars | do(intv_key)) as {assignment tuple: probability}, cached on disk.

    The cache file is ``Exp.Intv_SCMs + expr + ".txt"``. On a cache hit the stored
    string keys are turned back into tuples. Otherwise the distribution is computed
    from the SCM at ``Exp.SCM_PATH`` and written to the cache.

    NOTE(review): the cache file name depends only on *expr*, not on *intv_key*.
    Calls that share *expr* but differ in *intv_key* return the first cached result;
    confirm *expr* always encodes the intervention values.
    NOTE(review): cached keys are parsed with ``eval``. The file is produced by this
    function, but ``ast.literal_eval`` would be safer if the cache could be tampered with.
    """
    file_name = Exp.Intv_SCMs + expr + ".txt"
    if os.path.exists(file_name):
        with open(file_name) as f:
            cached = json.load(f)
        return {eval(key): prob for key, prob in cached.items()}

    bn_dict, _ = get_bayesian_network(Exp, intv_key, load_scm=1)
    _, _, _, dist = get_synthetic_dist(Exp, obs_vars, bn_dict["feature"])

    serialisable = {str(assignment): prob for assignment, prob in dist.items()}
    print(f"Saving {expr} at {file_name}")
    with open(file_name, 'w') as fp:
        json.dump(serialisable, fp)

    return dist




def get_tempo_cf_dist(Exp, targetVars, cf_intervene, cf_evidence):
    """Counterfactual distribution of a target variable in the twin network of the stored SCM.

    Builds ``Exp.Twin_Network`` plus the noise nodes in ``Exp.cf_exogenous`` as a
    pyAgrum BayesNet. It fills the network with the noise distributions and CPTs
    saved at ``Exp.SCM_PATH``; the twin copies ``X1p``/``X2p`` reuse the CPTs of
    ``X1``/``X2``. It then applies ``do(cf_intervene)`` and queries each target
    given ``cf_evidence``.

    Args:
        Exp: experiment configuration. Reads ``SCM_PATH``, ``Twin_Network``,
            ``cf_exogenous``, ``cflabel_names``, ``latent_state``, ``noise_states``,
            ``label_dim`` and ``Complete_DAG``.
        targetVars: names of the query variables; must be non-empty.
        cf_intervene: mapping variable -> state index to intervene on.
        cf_evidence: mapping variable -> observed state, passed as ``evs`` to
            ``gum.getPosterior``.

    Returns:
        dict mapping each state tuple of the LAST variable in ``targetVars`` to its
        posterior probability. Tuples are ordered by ``top_sort_dict`` against
        ``Exp.Complete_DAG``. Earlier targets are computed but not returned; this
        keeps the existing behavior for callers that pass several targets.

    Raises:
        ValueError: if ``targetVars`` is empty. Previously this surfaced as an
            UnboundLocalError after the whole network had been built.
    """
    targetVars = list(targetVars)
    if not targetVars:
        raise ValueError("get_tempo_cf_dist requires at least one target variable")

    ## Counterfactual SCM
    with open(Exp.SCM_PATH) as f:
        INSTANCE = json.load(f)
    noise_dist = INSTANCE["noise_dist"]

    bnc = gum.BayesNet("counterfactual SCM")

    for label in Exp.Twin_Network:
        if label in noise_dist:
            bnc.add(label, Exp.latent_state)
        else:
            bnc.add(label, Exp.label_dim[label]["feature"])

        for parent in Exp.Twin_Network[label]:
            bnc.addArc(parent, label)

    # A noise node may feed several twin labels; add it only once.
    added_latents = []
    for label in Exp.cf_exogenous:
        noise_node = Exp.cf_exogenous[label]
        if noise_node not in added_latents:
            bnc.add(noise_node, Exp.noise_states)
            added_latents.append(noise_node)
        bnc.addArc(noise_node, label)

    # assign probabilities to noise
    for noise in noise_dist:
        # NOTE(review): nX1's CPT is deliberately not filled here; the reason is not
        # visible in this file — confirm.
        if noise == "nX1":
            continue
        bnc.cpt(noise).fillWith(noise_dist[noise])

    # Twin copies reuse the CPT of their factual counterpart.
    cpt_source = {"X1p": "X1", "X2p": "X2"}
    for label in Exp.cflabel_names:
        probs = INSTANCE["cpt"][cpt_source.get(label, label)]['feature']

        if label in noise_dist:  # common confounders
            bnc.cpt(label).fillWith(noise_dist[label])
            continue

        for i, ins in enumerate(bnc.cpt(label).loopIn()):
            bnc.cpt(label).set(ins, probs[i])

        if label in cf_intervene:
            point_mass = [0] * Exp.label_dim[label]['feature']
            point_mass[cf_intervene[label]] = 1
            bnc.cpt(label).fillWith(point_mass)

    # Counterfactual inference on a copy with do(cf_intervene) applied, e.g.
    # P(Y | X1'=1, X2'=1, do(X1=0, X2=0)).
    true_bn = get_bn(Exp, gum.BayesNet(bnc), cf_intervene)

    target_dist = {}
    for Yvar in targetVars:
        cfYvar = gum.getPosterior(true_bn, evs=cf_evidence, target=Yvar)
        Yvar_dict = {}
        for inst in cfYvar.loopIn():
            comb = top_sort_dict(inst.todict(), Exp.Complete_DAG.keys())
            Yvar_dict[tuple(comb.values())] = cfYvar[inst]
        target_dist = Yvar_dict

    return target_dist



if __name__ == '__main__':
    # Experiment configuration (the commented-out alternative graph was set_mnist_random_graph).
    # lat_dim = 16
    Exp = Experiment("Exp1", set_mnist_nonId_newgraph,
                     dist_thresh=0.15,
                     causal_hierarchy=2,
                     noise_states=100,
                     latent_state=16,
                     new_experiment=True,
                     Synthetic_Sample_Size=40000,
                     intv_Sample_Size=40000,
                     Data_intervs=[{}],
                     allowed_noise=0.10,
                     )

    # produce_datasets(Exp, load_scm=1, SAVE_DATASET=False)

    # for x2 in range(9):
    #     true_bn, _ = get_bayesian_network(Exp, {"X2": x2}, load_scm=1)
    #     _, _, _, true_dist_dict = get_synthetic_dist(Exp, ["Ycolor"], true_bn["feature"])
    #     print(f" X2: {x2}, {true_dist_dict}")
    #
    # bn_dict, INSTANCES = get_bayesian_network(Exp, {}, load_scm=1)
    # ret = get_cond_synthetic_dist(["Ycolor"], ["X2"], ["Ycolor", "X2"], bn_dict["feature"])
    # print(ret)

    # Presumably every (X1, X2) state combination for state counts [2, 9]. The same
    # combinations serve both as interventions do(X1, X2) and as counterfactual
    # evidence on the twin variables X1p, X2p, so they are generated only once.
    perms = generate_permutations([2, 9]).tolist()
    intv_key_val = [dict(zip(["X1", "X2"], comb)) for comb in perms]
    ev_key_val = [dict(zip(["X1p", "X2p"], comb)) for comb in perms]

    for ikey in intv_key_val:
        for evkey in ev_key_val:
            ret = get_tempo_cf_dist(Exp, ["Ycolor"], ikey, evkey)
            print(ret)

    # intv = get_expected_true_intervs(Exp)
    # cf = get_expected_true_cf(Exp)
    # true_intvVscf = calculate_TVD(intv, cf, doPrint=False)

    # produce_result_image(Exp, SHOW=["ImgYdigit1", "ImgYdigit2"], SAVE_DATASET=["ImgYdigit1", "ImgYdigit2"])

    # produce_result_image(Exp, SAVE_DATASET=["ImgYdigit1"])
    # test_result_data(Exp)

    # for query in Exp.interv_queries:
    #     # {"obs": obs_vars, "intervs": key_val, "expr": intervention["expr"]}
    #     for intv_key in query["intervs"]:
    #         x = get_intv_dist(Exp, query["obs"], intv_key, query['expr'])
    #         # print(x.keys())


# get_cond_synthetic_dist