import copy
import json

import numpy as np
import torch


from Causal_Partial_Mnist.CausalGraph_Mnist import set_mnist_nonId_newgraph, getdoKey
from Causal_Partial_Mnist.Find_CF_Synthetic_Distribution_Mnist import get_intv_dist
from Causal_Partial_Mnist.RejectionSampling_Optimized import rejection_sampling_optimized
from ModularUtils.ControllerConstants import map_fill_to_discrete, map_dictfill_to_discrete
from ModularUtils.ControllerModel import get_generated_labels
from ModularUtils.FunctionsConstant import asKey
from ModularUtils.FunctionsDistribution import compare_conditionals_within, calculate_TVD, match_with_true_dist
from ModularUtils.FunctionsTraining import save_results


def compare_conditionals(Exp, label_generators, obs_real_dataset, vars):
    """Print per-feature TVDs between generated and real conditionals P(label | conditions).

    Samples ``Exp.Synthetic_Sample_Size`` label vectors for ``Exp.label_names`` from
    ``label_generators`` and discretizes them with ``map_fill_to_discrete``. Then, for
    every feature in ``Exp.features`` and every label, it compares
    P(label | conditions) estimated from the generated samples with the same
    conditional estimated from ``obs_real_dataset``. The conditioning set is
    ``Exp.train_mech_list[lbid]["compare"]`` with the label itself removed.

    Args:
        Exp: experiment configuration object.
        label_generators: generators forwarded to ``get_generated_labels``.
        obs_real_dataset: tensor of real observational labels; it is converted to an
            int numpy array before comparison.
        vars: dict forwarded to ``get_generated_labels``; ``vars["mech"]`` names the
            label whose TVD is returned.

    Returns:
        The normalized TVD of ``vars["mech"]`` for the last feature processed. Returns
        0 if that label is not in ``Exp.label_names`` or ``Exp.features`` is empty.
    """
    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, {}, Exp.label_names,
                                                 Exp.Synthetic_Sample_Size, vars)
    dims_list = [Exp.label_dim[lb][feat] for lb in Exp.label_names for feat in Exp.features]
    y_dims = sum(dims_list)
    generated_labels_full = torch.cat(list(generated_labels_dict.values()), 1).view(-1, y_dims)
    generated_labels_full = map_fill_to_discrete(Exp, generated_labels_full,
                                                 dims_list).detach().cpu().numpy().astype(int)

    # Loop-invariant: convert the real dataset once instead of once per (feature, label).
    real_labels_np = obs_real_dataset.detach().cpu().numpy().astype(int)

    # Initialized before the loop so an empty ``Exp.features`` returns 0 instead of
    # raising UnboundLocalError at the return statement.
    mech_tvd = 0
    for feat in Exp.features:
        for lbid, label in enumerate(Exp.label_names):
            conditions = copy.deepcopy(Exp.train_mech_list[lbid]["compare"])
            conditions.remove(label)
            pstr = feat + ":P(" + label + "|" + str(conditions) + ")"
            print(pstr)

            ret1 = compare_conditionals_within(Exp, generated_labels_full, feat, [label], conditions,
                                               doPrint=False)
            ret2 = compare_conditionals_within(Exp, real_labels_np, feat, [label], conditions,
                                               doPrint=False)

            # Divide by the product of the conditioning labels' dims (presumably the
            # number of conditioning configurations; np.prod([]) == 1 when unconditioned).
            div = np.prod([Exp.label_dim[lb][feat] for lb in conditions])
            tvd = calculate_TVD(ret1, ret2, doPrint=False) / div
            print(pstr + " TVD:", tvd)
            print("--------------")

            if label == vars["mech"]:
                mech_tvd = tvd

    return mech_tvd





def get_observational_loss(Exp, obs_vars, label_generators, tvd_diff, kl_diff):
    """Compare the generated observational joint of ``obs_vars`` with the true joint.

    Samples ``Exp.Synthetic_Sample_Size`` labels with no intervention and discretizes
    them. It then compares them, via ``match_with_true_dist``, against the
    distribution ``get_intv_dist`` returns for an empty intervention.

    The TVD and KL are appended to ``tvd_diff[query_str]`` / ``kl_diff[query_str]``,
    where ``query_str = getdoKey(obs_vars, {})``. The list is created on first use.

    Returns:
        ``(tvd_diff, kl_diff, true_dist, fake_dist)``, where the last two are whatever
        ``match_with_true_dist`` returns.
    """
    feat = "feature"
    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, {}, obs_vars,
                                                 Exp.Synthetic_Sample_Size)
    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, obs_vars)

    query_str = getdoKey(obs_vars, {})  # key used both for the saved SCM file and the metric dicts
    true_dist_dict = get_intv_dist(Exp, obs_vars, [], query_str)

    tvd, kl, true_dist, fake_dist = match_with_true_dist(Exp, obs_vars, generated_labels_full,
                                                         true_dist_dict, feat, doPrint=False)

    # Append rather than overwrite so the per-evaluation history is kept, consistent
    # with how the other loss helpers record metrics. Previously the list was replaced
    # with a single element on every call.
    tvd_diff.setdefault(query_str, []).append(tvd)
    kl_diff.setdefault(query_str, []).append(kl)

    return tvd_diff, kl_diff, true_dist, fake_dist







