import sys
import numpy as np
import matplotlib.pyplot as plt
import logging
from scipy.stats import hypergeom
import random

import torch
import torch.nn.functional as F

from mvhg.pt_heaviside import sigmoid

# Logging level applied by logging.basicConfig in the __main__ demo below.
log_level = logging.INFO
# Shared activation modules used by the sampling helpers in this file.
relu = torch.nn.ReLU()
# Softmax over the last dimension (the support {0, ..., m} of a count).
sm = torch.nn.Softmax(dim=-1)
# Smooth Heaviside approximation from mvhg.pt_heaviside; used by get_hside_check.
# NOTE(review): exact shape/slope semantics are defined in mvhg.pt_heaviside — confirm there.
hside_approx = sigmoid(slope=1.0)
# Multiplier of log(hside_check + eps) in get_logits_fisher: the larger it is,
# the more strongly counts with hside_check close to 0 are suppressed.
p_hside = 100


def sample_gumbel(shape, device, eps=1e-20):
    """Draw standard Gumbel noise via the inverse-CDF transform -log(-log(U)).

    Args:
        shape: shape of the returned tensor.
        device: device the noise is moved to (it is drawn on the default
            generator first and then transferred).
        eps: small constant added inside both logarithms to avoid log(0).

    Returns:
        Tensor of the requested shape holding Gumbel(0, 1) samples.
    """
    uniform = torch.rand(shape).to(device)
    inner = -torch.log(uniform + eps)
    return -torch.log(inner + eps)


def resample_omega(log_w, device):
    """Perturb log-weights with additive standard-normal noise.

    Args:
        log_w: tensor of log-weights.
        device: device on which the noise is generated.

    Returns:
        ``log_w + N(0, 1)`` noise of the same shape as ``log_w``.
    """
    noise = torch.randn(log_w.shape, device=device)
    perturbed = log_w + noise
    return perturbed


def gumbel_softmax_sample(logits, temperature, gumbel_noise, device):
    """Tempered softmax over the last dimension, optionally Gumbel-perturbed.

    Args:
        logits: unnormalised log-probabilities, softmax taken over dim -1.
        temperature: softmax temperature; logits are divided by it.
        gumbel_noise: if true, add Gumbel(0, 1) noise (from sample_gumbel)
            to the logits before the softmax.
        device: device on which the noise is generated.

    Returns:
        softmax((logits [+ gumbel]) / temperature, dim=-1).
    """
    # Only draw noise when it is actually used: this avoids wasted work and
    # leaves the global RNG state untouched for the deterministic path.
    if gumbel_noise:
        logits = logits + sample_gumbel(logits.size(), device)
    return F.softmax(logits / temperature, dim=-1)


def gumbel_softmax(logits, temperature, device, gumbel_noise=True):
    """Straight-through Gumbel-softmax over the last dimension.

    input:  logits of shape [*, n_class]
    return: [*, n_class] tensor whose forward value is a one-hot vector (the
            argmax of the soft sample) and whose gradient is that of the soft
            sample (straight-through estimator).
    """
    y_soft = gumbel_softmax_sample(logits, temperature, gumbel_noise, device)
    _, hard_idx = y_soft.max(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, hard_idx, 1)
    # Forward: y_hard; backward: gradient flows through y_soft only.
    return (y_hard - y_soft).detach() + y_soft


def get_logits_fisher(n, m, m2, log_w1, log_w2, hside_check, device, eps=1e-20):
    """Unnormalised log-pmf of a two-class Fisher noncentral hypergeometric.

    For every candidate count x1 in {0, ..., m} of the first class (size m,
    log-weight log_w1), the second class (size m2, log-weight log_w2) receives
    x2 = relu(n - x1). The returned logits are

        x1*log_w1 + (n - x1)*log_w2
        - [log x1! + log x2! + log (m - x1)! + log relu(m2 - x2)!]
        + p_hside * log(hside_check + eps),

    i.e. the Fisher log-pmf up to an additive constant, plus a log-barrier
    that pushes entries with hside_check close to 0 towards -inf.

    Args:
        n: number of items still to draw, broadcastable to (n_samples, 1, m + 1).
        m: size of the first class.
        m2: size of the second (merged "rest") class.
        log_w1, log_w2: log-weights; log_w1.shape[0] is the number of samples.
        hside_check: soft feasibility mask, broadcastable to the logits.
        device: device for the support tensor.
        eps: added inside the log of hside_check.

    Returns:
        Tensor of shape (n_samples, 1, m + 1) (after broadcasting).
    """
    n_samples = log_w1.shape[0]
    support = torch.arange(m + 1, dtype=torch.float32).to(device)
    # (m + 1,) -> (n_samples, 1, m + 1): one copy of the support per sample.
    support = support.unsqueeze(0).repeat(n_samples, 1, 1)
    x2 = relu(n - support)
    log_norm = (
        torch.lgamma(support + 1.0)
        + torch.lgamma(x2 + 1.0)
        + torch.lgamma(m - support + 1.0)
        + torch.lgamma(relu(m2 - x2) + 1.0)
    )
    log_weight_term = support * log_w1 + (n - support) * log_w2
    unmasked = log_weight_term - log_norm
    return unmasked + p_hside * torch.log(hside_check + eps)


