import torch
from torch.nn import functional as F
from torch.distributions.multivariate_normal import MultivariateNormal


# def make_correlation_mtx_from_vector(vector, num_of_cov):
#     #matrix = torch.eye((32, num_of_cov))
#     identity = torch.eye(num_of_cov, device=vector.device)
#     matrix = identity.unsqueeze(0).repeat(vector.shape[0], 1, 1)
#
#     indices = torch.triu_indices(row=num_of_cov, col=num_of_cov, offset=1)
#     matrix[:, indices[0], indices[1]] = vector
#     matrix[:, indices[1], indices[0]] = vector
#
#     return matrix


def make_correlation_mtx_from_vector(vector, num_of_cov):
    """Build correlation matrices from unconstrained off-diagonal parameters.

    The last dimension of ``vector`` is scattered into the strict upper
    triangle of a unit-diagonal matrix ``U``. The Gram matrix ``U^T U`` is
    symmetric positive definite (``U`` is invertible). Rescaling it by the
    inverse square roots of its diagonal yields a valid correlation matrix.

    Args:
        vector: Tensor of shape ``(..., num_of_cov * (num_of_cov - 1) // 2)``.
            Any number of leading batch dimensions is supported, including
            none (a single unbatched vector).
        num_of_cov: Size ``n`` of the square output matrices.

    Returns:
        Tensor of shape ``(..., num_of_cov, num_of_cov)`` with the same dtype
        and device as ``vector``. It has a unit diagonal and entries in
        ``[-1, 1]``.

    Raises:
        ValueError: If the last dimension of ``vector`` does not equal
            ``num_of_cov * (num_of_cov - 1) // 2``.
    """
    num_off_diag = num_of_cov * (num_of_cov - 1) // 2
    if vector.ndim == 0 or vector.shape[-1] != num_off_diag:
        raise ValueError(
            f"expected last dimension of size {num_off_diag} for "
            f"num_of_cov={num_of_cov}, got shape {tuple(vector.shape)}"
        )

    batch_shape = vector.shape[:-1]
    # Match the input dtype so index assignment neither fails nor downcasts.
    identity = torch.eye(num_of_cov, dtype=vector.dtype, device=vector.device)
    upper = identity.repeat(*batch_shape, 1, 1)
    rows, cols = torch.triu_indices(
        num_of_cov, num_of_cov, offset=1, device=vector.device
    )
    upper[..., rows, cols] = vector

    gram = upper.transpose(-2, -1) @ upper
    # Each diagonal entry of U^T U is 1 + (sum of squares) >= 1, so the
    # division below is always safe. Dividing exactly (no epsilon) keeps the
    # diagonal at 1 and makes the rescaling a congruence, which preserves
    # positive definiteness.
    std = torch.sqrt(torch.diagonal(gram, dim1=-2, dim2=-1))
    correlation_matrix = gram / (std.unsqueeze(-1) * std.unsqueeze(-2))
    # Only guards against rounding pushing |rho| marginally past 1.
    return torch.clamp(correlation_matrix, min=-1, max=1)

def extract_distributional_pars_from_vec(arr_, num_of_cov):
    """Split a parameter vector into mean, variance and covariance.

    Layout of the last dimension of ``arr_`` (with ``n = num_of_cov``):
      * ``[0, n)``   -> mean (returned unchanged)
      * ``[n, 2n)``  -> raw variances: softplus, then clamped to ``[0, 20]``
      * ``[2n, ...)`` -> off-diagonal correlation parameters, converted by
        ``make_correlation_mtx_from_vector``

    Args:
        arr_: Tensor of shape ``(..., D)``. Leading batch dimensions are
            optional and may be any number (including none).
        num_of_cov: Dimensionality ``n`` of the Gaussian.

    Returns:
        Tuple ``(mu, var, covariance)`` with shapes ``(..., n)``, ``(..., n)``
        and ``(..., n, n)``. The covariance is
        ``corr * outer(sqrt(var), sqrt(var))``.
    """
    arr_mu = arr_[..., :num_of_cov]
    arr_var = F.softplus(arr_[..., num_of_cov:2 * num_of_cov])
    # The upper bound caps the variance; the lower bound is a no-op after softplus.
    arr_var = torch.clamp(arr_var, min=0, max=20)
    arr_corr_vec = arr_[..., 2 * num_of_cov:]

    # The correlation helper is only guaranteed to accept a 2-D
    # (batch, k) tensor. Flatten all leading dims into one batch axis and
    # restore them afterwards. batch_shape.numel() is 1 for unbatched input.
    batch_shape = arr_.shape[:-1]
    flat_corr_vec = arr_corr_vec.reshape(batch_shape.numel(), arr_corr_vec.shape[-1])
    arr_corr_mtx = make_correlation_mtx_from_vector(flat_corr_vec, num_of_cov)
    arr_corr_mtx = arr_corr_mtx.reshape(*batch_shape, num_of_cov, num_of_cov)

    arr_std = torch.sqrt(arr_var)
    arr_var_mtx = arr_std.unsqueeze(-1) * arr_std.unsqueeze(-2)
    arr_covariance = arr_corr_mtx * arr_var_mtx

    return arr_mu, arr_var, arr_covariance





def sample_from_multivariate(mean, covariance, num_samples=1000):
    """Draw samples from a (possibly batched) multivariate normal.

    Args:
        mean: Tensor of shape ``(*batch, d)``. ``batch`` may be empty.
        covariance: Positive-definite tensor of shape ``(*batch, d, d)``, or
            any shape broadcastable against ``mean``.
        num_samples: Number of samples drawn per distribution.

    Returns:
        Tensor of shape ``(*batch, num_samples, d)``, where ``batch`` is the
        broadcast batch shape. Samples come from ``sample`` (not ``rsample``),
        so no gradients flow back to ``mean``/``covariance``.
    """
    mvn = MultivariateNormal(mean, covariance)
    # sample() returns (num_samples, *batch, d). Move the sample axis to just
    # before the event axis. This works for any number of batch dims, unlike
    # a fixed permute((1, 0, 2)), and is identical to it for one batch dim.
    samples = mvn.sample((num_samples,))
    return samples.movedim(0, -2)


if __name__ == "__main__":
    n = 4
    # Number of strictly-upper-triangular entries of an n x n matrix.
    num_off_diag = n * (n - 1) // 2

    # The helpers operate on (batch, features) tensors, so include an
    # explicit batch dimension of 1 instead of passing a bare 1-D vector.
    vector = torch.randn(1, num_off_diag)
    matrix = make_correlation_mtx_from_vector(vector, n)
    print(matrix)

    # Layout: n means, n raw variances, num_off_diag correlation parameters.
    arr_ = torch.randn(1, 2 * n + num_off_diag)
    arr_mu, arr_var, arr_cov = extract_distributional_pars_from_vec(arr_, n)

    print(arr_mu)
    print(arr_var)
    print(arr_cov)