def get_expected_loss_interventions(Exp, cur_mechs, label_generators, tvd_diff, kl_diff):
    """Record expected interventional TVD/KL for queries that involve ``cur_mechs``.

    For each query in ``Exp.interv_queries`` whose ``"obs"`` variables overlap
    ``cur_mechs``:
      * every intervention assignment in ``query["intervs"]`` is sampled from the
        generators and compared with the true interventional distribution from
        ``get_intv_dist``;
      * each assignment's TVD/KL is weighted by ``obs_dist`` (the observational
        distribution of the intervened variables) at that assignment's values;
      * the weighted sums are rounded to 4 places and appended to
        ``tvd_diff[query["expr"]]`` / ``kl_diff[query["expr"]]``.

    Returns:
        ``(tvd_diff, kl_diff, true_expected_dist, fake_expected_dist)``. The last two
        are uniform averages over each query's intervention assignments, accumulated
        across all processed queries.
    """
    feat = "feature"

    fake_expected_dist = {}
    true_expected_dist = {}
    for query in Exp.interv_queries:
        # Skip queries whose observed variables touch none of the current mechanisms.
        if set(query["obs"]).isdisjoint(cur_mechs):
            continue

        compare_Var = list(query["intervs"][0].keys())  # the intervened variables
        query_str = getdoKey(compare_Var, {})  # SCM saving file name
        obs_dist = get_intv_dist(Exp, compare_Var, {}, query_str)  # observational dist of the intervened vars

        n_intervs = len(query["intervs"])
        tvd_sum = 0
        kl_sum = 0
        for intv_key in query["intervs"]:
            query_string = getdoKey(query["obs"], intv_key)
            true_dist = get_intv_dist(Exp, query["obs"], intv_key, query_string)

            generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, intv_key, query["obs"],
                                                         Exp.Synthetic_Sample_Size)
            generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, query["obs"])
            obs_tvd, obs_kl, true_dist, fake_dist = match_with_true_dist(Exp, query["obs"], generated_labels_full,
                                                                         true_dist, feat, doPrint=False)

            print(f'{intv_key}: tvd:{obs_tvd}, kl:{obs_kl} and tvd<={np.sqrt(0.5 * obs_kl)}')

            weight = obs_dist[tuple(intv_key.values())]
            tvd_sum += obs_tvd * weight
            kl_sum += obs_kl * weight

            for key_dist in true_dist:
                if key_dist not in true_expected_dist:
                    true_expected_dist[key_dist] = 0
                    fake_expected_dist[key_dist] = 0
                true_expected_dist[key_dist] += true_dist[key_dist] * (1 / n_intervs)
                fake_expected_dist[key_dist] += fake_dist[key_dist] * (1 / n_intervs)

        print(f'--->Average tvd:{tvd_sum}, kl:{kl_sum} and tvd<={np.sqrt(0.5 * kl_sum)}')
        tvd_diff[query["expr"]].append(round(tvd_sum, 4))
        kl_diff[query["expr"]].append(round(kl_sum, 4))

    return tvd_diff, kl_diff, true_expected_dist, fake_expected_dist




def get_expected_true_cf(Exp):
    """Uniformly average the true counterfactual distributions of the first CF query.

    Iterates over every (evidence, intervention) pair of ``Exp.cf_queries[0]``. For
    each pair it prints the pair and the true CF distribution sorted by probability,
    and adds that distribution to a running mean with weight
    1/len(intervs) * 1/len(evidence).

    Returns:
        dict mapping each outcome key to its averaged probability.
    """
    cf_query = Exp.cf_queries[0]
    n_intervs = len(cf_query["intervs"])
    n_evidence = len(cf_query["evidence"])

    averaged = {}
    for ev in cf_query["evidence"]:
        for intv in cf_query["intervs"]:
            print("ev:", ev, "intv:", intv)
            # NOTE(review): get_cf_dist is not imported in this file's import block;
            # this raises NameError unless it is provided elsewhere. Confirm its module.
            cf_dist = get_cf_dist(Exp, cf_query["obs"], intv, ev, cf_query["expr"], load_dist=True)
            for outcome in cf_dist:
                averaged[outcome] = averaged.get(outcome, 0) + cf_dist[outcome] * (1 / n_intervs) * (1 / n_evidence)

            by_prob = sorted(cf_dist.items(), key=lambda pair: pair[1], reverse=True)
            print(dict(by_prob))

    return averaged