def get_log_p_x_from_logits(logits, dim=-1):
    """Normalise unnormalised log-probabilities into log-probabilities.

    Args:
        logits: unnormalised log-probabilities.
        dim: dimension holding the categories (default: last).

    Returns:
        ``logits - logsumexp(logits, dim)`` with the same shape as ``logits``,
        so that ``exp`` of the result sums to 1 along ``dim``.
    """
    # keepdim=True keeps the normaliser broadcastable against logits for any rank.
    return logits - torch.logsumexp(logits, dim=dim, keepdim=True)


def calc_group_w_m(m_group, log_w_group):
    """Merge a group G of classes into a single class.

    m_G   = sum_k m_k
    log w_G = log sum_k (m_k / m_G) * exp(log_w_k)

    i.e. the size-weighted mean of the member weights, computed in log-space
    and reduced over dim 2 of ``log_w_group``.

    Args:
        m_group: tensor of member sizes (cast to float); broadcast against
            ``log_w_group`` along its dim 2.
        log_w_group: member log-weights, members indexed by dim 2.

    Returns:
        (log_w_G, m_G)
    """
    sizes = m_group.float()
    m_G = sizes.sum()
    scaled = (log_w_group + sizes.log()) - m_G.log()
    return torch.logsumexp(scaled, dim=2), m_G


def _merge_hierarchical_side(keys, group, w_all, device):
    """Merge the members ``keys`` of one side of a hierarchical group into [m, log_w]."""
    m_all = group["m_all"]
    m_w = group["m_w"]
    # Effective member sizes: multiplier m_w[k] times size m_all[k].
    m_side = torch.tensor([m_w[k] * m_all[k] for k in keys], device=device)
    if not keys:
        # Empty side: empty size tensor and an effectively-zero weight.
        return [m_side, torch.tensor(-1000.0, device=device).unsqueeze(0)]
    log_w_members = torch.cat([w_all[k].unsqueeze(1) for k in keys], dim=1).unsqueeze(1)
    log_w_side, m_side = calc_group_w_m(m_side, log_w_members)
    return [m_side, log_w_side]


def get_m_w_hierarchical(key_g, grouping, w_all, device, eps=1e-10):
    """Merge the left and right sides of hierarchical group ``key_g``.

    ``grouping[key_g]`` must provide "L" and "R" (member keys of each side),
    "m_all" (member sizes) and "m_w" (per-member size multipliers). Each side
    is merged with calc_group_w_m into one class whose size is the sum of
    m_w[k] * m_all[k] and whose log-weight is the size-weighted log-mean-exp
    of the members' log-weights ``w_all[k]``.

    Args:
        key_g: group key into ``grouping``.
        grouping: dict of group descriptions.
        w_all: member key -> log-weights (first dim = samples).
        device: device for created tensors.
        eps: unused; kept for backward compatibility.

    Returns:
        ([m_l, log_w_l], [m_r, log_w_r]). For an empty side, m is an empty
        tensor and log_w is tensor([-1000.]).
    """
    group = grouping[key_g]
    m_w_l = _merge_hierarchical_side(group["L"], group, w_all, device)
    m_w_r = _merge_hierarchical_side(group["R"], group, w_all, device)
    return m_w_l, m_w_r


