import copy
import json
import os
import pickle
import numpy

import numpy as np
from matplotlib import pyplot as plt



import pyAgrum as gum

import random

from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import top_sort_dict, getdoKey
from ModularUtils.FunctionsTraining import top_sort_list
from Train_By_Components.Causal_TrainGraph import set_trainGraph

# Module-level variable; no function in this file (as visible here) reads or
# writes it. NOTE(review): confirm no other module imports it before removing.
trial = 0

def plot_labels(label_name, *args):
    """Overlay a normalised histogram of each array and show the figure.

    Each positional argument is histogrammed with weights ``1 / len(data)``,
    so the bar heights of each array are relative frequencies (summing to 1)
    rather than raw counts.

    Args:
        label_name: Text used as the x-axis label.
        *args: One or more array-likes of values to histogram.
    """
    for data in args:
        weights = np.ones_like(data) / len(data)
        plt.hist(data, weights=weights)

    plt.xlabel(label_name)
    plt.show()


def get_funcparents():
    """Return the node -> parent-list map that CPT tables are produced from.

    A new dict (with new lists) is built on every call, so callers may mutate
    the result freely.
    """
    parents = dict(
        X1=["U1"],
        X2=["U2", "X1"],
        W=["X1", "X2"],
        Ydigit1=["W"],
        Ydigit2=["W"],
        Ycolor=["U2"],
        Ythick=["W"],
    )
    return parents




def R(Exp, label, feature):
    """Return a one-hot row for ``label`` whose hot position is random.

    Args:
        Exp: Experiment object; only ``Exp.label_dim[label]`` is read, and it is
            used as the row length.
        label: Node whose state count sets the row length.
        feature: Unused; kept for call-site compatibility.

    Returns:
        A list of ``Exp.label_dim[label]`` ints containing a single 1 at an
        index drawn with ``random.randint``.
    """
    n_states = Exp.label_dim[label]
    hot = random.randint(0, n_states - 1)
    return [1 if k == hot else 0 for k in range(n_states)]


def R_func(Exp, cpt, label, lb_dim):
    """Build a flattened deterministic CPT: one one-hot row per parent assignment.

    ``cpt.loopIn()`` is walked in order. Each instantiation whose assignment of
    the variables other than ``label`` has not been seen yet contributes one
    row of ``lb_dim`` entries with a single 1. Rows are concatenated, so the
    result has ``lb_dim * (number of distinct parent assignments)`` entries.
    NOTE(review): this presumably relies on ``label`` varying fastest in
    ``loopIn`` so that the rows line up with the CPT layout when the caller
    writes them back with ``cpt.set`` — confirm.

    Choice of the hot state: a pool ``list(range(label_dim)) * per_state`` is
    built, where ``per_state = int(int(Exp.allowed_noise) / label_dim)``. If the
    pool is non-empty, its first element (state 0) is used; otherwise the state
    is drawn with ``random.randint``.

    Args:
        Exp: Experiment object (reads ``allowed_noise`` and ``label_dim``).
        cpt: pyAgrum CPT of ``label``.
        label: Name of the child variable.
        lb_dim: Length of each one-hot row.

    Returns:
        Flat list of ints.
    """
    n_states = Exp.label_dim[label]

    # A disabled design marginalised out the non-functional parents. It scaled
    # rem_dim by their state counts and let several parent assignments draw
    # successively from one shared pool. With that step disabled, rem_dim is 1
    # and every parent assignment starts from a fresh pool, so a non-empty pool
    # always yields its first state.
    rem_dim = 1
    distribute = int(rem_dim * Exp.allowed_noise)
    per_state = int(distribute / n_states)
    pool = list(range(n_states)) * per_state

    seen = set()
    newlst = []
    for inst in cpt.loopIn():
        parent_assignment = dict(inst.todict())
        del parent_assignment[label]
        key = tuple(sorted(parent_assignment.items()))
        if key in seen:
            continue  # row for these parent values was already emitted
        seen.add(key)

        res = pool[0] if pool else random.randint(0, n_states - 1)

        row = [0] * lb_dim
        row[res] = 1
        newlst += row

    return newlst