def get_expected_loss_countefactuals(Exp, cur_mechs, label_generators, tvd_diff, kl_diff):
    """Record the expected counterfactual TVD/KL for the first query in ``Exp.cf_queries``.

    Procedure:
      * Posterior labels, latents and gumbel noise are drawn for every evidence
        assignment with ``rejection_sampling_optimized``.
      * For each (evidence, intervention) pair, the generators produce counterfactual
        samples, which are compared with the true CF distribution from ``get_cf_dist``.
      * Per-intervention losses are weighted by ``obs_dist[intervention values]`` and
        per-evidence sums by ``obs_dist[evidence values]``. Here ``obs_dist`` is the
        observational distribution of the twin-mapped evidence variables.
      * The totals are appended to ``tvd_diff[expr]`` / ``kl_diff[expr]``.

    Returns:
        ``(tvd_diff, kl_diff, true_expected_dist, fake_expected_dist)``. The last two
        are uniform averages over all (evidence, intervention) pairs. When the query's
        observed variables do not overlap ``cur_mechs``, the metric dicts are returned
        unchanged together with two empty dicts. This keeps the return arity the same
        on both paths; previously the early return yielded only 2 values.
    """
    feat = "feature"
    cfquery = Exp.cf_queries[0]

    if set(cfquery["obs"]).isdisjoint(cur_mechs):
        return tvd_diff, kl_diff, {}, {}

    evidence_vars = [Exp.twin_map[lb] for lb in cfquery["evidence"][0].keys()]
    compare_Var = list(evidence_vars)
    query_str = getdoKey(compare_Var, {})  # SCM saving file name
    obs_dist = get_intv_dist(Exp, compare_Var, {}, query_str)  # observational dist of the evidence vars

    n_samples = Exp.Synthetic_Sample_Size
    n_intervs = len(cfquery["intervs"])
    n_evidence = len(cfquery["evidence"])

    evidence_list = list(cfquery["evidence"])
    all_posterior_label, all_posterior_latent, all_gumbel_noise = rejection_sampling_optimized(
        Exp, label_generators, n_samples, evidence_list, max_rejections=0, warn=100)

    final_tvd = 0
    final_kl = 0
    fake_expected_dist = {}
    true_expected_dist = {}
    for evidence in cfquery["evidence"]:
        kev = asKey(evidence)
        posterior_label = all_posterior_label[kev]
        posterior_latent = all_posterior_latent[kev]
        gumbel_noise = all_gumbel_noise[kev]

        tvd_sum = 0
        kl_sum = 0
        for intv_key in cfquery["intervs"]:
            cf_all_labels_dict = get_generated_labels(Exp, label_generators, posterior_label, posterior_latent,
                                                      intv_key, cfquery["obs"], n_samples, gumbel_noise=gumbel_noise)
            cf_samples = map_dictfill_to_discrete(Exp, cf_all_labels_dict, cfquery["obs"])

            # NOTE(review): get_cf_dist is not imported in this file's import block;
            # this raises NameError unless it is provided elsewhere. Confirm its module.
            true_cf_dist = get_cf_dist(Exp, cfquery["obs"], intv_key, evidence, cfquery["expr"], load_dist=True)

            cf_tvd, cf_kl, true_dist, fake_dist = match_with_true_dist(Exp, cfquery["obs"], cf_samples,
                                                                       true_cf_dist, feat, doPrint=False)

            print("---")
            print(dict(sorted(true_cf_dist.items(), key=lambda item: item[1], reverse=True)))
            print(dict(sorted(fake_dist.items(), key=lambda item: item[1], reverse=True)))

            # NOTE(review): obs_dist is over the twin-mapped evidence variables but is
            # indexed here by the intervention values. This is only meaningful if the
            # intervened variables coincide with those evidence variables. Confirm.
            intv_weight = obs_dist[tuple(intv_key.values())]
            tvd_sum += cf_tvd * intv_weight
            kl_sum += cf_kl * intv_weight

            for key_dist in true_dist:
                if key_dist not in true_expected_dist:
                    true_expected_dist[key_dist] = 0
                    fake_expected_dist[key_dist] = 0
                true_expected_dist[key_dist] += true_dist[key_dist] * (1 / n_intervs) * (1 / n_evidence)
                fake_expected_dist[key_dist] += fake_dist[key_dist] * (1 / n_intervs) * (1 / n_evidence)

            print(f"CF query done for evidence:{evidence}, intv_key: {intv_key} ")
            print("tvd:", cf_tvd, " prpb:", intv_weight, "and", "kl:", cf_kl, " prob:", intv_weight)

        evidence_weight = obs_dist[tuple(evidence.values())]
        final_tvd += tvd_sum * evidence_weight
        final_kl += kl_sum * evidence_weight

    tvd_diff[cfquery["expr"]].append(final_tvd)
    kl_diff[cfquery["expr"]].append(final_kl)

    return tvd_diff, kl_diff, true_expected_dist, fake_expected_dist