def get_m_w_rest_mods(relations, m_all, w_all, x_all, key_m, key_s, device, eps=1e-10):
    """Merge the not-yet-drawn subsets that follow ``key_s`` in mod ``key_m``.

    The candidates are the entries of ``relations["mods_subsets"][key_m]``
    after the first occurrence of ``key_s`` that are not yet keys of
    ``x_all``. They are merged into one "rest" class with calc_group_w_m.

    Args:
        relations: needs "mods_subsets" (mod key -> ordered subset keys).
        m_all: subset key -> size.
        w_all: subset key -> log-weights (first dim = samples).
        x_all: subset key -> already drawn counts (only membership is used).
        key_m, key_s: the current mod and subset.
        device: device for created tensors.
        eps: unused.

    Returns:
        (m_r, log_w_r). If there are no candidates, m_r is an empty tensor and
        log_w_r is tensor([-1000.]) (an effectively-zero weight).
    """
    remaining = iter(relations["mods_subsets"][key_m])
    # Advance past key_s; if it is absent the iterator is exhausted.
    for s_k in remaining:
        if s_k == key_s:
            break
    later = [s_k for s_k in remaining if s_k not in x_all]

    m_r = torch.tensor([m_all[s_k] for s_k in later], device=device)
    if not later:
        return m_r, torch.tensor(-1000.0, device=device).unsqueeze(0)
    w_r_v = torch.cat([w_all[s_k].unsqueeze(1) for s_k in later], dim=1).unsqueeze(1)
    log_w_r, m_r = calc_group_w_m(m_r, w_r_v)
    return m_r, log_w_r


def get_n_hierarchical(key_g, grouping, y_mask_all, device, n_samples):
    """Number of items to draw for hierarchical group ``key_g``.

    For every entry ``mod -> prev_keys`` of ``grouping[key_g]["n_prev"]`` the
    counts already drawn for the groups in ``prev_keys`` are read off their
    cumulative masks (every ``step``-th column, count = column sum - 1) and
    weighted by ``numer``, where ``(numer, step) = n_w[key]``. The prior
    ("n_prior", either a scalar or a dict keyed by mod) minus that weighted sum
    gives a per-mod budget; the minimum over mods is scaled by "n_m".

    Returns:
        Tensor of shape (n_samples, 1, 1).
    """
    group = grouping[key_g]
    n_prior = group["n_prior"]
    prev_per_mod = group["n_prev"]
    n_weights = group["n_w"]
    scale = group["n_m"]

    budgets = []
    for mod_key, prev_keys in prev_per_mod.items():
        drawn = torch.zeros((n_samples, 1, 1), device=device)
        for prev_key in prev_keys:
            numer, step = n_weights[prev_key]
            strided = y_mask_all[prev_key][:, :, 0::step]
            count = torch.sum(strided, dim=2).unsqueeze(2) - 1.0
            drawn += numer * count
        prior = n_prior[mod_key] if isinstance(n_prior, dict) else n_prior
        budgets.append(prior - drawn)

    smallest = torch.min(torch.cat(budgets, dim=2), dim=2)[0]
    return scale * smallest.unsqueeze(1)


def get_n_mods(relations, m_all, x_all, key_m, key_s, device, n_samples):
    """Number of items that may still be drawn for subset ``key_s``.

    For every mod ``m_k`` in ``relations["mods_subsets"]`` whose subset list
    contains ``key_s``, the budget is ``relu(relations["n"][m_k] - drawn)``,
    where ``drawn`` sums the counts in ``x_all`` of that mod's subsets that
    have already been drawn. The minimum budget over those mods is returned.

    Args:
        relations: needs "mods_subsets" (mod key -> subset keys) and "n"
            (mod key -> total number of draws).
        m_all: unused; kept for signature compatibility.
        x_all: subset key -> drawn counts, added into a (n_samples, 1, 1) tensor.
        key_m: unused; kept for signature compatibility.
        key_s: subset whose remaining budget is requested.
        device: device for created tensors.
        n_samples: number of samples (leading dim).

    Returns:
        Tensor of shape (n_samples, 1, 1).

    Raises:
        ValueError: if no mod contains ``key_s``.
    """
    budgets = []
    for m_k, m_v in relations["mods_subsets"].items():
        if key_s not in m_v:
            continue
        drawn = torch.zeros((n_samples, 1, 1), device=device)
        for s_k in m_v:
            if s_k in x_all:
                drawn += x_all[s_k]
        budgets.append(torch.relu(relations["n"][m_k] - drawn))
    if not budgets:
        raise ValueError(
            f"subset {key_s!r} does not belong to any mod in relations['mods_subsets']"
        )
    return torch.min(torch.cat(budgets, dim=2), dim=2)[0].unsqueeze(1)


