import math
import torch
import numpy as np

from scipy import linalg
from scipy.stats import multivariate_normal, uniform, norm, truncexpon
from sklearn.datasets import make_swiss_roll

def calc_true_KL_div_gauss(
        myu_numerator,
        myu_denominator,
        sigma_mat_numerator,
        sigma_mat_denominator,
        dim_data
        ) -> float:
    """Return the closed-form KL divergence KL(N_nu || N_de) of two Gaussians.

    Computes::

        0.5 * [ tr(S_de^-1 S_nu) - k
                + (m_de - m_nu)^T S_de^-1 (m_de - m_nu)
                + log det S_de - log det S_nu ]

    where ``nu`` is the numerator and ``de`` the denominator distribution.

    Args:
        myu_numerator: mean vector of the numerator Gaussian, shape (k,).
        myu_denominator: mean vector of the denominator Gaussian, shape (k,).
        sigma_mat_numerator: covariance of the numerator, shape (k, k).
        sigma_mat_denominator: covariance of the denominator, shape (k, k).
        dim_data: the dimension k.

    Returns:
        The KL divergence as a Python float.

    Raises:
        ValueError: if either covariance has a non-positive determinant.
        numpy.linalg.LinAlgError: if the denominator covariance is singular.
    """
    mean_diff = (np.asarray(myu_denominator, dtype=float)
                 - np.asarray(myu_numerator, dtype=float))
    # Solve linear systems instead of forming an explicit inverse.
    trace_term = np.trace(
        linalg.solve(sigma_mat_denominator, sigma_mat_numerator))
    quad_term = mean_diff @ linalg.solve(sigma_mat_denominator, mean_diff)
    # slogdet stays finite where det() under/overflows in high dimension.
    sign_de, logdet_de = np.linalg.slogdet(sigma_mat_denominator)
    sign_nu, logdet_nu = np.linalg.slogdet(sigma_mat_numerator)
    if sign_de <= 0 or sign_nu <= 0:
        raise ValueError(
            "covariance matrices must have a positive determinant")
    kv = (trace_term - dim_data + quad_term + logdet_de - logdet_nu) / 2.0
    return float(kv)

def calc_true_mi_info_gauss(
        rho: float,
        dim_data: int) -> float:
    """Return the information shared by the axes of an equicorrelated Gaussian.

    The covariance has ones on the diagonal and ``rho`` everywhere else.
    Because every marginal has unit variance, sum_i H(X_i) - H(X) reduces
    to ``-0.5 * log det(Sigma)``; the (2*pi*e) constants cancel.  For
    ``dim_data == 2`` this is the mutual information
    I(X_1; X_2) = -0.5 * log(1 - rho**2).  For larger dimensions it is the
    total correlation (multi-information) of the coordinates.

    Raises:
        ValueError: if the covariance does not have a positive determinant,
            i.e. ``rho`` is outside the valid range for ``dim_data``.
    """
    sigma_mat = np.full((dim_data, dim_data), rho, dtype=float)
    np.fill_diagonal(sigma_mat, 1.0)
    # slogdet stays finite where det() underflows to 0 (large dim, rho -> 1).
    sign, logdet = np.linalg.slogdet(sigma_mat)
    if sign <= 0:
        raise ValueError(
            "rho does not yield a positive-definite covariance matrix")
    return float(-0.5 * logdet)

def calc_log_prob_gauss(samples, means, cov_matrix):
    """Return log N(samples | means, cov_matrix).

    Thin wrapper over ``scipy.stats.multivariate_normal(...).logpdf``; the
    output shape follows scipy's rules for ``samples``.
    """
    return multivariate_normal(mean=means, cov=cov_matrix).logpdf(samples)

def calc_cdf_gauss(x, means, cov_matrix):
    """Return P(X <= x) for X ~ N(means, cov_matrix) via scipy's ``cdf``."""
    distribution = multivariate_normal(mean=means, cov=cov_matrix)
    probability = distribution.cdf(x)
    return probability

class GenDataCorrelatedGaussVsNonCorrelatedGauss(object):
    """Equicorrelated Gaussian (numerator) vs. standard Gaussian (denominator).

    Both distributions are centred at the origin.  The numerator covariance
    has ones on the diagonal and ``rho`` everywhere else; the denominator
    covariance is the identity.
    """

    def __init__(self, rho, dim_data):
        self.dim_ = dim_data
        self.rho_ = rho
        self.means_ = np.zeros(dim_data)
        corr = np.full((dim_data, dim_data), rho)
        np.fill_diagonal(corr, 1)
        # Both attributes refer to the same array object.
        self.variances_ = corr
        self.cov_matrix_numerator_ = corr
        self.cov_matrix_denominator_ = np.eye(self.dim_)
        self.true_KL_ = calc_true_KL_div_gauss(
            self.means_, self.means_,
            self.cov_matrix_numerator_, self.cov_matrix_denominator_,
            self.dim_)

    def get_true_KL(self):
        """Return the KL divergence computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        draws = multivariate_normal(mean=means, cov=cov_matrix).rvs(size=n_samples)
        return draws.astype(np.float32)

    def sample_numerator(self, n_samples):
        """Draw from the correlated numerator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_, self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw from the identity-covariance denominator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_, self.cov_matrix_denominator_)

    def calc_true_density_rate(self, samples):
        """Return p_nu(x) / p_de(x) at every sample, computed in log space."""
        log_de = calc_log_prob_gauss(
            samples, self.means_, self.cov_matrix_denominator_)
        log_nu = calc_log_prob_gauss(
            samples, self.means_, self.cov_matrix_numerator_)
        return np.exp(log_nu - log_de)

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; the ratio is evaluated at ``de``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        return de, nu, self.calc_true_density_rate(de)

class GenDataShiftedGaussVsCenteredGauss(object):
    """Shifted, equicorrelated Gaussian (numerator) vs. standard Gaussian.

    Numerator: mean 1.0 on the first axis (0 elsewhere) and a covariance
    with ones on the diagonal and ``rho`` elsewhere.
    Denominator: zero mean and identity covariance.
    """

    def __init__(self, rho, dim_data):
        self.dim_ = dim_data
        self.rho_ = rho
        shifted_mean = np.zeros(dim_data)
        shifted_mean[0] = 1.0
        self.means_numerator_ = shifted_mean
        self.means_denominator_ = np.zeros(dim_data)
        cov = np.full((dim_data, dim_data), rho)
        np.fill_diagonal(cov, 1)
        # Both attributes refer to the same array object.
        self.variances_ = cov
        self.cov_matrix_numerator_ = cov
        self.cov_matrix_denominator_ = np.eye(self.dim_)
        self.true_KL_ = calc_true_KL_div_gauss(
            self.means_numerator_,
            self.means_denominator_,
            self.cov_matrix_numerator_,
            self.cov_matrix_denominator_,
            self.dim_)

    def get_true_KL(self):
        """Return the KL divergence computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        dist = multivariate_normal(mean=means, cov=cov_matrix)
        return dist.rvs(size=n_samples).astype(np.float32)

    def calc_true_density_rate(self, samples):
        """Return p_nu(x) / p_de(x) at every sample, computed in log space."""
        log_de = calc_log_prob_gauss(
            samples, self.means_denominator_, self.cov_matrix_denominator_)
        log_nu = calc_log_prob_gauss(
            samples, self.means_numerator_, self.cov_matrix_numerator_)
        return np.exp(log_nu - log_de)

    def sample_numerator(self, n_samples):
        """Draw from the shifted numerator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_numerator_, self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw from the centred denominator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_denominator_, self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; the ratio is evaluated at ``de``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        return de, nu, self.calc_true_density_rate(de)