def get_true_noise_dist(Exp):
    """Return the ``"noise_dist"`` entry of the SCM JSON file at ``Exp.SCM_PATH``.

    ``get_bayesian_network`` writes that file when called with ``load_scm=0``.
    """
    with open(Exp.SCM_PATH) as scm_file:
        instance = json.load(scm_file)
    return instance["noise_dist"]


def test_cpt_nonzero_prob(Exp, bn, target, parents, feature):
    """Check that ``P(target | parents)`` in ``bn`` has no near-zero entry.

    The conditional table comes from ``evidenceImpact``. Its smallest entry is
    compared with ``thresh / 3``, where ``thresh`` is ``1 / prod(label_dim)``
    over ``target`` and ``parents``.

    NOTE(review): the ``test_`` prefix may make pytest collect this function if
    the module is star-imported into a test file.

    Args:
        Exp: Experiment object; ``Exp.label_dim`` gives per-node state counts.
        bn: pyAgrum ``BayesNet`` to query.
        target: Name of the queried node.
        parents: List of conditioning node names.
        feature: Unused; kept for call-site compatibility.

    Returns:
        False if the minimum entry is ``<= thresh / 3``. Otherwise prints
        ``"passed"`` and returns True.
    """
    ie = gum.LazyPropagation(bn)
    ie.addJointTarget(set([target] + parents))
    ie.makeInference()
    r_interv_org = ie.evidenceImpact(target, parents)

    min_val = min(r_interv_org.toarray().ravel())
    thresh = 1 / np.prod([Exp.label_dim[lb] for lb in [target] + parents])
    if min_val <= thresh / 3:
        return False
    print("passed")
    return True


# implement intervened and latents
def get_bayesian_network(Exp, intervened, load_scm):
    """Build the pyAgrum Bayesian network for the experiment's SCM.

    The graph comes from ``Exp.Complete_DAG``, plus one exogenous noise node per
    entry of ``Exp.exogenous``. Conditional tables are either loaded from
    ``Exp.SCM_PATH`` or drawn at random. The requested hard interventions are
    applied last.

    Args:
        Exp: Experiment object. Reads ``complete_labels``, ``label_names``,
            ``label_dim``, ``noise_params``, ``latent_state``, ``noise_states``,
            ``Complete_DAG``, ``Complete_DAG_desc``, ``Observed_DAG``,
            ``exogenous`` and ``SCM_PATH``.
        intervened: Mapping ``{node: state}``. Each node is cut from its parents
            and pinned to ``state`` (via ``get_bn``).
        load_scm: One of:
            - ``1``: load noise distributions and CPTs from ``Exp.SCM_PATH``.
            - ``0``: sample them (noise from a Dirichlet, CPTs via ``R_func`` /
              ``R``) and write the result to ``Exp.SCM_PATH``.
            - any other value: sample CPTs without loading or saving.

    Returns:
        ``(bn_dict, INSTANCE)``. ``bn_dict`` is ``{"feature": BayesNet}``.
        ``INSTANCE`` holds the ``"cpt"`` tables, and also ``"noise_dist"`` when
        ``load_scm`` is 0 or 1.

    Raises:
        RuntimeError: if a CPT loaded from ``Exp.SCM_PATH`` fails
            ``test_cpt_nonzero_prob``. A loaded table is identical on every
            retry, so the resampling loop could otherwise never terminate.
    """
    INSTANCE = {"cpt": {label: {"feature": {}} for label in Exp.complete_labels}}

    noise_dist = {}
    if load_scm == 1:
        with open(Exp.SCM_PATH) as f:
            INSTANCE = json.load(f)
        noise_dist = INSTANCE["noise_dist"]
    elif load_scm == 0:
        for noise in Exp.noise_params:
            const, states = Exp.noise_params[noise]
            noise_dist[noise] = np.random.dirichlet(const * np.ones(states), size=1)[0].tolist()
        INSTANCE["noise_dist"] = noise_dist

    bn_dict = {"feature": gum.BayesNet(Exp.Complete_DAG_desc)}

    # Nodes of the complete DAG and their arcs. Nodes listed in noise_dist get
    # Exp.latent_state states; every other node gets Exp.label_dim[label].
    for label in Exp.Complete_DAG:
        if label in noise_dist:
            bn_dict["feature"].add(label, Exp.latent_state)
        else:
            bn_dict["feature"].add(label, Exp.label_dim[label])

        for parent in Exp.Complete_DAG[label]:
            for feat in bn_dict:
                bn_dict[feat].addArc(parent, label)

    # One exogenous noise node per entry of Exp.exogenous, feeding its label.
    for label in Exp.exogenous:
        for feat in bn_dict:
            bn_dict[feat].add(Exp.exogenous[label], Exp.noise_states)
            bn_dict[feat].addArc(Exp.exogenous[label], label)

    # Nodes in noise_dist take their distribution directly.
    for noise in noise_dist:
        for feat in bn_dict:
            bn_dict[feat].cpt(noise).fillWith(noise_dist[noise])

    for label in Exp.complete_labels:
        for feature in bn_dict:
            bn = bn_dict[feature]
            successful = False
            while not successful:
                # Load or draw the flattened CPT. Drawn tables are resampled
                # until they pass test_cpt_nonzero_prob.
                if load_scm == 1:
                    probs = INSTANCE["cpt"][label][feature]
                elif label in Exp.label_names:
                    # label_dim[label] is an int state count (it sizes the node
                    # in bn.add above); subscripting it with [feature] raised
                    # TypeError.
                    probs = R_func(Exp, bn.cpt(label), label, Exp.label_dim[label])
                else:
                    n_parent_rows = int(bn.cpt(label).domainSize() / Exp.label_dim[label])
                    probs = []
                    for _ in range(n_parent_rows):
                        probs += R(Exp, label, feature)

                INSTANCE["cpt"][label][feature] = probs

                if label in noise_dist:  # common confounders keep noise_dist
                    bn.cpt(label).fillWith(noise_dist[label])
                    successful = True
                    continue

                for i, ins in enumerate(bn.cpt(label).loopIn()):
                    bn.cpt(label).set(ins, probs[i])

                successful = test_cpt_nonzero_prob(Exp, bn, label, Exp.Observed_DAG[label], feature)
                if not successful and load_scm == 1:
                    raise RuntimeError(
                        f"CPT of {label!r} loaded from {Exp.SCM_PATH} fails "
                        "test_cpt_nonzero_prob; regenerate the SCM with load_scm=0"
                    )

    if load_scm == 0:
        with open(Exp.SCM_PATH, 'w') as fp:
            fp.write(json.dumps(INSTANCE))

    # Hard interventions: cut incoming arcs and pin each node to the given state.
    get_bn(Exp, bn_dict['feature'], intervened)

    return bn_dict, INSTANCE