def check_mvhg_hierarchical(relations, y_mask_all, p_x_all, n_all, device, n_samples):
    """Debug check for the hierarchical sampler; prints mismatching samples.

    For each of the last three groups in ``relations["order"]`` the drawn
    counts of every key in the group's "m_all" are summed (count = column sum
    of the cumulative mask minus 1, using every ``step``-th column for keys
    listed in the group's "n_w" as ``(numer, step)``) and compared with
    ``m_all[group]``. For every mismatching sample the per-key counts and
    masks are printed.

    NOTE(review): the per-sample report prints the hard-coded groups "m0_m1",
    "m1_m2" and "m0_m2", so it assumes those keys exist in y_mask_all / n_all.

    Args:
        p_x_all: unused.

    Returns:
        None; output is printed only.
    """
    g_order = relations["order"]
    grouping = relations["grouping"]
    for m_k in g_order[-3:]:
        m_all_m = grouping[m_k]["m_all"]
        n_w = grouping[m_k]["n_w"]
        n_m = m_all_m[m_k]
        x_c = torch.zeros((n_samples, 1, 1), device=device)
        for s_k in m_all_m:
            v = y_mask_all[s_k]
            if s_k in n_w:
                _, denom = n_w[s_k]
                v = v[:, :, 0::denom]
            x_c += torch.sum(v, dim=2).unsqueeze(2) - 1.0

        if (x_c == n_m).all():
            continue
        print(m_k, ": INCORRECT number of drawn samples")
        for k in range(n_samples):
            if x_c[k] == n_m:
                continue
            for s_k in m_all_m:
                v = y_mask_all[s_k][k]
                if s_k in n_w:
                    _, denom = n_w[s_k]
                    sub_ = v[:, 0::denom]
                else:
                    # Unstrided key: the counted slice is the whole mask.
                    # (Previously sub_ was unbound/stale here.)
                    sub_ = v
                x_i = torch.sum(sub_) - 1.0
                print(s_k, ": ", x_i.item(), end=", ")
                print(v)
                print(sub_)
            print("")
            for key in ("m0_m1", "m1_m2", "m0_m2"):
                print(
                    key + ": ",
                    torch.sum(y_mask_all[key][k]).item(),
                    "  n: ",
                    n_all[key][k].item(),
                )
            print("")


def check_mvhg_mods(
    relations,
    m_all,
    m_rest_all,
    w_all,
    w_rest_all,
    x_all,
    p_x_all,
    logits_all,
    hside_all,
    n_new_all,
    device,
    n_samples,
):
    """Debug check for the mod-based sampler; prints mismatching samples.

    For each mod in ``relations["mods_subsets"]`` the drawn counts ``x_all``
    of its subsets are summed and compared with ``relations["n"][mod]``. For
    every sample where they differ, the per-subset sampling state is printed
    (the w_rest / p_x / logits entries only for subsets present in p_x_all).

    Returns:
        None; output is printed only.
    """
    subsets_per_mod = relations["mods_subsets"]
    n_per_mod = relations["n"]
    for mod_key, members in subsets_per_mod.items():
        total = torch.zeros((n_samples, 1, 1), device=device)
        for s_k in members:
            total += x_all[s_k]
        if (n_per_mod[mod_key] == total).all():
            continue
        print("")
        print(mod_key)
        for k in range(0, n_samples):
            if n_per_mod[mod_key] == total[k]:
                continue
            for s_k in members:
                print(s_k, "n_new: ", n_new_all[s_k][k])
                print("m: ", m_all[s_k])
                print("m_rest: ", m_rest_all[s_k])
                print("x: ", x_all[s_k][k])
                print("hside: ", hside_all[s_k][k])
                print("w: ", w_all[s_k][k])
                if s_k in p_x_all:
                    print("w_rest: ", w_rest_all[s_k][k])
                    print("p_x_k: ", p_x_all[s_k][k])
                    print("logits_x_k: ", logits_all[s_k][k])


def get_hside_check(n, x_i, m_r):
    """Soft feasibility mask for drawing ``x_i`` items out of ``n``.

    Returns ``hside_approx(n - x_i) * hside_approx(m_r - relu(n - x_i))``:
    the first factor tests x_i <= n, the second tests that the remaining
    relu(n - x_i) items fit into the rest group of size m_r (assuming
    hside_approx is a smooth step at 0 — see mvhg.pt_heaviside).
    """
    n_left = n - x_i
    fits_in_rest = m_r - relu(n - x_i)
    return hside_approx(n_left) * hside_approx(fits_in_rest)