class GenDataTwoShiftedGausses(object):
    """Two identity-covariance Gaussians separated along the first axis.

    The numerator is centred at +distance/2 and the denominator at
    -distance/2 on axis 0; all other coordinates are zero.
    """

    def __init__(self,
                 distance_between_nu_and_de, dim_data):
        half_gap = distance_between_nu_and_de / 2.0
        self.means_numerator_ = np.zeros(dim_data)
        self.means_numerator_[0] = half_gap
        self.means_denominator_ = np.zeros(dim_data)
        self.means_denominator_[0] = -half_gap
        self.dim_ = dim_data
        self.cov_matrix_numerator_ = np.eye(self.dim_)
        self.cov_matrix_denominator_ = np.eye(self.dim_)
        self.true_KL_ = calc_true_KL_div_gauss(
            self.means_numerator_,
            self.means_denominator_,
            self.cov_matrix_numerator_,
            self.cov_matrix_denominator_,
            self.dim_)

    def get_true_KL(self):
        """Return the KL divergence computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        raw = multivariate_normal(mean=means, cov=cov_matrix).rvs(size=n_samples)
        return raw.astype(np.float32)

    def calc_true_density_rate(self, samples):
        """Return p_nu(x) / p_de(x) at every sample, computed in log space."""
        log_de = calc_log_prob_gauss(
            samples, self.means_denominator_, self.cov_matrix_denominator_)
        log_nu = calc_log_prob_gauss(
            samples, self.means_numerator_, self.cov_matrix_numerator_)
        return np.exp(log_nu - log_de)

    def sample_numerator(self, n_samples):
        """Draw from the numerator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_numerator_, self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw from the denominator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_denominator_, self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; the ratio is evaluated at ``de``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        return de, nu, self.calc_true_density_rate(de)

class GenDataTwoAllShiftedGausses_WithDifferentSD(object):
    """Two diagonal Gaussians shifted along every axis.

    The numerator mean is -distance/2 on every axis and the denominator mean
    is +distance/2.  Note: ``numerator_sigma`` / ``denominator_sigma`` are
    placed directly on the covariance diagonal, i.e. they act as variances
    (they are not squared).
    """

    def __init__(self,
                 distance_between_nu_and_de,
                 denominator_sigma,
                 numerator_sigma,
                 dim_data):
        self.means_numerator_ = np.full(dim_data, -distance_between_nu_and_de / 2)
        self.means_denominator_ = np.full(dim_data, distance_between_nu_and_de / 2)
        self.dim_ = dim_data
        unit_diag = np.ones(self.dim_)
        self.cov_matrix_numerator_ = np.diag(numerator_sigma * unit_diag)
        self.cov_matrix_denominator_ = np.diag(denominator_sigma * unit_diag)
        self.true_KL_ = calc_true_KL_div_gauss(
            self.means_numerator_,
            self.means_denominator_,
            self.cov_matrix_numerator_,
            self.cov_matrix_denominator_,
            self.dim_)

    def get_true_KL(self):
        """Return the KL divergence computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        gaussian = multivariate_normal(mean=means, cov=cov_matrix)
        return gaussian.rvs(size=n_samples).astype(np.float32)

    def calc_true_density_rate(self, samples):
        """Return p_nu(x) / p_de(x) at every sample, computed in log space."""
        log_de = calc_log_prob_gauss(
            samples, self.means_denominator_, self.cov_matrix_denominator_)
        log_nu = calc_log_prob_gauss(
            samples, self.means_numerator_, self.cov_matrix_numerator_)
        return np.exp(log_nu - log_de)

    def sample_numerator(self, n_samples):
        """Draw from the numerator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_numerator_, self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw from the denominator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_denominator_, self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; the ratio is evaluated at ``de``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        return de, nu, self.calc_true_density_rate(de)