def evaluate_after_epochs(Exp, cur_mechs, label_generators, dataset_dict, tvd_diff, kl_diff):
    """Evaluate the label generators and append per-query TVD/KL to the history dicts.

    Procedure:
      * For every intervention in ``Exp.Data_intervs`` and every mechanism group in
        ``Exp.Data_observs``, the generators are sampled under that intervention.
      * The samples of the group's "compare" variables are compared with the true
        distribution from ``get_intv_dist``.
      * The TVD/KL, rounded to 4 places, are appended to ``tvd_diff[query_str]`` /
        ``kl_diff[query_str]``.
      * Afterwards the full observational loss is recorded via
        ``get_observational_loss``, results are saved with ``save_results``, and the
        last (up to) 10 TVD entries of every query are printed.

    Generators are switched to eval mode for the evaluation and back to train mode
    afterwards, including when evaluation raises.

    Args:
        Exp: experiment configuration object.
        cur_mechs: currently unused; kept for interface compatibility.
        label_generators: dict of generators keyed by name (presumably torch modules,
            since ``.eval()``/``.train()`` are called on them).
        dataset_dict: maps ``asKey(intervention)`` to a tensor of real labels.
            Interventions without an entry get an empty tensor as their real labels.
        tvd_diff, kl_diff: dicts of metric histories; updated in place.

    Returns:
        ``(tvd_diff, kl_diff)``.
    """
    for gen in label_generators:
        label_generators[gen].eval()

    try:
        with torch.no_grad():
            feat = "feature"
            all_generated_labels = {}
            all_real_labels = {}

            for interv_no, key in enumerate(Exp.Data_intervs):
                intv_key = asKey(key)

                # Distinct name so the loop does not shadow the ``cur_mechs`` parameter.
                for obs_mechs in Exp.Data_observs:
                    # Union of the "compare" variables of every mechanism in the group,
                    # preserving first-seen order.
                    compare_Var = []
                    for mech in obs_mechs:
                        compare_Var += [lb for lb in Exp.train_mech_dict[mech][interv_no]["compare"]
                                        if lb not in compare_Var]

                    obs_indices = [Exp.label_names.index(lb) for lb in compare_Var]

                    current_real_label = []
                    if intv_key in dataset_dict:
                        current_real_label = dataset_dict[intv_key][:, obs_indices].type(torch.LongTensor).view(
                            -1, len(obs_indices)).to(Exp.DEVICE)

                    generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, dict(intv_key),
                                                                 compare_Var, Exp.Synthetic_Sample_Size)
                    generated_labels_full = map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var)

                    query_str = getdoKey(compare_Var, dict(intv_key))
                    true_dist_dict = get_intv_dist(Exp, compare_Var, dict(intv_key), query_str)

                    obs_tvd, obs_kl, _, _ = match_with_true_dist(Exp, compare_Var, generated_labels_full,
                                                                 true_dist_dict, feat, doPrint=False)

                    tvd_diff[query_str].append(round(obs_tvd, 4))  # TODO: fix it
                    kl_diff[query_str].append(round(obs_kl, 4))  # TODO: fix it
                    # as_tensor instead of torch.tensor: torch.tensor(<Tensor>) emits a
                    # copy-construct UserWarning, and these values are freshly built here.
                    all_generated_labels[intv_key] = torch.as_tensor(generated_labels_full)
                    all_real_labels[intv_key] = torch.as_tensor(current_real_label)

            tvd_diff, kl_diff, _, _ = get_observational_loss(Exp, Exp.label_names, label_generators, tvd_diff,
                                                             kl_diff)

            save_results(Exp, Exp.SAVED_PATH, all_generated_labels, all_real_labels,
                         tvd_diff, kl_diff, Exp.G_avg_losses, Exp.D_avg_losses)
    finally:
        # Restore training mode even if evaluation fails, so the caller's training
        # loop does not silently continue with generators in eval mode.
        for gen in label_generators:
            label_generators[gen].train()

    # Print at most the last 10 entries; the window length is taken from the first metric.
    ll = -min(10, len(next(iter(tvd_diff.values()))))
    for dist in tvd_diff:
        print("###", dist, " loss%:", tvd_diff[dist][ll:])
    print(Exp.SAVED_PATH)

    return tvd_diff, kl_diff