def get_bn(Exp, bn2, intervened):
    """Apply hard interventions to ``bn2`` in place and return it.

    For every ``node: state`` pair in ``intervened``:
    - every arc from another variable of ``node``'s CPT into ``node`` is erased;
    - the CPT is filled with a one-hot vector of length ``Exp.label_dim[node]``
      that has its 1 at ``state``.
    """
    for node, state in intervened.items():
        incoming = [p for p in bn2.cpt(node).var_names if p != node]
        for p in incoming:
            bn2.eraseArc(p, node)

        one_hot = [0] * Exp.label_dim[node]
        one_hot[state] = 1
        bn2.cpt(node).fillWith(one_hot)

    return bn2


def check_queries(Exp, bn):
    """Debug helper: print marginals of ``bn`` under do(X1=1, W=4).

    The intervention is applied with ``get_bn`` to a copy made by
    ``gum.BayesNet(bn)``. The helper then prints the single-node marginals of
    X1, X2, W, Ydigit1, Ydigit2 and Ycolor, and the joint of (Ythick, W).
    Node names and intervention values are hard-coded.
    """
    # get_bn mutates the network it is given, so intervene on a copy.
    true_bn_interv = get_bn(Exp, gum.BayesNet(bn), {"X1": 1, "W": 4})

    ie = gum.LazyPropagation(true_bn_interv)
    ie.addJointTarget(set(Exp.label_names))
    ie.makeInference()

    print(ie.evidenceJointImpact(["X1"], []),
          ie.evidenceJointImpact(["X2"], []),
          ie.evidenceJointImpact(["W"], []),
          ie.evidenceJointImpact(["Ydigit1"], []),
          ie.evidenceJointImpact(["Ydigit2"], []),
          ie.evidenceJointImpact(["Ycolor"], []),
          ie.evidenceJointImpact(["Ythick", "W"], [])
          )


