import torch
import torch.functional as F


def make_covariance_from_vector(vector, num_of_cov):
    """Build a symmetric, unit-diagonal matrix from its off-diagonal entries.

    Despite the name, the result has ones on the diagonal, so it has the
    layout of a correlation matrix rather than a general covariance matrix.

    Args:
        vector: Tensor whose last dimension holds the strictly-upper-triangular
            entries in row-major order, i.e. ``num_of_cov * (num_of_cov - 1) / 2``
            values. Any leading dimensions are treated as batch dimensions.
        num_of_cov: Side length of the square matrix to build.

    Returns:
        Tensor of shape ``(*vector.shape[:-1], num_of_cov, num_of_cov)`` on the
        same device as ``vector``. Floating-point inputs keep their dtype;
        other inputs are stored in the default floating-point dtype.
    """
    vector = torch.as_tensor(vector)
    # Keep the precision of floating-point inputs; integer/bool inputs are
    # cast into the default float dtype, as a plain ``torch.eye`` would do.
    if vector.is_floating_point():
        dtype = vector.dtype
    else:
        dtype = torch.get_default_dtype()

    rows, cols = torch.triu_indices(
        row=num_of_cov, col=num_of_cov, offset=1, device=vector.device
    )

    # One identity matrix per batch element; ``repeat`` yields fresh memory,
    # so the in-place writes below never alias each other.
    batch_shape = vector.shape[:-1]
    matrix = torch.eye(num_of_cov, dtype=dtype, device=vector.device).repeat(
        *batch_shape, 1, 1
    )

    # Mirror the entries across the diagonal to keep the matrix symmetric.
    matrix[..., rows, cols] = vector
    matrix[..., cols, rows] = vector
    return matrix

def extract_distributional_pars(arr_, num_of_cov):
    """Split a raw parameter tensor into mean, variance and covariance.

    The last dimension of ``arr_`` is read as three consecutive segments:

    * ``num_of_cov`` means, returned unchanged;
    * ``num_of_cov`` raw variances, mapped through softplus so they are
      positive;
    * ``num_of_cov * (num_of_cov - 1) / 2`` raw correlations, mapped through
      tanh into (-1, 1).

    Args:
        arr_: Tensor whose last dimension has length
            ``2 * num_of_cov + num_of_cov * (num_of_cov - 1) / 2``. Leading
            dimensions are kept by the slicing here; whether batched input
            works end to end depends on ``make_covariance_from_vector``
            accepting a batched vector.
        num_of_cov: Number of variables in the distribution.

    Returns:
        Tuple ``(arr_mu, arr_var, arr_covariance)`` where ``arr_covariance``
        is the correlation matrix scaled element-wise by the outer product of
        the standard deviations.

    Raises:
        ValueError: If the last dimension of ``arr_`` has the wrong length.

    NOTE: each correlation is squashed independently, so the resulting matrix
    is symmetric but not guaranteed to be positive semi-definite.
    """
    num_corr = num_of_cov * (num_of_cov - 1) // 2
    expected_len = 2 * num_of_cov + num_corr
    if arr_.shape[-1] != expected_len:
        raise ValueError(
            f"expected last dimension of size {expected_len} for "
            f"num_of_cov={num_of_cov}, got {arr_.shape[-1]}"
        )

    arr_mu = arr_[..., :num_of_cov]
    # The module-level alias ``F`` points at ``torch.functional``, which has no
    # ``softplus``; use the fully qualified ``torch.nn.functional`` instead.
    arr_var = torch.nn.functional.softplus(arr_[..., num_of_cov:2 * num_of_cov])
    arr_corr_vec = torch.tanh(arr_[..., 2 * num_of_cov:])  # values in (-1, 1)
    arr_std = torch.sqrt(arr_var)

    # Outer product std_i * std_j over the last dimension.
    arr_var_mtx = arr_std[..., :, None] * arr_std[..., None, :]
    arr_corr_mtx = make_covariance_from_vector(arr_corr_vec, num_of_cov)
    arr_covariance = arr_corr_mtx * arr_var_mtx

    return arr_mu, arr_var, arr_covariance



if __name__ == "__main__":
    # Manual smoke test: build one matrix from random off-diagonal values,
    # then run the full parameter extraction on a random input vector.
    n = 4
    num_pairs = n * (n - 1) // 2  # strictly-upper-triangular entries (6 for n=4)

    vector = torch.randn(num_pairs)
    matrix = make_covariance_from_vector(vector, n)
    print(matrix)

    # n means + n raw variances + num_pairs raw correlations (14 for n=4).
    arr_ = torch.randn(2 * n + num_pairs)
    arr_cov = extract_distributional_pars(arr_, n)  # (mu, var, covariance)
    print(arr_cov)