def pmf_mvfnchg_modbased(
    w_all,
    temperature,
    relations,
    logits_func=get_logits_fisher,
    device="cuda",
    eps=1e-10,
):
    """Draw a differentiable sample of all subset counts, one subset at a time.

    Mods in ``relations["mods_subsets"]`` are visited in random order; each
    subset (drawn once even if it belongs to several mods) gets a conditional
    two-class distribution against the merged not-yet-drawn subsets that follow
    it in the current mod (get_m_w_rest_mods), with the draw budget from
    get_n_mods. Subsets flagged in ``relations["last_class"]`` are not sampled;
    their one-hot-like vector is the Heaviside mask itself.

    Args:
        w_all: subset key -> log-weights (first dim = samples); they are
            passed to logits_func as log-weights.
        temperature: Gumbel-softmax temperature.
        relations: needs "m_all", "mods_subsets", "last_class" and the keys
            read by get_m_w_rest_mods / get_n_mods ("n").
        logits_func: unnormalised log-pmf builder (default Fisher).
        device: device for created tensors.
        eps: forwarded to logits_func.

    Returns:
        (y_all, x_all, y_mask_all): per-subset (straight-through) one-hot
        vectors, counts of shape (n_samples, 1, 1), and cumulative masks.
    """
    n_samples = w_all[list(w_all.keys())[0]].shape[0]
    y_all = {}
    y_mask_all = {}
    x_all = {}
    # Intermediate state kept for debugging (not returned).
    p_x_i_all = {}
    n_new_all = {}
    hside_all = {}
    logits_all = {}
    w_rest_all = {}
    m_rest_all = {}

    m_all = relations["m_all"]
    mods_subsets = relations["mods_subsets"]
    for m_k, m_v in random.sample(list(mods_subsets.items()), len(mods_subsets)):
        for s_k in m_v:
            if s_k in y_mask_all:
                continue  # subset shared by several mods: draw it once
            w_i = w_all[s_k].unsqueeze(1).unsqueeze(1)
            m_i = m_all[s_k]
            # Support {0, ..., m_i} of this subset's count (previously the
            # undefined / stale count x_i was used here).
            x_support = torch.arange(m_i + 1).to(device)
            m_rest, log_w_rest = get_m_w_rest_mods(
                relations, m_all, w_all, x_all, m_k, s_k, device
            )
            m_rest_sum = torch.sum(m_rest)
            m_rest_all[s_k] = m_rest_sum
            log_w_rest = log_w_rest.unsqueeze(1)
            w_rest_all[s_k] = log_w_rest
            n_new = get_n_mods(relations, m_all, x_all, m_k, s_k, device, n_samples)
            n_new_all[s_k] = n_new
            hside_check = get_hside_check(n_new, x_support, m_rest_sum)
            hside_all[s_k] = hside_check
            lc_rel = relations["last_class"]
            if not lc_rel[s_k]:
                logits_p_x_i = logits_func(
                    n_new, m_i, m_rest_sum, w_i, log_w_rest, hside_check, device, eps
                )
                logits_all[s_k] = logits_p_x_i
                p_x_i_all[s_k] = sm(logits_p_x_i)
                y_i = gumbel_softmax(logits_p_x_i, temperature, device, True)
            else:
                y_i = hside_check
            y_all[s_k] = y_i
            size = int(m_i) + 1
            # y @ tril(ones) turns a one-hot at j into ones in columns 0..j.
            lt = torch.tril(torch.ones((n_samples, size, size)).to(device))
            y_mask_filled = torch.matmul(y_i, lt)
            y_mask_filled[:, :, 0] = torch.ones((1, 1), device=device)
            y_mask_all[s_k] = y_mask_filled
            x_all[s_k] = torch.sum(y_mask_filled, dim=2).unsqueeze(2) - 1.0
    return y_all, x_all, y_mask_all


