import math
import torch
import numpy as np

from scipy import linalg
from scipy.stats import multivariate_normal

def calc_true_mi_info(
        rho: float,
        dim_data: int) -> float:
    """Return ``-log(det(Sigma)) / 2`` for an equicorrelation matrix.

    ``Sigma`` is the ``dim_data x dim_data`` matrix with ones on the
    diagonal and ``rho`` everywhere else.  Because ``Sigma`` has a unit
    diagonal, this value equals the KL divergence from ``N(0, Sigma)`` to
    ``N(0, I)``.

    Args:
        rho: Off-diagonal entry of ``Sigma``.
        dim_data: Number of rows/columns of ``Sigma``.

    Returns:
        ``-log(det(Sigma)) / 2``.  NaN when the determinant is negative
        and ``inf`` when ``Sigma`` is exactly singular.
    """
    sigma_numerator_mat = np.full(
        (dim_data, dim_data), rho, dtype=np.float64)
    np.fill_diagonal(sigma_numerator_mat, 1.0)
    # Work with the log-determinant directly: the plain determinant
    # underflows to 0.0 for large ``dim_data`` (for example rho=0.9 with
    # dim_data=500), which would turn the result into ``inf``.
    sign, log_abs_det = np.linalg.slogdet(sigma_numerator_mat)
    if sign < 0:
        # The logarithm of a negative determinant is undefined.
        return float("nan")
    true_mi_info = -log_abs_det / 2.0
    return true_mi_info

def calc_log_prob(samples, means, cov_matrix):
    """Evaluate the multivariate normal log-density at ``samples``.

    Args:
        samples: Point or batch of points at which to evaluate the density.
        means: Mean vector of the Gaussian.
        cov_matrix: Covariance matrix of the Gaussian.

    Returns:
        The value(s) of ``scipy.stats.multivariate_normal.logpdf`` for the
        given parameters.
    """
    return multivariate_normal.logpdf(
        samples, mean=means, cov=cov_matrix)

class GenDataExpAlpha(object):
    """Sampler for a pair of zero-mean Gaussians.

    The "numerator" distribution is ``N(0, Sigma)`` where ``Sigma`` has a
    unit diagonal and ``rho`` in every off-diagonal entry; the
    "denominator" distribution is ``N(0, I)``.
    """

    def __init__(self, rho, dim_data):
        """Build the mean vector and both covariance matrices.

        Args:
            rho: Off-diagonal entry of the numerator covariance.
            dim_data: Dimensionality of each sample.
        """
        self.means_ = np.zeros(dim_data)
        self.dim_ = dim_data
        self.rho_ = rho
        sigma_mat = np.full((dim_data, dim_data), rho)
        np.fill_diagonal(sigma_mat, 1)
        # ``variances_`` and ``cov_matrix_numerator_`` are the same array
        # object, not copies.
        self.variances_ = sigma_mat
        self.cov_matrix_numerator_ = sigma_mat
        self.cov_matrix_denominator_ = np.eye(dim_data)
        self.true_mi_ = calc_true_mi_info(rho, dim_data)

    def get_true_mi(self):
        """Return the value computed by ``calc_true_mi_info`` at init."""
        return self.true_mi_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw float32 samples from ``N(means, cov_matrix)``.

        Args:
            n_samples: Number of draws (an int, or a tuple batch shape).
            means: Mean vector of the Gaussian.
            cov_matrix: Covariance matrix of the Gaussian.

        Returns:
            A float32 array of shape ``(n_samples, dim)``.
        """
        gauss = multivariate_normal(mean=means, cov=cov_matrix)
        draws = np.asarray(gauss.rvs(n_samples), dtype=np.float32)
        # ``rvs`` squeezes singleton axes, so one draw or one-dimensional
        # data would otherwise lose an axis; restore the batch layout.
        batch_shape = tuple(np.atleast_1d(n_samples))
        return draws.reshape(batch_shape + (gauss.dim,))

    def sample_numerator(self, n_samples):
        """Draw ``n_samples`` points from the correlated Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_, self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw ``n_samples`` points from the identity-covariance Gaussian."""
        return self.sample_gaussian(
            n_samples, self.means_, self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(denominator_samples, numerator_samples)``, ``n`` each.

        The denominator is drawn first; keeping that order keeps seeded
        runs on the same random stream.
        """
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        return de, nu