def get_synthetic_dist(Exp, joint_vars, bn):
    """Tabulate the joint posterior of ``joint_vars`` in ``bn``.

    Args:
        Exp: Experiment object. The key order of ``Exp.Complete_DAG`` is passed
            to ``top_sort_dict`` / ``top_sort_list`` to order the variables.
        joint_vars: Names of the variables whose joint distribution is computed.
        bn: pyAgrum ``BayesNet``.

    Returns:
        ``1`` if ``joint_vars`` is empty. NOTE(review): callers that unpack
        four values fail on this.

        Otherwise ``(variables, combinations, joint_dist, dist_dict)``:
        - ``combinations[k]`` is the k-th joint assignment, with columns
          presumably ordered like ``variables``;
        - ``joint_dist[k]`` is its probability;
        - ``dist_dict`` maps ``tuple(combinations[k])`` to that probability.
    """
    if len(joint_vars) == 0:
        return 1

    ie = gum.LazyPropagation(bn)
    var_set = set(joint_vars)
    ie.addJointTarget(var_set)
    ie.makeInference()

    res = ie.jointPosterior(var_set)
    print(res)

    combinations = []
    probs = []
    dist_dict = {}
    for inst in res.loopIn():
        comb = top_sort_dict(inst.todict(), Exp.Complete_DAG.keys())
        prob = res[inst]
        dist_dict[tuple(comb.values())] = prob
        combinations.append(list(comb.values()))
        probs.append(prob)

    variables = top_sort_list([v.description() for v in res.variablesSequence()], Exp.Complete_DAG.keys())

    # Probabilities are collected in the same loop as the assignments, so
    # joint_dist[k] always belongs to combinations[k] regardless of the axis
    # order used by Potential.toarray().
    return variables, np.array(combinations), np.array(probs), dist_dict


def get_cond_synthetic_dist(joint_vars, conditions, arrangeKeys, bn):
    """Tabulate ``P(joint_vars | conditions)`` via ``evidenceJointImpact``.

    Args:
        joint_vars: Names of the target variables.
        conditions: Names of the conditioning variables.
        arrangeKeys: Key order passed to ``top_sort_dict`` / ``top_sort_list``
            to order the columns of the result.
        bn: pyAgrum ``BayesNet``.

    Returns:
        ``(variables, combinations, joint_dist, dist_dict)``:
        - row k of ``combinations`` assigns every variable of the resulting
          table (``joint_vars`` and ``conditions``);
        - ``joint_dist[k]`` is the table's value for that row;
        - ``dist_dict`` maps ``tuple(combinations[k])`` to the same value.
    """
    ie = gum.LazyPropagation(bn)
    ie.addJointTarget(set(joint_vars))
    ie.makeInference()

    res = ie.evidenceJointImpact(joint_vars, conditions)
    print(res)

    combinations = []
    probs = []
    dist_dict = {}
    for inst in res.loopIn():
        comb = top_sort_dict(inst.todict(), arrangeKeys)
        prob = res[inst]
        dist_dict[tuple(comb.values())] = prob
        combinations.append(list(comb.values()))
        probs.append(prob)

    variables = top_sort_list([v.description() for v in res.variablesSequence()], arrangeKeys)

    # Values are collected in the same loop as the assignments, so
    # joint_dist[k] always belongs to combinations[k] regardless of the axis
    # order used by Potential.toarray().
    return variables, np.array(combinations), np.array(probs), dist_dict


def generate_synthetic_samples(Exp, feature, varialbes, combinations, joint_dist, sample_size, plot=True):
    """Draw ``sample_size`` i.i.d. rows from a tabulated joint distribution.

    Args:
        Exp: Unused; kept for call-site compatibility.
        feature: Suffix appended to each variable name in the plot x-label.
        varialbes: Variable names, one per column of ``combinations``.
        combinations: 2-D array-like; row k is the k-th joint assignment.
        joint_dist: Probability of each row of ``combinations`` (passed as
            ``p`` to ``np.random.choice``).
        sample_size: Number of samples to draw.
        plot: If True (the default, matching previous behaviour), show a
            histogram of each variable's samples via ``plot_labels``.

    Returns:
        Dict mapping each variable name to the 1-D array of its sampled values.
    """
    print("For " + feature + " Producing samples:", sample_size, " joint dist:", joint_dist)
    print(sum(joint_dist))

    row_idx = np.random.choice(len(joint_dist), sample_size, p=joint_dist)
    observations = np.asarray(combinations)[row_idx]

    label_dict = {}
    for col, name in enumerate(varialbes):
        if plot:
            plot_labels(name + feature, observations[:, col])
        label_dict[name] = observations[:, col]

    return label_dict