def get_log_prob_modbased(
    X_all,
    log_w_all,
    temperature,
    relations,
    logits_func=get_logits_fisher,
    device="cuda",
    eps=1e-10,
):
    """Log-probability of observed subset counts under the mod-based construction.

    Subsets are visited in the same (randomised) mod order as in
    pmf_mvfnchg_modbased. For every subset not flagged in
    ``relations["last_class"]`` the conditional log-pmf over {0, ..., m_i},
    given the observed counts of the subsets visited before it, is computed
    and its entry at the observed count ``X_all[s_k]`` is added to the total.
    Last-class subsets contribute no term; they are only conditioned on.

    Args:
        X_all: subset key -> observed counts.
            NOTE(review): assumed to hold one count per sample, reshapeable to
            (n_samples, 1, 1) like the x_all returned by pmf_mvfnchg_modbased
            — confirm with callers.
        log_w_all: subset key -> log-weights (first dim = samples).
        temperature: unused; kept for signature compatibility.
        relations: same structure as for pmf_mvfnchg_modbased.
        logits_func: unnormalised log-pmf builder (default Fisher).
        device: device for created tensors.
        eps: forwarded to logits_func.

    Returns:
        (log_p_X, p_X), each of shape (n_samples,).
    """
    n_samples = log_w_all[list(log_w_all.keys())[0]].shape[0]
    x_all = {}
    log_p_X = torch.zeros(n_samples, device=device)

    m_all = relations["m_all"]
    mods_subsets = relations["mods_subsets"]
    for m_k, m_v in random.sample(list(mods_subsets.items()), len(mods_subsets)):
        for s_k in m_v:
            if s_k in x_all:
                continue  # subset shared by several mods: count it once
            x_obs = (
                X_all[s_k].to(device=device, dtype=torch.float32).reshape(n_samples, 1, 1)
            )
            log_w_i = log_w_all[s_k].unsqueeze(1).unsqueeze(1)
            m_i = m_all[s_k]
            x_support = torch.arange(m_i + 1).to(device)
            m_rest, log_w_rest = get_m_w_rest_mods(
                relations, m_all, log_w_all, x_all, m_k, s_k, device
            )
            m_rest_sum = torch.sum(m_rest)
            log_w_rest = log_w_rest.unsqueeze(1)
            n_new = get_n_mods(relations, m_all, x_all, m_k, s_k, device, n_samples)
            hside_check = get_hside_check(n_new, x_support, m_rest_sum)
            if not relations["last_class"][s_k]:
                logits_p_x_i = logits_func(
                    n_new, m_i, m_rest_sum, log_w_i, log_w_rest, hside_check, device, eps
                )
                # Normalise over the support (last dim).
                log_p_x_i = logits_p_x_i - torch.logsumexp(
                    logits_p_x_i, dim=-1, keepdim=True
                )
                idx = x_obs.round().long()
                log_p_X = log_p_X + torch.gather(log_p_x_i, -1, idx).reshape(n_samples)
            # Condition the following subsets on this observed count.
            x_all[s_k] = x_obs
    p_X = torch.exp(log_p_X)
    return log_p_X, p_X


def pmf_mvfnchg_hierarchical(
    w_all,
    temperature,
    relations,
    logits_func=get_logits_fisher,
    device="cuda",
    eps=1e-20,
):
    """Draw a differentiable sample group by group along ``relations["order"]``.

    Each group in ``relations["grouping"]`` is split into a merged left and
    right class (get_m_w_hierarchical); the left count is sampled from the
    two-class distribution given by ``logits_func`` with the budget from
    get_n_hierarchical. Groups whose "last" flag is set are not sampled; their
    vector is the Heaviside mask itself.

    Args:
        w_all: member key -> log-weights (first dim = samples).
        temperature: Gumbel-softmax temperature.
        relations: needs "order" and "grouping".
        logits_func: unnormalised log-pmf builder (default Fisher).
        device: device for created tensors.
        eps: forwarded to logits_func.

    Returns:
        (y_all, x_all, y_mask_all) keyed by group.
    """
    n_samples = w_all[list(w_all.keys())[0]].shape[0]
    y_all, y_mask_all, x_all = {}, {}, {}
    # Kept for debugging only (not returned).
    p_x_all, n_all = {}, {}
    group_order = relations["order"]
    grouping = relations["grouping"]
    for g_key in group_order:
        (m_l, log_w_l), (m_r, log_w_r) = get_m_w_hierarchical(
            g_key, grouping, w_all, device
        )
        m_left = torch.sum(m_l)
        m_right = torch.sum(m_r)
        log_w_l = log_w_l.unsqueeze(1)
        log_w_r = log_w_r.unsqueeze(1)
        n_new = get_n_hierarchical(g_key, grouping, y_mask_all, device, n_samples)
        n_all[g_key] = n_new
        support = torch.arange(m_left + 1).to(device)
        hside_check = get_hside_check(n_new, support, m_right)
        if grouping[g_key]["last"]:
            y_i = hside_check
        else:
            logits = logits_func(
                n_new, m_left, m_right, log_w_l, log_w_r, hside_check, device, eps
            )
            p_x_all[g_key] = sm(logits)
            y_i = gumbel_softmax(logits, temperature, device, True)
        size = int(m_left) + 1
        # Lower-triangular ones turn a one-hot at j into ones in columns 0..j.
        lower = torch.tril(torch.ones((n_samples, size, size)).to(device))
        y_mask = torch.matmul(y_i, lower)
        y_all[g_key] = y_i
        y_mask_all[g_key] = y_mask
        x_all[g_key] = torch.sum(y_mask, dim=2).unsqueeze(2) - 1.0
    return y_all, x_all, y_mask_all


