import torch

from tfFunctionsUtils import get_fake_distribution, get_joint_distributions_from_samples, calculate_TVD, \
    calculate_KL, getdoKey


def csXrayEvaluation(Exp, label_generators, train_dataset, tvd_diff, kl_diff):
    """Score the generators' joint label distribution against the training samples.

    The generators are put in eval mode, the joint distribution over
    ``['covid_19', 'pneum']`` is obtained from ``get_fake_distribution`` and
    from ``get_joint_distributions_from_samples`` (fed with ``train_dataset``),
    and the values returned by ``calculate_TVD`` / ``calculate_KL`` are rounded
    to 4 decimals and appended to the histories under the key produced by
    ``getdoKey``. The generators are always switched back to train mode before
    the function exits, even when one of the helpers raises.

    Args:
        Exp: experiment object forwarded to the helpers; its ``SAVED_PATH``
            attribute is printed at the end.
        label_generators: mapping whose values expose ``eval()`` and ``train()``
            (presumably ``torch.nn.Module`` generators -- confirm with caller).
        train_dataset: object supporting ``.detach().cpu().numpy()``
            (presumably a tensor of integer-coded labels -- confirm with caller).
        tvd_diff: dict mapping a query key to a list of recorded TVD values;
            mutated in place.
        kl_diff: dict mapping a query key to a list of recorded KL values;
            mutated in place.

    Returns:
        The same ``(tvd_diff, kl_diff)`` objects that were passed in.
    """
    compare_Var = ['covid_19', 'pneum']
    history_window = 10  # number of most recent entries shown per query

    for gen in label_generators:
        label_generators[gen].eval()

    try:
        with torch.no_grad():
            # Pass a copy so the helper receives its own list, exactly as it
            # did when the literal was repeated inline.
            fake_dist_dict = get_fake_distribution(Exp, label_generators, {}, list(compare_Var))

            real_samples = train_dataset.detach().cpu().numpy().astype(int)
            dataset_dist_dict = get_joint_distributions_from_samples(
                Exp, compare_Var, real_samples, "feature")

            obs_tvd = calculate_TVD(fake_dist_dict, dataset_dist_dict, doPrint=False)
            obs_kl = calculate_KL(fake_dist_dict, dataset_dist_dict, doPrint=False)

            query_str = getdoKey(compare_Var, {})
            # setdefault keeps pre-seeded histories untouched and avoids a
            # KeyError when the caller did not register this query beforehand.
            tvd_diff.setdefault(query_str, []).append(round(obs_tvd, 4))
            kl_diff.setdefault(query_str, []).append(round(obs_kl, 4))
    finally:
        # Restore train mode even if a helper raised, so a failed evaluation
        # does not leave the generators stuck in eval mode.
        for gen in label_generators:
            label_generators[gen].train()

    # Slice each history on its own: a window derived from the first entry
    # breaks when that entry is empty (-0 selects the whole list) or shorter
    # than the others.
    for dist in tvd_diff:
        print("###", dist, " loss%:", [round(val, 4) for val in tvd_diff[dist][-history_window:]])
    print(Exp.SAVED_PATH)

    return tvd_diff, kl_diff

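    # NOTE(review): everything below is commented-out code placed after the
    # `return` above; it is never executed. TODO: confirm whether it can be removed.
    #     ######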
    #
    #     all_generated_labels={}
    #     all_real_labels={}
    #
    #     for query in Exp.interv_queries:
    #         for key in query["intervs"]:
    #             compare_Var= query["obs"]
    #             # for interv_no, key in enumerate(Exp.Data_intervs):
    #             intv_key = asKey(key)
    #             query_str = getdoKey(compare_Var, dict(intv_key))
    #
    #             if key=={}:
    #                 # continue
    #
    #                 if len(compare_Var)==0:
    #                     continue
    #
    #                 _, _, _, graph_label_vars = get_training_variables(Exp, Exp.label_names, 0, key)
    #                 obs_indices = [graph_label_vars.index(lb) for lb in compare_Var]
    #                 current_real_label = []
    #                 if intv_key in dataset_dict:
    #                     current_real_label = dataset_dict[intv_key]["obs"][:, obs_indices].type(torch.LongTensor).view(-1,len(obs_indices)).to(Exp.DEVICE)
    #
    #
    #                 fake_dist_dict= get_fake_distribution(Exp, label_generators, intv_key, compare_Var)
    #                 dataset_dist_dict = get_joint_distributions_from_samples(Exp, compare_Var,
    #                                                                          current_real_label.detach().cpu().numpy().astype(
    #                                                                              int), "feature")
    #
    #                 # true_dist_dict = get_intv_dist(Exp, compare_Var, dict(intv_key), query_str, load_prev=False)
    #                 # dataset_dist_dict=true_dist_dict
    #
    #                 obs_tvd = calculate_TVD(fake_dist_dict, dataset_dist_dict, doPrint=False)
    #                 obs_kl= calculate_KL(fake_dist_dict, dataset_dist_dict, doPrint=False)
    #
    #                 if query_str in tvd_diff:
    #                     tvd_diff[query_str].append(round(obs_tvd, 4))  # todo: fix it
    #                     kl_diff[query_str].append(round(obs_kl, 4))
    #
    #                 # print("")
    #
    #
    #             # elif query["expr"]=="P(genC|do(D))":
    #             else:
    #                 fake_dist_dict= get_fake_distribution(Exp, label_generators, key, compare_Var)
    #                 print('fake intv dist_dict',fake_dist_dict)
    #                 D = get_dataset(Exp, 'medD', 0)
    #                 U0 = get_dataset(Exp, 'medU0', 0)
    #                 genC = get_dataset(Exp, 'medC', 0)
    #                 cur_data = torch.cat([U0, D, genC], 1).cpu().numpy()
    #                 px = pd.DataFrame(cur_data)
    #                 px = px.rename(columns={0: 'medU0', 1: 'medD', 2: 'medC'})
    #                 dataset_dist_dict = estiamte_ate_backdoor_direct(Exp, px, 'medD', 'medC', ['medU0'])[list(key.values())[0]]
    #                 dataset_dist_dict= {tuple([key]):val for key,val in dataset_dist_dict.items()}
    #                 obs_tvd = calculate_TVD(fake_dist_dict, dataset_dist_dict, doPrint=False)
    #                 obs_kl = calculate_KL(fake_dist_dict, dataset_dist_dict, doPrint=False)
    #
    #                 if query_str in tvd_diff:
    #                     tvd_diff[query_str].append(round(obs_tvd, 4))  # todo: fix it
    #                     kl_diff[query_str].append(round(obs_kl, 4))
    #
    #
    #     #####-----------
    #     # if set(all_compare_Var) & set(Exp.image_labels) !=set():
    #         # compare_Var = cur_mechs[0:-1]
    #
    #         # compare_Var=["D", "C"]
    #
    #             showImage=True
    #             if key=={} and showImage and (Exp.curr_epoochs+1)%1==0:
    #                 minibatch = 2
    #                 generated_labels_dict = get_generated_labels(Exp, label_generators, {}, {}, {}, compare_Var+[Exp.image_labels[0]], minibatch, hard=True)
    #                 generated_image = generated_labels_dict[Exp.image_labels[0]]
    #                 del generated_labels_dict[Exp.image_labels[0]]
    #
    #                 # y_dims = sum([Exp.label_dim[lb]["feature"] for lb in compare_Var])
    #                 # ret = list(generated_labels_dict.values())
    #                 # generated_labels_ig = torch.cat(ret, 1).view(-1, y_dims)
    #                 generated_labels_ig = map_dictfill_to_discrete(Exp, generated_labels_dict, compare_Var)
    #
    #                 for grow, genimg in zip(generated_labels_ig, generated_image):
    #                     print("gen", grow)
    #                     genimg = genimg.permute(1, 2, 0).detach().cpu().numpy()
    #                     # plot_dataset_digits(1, 2, [obsimg, genimg], f'Real {Ores_digit[id]}')
    #
    #
    #                     cur_fold=os.getcwd()
    #                     plot_trained_digits(1, 1, [genimg], f'Real {grow}', f'{cur_fold}/PLOTS')
    #
    #
    #     save_results(Exp, Exp.SAVED_PATH, all_generated_labels ,all_real_labels,
    #                  tvd_diff, kl_diff, Exp.G_avg_losses, Exp.D_avg_losses)
    #
    #
    #
    # for gen in label_generators:
    #     label_generators[gen].train()
    #
    # ll = -min(10, len(list(tvd_diff.values())[0]))
    # # printing loss
    # for dist in tvd_diff:
    #     print("###", dist, " loss%:",  [round(val, 4) for val in tvd_diff[dist][ll:]])
    # print(Exp.SAVED_PATH)
    #
    # return tvd_diff , kl_diff
    #
    #
    #
    #