def save_datasets(SAVE_DATASET, label_save_dir, feature, true_data):
    """Pickle each label's samples to ``label_save_dir + label + feature + ".pkl"``.

    Args:
        SAVE_DATASET: Nothing is written when this is falsy.
        label_save_dir: Path prefix, concatenated as-is, so it presumably ends
            with a path separator.
        feature: Suffix appended to the label in the file name.
        true_data: Mapping label -> samples; each value is pickled as
            ``np.array``.
    """
    # `not` (rather than `== False`) also skips None / "" / other falsy flags.
    if not SAVE_DATASET:
        return

    for label, samples in true_data.items():
        file_name = label_save_dir + label + feature + ".pkl"
        with open(file_name, 'wb') as fp:
            pickle.dump(np.array(samples), fp)
        print(file_name, " saved")


def produce_datasets(Exp, load_scm, SAVE_DATASET=True):
    """Sample (and optionally save) one dataset per entry of ``Exp.Data_intervs``.

    Steps:
      1. If ``load_scm != 1``, build the SCM once with ``load_scm`` so it is
         sampled; ``get_bayesian_network`` writes it to ``Exp.SCM_PATH`` when
         ``load_scm == 0``.
      2. Reload the SCM from file under ``Exp.Data_intervs[0]``, tabulate the
         joint of ``Exp.label_names``, and draw ``Exp.Synthetic_Sample_Size``
         samples into ``Exp.file_roots[0]``.
      3. For ``i`` in ``1 .. Exp.num_datasets - 1``, do the same under
         ``Exp.Data_intervs[i]`` with ``Exp.intv_Sample_Size`` samples into
         ``Exp.file_roots[i]``.
      4. If ``SAVE_DATASET``, also write the SCM tables to
         ``Exp.file_roots[0] + "scm.txt"``.
    """
    intervened = Exp.Data_intervs[0]
    if load_scm != 1:
        # Only needed to sample/persist the SCM. The returned network is
        # discarded because everything below reloads it from Exp.SCM_PATH.
        get_bayesian_network(Exp, intervened, load_scm)

    bn_dict, INSTANCES = get_bayesian_network(Exp, intervened, load_scm=1)
    variables, combs, jdist, _ = get_synthetic_dist(Exp, Exp.label_names, bn_dict["feature"])

    # Diagnostics, printed only. get_expected_true_intervs may also (re)write
    # the per-query files under Exp.Intv_SCMs.
    intv_dist = get_expected_true_intervs(Exp, load_scm)
    print("intv_dist", intv_dist)
    # NOTE(review): "Y0"/"W0"/"Y1"/"W1" are hard-coded and the results are
    # unused; confirm these nodes exist in the current graph.
    get_cond_synthetic_dist(["Y0"], ["W0"], Exp.label_names, bn_dict["feature"])
    get_cond_synthetic_dist(["Y1"], ["W1"], Exp.label_names, bn_dict["feature"])

    # Dataset 0: sampled under the first entry of Exp.Data_intervs.
    true_obs = generate_synthetic_samples(Exp, "feature", variables, combs, jdist, Exp.Synthetic_Sample_Size)
    save_datasets(SAVE_DATASET, Exp.file_roots[0], "feature", true_obs)

    # Remaining datasets: one per further intervention.
    for i in range(1, Exp.num_datasets):
        print(i, Exp.Data_intervs[i])
        bn_dict, INSTANCES = get_bayesian_network(Exp, Exp.Data_intervs[i], load_scm=1)
        variables, combs, jdist, _ = get_synthetic_dist(Exp, Exp.label_names, bn_dict["feature"])
        true_obs = generate_synthetic_samples(Exp, "feature", variables, combs, jdist, Exp.intv_Sample_Size)
        save_datasets(SAVE_DATASET, Exp.file_roots[i], "feature", true_obs)

    if SAVE_DATASET:
        with open(Exp.file_roots[0] + "scm.txt", 'w') as fp:
            fp.write(json.dumps(INSTANCES))