def pmf_noncentral_fmvhg(
    m_all,
    n,
    w_all,
    temperature,
    logits_func=get_logits_fisher,
    device="cuda",
    eps=1e-10,
):
    """Differentiable sample from a multivariate Fisher noncentral hypergeometric.

    Classes are processed in order: class i is sampled from a two-class
    distribution against the merged remaining classes i+1.. (calc_group_w_m),
    drawing from the budget relu(n - already drawn). The last class is not
    sampled; its vector is the Heaviside mask itself.

    Args:
        m_all: 1-D tensor of class sizes (n_classes entries).
        n: number of draws per sample; after ``unsqueeze(1)`` it must broadcast
            against (n_samples, 1, 1), e.g. shape (n_samples, 1).
        w_all: (n_samples, n_classes) class weights, used as log-weights by
            calc_group_w_m and logits_func.
        temperature: Gumbel-softmax temperature.
        logits_func: unnormalised log-pmf builder (default Fisher).
        device: device for created tensors.
        eps: forwarded to logits_func.

    Returns:
        (y_all, x_all, y_mask_all): per-class lists of (straight-through)
        one-hot vectors, counts of shape (n_samples, 1, 1), cumulative masks.

    Raises:
        ValueError: if w_all does not have exactly one weight per class.
    """
    n_c = m_all.shape[0]
    if w_all.shape[-1] != n_c:
        raise ValueError(
            f"w_all has {w_all.shape[-1]} weights per sample but m_all has {n_c} classes"
        )
    n = n.unsqueeze(1)
    n_samples = w_all.shape[0]
    n_out = torch.zeros((n_samples, 1, 1), device=device)
    w_all = w_all.unsqueeze(1)  # (n_samples, 1, n_c)
    y_all = []
    y_mask_all = []
    x_all = []
    for i in range(n_c):
        is_last = i + 1 == n_c
        n_new = relu(n - n_out)
        m_i = m_all[i]
        w_i = w_all[:, :, i].unsqueeze(1)
        x_support = torch.arange(m_i + 1).to(device)
        x_support = x_support.unsqueeze(0).unsqueeze(1).repeat(n_samples, 1, 1)
        if is_last:
            m_rest = torch.tensor(0.0, device=device)
        else:
            # Merge all later classes into a single "rest" class.
            w_rest, m_rest = calc_group_w_m(m_all[i + 1 :], w_all[:, :, i + 1 :])
        hside_check = get_hside_check(n_new, x_support, m_rest)
        if is_last:
            y_i = hside_check
        else:
            logits_p_x_i = logits_func(
                n_new, m_i, m_rest, w_i, w_rest.unsqueeze(-1), hside_check, device, eps
            )
            y_i = gumbel_softmax(logits_p_x_i, temperature, device, True)
        y_all.append(y_i)
        size = int(m_i) + 1
        # y @ tril(ones) turns a one-hot at j into ones in columns 0..j.
        lt = torch.tril(torch.ones((n_samples, size, size), device=device))
        y_mask_filled = torch.matmul(y_i, lt)
        y_mask_all.append(y_mask_filled)
        x_i = torch.sum(y_mask_filled, dim=2).unsqueeze(2) - 1.0
        x_all.append(x_i)
        n_out += x_i
    assert torch.all((n - n_out) == 0), "num elements samples NOT correct"
    return y_all, x_all, y_mask_all