class GenDataExpMSE(object):
    """Sampler for two Gaussians together with their exact density ratio.

    The "numerator" distribution is ``N(mu, Sigma)`` where ``mu`` is zero
    except for a 1.0 in its first coordinate and ``Sigma`` has a unit
    diagonal with ``rho`` in every off-diagonal entry; the "denominator"
    distribution is ``N(0, I)``.
    """

    def __init__(self, rho, dim_data):
        """Build the mean vectors and both covariance matrices.

        Args:
            rho: Off-diagonal entry of the numerator covariance.
            dim_data: Dimensionality of each sample.
        """
        self.means_numerator_ = np.zeros(dim_data)
        self.means_numerator_[0] = 1.0
        self.means_denominator_ = np.zeros(dim_data)
        self.dim_ = dim_data
        self.rho_ = rho
        sigma_mat = np.full((dim_data, dim_data), rho)
        np.fill_diagonal(sigma_mat, 1)
        # ``variances_`` and ``cov_matrix_numerator_`` are the same array
        # object, not copies.
        self.variances_ = sigma_mat
        self.cov_matrix_numerator_ = sigma_mat
        self.cov_matrix_denominator_ = np.eye(dim_data)
        # NOTE(review): this value depends only on rho and dim_data; the
        # numerator mean offset is not taken into account - confirm that
        # this is intended.
        self.true_mi_ = calc_true_mi_info(rho, dim_data)

    def get_true_mi(self):
        """Return the value computed by ``calc_true_mi_info`` at init."""
        return self.true_mi_

    def sample_gaussian(self, n_samples, means, cov_matrix):
        """Draw float32 samples from ``N(means, cov_matrix)``.

        Args:
            n_samples: Number of draws (an int, or a tuple batch shape).
            means: Mean vector of the Gaussian.
            cov_matrix: Covariance matrix of the Gaussian.

        Returns:
            A float32 array of shape ``(n_samples, dim)``.
        """
        gauss = multivariate_normal(mean=means, cov=cov_matrix)
        draws = np.asarray(gauss.rvs(n_samples), dtype=np.float32)
        # ``rvs`` squeezes singleton axes, so one draw or one-dimensional
        # data would otherwise lose an axis; restore the batch layout.
        batch_shape = tuple(np.atleast_1d(n_samples))
        return draws.reshape(batch_shape + (gauss.dim,))

    def calc_true_density_rate(self, samples):
        """Return the numerator/denominator density ratio at ``samples``.

        The ratio is formed in log space and exponentiated once.
        """
        de_logprob = calc_log_prob(
            samples,
            self.means_denominator_,
            self.cov_matrix_denominator_)
        nu_logprob = calc_log_prob(
            samples,
            self.means_numerator_,
            self.cov_matrix_numerator_)
        true_dre = np.exp(nu_logprob - de_logprob)
        return true_dre

    def sample_numerator(self, n_samples):
        """Draw ``n_samples`` points from the shifted, correlated Gaussian."""
        return self.sample_gaussian(
            n_samples,
            self.means_numerator_,
            self.cov_matrix_numerator_)

    def sample_denominator(self, n_samples):
        """Draw ``n_samples`` points from the standard Gaussian."""
        return self.sample_gaussian(
            n_samples,
            self.means_denominator_,
            self.cov_matrix_denominator_)

    def sample(self, n):
        """Return ``(de, nu, true_dre)`` for ``n`` draws of each kind.

        ``true_dre`` is the exact density ratio evaluated on the
        denominator samples ``de``.  The denominator is drawn first;
        keeping that order keeps seeded runs on the same random stream.
        """
        de = self.sample_denominator(n)
        nu = self.sample_numerator(n)
        # The log-density of a single-row batch comes back as a scalar;
        # keep one ratio per sample so the output lines up with ``de``.
        true_dre = np.atleast_1d(self.calc_true_density_rate(de))
        return de, nu, true_dre