class GenDataTwoShiftedGausses_constKL(object):
    """Two identity-covariance Gaussians shifted to reach a target KL.

    With identity covariances KL(N(m_nu, I) || N(m_de, I)) equals
    ||m_nu - m_de||^2 / 2.  Placing the means at +/- sqrt(target_kl / 2) on
    the first axis therefore gives KL == target_kl for any ``dim_data``.
    """

    def __init__(self,
                 target_kl,
                 dim_data):
        # sqrt of a negative target would silently produce NaN means.
        if target_kl < 0:
            raise ValueError("target_kl must be non-negative")
        self.dim_ = dim_data
        self.means_numerator_ = np.repeat(0.0, self.dim_)
        self.means_denominator_ = np.repeat(0.0, self.dim_)
        half_shift = np.sqrt(target_kl / 2.0)
        self.means_numerator_[0] = half_shift
        self.means_denominator_[0] = -half_shift
        self.cov_matrix_numerator_ = np.diag(np.ones(self.dim_))
        self.cov_matrix_denominator_ = np.diag(np.ones(self.dim_))
        self.true_KL_ = calc_true_KL_div_gauss(
            self.means_numerator_,
            self.means_denominator_,
            self.cov_matrix_numerator_,
            self.cov_matrix_denominator_,
            self.dim_)

    def get_true_KL(self):
        """Return the KL divergence computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        gauss = multivariate_normal(
              mean=means, cov=cov_matrix)
        samples = gauss.rvs(n_samples).astype(np.float32)
        return samples

    def calc_true_density_rate(self, samples):
        """Return p_nu(x) / p_de(x) at every sample, computed in log space."""
        de_logprob = calc_log_prob_gauss(
              samples,
              self.means_denominator_,
              self.cov_matrix_denominator_)
        nu_logprob = calc_log_prob_gauss(
              samples,
              self.means_numerator_,
              self.cov_matrix_numerator_)
        true_dre = np.exp(nu_logprob - de_logprob)
        return true_dre

    def sample_numerator(self, n_samples):
        """Draw from the numerator Gaussian."""
        return self.sample_gaussian(
                  n_samples,
                  self.means_numerator_,
                  self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw from the denominator Gaussian."""
        return self.sample_gaussian(
                  n_samples,
                  self.means_denominator_,
                  self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; the ratio is evaluated at ``de``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        true_dre = self.calc_true_density_rate(de)
        return de, nu, true_dre

class GenDataTwoShiftedGausses_KL5(object):
    """Identity-covariance Gaussians shifted on up to five axes.

    The first ``min(dim_data, 5)`` coordinates of the numerator / denominator
    means are +/- sqrt(1/2).  Each shifted axis then contributes 1 to the
    KL, so the KL is 5 when ``dim_data >= 5``.
    """

    def __init__(self,
                 dim_data):
        self.dim_ = dim_data
        self.means_numerator_ = np.zeros(self.dim_)
        self.means_denominator_ = np.zeros(self.dim_)
        n_shifted = min(self.dim_, 5)
        offset = np.sqrt(1.0 / 2.0)
        self.means_numerator_[:n_shifted] = offset
        self.means_denominator_[:n_shifted] = -offset
        self.cov_matrix_numerator_ = np.eye(self.dim_)
        self.cov_matrix_denominator_ = np.eye(self.dim_)
        self.true_KL_ = calc_true_KL_div_gauss(
            self.means_numerator_,
            self.means_denominator_,
            self.cov_matrix_numerator_,
            self.cov_matrix_denominator_,
            self.dim_)

    def get_true_KL(self):
        """Return the KL divergence computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        out = multivariate_normal(mean=means, cov=cov_matrix).rvs(size=n_samples)
        return out.astype(np.float32)

    def calc_true_density_rate(self, samples):
        """Return p_nu(x) / p_de(x) at every sample, computed in log space."""
        log_de = calc_log_prob_gauss(
            samples, self.means_denominator_, self.cov_matrix_denominator_)
        log_nu = calc_log_prob_gauss(
            samples, self.means_numerator_, self.cov_matrix_numerator_)
        return np.exp(log_nu - log_de)

    def sample_numerator(self, n_samples):
        """Draw from the numerator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_numerator_, self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw from the denominator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_denominator_, self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; the ratio is evaluated at ``de``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        return de, nu, self.calc_true_density_rate(de)

class GenDataTwoUnitShiftedGausses_WithDifferentSD(object):
    """Two equicorrelated Gaussians shifted along every axis.

    The numerator mean is +distance/2 on every axis and the denominator mean
    is -distance/2.  Each covariance has ``sigma**2`` on the diagonal and
    ``rho`` off the diagonal; ``rho`` is used as the raw off-diagonal
    covariance entry, not scaled by ``sigma**2``.
    """

    def __init__(self,
                 distance_between_nu_and_de,
                 denominator_rho,
                 denominator_sigma,
                 numerator_rho,
                 numerator_sigma,
                 dim_data):
        self.dim_ = dim_data
        self.means_numerator_ = np.repeat(distance_between_nu_and_de / 2.0, self.dim_)
        self.means_denominator_ = np.repeat(-distance_between_nu_and_de / 2.0, self.dim_)
        self.cov_matrix_denominator_ = self._equicorrelated_cov(
            denominator_rho, denominator_sigma * denominator_sigma, self.dim_)
        self.cov_matrix_numerator_ = self._equicorrelated_cov(
            numerator_rho, numerator_sigma * numerator_sigma, self.dim_)
        self.true_KL_ = calc_true_KL_div_gauss(
            self.means_numerator_,
            self.means_denominator_,
            self.cov_matrix_numerator_,
            self.cov_matrix_denominator_,
            self.dim_)

    @staticmethod
    def _equicorrelated_cov(off_diag, diag, dim):
        """Return a (dim, dim) float matrix with ``diag`` on the diagonal.

        The matrix is forced to float64.  An integer ``off_diag`` (e.g. 0)
        would otherwise produce an int array whose non-integer diagonal
        entries are silently truncated by ``np.fill_diagonal``.
        """
        cov = np.full((dim, dim), off_diag, dtype=np.float64)
        np.fill_diagonal(cov, diag)
        return cov

    def get_true_KL(self):
        """Return the KL divergence computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        gauss = multivariate_normal(
              mean=means, cov=cov_matrix)
        samples = gauss.rvs(n_samples).astype(np.float32)
        return samples

    def calc_true_density_rate(self, samples):
        """Return p_nu(x) / p_de(x) at every sample, computed in log space."""
        de_logprob = calc_log_prob_gauss(
              samples,
              self.means_denominator_,
              self.cov_matrix_denominator_)
        nu_logprob = calc_log_prob_gauss(
              samples,
              self.means_numerator_,
              self.cov_matrix_numerator_)
        true_dre = np.exp(nu_logprob - de_logprob)
        return true_dre

    def sample_numerator(self, n_samples):
        """Draw from the numerator Gaussian."""
        return self.sample_gaussian(
                  n_samples,
                  self.means_numerator_,
                  self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw from the denominator Gaussian."""
        return self.sample_gaussian(
                  n_samples,
                  self.means_denominator_,
                  self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; the ratio is evaluated at ``de``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        true_dre = self.calc_true_density_rate(de)
        return de, nu, true_dre

class GenDataCenteredGaussWithDifferentSD(object):
    """Two zero-mean diagonal Gaussians with different spreads.

    Note: ``numerator_sigma`` / ``denominator_sigma`` are placed directly on
    the covariance diagonal, i.e. they act as variances (they are not
    squared).  Off-diagonal entries are ``rho_`` (fixed at 0.0).
    """

    def __init__(self,
                denominator_sigma,
                numerator_sigma,
                dim_data):
        self.rho_ = 0.0
        self.dim_ = dim_data
        self.denominator_sigma_ = denominator_sigma
        self.numerator_sigma_ = numerator_sigma
        self.means_numerator_ = np.zeros(self.dim_)
        self.means_denominator_ = np.zeros(self.dim_)
        cov_nu = np.full((self.dim_, self.dim_), self.rho_)
        np.fill_diagonal(cov_nu, self.numerator_sigma_)
        self.cov_matrix_numerator_ = cov_nu
        cov_de = np.full((self.dim_, self.dim_), self.rho_)
        np.fill_diagonal(cov_de, self.denominator_sigma_)
        self.cov_matrix_denominator_ = cov_de
        self.true_KL_ = calc_true_KL_div_gauss(
            self.means_numerator_,
            self.means_denominator_,
            self.cov_matrix_numerator_,
            self.cov_matrix_denominator_,
            self.dim_)

    def get_true_KL(self):
        """Return the KL divergence computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        points = multivariate_normal(mean=means, cov=cov_matrix).rvs(size=n_samples)
        return points.astype(np.float32)

    def calc_true_density_rate(self, samples):
        """Return p_nu(x) / p_de(x) at every sample, computed in log space."""
        log_de = calc_log_prob_gauss(
            samples, self.means_denominator_, self.cov_matrix_denominator_)
        log_nu = calc_log_prob_gauss(
            samples, self.means_numerator_, self.cov_matrix_numerator_)
        return np.exp(log_nu - log_de)

    def sample_numerator(self, n_samples):
        """Draw from the numerator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_numerator_, self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw from the denominator Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_denominator_, self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; the ratio is evaluated at ``de``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        return de, nu, self.calc_true_density_rate(de)

class GenDataMultiGaussVsSingleGauss(object):
    """Gaussian-mixture numerator vs. single centred Gaussian denominator.

    The numerator mixes ``n_modal_numerator_`` Gaussians with random weights
    ``p_vec_`` (uniform draws, normalised) and random means drawn from
    U[0, 1)^dim.  All components share one covariance matrix.  The
    denominator is a single Gaussian centred at the origin.

    Covariances carry ``sigma_*`` on the diagonal and ``rho_*`` off it, so
    ``sigma_*`` acts as a variance and ``rho_*`` as a raw covariance entry.
    """

    def __init__(self,
            rho_numerator_,
            sigma_numerator_,
            rho_denominator_,
            sigma_denominator_,
            n_modal_numerator_,
            dim_data, ):
        # Random mixture weights and component means (global NumPy RNG).
        p_vec_org = np.random.uniform(size=n_modal_numerator_)
        self.p_vec_ = p_vec_org / np.sum(p_vec_org)
        self.n_modal_numerator_ = n_modal_numerator_
        self.means_numerator_mat_ = np.random.uniform(
                size=(n_modal_numerator_, dim_data))
        self.means_denominator_ = np.repeat(0.0, dim_data)
        self.dim_ = dim_data

        self.rho_numerator_ = rho_numerator_
        self.sigma_numerator_ = sigma_numerator_
        self.cov_matrix_numerator_ = self._equicorrelated_cov(
            self.rho_numerator_, self.sigma_numerator_, dim_data)

        self.rho_denominator_ = rho_denominator_
        self.sigma_denominator_ = sigma_denominator_
        self.cov_matrix_denominator_ = self._equicorrelated_cov(
            self.rho_denominator_, self.sigma_denominator_, dim_data)

        # KL(mixture || Gaussian) has no closed form.  Report the p-weighted
        # average of the per-component KLs, which upper-bounds the true value
        # by convexity of KL in its first argument.
        self.true_KL_ = 0.0
        for i_mdl in range(n_modal_numerator_):
            self.true_KL_ += self.p_vec_[i_mdl] * calc_true_KL_div_gauss(
                self.means_numerator_mat_[i_mdl, :],
                self.means_denominator_,
                self.cov_matrix_numerator_,
                self.cov_matrix_denominator_,
                self.dim_)

    @staticmethod
    def _equicorrelated_cov(off_diag, diag, dim):
        """Return a float64 (dim, dim) matrix with ``diag`` on the diagonal.

        The dtype is forced so that an integer ``off_diag`` does not truncate
        non-integer diagonal values in ``np.fill_diagonal``.
        """
        cov = np.full((dim, dim), off_diag, dtype=np.float64)
        np.fill_diagonal(cov, diag)
        return cov

    def get_true_KL(self):
        """Return the KL value computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        gauss = multivariate_normal(
              mean=means, cov=cov_matrix)
        samples = gauss.rvs(n_samples).astype(np.float32)
        return samples

    def sample_numerator(self, n_samples):
        """Draw from the mixture; returns ``(samples, log_prob)``.

        ``samples`` has shape (n_samples, dim).  Each row's component is
        chosen with probability ``p_vec_``.
        """
        # One batch per component; reshape keeps (n, dim) even when
        # n_samples == 1 or dim == 1 (scipy squeezes those axes).
        per_mode = np.array([
            self.sample_gaussian(
                n_samples,
                self.means_numerator_mat_[i_mdl, :],
                self.cov_matrix_numerator_).reshape(n_samples, self.dim_)
            for i_mdl in range(self.n_modal_numerator_)])
        mode_idx = np.random.choice(
            self.n_modal_numerator_, n_samples, p=self.p_vec_)
        samples = per_mode[mode_idx, np.arange(n_samples)].astype(np.float32)
        logprob = self.calc_log_prob_numerator(samples)
        return samples, logprob

    def sample_denominator(self, n_samples):
        """Draw from the denominator Gaussian; returns ``(samples, log_prob)``."""
        samples = self.sample_gaussian(
                  n_samples,
                  self.means_denominator_,
                  self.cov_matrix_denominator_)
        logprob = self.calc_log_prob_denominator(samples)
        return samples, logprob

    def calc_log_prob_denominator(self, samples):
        """Return the denominator Gaussian log-density at ``samples``."""
        logprob = calc_log_prob_gauss(
                    samples,
                    self.means_denominator_,
                    self.cov_matrix_denominator_)
        return logprob

    def calc_log_prob_numerator(self, samples):
        """Return the mixture log-density log(sum_i p_i N_i(x)) as float32.

        Computed in log space with ``np.logaddexp.reduce`` for stability.
        """
        # A zero-weight component legitimately contributes log(0) = -inf.
        with np.errstate(divide="ignore"):
            log_weights = np.log(self.p_vec_)
        weighted = np.array([
            log_weights[i_mdl] + multivariate_normal(
                mean=self.means_numerator_mat_[i_mdl, :],
                cov=self.cov_matrix_numerator_).logpdf(samples)
            for i_mdl in range(self.n_modal_numerator_)])
        logprob = np.logaddexp.reduce(weighted, axis=0)
        return np.asarray(logprob).astype(np.float32)

    def sample(self, n):
        """Return ``(de, nu, true_dre, log_prob_de, log_prob_nu)``.

        ``true_dre`` is p_nu / p_de evaluated at the denominator samples.
        """
        de, log_prob_de = self.sample_denominator(n)
        nu, log_prob_nu = self.sample_numerator(n)
        nu_logprob_at_de = self.calc_log_prob_numerator(de)
        true_dre = np.exp(nu_logprob_at_de - log_prob_de)
        return de, nu, true_dre, log_prob_de, log_prob_nu

class GenDataMultiGaussVsSingleGauss_with_KL(object):
    """Equal-weight Gaussian mixture vs. standard Gaussian at a given KL.

    Each of the ``n_modal_numerator`` identity-covariance components has a
    mean of norm sqrt(2 * KL_div) in a random direction of the positive
    orthant.  Each component's KL to N(0, I) is therefore ``KL_div``.
    ``true_KL_`` is the average of those per-component KLs.  By convexity
    of KL this upper-bounds the KL of the mixture itself.
    """

    def __init__(self,
            KL_div,
            dim_data,
            n_modal_numerator,
            ):
        self.dim_ = dim_data
        self.n_modal_numerator_ = n_modal_numerator
        vec = np.random.uniform(
                size=(self.n_modal_numerator_, self.dim_))
        norm_vec = np.reshape(
            np.linalg.norm(vec, axis=1),
            (self.n_modal_numerator_, 1))
        self.means_numerator_mat_ = np.sqrt(KL_div * 2.0) * (vec / norm_vec)
        self.means_denominator_ = np.repeat(0.0, dim_data)
        self.cov_matrix_denominator_ = np.diag(np.ones(self.dim_))
        self.cov_matrix_numerator_ = np.diag(np.ones(self.dim_))
        self.true_KL_ = 0
        for _i_mdl in range(self.n_modal_numerator_):
            self.true_KL_ += calc_true_KL_div_gauss(
                self.means_numerator_mat_[_i_mdl, :],
                self.means_denominator_,
                self.cov_matrix_numerator_,
                self.cov_matrix_denominator_,
                self.dim_) / self.n_modal_numerator_

    def get_true_KL(self):
        """Return the KL value computed in ``__init__``."""
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        gauss = multivariate_normal(
              mean=means, cov=cov_matrix)
        samples = gauss.rvs(n_samples).astype(np.float32)
        return samples

    def sample_numerator(self, n_samples):
        """Draw from the equal-weight mixture; returns ``(samples, log_prob)``."""
        # One batch per component; reshape keeps (n, dim) even when
        # n_samples == 1 or dim == 1 (scipy squeezes those axes).
        per_mode = np.array([
            self.sample_gaussian(
                n_samples,
                self.means_numerator_mat_[i_mdl, :],
                self.cov_matrix_numerator_).reshape(n_samples, self.dim_)
            for i_mdl in range(self.n_modal_numerator_)])
        # Uniform component choice matches the equal mixture weights.
        mode_idx = np.random.choice(self.n_modal_numerator_, n_samples)
        samples = per_mode[mode_idx, np.arange(n_samples)].astype(np.float32)
        logprob = self.calc_log_prob_numerator(samples)
        return samples, logprob

    def sample_denominator(self, n_samples):
        """Draw from the denominator Gaussian; returns ``(samples, log_prob)``."""
        samples = self.sample_gaussian(
                  n_samples,
                  self.means_denominator_,
                  self.cov_matrix_denominator_)
        logprob = self.calc_log_prob_denominator(samples)
        return samples, logprob

    def calc_log_prob_denominator(self, samples):
        """Return the denominator Gaussian log-density at ``samples``."""
        logprob = calc_log_prob_gauss(
                    samples,
                    self.means_denominator_,
                    self.cov_matrix_denominator_)
        return logprob

    def calc_log_prob_numerator(self, samples):
        """Return log((1/K) * sum_i N_i(x)) for the K-component mixture, as float32.

        This is the log of the mean density, not the mean of the
        log-densities.  It is evaluated stably with ``np.logaddexp.reduce``.
        """
        component_logprobs = np.array([
            multivariate_normal(
                mean=self.means_numerator_mat_[i_mdl, :],
                cov=self.cov_matrix_numerator_).logpdf(samples)
            for i_mdl in range(self.n_modal_numerator_)])
        logprob = (np.logaddexp.reduce(component_logprobs, axis=0)
                   - np.log(self.n_modal_numerator_))
        return np.asarray(logprob).astype(np.float32)

    def sample(self, n):
        """Return ``(de, nu, true_dre, log_prob_de, log_prob_nu)``.

        ``true_dre`` is p_nu / p_de evaluated at the denominator samples.
        """
        de, log_prob_de = self.sample_denominator(n)
        nu, log_prob_nu = self.sample_numerator(n)
        nu_logprob_at_de = self.calc_log_prob_numerator(de)
        true_dre = np.exp(nu_logprob_at_de - log_prob_de)
        return de, nu, true_dre, log_prob_de, log_prob_nu

class GenDataMultiGaussVsUniform(object):
    """Gaussian-mixture numerator vs. uniform-hypercube denominator.

    Numerator: ``n_modal`` Gaussians with random weights ``p_vec_`` and means
    drawn from U[0, 1)^dim.  They share a covariance with ones on the
    diagonal and ``rho`` elsewhere.  Denominator: independent
    U[lower, upper] on every axis.
    """

    def __init__(self, rho, n_modal,
                uniform_lower_support,
                uniform_upper_support,
                dim_data):
        # Random mixture weights and component means (global NumPy RNG).
        p_vec_org = np.random.uniform(size=n_modal)
        self.p_vec_ = p_vec_org / np.sum(p_vec_org)
        self.n_modal_ = n_modal
        self.means_numerator_mat_ = np.random.uniform(
                size=(n_modal, dim_data))
        self.dim_ = dim_data
        self.rho_ = rho
        sigma_mat = np.repeat(
            rho, dim_data*dim_data).reshape(dim_data, dim_data)
        np.fill_diagonal(sigma_mat, 1)
        self.variances_ = sigma_mat
        self.cov_matrix_numerator_ = sigma_mat
        self.uniform_loc_ = uniform_lower_support
        self.uniform_scale_ = (
            uniform_upper_support - uniform_lower_support)

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` points from N(means, cov_matrix) as float32."""
        gauss = multivariate_normal(
              mean=means, cov=cov_matrix)
        samples = gauss.rvs(n_samples).astype(np.float32)
        return samples

    def sample_numerator(self, n_samples):
        """Draw from the mixture; returns ``(samples, log_prob)``.

        Each row's component is chosen with probability ``p_vec_``.
        """
        # One batch per component; reshape keeps (n, dim) even when
        # n_samples == 1 or dim == 1 (scipy squeezes those axes).
        per_mode = np.array([
            self.sample_gaussian(
                n_samples,
                self.means_numerator_mat_[i_mdl, :],
                self.cov_matrix_numerator_).reshape(n_samples, self.dim_)
            for i_mdl in range(self.n_modal_)])
        mode_idx = np.random.choice(self.n_modal_, n_samples, p=self.p_vec_)
        samples = per_mode[mode_idx, np.arange(n_samples)].astype(np.float32)
        logprob = self.calc_log_prob_numerator(samples)
        return samples, logprob

    def sample_denominator(self, n_samples):
        """Draw (n_samples, dim) points uniformly from the hypercube.

        Returns ``(samples, log_prob)``.
        """
        columns = [
            uniform.rvs(
                loc=self.uniform_loc_,
                scale=self.uniform_scale_,
                size=n_samples).reshape(n_samples, 1)
            for _ in range(self.dim_)]
        samples = np.concatenate(columns, axis=1)
        logprob = self.calc_log_prob_denominator(samples)
        return samples, logprob

    def calc_log_prob_denominator(self, samples):
        """Return the per-sample uniform log-density as float32.

        The value is ``-dim * log(scale)`` inside the support and ``-inf``
        outside it.  Rows of ``samples`` are points.
        """
        samples = np.asarray(samples)
        log_density = -self.dim_ * np.log(self.uniform_scale_)
        upper = self.uniform_loc_ + self.uniform_scale_
        inside = np.all(
            (samples >= self.uniform_loc_) & (samples <= upper), axis=-1)
        return np.where(inside, log_density, -np.inf).astype(np.float32)

    def calc_log_prob_numerator(self, samples):
        """Return the mixture log-density log(sum_i p_i N_i(x)) as float32.

        Computed in log space with ``np.logaddexp.reduce`` for stability.
        """
        # A zero-weight component legitimately contributes log(0) = -inf.
        with np.errstate(divide="ignore"):
            log_weights = np.log(self.p_vec_)
        weighted = np.array([
            log_weights[i_mdl] + multivariate_normal(
                mean=self.means_numerator_mat_[i_mdl, :],
                cov=self.cov_matrix_numerator_).logpdf(samples)
            for i_mdl in range(self.n_modal_)])
        logprob = np.logaddexp.reduce(weighted, axis=0)
        return np.asarray(logprob).astype(np.float32)

    def sample(self, n):
        """Return ``(de, nu, true_dre, log_prob_de, log_prob_nu)``.

        ``true_dre`` is p_nu / p_de evaluated at the denominator samples.
        """
        de, log_prob_de = self.sample_denominator(n)
        nu, log_prob_nu = self.sample_numerator(n)
        nu_logprob_at_de = self.calc_log_prob_numerator(de)
        true_dre = np.exp(nu_logprob_at_de - log_prob_de)
        return de, nu, true_dre, log_prob_de, log_prob_nu

class GenDataSingleGaussVsUniform(object):
    """Truncated isotropic Gaussian (numerator) vs. uniform box (denominator).

    Numerator: N(0, sigma^2 I), restricted by rejection sampling to the box
    [uniform_lower_support, uniform_upper_support]^dim_data.
    Denominator: uniform on the same box.
    ``true_KL_`` is the closed-form KL(denominator || numerator).

    NOTE(review): GenDataSingleGaussVsUniform is defined twice in this module;
    the later definition shadows the earlier one. Remove the duplicate.
    """

    def __init__(self,
                 sigma,
                 uniform_lower_support,
                 uniform_upper_support,
                 dim_data,
                 ):
        self.dim_ = dim_data
        self.sigma_ = sigma
        self.means_numerator_ = np.repeat(0.0, self.dim_)
        sigma_mat = np.repeat(
            0.0, self.dim_*self.dim_).reshape(self.dim_, self.dim_)
        np.fill_diagonal(sigma_mat, sigma*sigma)
        self.cov_matrix_numerator_ = sigma_mat
        self.uniform_lower_support_ = uniform_lower_support
        self.uniform_upper_support_ = uniform_upper_support
        self.uniform_loc_ = uniform_lower_support
        self.uniform_scale_ = (
            uniform_upper_support - uniform_lower_support)
        # KL(U || p_trunc) = E_U[log U(x) - log p_trunc(x)]; per coordinate:
        #   log U(x)        = -log L,                L = upper - lower
        #   -log p_trunc(x) = log(sqrt(2 pi) sigma) + x^2 / (2 sigma^2) + log Z
        #   E_U[x^2]        = (upper^3 - lower^3) / (3 L)
        # where Z is the Gaussian mass of one coordinate inside the support.
        self.true_KL_ = (
            - self.dim_ * np.log(self.uniform_scale_)
            + self.dim_ * np.log(np.sqrt(2*np.pi)*sigma)
            + self.dim_ * (
                (np.power(self.uniform_upper_support_, 3.0)
                 - np.power(self.uniform_lower_support_, 3.0))
                / (6.0*sigma**2) / self.uniform_scale_)
            + self._calc_log_all_prob_numerator_gauss_over_support()
        )

    def _calc_log_all_prob_numerator_gauss_over_support(self):
        """Return ``dim_ * log Z``, Z = N(0, sigma^2) mass inside [lower, upper]."""
        gauss = norm(0.0, self.sigma_)
        unit_prob = (
            gauss.cdf(self.uniform_upper_support_)
            - gauss.cdf(self.uniform_lower_support_))
        return self.dim_ * np.log(unit_prob)

    def get_true_KL(self):
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Rejection-sample ``n_samples`` Gaussian points inside the support box.

        Returns:
            float32 array of shape ``(n_samples, len(means))``.
        """
        dim = np.size(means)
        if n_samples == 0:
            return np.empty((0, dim), dtype=np.float32)
        gauss = multivariate_normal(mean=means, cov=cov_matrix)
        accepted = []
        n_accepted = 0
        while n_accepted < n_samples:
            # ``rvs`` squeezes singleton axes (one sample or a 1-D Gaussian);
            # restore the (n_samples, dim) layout before filtering rows.
            batch = np.reshape(
                gauss.rvs(n_samples), (n_samples, dim)).astype(np.float32)
            in_range = np.all(
                (batch <= self.uniform_upper_support_)
                & (batch >= self.uniform_lower_support_),
                axis=1)
            kept = batch[in_range, :]
            accepted.append(kept)
            n_accepted += kept.shape[0]
        return np.concatenate(accepted, axis=0)[0:n_samples, :]

    def sample_numerator(self, n_samples):
        samples = self.sample_gaussian(
                    n_samples,
                    self.means_numerator_,
                    self.cov_matrix_numerator_)
        logprob = self.calc_log_prob_numerator(samples)
        return samples, logprob

    def sample_denominator(self, n_samples):
        sample_each_col_list = []
        for _i_mdl in range(self.dim_):
            samples_each_col = uniform.rvs(
                    loc=self.uniform_loc_,
                    scale=self.uniform_scale_,
                    size=n_samples).reshape(n_samples, 1)
            sample_each_col_list.append(samples_each_col)
        samples = np.concatenate(
            sample_each_col_list, axis=1).astype(np.float32)
        logprob = self.calc_log_prob_denominator(samples)
        return samples, logprob

    def calc_log_prob_denominator(self, samples):
        """Return the scalar uniform log-density ``-dim_ * log(uniform_scale_)``."""
        logprob = - self.dim_*np.log(
                        self.uniform_scale_).astype(np.float32)
        return logprob

    def calc_log_prob_numerator(self, samples):
        """Return the truncated-Gaussian log-density of each sample.

        Numerator samples are rejection-sampled into the support box, so their
        density is the Gaussian renormalized by its mass ``Z^dim_`` inside the
        box. Subtracting ``dim_ * log Z`` keeps ``exp(log p_nu - log p_de)``
        consistent with ``true_KL_``. Assumes ``samples`` lie inside the box
        (true for draws from either distribution of this class).
        """
        logprob = (
            calc_log_prob_gauss(
                samples,
                self.means_numerator_,
                self.cov_matrix_numerator_)
            - self._calc_log_all_prob_numerator_gauss_over_support()
        ).astype(np.float32)
        return logprob

    def sample(self, n):
        """Draw ``n`` points from each side plus the true ratio at ``de``.

        Returns:
            (de, nu, true_dre, log_prob_de, log_prob_nu) where
            ``true_dre = exp(log p_nu(de) - log p_de(de))``.
        """
        de, log_prob_de = self.sample_denominator(n)
        nu, log_prob_nu = self.sample_numerator(n)
        de_logprob = self.calc_log_prob_denominator(de)
        nu_logprob = self.calc_log_prob_numerator(de)
        true_dre = np.exp(nu_logprob - de_logprob)
        return de, nu, true_dre, log_prob_de, log_prob_nu

class GenDataUniformVsSingleGauss(object):
    """Uniform box (numerator) vs. correlated Gaussian (denominator).

    Numerator: uniform on [uniform_lower_support, uniform_upper_support]^dim_data.
    Denominator: N(0, Sigma) with unit diagonal and every off-diagonal entry
    equal to ``rho``.
    """

    def __init__(self, rho, dim_data,
                 uniform_lower_support,
                 uniform_upper_support):
        self.dim_ = dim_data
        self.means_denominator_ = np.repeat(0.0, self.dim_)
        self.rho_ = rho
        sigma_mat = np.repeat(
            self.rho_, self.dim_*self.dim_).reshape(self.dim_, self.dim_)
        np.fill_diagonal(sigma_mat, 1)
        self.cov_matrix_denominator_ = sigma_mat
        self.uniform_loc_ = uniform_lower_support
        self.uniform_scale_ = (
            uniform_upper_support - uniform_lower_support)

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw ``n_samples`` Gaussian points as a float32 ``(n_samples, dim)`` array."""
        gauss = multivariate_normal(
              mean=means, cov=cov_matrix)
        # ``rvs`` squeezes singleton axes (one sample or a 1-D Gaussian);
        # always return a 2-D matrix like the uniform sampler does.
        samples = np.reshape(
            gauss.rvs(n_samples), (n_samples, np.size(means))).astype(np.float32)
        return samples

    def sample_denominator(self, n_samples):
        samples = self.sample_gaussian(
                    n_samples,
                    self.means_denominator_,
                    self.cov_matrix_denominator_)
        logprob = self.calc_log_prob_denominator(samples)
        return samples, logprob

    def sample_numerator(self, n_samples):
        sample_each_col_list = []
        for _i_mdl in range(self.dim_):
            samples_each_col = uniform.rvs(
                    loc=self.uniform_loc_,
                    scale=self.uniform_scale_,
                    size=n_samples).reshape(n_samples, 1)
            sample_each_col_list.append(samples_each_col)
        samples = np.concatenate(
            sample_each_col_list, axis=1)
        logprob = self.calc_log_prob_numerator(samples)
        return samples, logprob

    def calc_log_prob_denominator(self, samples):
        logprob = calc_log_prob_gauss(
                            samples,
                            self.means_denominator_,
                            self.cov_matrix_denominator_
                        ).astype(np.float32)
        return logprob

    def calc_log_prob_numerator(self, samples):
        """Return the per-sample log-density of the uniform numerator.

        The result is ``-dim_ * log(uniform_scale_)`` for rows inside the box
        ``[uniform_loc_, uniform_loc_ + uniform_scale_]^dim_`` and ``-inf``
        elsewhere. ``sample`` evaluates this at Gaussian (denominator) draws,
        which can fall outside the box; there the true density ratio is 0.

        Returns:
            float32 array with one value per sample.
        """
        x = np.asarray(samples).reshape(-1, self.dim_)
        lower = self.uniform_loc_
        upper = self.uniform_loc_ + self.uniform_scale_
        inside = np.all((x >= lower) & (x <= upper), axis=1)
        log_density = -self.dim_ * np.log(self.uniform_scale_)
        return np.where(inside, log_density, -np.inf).astype(np.float32)

    def sample(self, n):
        """Draw ``n`` points from each side plus the true ratio at ``de``.

        Returns:
            (de, nu, true_dre, log_prob_de, log_prob_nu) where
            ``true_dre = exp(log p_nu(de) - log p_de(de))``.
        """
        de, log_prob_de = self.sample_denominator(n)
        nu, log_prob_nu = self.sample_numerator(n)
        de_logprob = self.calc_log_prob_denominator(de)
        nu_logprob = self.calc_log_prob_numerator(de)
        true_dre = np.exp(nu_logprob - de_logprob)
        return de, nu, true_dre, log_prob_de, log_prob_nu

class GenDataSingleGaussVsUniform(object):
    """Truncated isotropic Gaussian (numerator) vs. uniform box (denominator).

    Numerator: N(0, sigma^2 I), restricted by rejection sampling to the box
    [uniform_lower_support, uniform_upper_support]^dim_data.
    Denominator: uniform on the same box.
    ``true_KL_`` is the closed-form KL(denominator || numerator).

    NOTE(review): GenDataSingleGaussVsUniform is defined twice in this module;
    the later definition shadows the earlier one. Remove the duplicate.
    """

    def __init__(self,
                 sigma,
                 uniform_lower_support,
                 uniform_upper_support,
                 dim_data,
                 ):
        self.dim_ = dim_data
        self.sigma_ = sigma
        self.means_numerator_ = np.repeat(0.0, self.dim_)
        sigma_mat = np.repeat(
            0.0, self.dim_*self.dim_).reshape(self.dim_, self.dim_)
        np.fill_diagonal(sigma_mat, sigma*sigma)
        self.cov_matrix_numerator_ = sigma_mat
        self.uniform_lower_support_ = uniform_lower_support
        self.uniform_upper_support_ = uniform_upper_support
        self.uniform_loc_ = uniform_lower_support
        self.uniform_scale_ = (
            uniform_upper_support - uniform_lower_support)
        # KL(U || p_trunc) = E_U[log U(x) - log p_trunc(x)]; per coordinate:
        #   log U(x)        = -log L,                L = upper - lower
        #   -log p_trunc(x) = log(sqrt(2 pi) sigma) + x^2 / (2 sigma^2) + log Z
        #   E_U[x^2]        = (upper^3 - lower^3) / (3 L)
        # where Z is the Gaussian mass of one coordinate inside the support.
        self.true_KL_ = (
            - self.dim_ * np.log(self.uniform_scale_)
            + self.dim_ * np.log(np.sqrt(2*np.pi)*sigma)
            + self.dim_ * (
                (np.power(self.uniform_upper_support_, 3.0)
                 - np.power(self.uniform_lower_support_, 3.0))
                / (6.0*sigma**2) / self.uniform_scale_)
            + self._calc_log_all_prob_numerator_gauss_over_support()
        )

    def _calc_log_all_prob_numerator_gauss_over_support(self):
        """Return ``dim_ * log Z``, Z = N(0, sigma^2) mass inside [lower, upper]."""
        gauss = norm(0.0, self.sigma_)
        unit_prob = (
            gauss.cdf(self.uniform_upper_support_)
            - gauss.cdf(self.uniform_lower_support_))
        return self.dim_ * np.log(unit_prob)

    def get_true_KL(self):
        return self.true_KL_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Rejection-sample ``n_samples`` Gaussian points inside the support box.

        Returns:
            float32 array of shape ``(n_samples, len(means))``.
        """
        dim = np.size(means)
        if n_samples == 0:
            return np.empty((0, dim), dtype=np.float32)
        gauss = multivariate_normal(mean=means, cov=cov_matrix)
        accepted = []
        n_accepted = 0
        while n_accepted < n_samples:
            # ``rvs`` squeezes singleton axes (one sample or a 1-D Gaussian);
            # restore the (n_samples, dim) layout before filtering rows.
            batch = np.reshape(
                gauss.rvs(n_samples), (n_samples, dim)).astype(np.float32)
            in_range = np.all(
                (batch <= self.uniform_upper_support_)
                & (batch >= self.uniform_lower_support_),
                axis=1)
            kept = batch[in_range, :]
            accepted.append(kept)
            n_accepted += kept.shape[0]
        return np.concatenate(accepted, axis=0)[0:n_samples, :]

    def sample_numerator(self, n_samples):
        samples = self.sample_gaussian(
                    n_samples,
                    self.means_numerator_,
                    self.cov_matrix_numerator_)
        logprob = self.calc_log_prob_numerator(samples)
        return samples, logprob

    def sample_denominator(self, n_samples):
        sample_each_col_list = []
        for _i_mdl in range(self.dim_):
            samples_each_col = uniform.rvs(
                    loc=self.uniform_loc_,
                    scale=self.uniform_scale_,
                    size=n_samples).reshape(n_samples, 1)
            sample_each_col_list.append(samples_each_col)
        samples = np.concatenate(
            sample_each_col_list, axis=1).astype(np.float32)
        logprob = self.calc_log_prob_denominator(samples)
        return samples, logprob

    def calc_log_prob_denominator(self, samples):
        """Return the scalar uniform log-density ``-dim_ * log(uniform_scale_)``."""
        logprob = - self.dim_*np.log(
                        self.uniform_scale_).astype(np.float32)
        return logprob

    def calc_log_prob_numerator(self, samples):
        """Return the truncated-Gaussian log-density of each sample.

        Numerator samples are rejection-sampled into the support box, so their
        density is the Gaussian renormalized by its mass ``Z^dim_`` inside the
        box. Subtracting ``dim_ * log Z`` keeps ``exp(log p_nu - log p_de)``
        consistent with ``true_KL_``. Assumes ``samples`` lie inside the box
        (true for draws from either distribution of this class).
        """
        logprob = (
            calc_log_prob_gauss(
                samples,
                self.means_numerator_,
                self.cov_matrix_numerator_)
            - self._calc_log_all_prob_numerator_gauss_over_support()
        ).astype(np.float32)
        return logprob

    def sample(self, n):
        """Draw ``n`` points from each side plus the true ratio at ``de``.

        Returns:
            (de, nu, true_dre, log_prob_de, log_prob_nu) where
            ``true_dre = exp(log p_nu(de) - log p_de(de))``.
        """
        de, log_prob_de = self.sample_denominator(n)
        nu, log_prob_nu = self.sample_numerator(n)
        de_logprob = self.calc_log_prob_denominator(de)
        nu_logprob = self.calc_log_prob_numerator(de)
        true_dre = np.exp(nu_logprob - de_logprob)
        return de, nu, true_dre, log_prob_de, log_prob_nu

class GenDataTruncExponential(object):
    """Products of truncated exponentials on [0, right_end_point]^dim_data.

    Each coordinate of the numerator (denominator) has density
    ``lamda * exp(-lamda x) / (1 - exp(-lamda R))`` on ``[0, R]`` with
    ``lamda = numerator_lamda`` (``denominator_lamda``) and
    ``R = right_end_point``. ``true_KL_`` is KL(numerator || denominator).
    """

    def __init__(self,
                 denominator_lamda,
                 numerator_lamda,
                 right_end_point,
                 dim_data,
                 ):
        self.dim_ = dim_data
        self.denominator_lamda_ = denominator_lamda
        self.numerator_lamda_ = numerator_lamda
        self.right_end_point_ = right_end_point
        cum_prob_denominator = self._cal_cdf_expon(
            right_end_point, self.denominator_lamda_)
        cum_prob_numerator = self._cal_cdf_expon(
            right_end_point, self.numerator_lamda_)
        # Per coordinate, with C = integral of exp(-lamda t) over [0, R]:
        #   log p_nu(x) - log p_de(x) = (lamda_de - lamda_nu) * x + log(C_de / C_nu)
        # so the KL needs E_nu[x] of the *truncated* numerator, which is
        # strictly smaller than the untruncated mean 1 / lamda_nu.
        mean_numerator = self._truncexpon_mean(
            right_end_point, self.numerator_lamda_)
        self.true_KL_ = self.dim_ * (
            (self.denominator_lamda_ - self.numerator_lamda_) * mean_numerator
            + np.log(cum_prob_denominator / cum_prob_numerator))

    def _cal_cdf_expon(self, x, lamda):
        """Return ``(1 - exp(-lamda * x)) / lamda``.

        This is the integral of ``exp(-lamda t)`` over ``[0, x]`` (the
        normalizer of the truncated density), i.e. the Exp(lamda) CDF at
        ``x`` divided by ``lamda``. ``expm1`` keeps it accurate for small
        ``lamda * x``.
        """
        return -np.expm1(-lamda * x) / lamda

    def _truncexpon_mean(self, right_end_point, lamda):
        """Return the mean of Exp(lamda) truncated to ``[0, right_end_point]``.

        ``1/lamda - R * exp(-lamda R) / (1 - exp(-lamda R))``, written with
        ``exp(-b)`` so large ``lamda * R`` underflows to 0 instead of overflowing.
        """
        b = lamda * right_end_point
        return 1.0 / lamda - right_end_point * np.exp(-b) / -np.expm1(-b)

    def _sample_multi_truncexpon(self, lamda, n_samples):
        """Return ``(n_samples, dim_)`` i.i.d. truncated-exponential draws."""
        rv = truncexpon(b=self.right_end_point_*lamda, scale=1/lamda)
        # One column per dimension, drawn in column order.
        return np.column_stack(
            [rv.rvs(size=n_samples) for _ in range(self.dim_)])

    def _calc_log_prob_multi_truncexpon(self, x_mat, lamda):
        """Return the product-density log-prob of each row of ``x_mat``.

        Only the first ``dim_`` columns are used; values outside ``[0, R]``
        get ``-inf`` from scipy's ``logpdf``.
        """
        rv = truncexpon(b=self.right_end_point_*lamda, scale=1/lamda)
        x_mat = np.asarray(x_mat)
        # truncexpon.logpdf is already normalized, so no extra constant.
        return np.sum(rv.logpdf(x_mat[:, :self.dim_]), axis=1)

    def get_true_KL(self):
        return self.true_KL_

    def sample_denominator(self, n_samples):
        samples = self._sample_multi_truncexpon(
                    self.denominator_lamda_,
                    n_samples).astype(np.float32)
        return samples

    def sample_numerator(self, n_samples):
        samples = self._sample_multi_truncexpon(
                    self.numerator_lamda_,
                    n_samples).astype(np.float32)
        return samples

    def calc_log_prob_denominator(self, samples):
        logprob = self._calc_log_prob_multi_truncexpon(
                            samples,
                            self.denominator_lamda_
                        ).astype(np.float32)
        return logprob

    def calc_log_prob_numerator(self, samples):
        logprob = self._calc_log_prob_multi_truncexpon(
                            samples,
                            self.numerator_lamda_
                        ).astype(np.float32)
        return logprob

    def sample(self, n):
        """Return ``(de, nu, true_dre)``; ``true_dre = p_nu(de) / p_de(de)``."""
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        de_logprob = self.calc_log_prob_denominator(de)
        nu_logprob = self.calc_log_prob_numerator(de)
        true_dre = np.exp(nu_logprob - de_logprob)
        return de, nu, true_dre

class GenSwissRoll:
    """Swiss-roll (numerator) vs. standard 2-D Gaussian (denominator) data."""

    def __init__(self):
        """Nothing to configure; all settings are fixed inside ``sample``."""

    def sample(self, n_samples):
        """Return ``(de, nu)`` as float32 arrays.

        ``nu`` keeps columns 0 and 2 of ``make_swiss_roll`` output (noise 0.25)
        divided by 7.5; ``de`` is drawn from N(0, I_2). The swiss roll is drawn
        before the Gaussian.
        """
        roll_xyz, _ = make_swiss_roll(n_samples=n_samples, noise=0.25)
        nu = roll_xyz.astype('float32')[:, [0, 2]]
        nu /= 7.5
        standard_gauss = multivariate_normal(
            mean=[0, 0], cov=[[1, 0], [0, 1]])
        de = standard_gauss.rvs(n_samples).astype(np.float32)
        return de, nu
    
# class GenSwissRoll():
#     def __init__(self): 
#         pass

#     def sample(self, n_samples):
#         tmp_de = make_swiss_roll(
#                     n_samples=n_samples, 
#                     noise=0.25
#                 )[0]
#         tmp_de = tmp_de.astype('float32')[:, [0, 2]]
#         tmp_de /= 7.5
#         de = tmp_de
#         gauss = multivariate_normal(
#             mean=[0, 0], cov=[[1, 0], [0, 1]])
#         nu = gauss.rvs(n_samples).astype(np.float32)
#         return de, nu