def save_queryscm(Exp, query):
    """Tabulate ``query["obs"]`` under ``do(query["interv"])`` and save it as JSON.

    The output file is ``Exp.Intv_SCMs + str(query["interv"]) + ".txt"``. Its
    keys are ``str(assignment_tuple)`` and its values are the probabilities.
    """
    intervention = query["interv"]
    obs_vars = query["obs"]

    intervened, _ = get_bayesian_network(Exp, intervention, load_scm=1)
    _, _, _, dist_by_assignment = get_synthetic_dist(Exp, obs_vars, intervened["feature"])

    serialisable = {str(assignment): prob for assignment, prob in dist_by_assignment.items()}

    file_name = Exp.Intv_SCMs + str(intervention) + ".txt"
    print(f"Saving {query['obs']}|do({intervention}) at {file_name}")
    with open(file_name, 'w') as fp:
        fp.write(json.dumps(serialisable))


def get_intv_dist(Exp, obs_vars, intv_key, expr, load_scm):
    """Return the distribution of ``obs_vars`` under ``do(intv_key)``, keyed by tuple.

    The cache file is ``Exp.Intv_SCMs + expr + ".txt"``. If ``load_scm == 1``
    and that file exists, it is returned. Otherwise the distribution is
    computed from the SCM reloaded from ``Exp.SCM_PATH`` and written to the
    cache file.
    """
    file_name = Exp.Intv_SCMs + expr + ".txt"

    if load_scm == 1 and os.path.exists(file_name):
        with open(file_name) as f:
            cached = json.load(f)
        # NOTE(review): eval() executes whatever the cache file contains. The
        # keys are str(tuple) written below, so ast.literal_eval would suffice
        # if the tuple elements are plain ints — confirm before switching.
        return {eval(k): cached[k] for k in cached}

    intervened, _ = get_bayesian_network(Exp, intv_key, load_scm=1)
    _, _, _, dist = get_synthetic_dist(Exp, obs_vars, intervened["feature"])

    serialisable = {str(k): v for k, v in dist.items()}
    print(f"Saving {expr} at {file_name}")
    with open(file_name, 'w') as fp:
        fp.write(json.dumps(serialisable))

    return dist



def get_expected_true_intervs(Exp, load_scm):
    """Average each query's interventional distributions over its interventions.

    For every query in ``Exp.interv_queries`` and every intervention in
    ``query["intervs"]``, the distribution of ``query["obs"]`` is fetched with
    ``get_intv_dist`` and added with weight ``1 / len(query["intervs"])``.
    Contributions from all queries are summed into one dict keyed by
    assignment tuple.
    """
    expected = {}
    for query in Exp.interv_queries:
        interventions = query["intervs"]
        for intervention in interventions:
            query_string = getdoKey(query["obs"], intervention)
            dist = get_intv_dist(Exp, query["obs"], intervention, query_string, load_scm)

            print(query_string, "->", dist)

            weight = 1 / len(interventions)
            for assignment, prob in dist.items():
                expected[assignment] = expected.get(assignment, 0) + prob * weight

    return expected


if __name__ == '__main__':
    # Script entry point: configure the experiment, then generate datasets
    # from the SCM stored at Exp.SCM_PATH. load_scm=1 reloads that SCM instead
    # of sampling a new one.
    # configuration starts
    # lat_dim = 16
    Exp = Experiment("Exp1", set_trainGraph,
                     dist_thresh=0.15,
                     Synthetic_Sample_Size=10000,
                     intv_Sample_Size=10000,
                     batch_size=200,
                     features=["feature"],
                     noise_states=100,
                     latent_state=16,
                     Data_intervs=[{}],
                     new_experiment=False,
                     allowed_noise=0.25,
                     obs_state=3,
                     )

    # SAVE_DATASET=True pickles every sampled label under Exp.file_roots and
    # writes scm.txt into Exp.file_roots[0].
    produce_datasets(Exp, load_scm=1, SAVE_DATASET=True)


# get_cond_synthetic_dist


