
import sys
import numpy as np
import matplotlib.pyplot as plt
import logging
from scipy.stats import hypergeom

import torch
import torch.nn.functional as F

from os.path import dirname
sys.path.append(dirname(dirname(__file__)))

# Element-wise ReLU module, i.e. max(x, 0).
relu = torch.nn.ReLU()
# Softmax module that normalises over the last dimension.
sm = torch.nn.Softmax(dim=-1)


def sample_gumbel(shape, device, eps=1e-20):
    """Draw Gumbel noise of the requested shape.

    Uniform samples are drawn with ``torch.rand`` and then moved to
    ``device``. They are mapped through ``-log(-log(U))``, with ``eps`` added
    inside both logarithms.

    Args:
        shape: shape of the tensor to draw.
        device: device the uniform samples are moved to.
        eps: small constant added to the argument of each logarithm.

    Returns:
        Tensor of the given shape on ``device``.
    """
    uniform = torch.rand(shape).to(device)
    neg_log_uniform = -torch.log(uniform + eps)
    return -torch.log(neg_log_uniform + eps)


def gumbel_softmax_sample(logits, temperature, gumbel_noise, device):
    """Return a tempered softmax of ``logits``, optionally with Gumbel noise.

    Args:
        logits: unnormalised log-probabilities; the softmax is taken over the
            last dimension.
        temperature: divisor applied to the (possibly perturbed) logits
            before the softmax.
        gumbel_noise: if truthy, noise from ``sample_gumbel`` is added to the
            logits first.
        device: device passed to ``sample_gumbel``.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    # Draw noise only when it is used, so that a noise-free call neither
    # pays for sampling nor advances the random number generator.
    if gumbel_noise:
        logits = logits + sample_gumbel(logits.size(), device)
    return F.softmax(logits / temperature, dim=-1)


def gumbel_softmax(logits, temperature, device, gumbel_noise=True):
    """
    Straight-through Gumbel-softmax.

    input: [*, n_class]
    return: [*, n_class] a one-hot vector

    The returned values equal the one-hot encoding of the arg-max of the soft
    sample. The one-hot correction term is detached, so gradients flow
    through the soft sample only.
    """
    soft = gumbel_softmax_sample(logits, temperature, gumbel_noise, device)
    winners = soft.max(dim=-1).indices
    # Write a 1 at the winning class along the last dimension.
    hard = torch.zeros_like(soft).scatter_(-1, winners.unsqueeze(-1), 1.0)
    return soft + (hard - soft).detach()

def get_logits(x, n, m, m2, w, hside_check, device, eps=1e-20):
    """Return unnormalised log-probabilities over the counts ``0..x``.

    For every candidate count ``k`` in ``0..x`` the returned value is::

        k * log(w) - [lgamma(k + 1) + lgamma(r + 1)
                      + lgamma(m - k + 1) + lgamma(m2 - r + 1)]
        + log(hside_check)

    where ``r = relu(n - k)``. Entries whose ``hside_check`` is 0 become
    ``-inf`` and therefore receive zero mass under a softmax.

    Args:
        x: largest candidate count; the support is ``arange(x + 1)``.
        n: value from which the candidate count is subtracted to obtain
            ``r`` (must broadcast against ``(n_samples, 1, x + 1)``).
        m: value used in ``lgamma(m - k + 1)``.
        m2: value used in ``lgamma(m2 - r + 1)``.
        w: weight tensor; ``w.shape[0]`` is taken as the number of samples
            and ``log(w)`` must broadcast against the support.
        hside_check: 0/1 mask broadcastable to the support.
        device: device on which the support is created.
        eps: unused; kept so existing callers remain valid.

    Returns:
        float64 tensor of shape ``(n_samples, 1, x + 1)`` when the inputs
        broadcast to that shape.
    """
    n_samples = w.shape[0]
    support = torch.arange(x + 1, dtype=torch.float64, device=device)
    support = support.unsqueeze(0).repeat(n_samples, 1, 1)
    # Complementary count, clamped at zero; computed once and reused.
    rest = torch.relu(n - support)
    log_m_x = (torch.lgamma(support + 1.0)
               + torch.lgamma(rest + 1.0)
               + torch.lgamma(m - support + 1.0)
               + torch.lgamma(m2 - rest + 1.0))
    log_p_shifted = support * torch.log(w) - log_m_x
    # Lazy %-style arguments: tensors are only rendered to text when the
    # INFO level is actually enabled.
    logging.info('log_p_shifted.shape: \n%s', log_p_shifted.shape)
    logging.info('log_p_shifted: \n%s', log_p_shifted[0])
    log_p_shifted_sup = log_p_shifted + torch.log(hside_check)
    logging.info('log_p_shifted_sup.shape: \n%s', log_p_shifted_sup.shape)
    logging.info('log_p_shifted_sup: \n%s', log_p_shifted_sup[0])
    return log_p_shifted_sup


def pmf_noncentral_fmvhg(m_all, n, w_all, temperature, device='cuda'):
    """Sample per-class counts sequentially, one class at a time.

    The name suggests a noncentral multivariate hypergeometric distribution
    (presumably Fisher's variant). For every class except the last, the
    count is drawn with ``gumbel_softmax`` from the logits returned by
    ``get_logits``, with the remaining classes merged into a single "rest"
    class. The last class receives whatever is left of ``n``.

    Args:
        m_all: 1-D tensor with the number of elements per class. Each entry
            must be convertible with ``int()``.
        n: number of elements to draw per sample. It is unsqueezed at dim 1
            and must then broadcast against ``(n_samples, 1, 1)``; the script
            block of this file passes shape ``(n_samples, 1)``.
        w_all: tensor whose first dimension is the number of samples and
            whose last dimension indexes classes (shape
            ``(n_samples, n_classes)`` in the script block).
        temperature: temperature forwarded to ``gumbel_softmax``.
        device: device on which intermediate tensors are created.

    Returns:
        ``(y_all, x_all)``: two lists with one entry per class. ``y_all[i]``
        is the selection over the counts ``0..m_all[i]`` (last dimension of
        size ``m_all[i] + 1``). ``x_all[i]`` is the count derived from it,
        with shape ``(n_samples, 1, 1)``.

    Raises:
        AssertionError: if the sampled counts do not add up to ``n``.
    """
    n_c = m_all.shape[0]
    n = n.unsqueeze(1)
    n_samples = w_all.shape[0]
    n_out = torch.zeros((n_samples, 1, 1), device=device)
    m_out = torch.tensor(0.0, device=device)
    w_all = w_all.unsqueeze(1)
    m_total = torch.sum(m_all)
    y_all = []
    x_all = []
    for i in range(n_c):
        is_last = i + 1 >= n_c
        n_new = n - n_out
        m_i = m_all[i]
        n_values = int(m_i) + 1
        w_i = w_all[:, :, i].unsqueeze(1)
        # Candidate counts 0..m_i for this class, one row per sample.
        support = torch.arange(n_values, device=device)
        support = support.unsqueeze(0).unsqueeze(1).repeat(n_samples, 1, 1)
        if is_last:
            m_rest = torch.tensor(0.0, device=device)
            w = torch.ones((n_samples, 1, 1), device=device)
        else:
            m_rest = m_total - m_out - m_i
            # Weight of the merged "rest" class: count-weighted mean of the
            # weights of all remaining classes.
            w_rest_enum = torch.sum(w_all[:, :, i+1:] * m_all[i+1:],
                                    dim=2).unsqueeze(1)
            w_rest = w_rest_enum / m_rest
            w = w_i / w_rest
            logging.info('w_rest: \n%s', w_rest[0])

        x_rest = torch.relu(n_new - support)
        # DEBUG: Implement differentiable way to enforce constraints
        # TODO: check for differentiability
        check_x_i = n_new - support
        check_x_rest = m_rest - x_rest
        # The heaviside value at exactly 0 is 1. The ``values`` tensor is
        # built with the dtype and device of its input, because
        # torch.heaviside rejects mismatching dtypes.
        hside_check_x_i = torch.heaviside(check_x_i,
                                          check_x_i.new_ones(()))
        hside_check_x_rest = torch.heaviside(check_x_rest,
                                             check_x_rest.new_ones(()))
        hside_check = hside_check_x_i * hside_check_x_rest
        # Lazy %-style arguments avoid rendering tensors to text (and the
        # device synchronisation that implies) when INFO is disabled.
        logging.info('check x_i: \n%s', hside_check_x_i[0])
        logging.info('check x_rest: \n%s', hside_check_x_rest[0])
        logging.info('check combined: \n%s', hside_check[0])
        if is_last:
            # With m_rest == 0 the two checks leave only the candidate equal
            # to the remaining count, so the mask is itself the selection.
            y_i = hside_check.double()
        else:
            logging.info('m_i: \n%s', m_i)
            logging.info('m_rest: \n%s', m_rest)
            logging.info('n_i: \n%s', n_new[0])
            logging.info('x_i: \n%s', support[0])
            logging.info('x_rest: \n%s', x_rest[0])
            logging.info('w_i: \n%s', w_i[0])
            logging.info('w: \n%s', w[0])

            logits_p_x_i = get_logits(m_i, n_new, m_i, m_rest, w, hside_check,
                                      device)
            # The normalised probabilities are needed for logging only.
            if logging.getLogger().isEnabledFor(logging.INFO):
                p_x_i = torch.softmax(logits_p_x_i, dim=-1)
                logging.info('logits(p_x): \n%s', logits_p_x_i[0])
                logging.info('sum(p_x_i): \n%s', p_x_i[0].sum())
                logging.info('p_x_i: \n%s', p_x_i[0])

            # sample x_j given distribution weights
            y_i = gumbel_softmax(logits_p_x_i, temperature, device, True)
        y_all.append(y_i)
        # Multiplying a one-hot row by a lower-triangular matrix of ones
        # fills positions 0..k with ones, so the row sum minus one recovers
        # the selected count k. A single (m_i+1, m_i+1) matrix broadcasts
        # over the sample dimension.
        lt = torch.tril(torch.ones((n_values, n_values),
                                   dtype=torch.float64, device=device))
        mask_filled = torch.matmul(y_i, lt)
        x_i = torch.sum(mask_filled, dim=2).unsqueeze(2) - 1.0
        x_all.append(x_i)
        n_out += x_i
        m_out += m_i

        logging.info('y_i: \n%s', y_i[0])
        logging.info('x_i: \n%s', x_i[0])
    assert torch.all((n - n_out) == 0), 'num elements samples NOT correct'
    return y_all, x_all


if __name__ == '__main__':
    # Demo: draw samples for three equally weighted classes and either plot
    # the per-class histogram of selected counts or print and verify every
    # sample.
    import os

    log_level = logging.INFO
    logging.basicConfig(level=log_level)
    class_names = ['violet', 'yellow', 'green']
    colors = ['blueviolet', 'yellow', 'green']
    m = torch.tensor([7, 7, 7])
    n = torch.tensor(10)
    w = torch.tensor([1.0, 1.0, 1.0])
    num_classes = m.shape[0]
    num_samples = 10000
    tau = 1.0
    create_plot = True
    # One row per sample: n -> (num_samples, 1), w -> (num_samples, classes).
    n = n.unsqueeze(0).repeat(num_samples, 1)
    w = w.unsqueeze(0).repeat(num_samples, 1)
    n_repeats = 1
    for h in range(n_repeats):
        y, x = pmf_noncentral_fmvhg(m, n, w, tau, device='cpu')
        if create_plot:
            str_ws = [str(w_j) for w_j in list(w[0].cpu().numpy().flatten())]
            str_weights = '_'.join(str_ws)
            fn_plot = './results/pt_samples_f_' + str_weights + '.png'
            # savefig does not create missing directories.
            os.makedirs(os.path.dirname(fn_plot), exist_ok=True)
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            for j in range(num_classes):
                ind_j = np.arange(m[j].cpu().numpy() + 1)
                # Mean over samples of the selection vectors, i.e. the
                # empirical frequency of every count.
                y_avg_j = y[j].mean(dim=0).cpu().numpy().flatten()
                ax.bar(ind_j, y_avg_j, alpha=0.5, color=colors[j],
                       label=class_names[j])
            ax.set_title(str_ws)
            ax.legend()
            fig.tight_layout()
            fig.savefig(fn_plot, format='png')
            # Release the figure so repeated runs do not accumulate them.
            plt.close(fig)
        else:
            print()
            for i in range(num_samples):
                print(h, i)
                sum_sample = 0
                for j in range(num_classes):
                    print(y[j][i], x[j][i])
                    sum_sample += x[j][i]
                if sum_sample != n[0]:
                    print('ERROR !!!! ')
                    # Non-zero status so callers can detect the failure.
                    sys.exit(1)
                else:
                    print('correct: ', sum_sample)