def pmf_noncentral_fmvhg_rsample_w(
    m_all,
    n,
    log_w_all,
    temperature,
    logits_func=get_logits_fisher,
    device="cuda",
    eps=1e-10,
):
    """Sequential Fisher MVHG sample with Gaussian-perturbed log-weights.

    Identical to the class-by-class construction of pmf_noncentral_fmvhg,
    except that the log-weights are first perturbed with resample_omega
    (additive N(0, 1) noise) and the per-class Gumbel-softmax is taken
    without Gumbel noise (straight-through argmax of the tempered softmax).

    Args:
        m_all: 1-D tensor of class sizes (n_classes entries).
        n: number of draws per sample; after ``unsqueeze(1)`` it must broadcast
            against (n_samples, 1, 1), e.g. shape (n_samples, 1).
        log_w_all: (n_samples, n_classes) class log-weights.
        temperature: softmax temperature.
        logits_func: unnormalised log-pmf builder (default Fisher).
        device: device for created tensors.
        eps: forwarded to logits_func.

    Returns:
        (y_all, x_all, y_mask_all) as lists over classes.

    Raises:
        ValueError: if log_w_all does not have exactly one weight per class.
    """
    n_c = m_all.shape[0]
    if log_w_all.shape[-1] != n_c:
        raise ValueError(
            f"log_w_all has {log_w_all.shape[-1]} weights per sample but m_all has {n_c} classes"
        )
    n = n.unsqueeze(1)
    n_samples = log_w_all.shape[0]
    n_out = torch.zeros((n_samples, 1, 1), device=device)
    log_w_all = log_w_all.unsqueeze(1)  # (n_samples, 1, n_c)
    y_all = []
    y_mask_all = []
    x_all = []

    log_w_all = resample_omega(log_w_all, device)
    for i in range(n_c):
        is_last = i + 1 == n_c
        n_new = relu(n - n_out)
        m_i = m_all[i]
        log_w_i = log_w_all[:, :, i].unsqueeze(1)
        x_support = torch.arange(m_i + 1).to(device)
        x_support = x_support.unsqueeze(0).unsqueeze(1).repeat(n_samples, 1, 1)
        if is_last:
            m_rest = torch.tensor(0.0, device=device)
        else:
            # Merge all later classes into a single "rest" class.
            log_w_rest, m_rest = calc_group_w_m(
                m_all[i + 1 :], log_w_all[:, :, i + 1 :]
            )
        hside_check = get_hside_check(n_new, x_support, m_rest)
        if is_last:
            y_i = hside_check
        else:
            logits_p_x_i = logits_func(
                n_new,
                m_i,
                m_rest,
                log_w_i,
                log_w_rest.unsqueeze(-1),
                hside_check,
                device,
                eps,
            )
            y_i = gumbel_softmax(logits_p_x_i, temperature, device, False)
        y_all.append(y_i)
        size = int(m_i) + 1
        # y @ tril(ones) turns a one-hot at j into ones in columns 0..j.
        lt = torch.tril(torch.ones((n_samples, size, size), device=device))
        y_mask_filled = torch.matmul(y_i, lt)
        y_mask_all.append(y_mask_filled)
        x_i = torch.sum(y_mask_filled, dim=2).unsqueeze(2) - 1.0
        x_all.append(x_i)
        n_out += x_i
    assert torch.all((n - n_out) == 0), "num elements samples NOT correct"
    return y_all, x_all, y_mask_all


if __name__ == "__main__":
    import os

    # Demo: draw num_samples samples from a Fisher MVHG with len(m) classes
    # and plot the average (straight-through one-hot) sample of every class.
    logging.basicConfig(level=log_level)
    class_names = ["violet", "yellow", "green", "red"]
    colors = ["blueviolet", "yellow", "green", "red"]
    m = torch.tensor([64, 64, 64, 64])
    n = torch.tensor(64)
    # Alternative weight settings:
    # w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
    # w = torch.tensor([0.7828, 3.9019, 2.0981, 0.5994, 1.0950, 1.3564, 0.3448])
    w = torch.tensor([1.3244, 3.9019, 2.0981, 0.5994, 1.0950, 1.3564, 0.3448])
    # The sampler needs exactly one weight per class, and class_names/colors
    # describe len(m) classes, so the class count is taken from m.
    # NOTE(review): w lists 7 weights for 4 classes; only the first len(m)
    # are used — confirm this is the intended configuration.
    num_classes = m.shape[0]
    w = w[:num_classes]
    num_samples = 512
    tau = 1.0
    dist_type = "fisher"
    if dist_type == "fisher":
        get_logits = get_logits_fisher
        fn_pre = "f"
    else:
        print("distribution type not available...exit")
        sys.exit()
    create_plot = True
    n = n.unsqueeze(0).repeat(num_samples, 1)
    w = w.unsqueeze(0).repeat(num_samples, 1)
    n_repeats = 1
    for h in range(n_repeats):
        print(m.shape, n.shape, w.shape)
        y, x, y_m = pmf_noncentral_fmvhg(
            m, n, w, tau, logits_func=get_logits, device="cpu"
        )
        if not create_plot:
            print()
            for i in range(num_samples):
                print(h, i)
                sum_sample = 0
                for j in range(num_classes):
                    print(y[j][i], x[j][i])
                    sum_sample += x[j][i]
                if sum_sample != n[0]:
                    print("ERROR !!!! ")
                    sys.exit()
                else:
                    print("correct: ", sum_sample)
        if create_plot:
            str_ws = [str(w_j) for w_j in list(w[0].cpu().numpy().flatten())]
            str_weights = "_".join(str_ws)
            # savefig fails if the output directory does not exist.
            os.makedirs("./results", exist_ok=True)
            fn_plot = "./results/pt_samples_" + fn_pre + "_" + str_weights + ".png"
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            for j in range(num_classes):
                ind_j = np.arange(m[j].cpu().numpy() + 1)
                y_avg_j = y[j].detach().mean(dim=0).cpu().numpy().flatten()
                ax.bar(ind_j, y_avg_j, alpha=0.5, color=colors[j], label=class_names[j])
            plt.title(str_ws)
            plt.legend()
            fig.tight_layout()
            plt.draw()
            plt.savefig(fn_plot, format="png")
            plt.close(fig)