import torch


def _counted_part(v, n_w, s_k):
    """Return the part of mask ``v`` that is counted for subset ``s_k``.

    If ``s_k`` has an entry ``(enum, denom)`` in ``n_w``, only every
    ``denom``-th element along the last dimension is kept (``enum`` is not
    used here); otherwise ``v`` is returned unchanged.
    """
    if s_k in n_w:
        _, denom = n_w[s_k]
        return v[..., 0::denom]
    return v


def check_mvhg_hierarchical(relations, y_mask_all, p_x_all, n_all, device, n_samples):
    """Print diagnostics when hierarchical MVHG draw counts do not add up.

    For each of the last three group keys ``m_k`` in ``relations["order"]``,
    the per-sample count ``sum(mask) - 1`` is accumulated over every key of
    ``relations["grouping"][m_k]["m_all"]``. For keys listed in that group's
    ``"n_w"``, only every ``denom``-th mask entry is counted. The total is
    compared with the expected value ``m_all[m_k]``. If any sample
    disagrees, a per-sample breakdown is printed. Nothing is returned.

    Args:
        relations: dict with ``"order"`` (sequence of group keys) and
            ``"grouping"`` (group key -> dict with ``"m_all"`` and ``"n_w"``).
        y_mask_all: subset key -> mask tensor. Assumed shape is
            ``(n_samples, 1, L)``, inferred from the ``(n_samples, 1, 1)``
            accumulator and the reduction over dim 2 -- TODO confirm.
        p_x_all: unused; kept for interface compatibility.
        n_all: pair key -> per-sample tensor. Printed for the pair keys
            ``"m0_m1"``, ``"m1_m2"`` and ``"m0_m2"`` when present.
        device: device on which the count accumulator is allocated.
        n_samples: number of samples (leading dimension of the masks).
    """
    g_order = relations["order"]
    grouping = relations["grouping"]
    # Pairwise relations summarised for each mismatching sample; pairs that a
    # given relation set does not define are skipped rather than crashing.
    pair_keys = ("m0_m1", "m1_m2", "m0_m2")
    for m_k in g_order[-3:]:
        m_all_m = grouping[m_k]["m_all"]
        n_w = grouping[m_k]["n_w"]
        n_m = m_all_m[m_k]

        # Per-sample number of drawn items, summed over all subsets of the group.
        x_c = torch.zeros((n_samples, 1, 1), device=device)
        for s_k in m_all_m:
            counted = _counted_part(y_mask_all[s_k], n_w, s_k)
            x_c += torch.sum(counted, dim=2, keepdim=True) - 1.0

        if (x_c == n_m).all():
            continue
        print(m_k, ": INCORRECT number of drawn samples")
        for k in range(n_samples):
            if x_c[k] == n_m:
                continue
            for s_k in m_all_m:
                v = y_mask_all[s_k][k]
                sub_ = _counted_part(v, n_w, s_k)
                x_i = torch.sum(sub_) - 1.0
                print(s_k, ": ", x_i.item(), end=", ")
                print(v)
                # The strided view only exists for subsets listed in n_w;
                # for the others there is nothing extra to show.
                if s_k in n_w:
                    print(sub_)
            print("")
            for pair in pair_keys:
                if pair not in y_mask_all or pair not in n_all:
                    continue
                print(
                    pair + ": ",
                    torch.sum(y_mask_all[pair][k]).item(),
                    "  n: ",
                    n_all[pair][k].item(),
                )
            print("")


def check_mvhg_mods(
    relations,
    m_all,
    m_rest_all,
    w_all,
    w_rest_all,
    x_all,
    p_x_all,
    logits_all,
    hside_all,
    n_new_all,
    device,
    n_samples,
):
    """Print diagnostics when per-modality draw totals do not add up.

    For every modality ``mod_key`` in ``relations["mods_subsets"]``, the
    tensors ``x_all[subset]`` of its subsets are added into a
    ``(n_samples, 1, 1)`` accumulator. The result is compared with
    ``relations["n"][mod_key]``. On any mismatch, the modality key is
    printed. Then, for each mismatching sample, the stored values of every
    subset are printed. ``w_rest``, ``p_x_k`` and ``logits_x_k`` are printed
    only for subsets present in ``p_x_all``. ``m_all`` and ``m_rest_all``
    are printed whole, not indexed by sample. Nothing is returned.
    """
    subsets_per_mod = relations["mods_subsets"]
    n_per_mod = relations["n"]
    for mod_key, subset_keys in subsets_per_mod.items():
        total = torch.zeros((n_samples, 1, 1), device=device)
        for subset_key in subset_keys:
            total += x_all[subset_key]

        if (n_per_mod[mod_key] == total).all():
            continue

        print("")
        print(mod_key)
        for idx in range(n_samples):
            if n_per_mod[mod_key] == total[idx]:
                continue
            for subset_key in subset_keys:
                x_tensor = x_all[subset_key]
                print(subset_key, "n_new: ", n_new_all[subset_key][idx])
                print("m: ", m_all[subset_key])
                print("m_rest: ", m_rest_all[subset_key])
                print("x: ", x_tensor[idx])
                print("hside: ", hside_all[subset_key][idx])
                print("w: ", w_all[subset_key][idx])
                if subset_key not in p_x_all:
                    continue
                print("w_rest: ", w_rest_all[subset_key][idx])
                print("p_x_k: ", p_x_all[subset_key][idx])
                print("logits_x_k: ", logits_all[subset_key][idx])
