import numpy as np
from numpy.random import RandomState
import sklearn


def centered_mean_gm(dim, num_components=9, seed=1):
    """Build means and covariances for a mixture with a zero-mean P_0 component.

    Parameters
    ----------
    dim : int
        Dimensionality of each component.
    num_components : int
        Number of mixture components placed at random locations.
    seed : int
        Seed for the ``RandomState`` that draws the component means.

    Returns
    -------
    means : list of np.ndarray
        ``[zeros of shape (1, dim), uniform(-3, 3) draws of shape (num_components, dim)]``.
    covs : list of np.ndarray
        ``[identity of shape (1, dim, dim), identities of shape (num_components, dim, dim)]``.
    """
    rng = RandomState(seed)
    component_means = rng.uniform(low=-3, high=3, size=(num_components, dim))

    identity = np.eye(dim)
    # Alternative (previously commented out): random SPD covariances via
    # cov_find_seed(dim, num_components, seeded=...).
    component_covs = np.tile(identity, (num_components, 1, 1))

    means = [np.zeros((1, dim)), component_means]
    covs = [identity[np.newaxis, ...], component_covs]
    return means, covs


def given_singular_spd_cov(n_dim, random_state=None, range_sing=(0.5, 1)):
    """Draw a random symmetric positive-definite matrix with bounded spectrum.

    A random matrix ``A`` with entries uniform on [0, 1) is drawn. The
    singular vectors of the symmetric matrix ``A.T @ A`` give an orthonormal
    basis. New singular values are then drawn uniformly from
    ``[range_sing[0], range_sing[1])`` and recombined with that basis.

    Parameters
    ----------
    n_dim : int
        Size of the (n_dim, n_dim) output matrix.
    random_state : None, int or np.random.RandomState
        Passed to ``sklearn.utils.check_random_state``.
    range_sing : sequence of two floats
        Lower and upper bound for the drawn singular values. A tuple default
        is used so the default cannot be mutated across calls. The result is
        only positive definite if ``range_sing[0] > 0``.

    Returns
    -------
    np.ndarray
        Exactly symmetric (n_dim, n_dim) matrix.
    """
    generator = sklearn.utils.check_random_state(random_state)
    A = generator.rand(n_dim, n_dim)
    U, _, Vt = np.linalg.svd(np.dot(A.T, A))
    low, high = range_sing[0], range_sing[1]
    singular_values = low + generator.rand(n_dim) * (high - low)
    X = np.dot(np.dot(U, np.diag(singular_values)), Vt)
    # U @ diag(s) @ Vt is only symmetric up to roundoff. Averaging with the
    # transpose makes the covariance exactly symmetric (IEEE addition is
    # commutative), so consumers reading either triangle agree.
    return (X + X.T) / 2


def cov_find_seed(n_dim, num_components, seeded=True):
    """Stack ``num_components`` random SPD matrices from ``given_singular_spd_cov``.

    Parameters
    ----------
    n_dim : int
        Size of each (n_dim, n_dim) matrix.
    num_components : int
        Number of matrices to generate.
    seeded : bool
        If True, component ``i`` uses seed ``i``, making the output
        reproducible. If False, every component uses ``None`` as its random state.

    Returns
    -------
    np.ndarray
        Array of shape (num_components, n_dim, n_dim).
    """
    seeds = np.arange(num_components) if seeded else [None] * num_components
    stacked = np.zeros([num_components, n_dim, n_dim])
    for position, component_seed in enumerate(seeds):
        stacked[position] = given_singular_spd_cov(n_dim, component_seed)
    return stacked

# * The first array in each mean/cov list (the P_0 component) is now rarely used.
# * To change P_0, edit that first array.


def select_mean_and_cov_gauss(dim, seed=1):
    """Return a random mean vector and random SPD covariance for one Gaussian.

    Parameters
    ----------
    dim : int
        Dimensionality of the Gaussian.
    seed : int
        Seeds both the mean draw (standard normal via ``RandomState``) and the
        covariance draw (``given_singular_spd_cov``). Each draw uses its own generator.

    Returns
    -------
    mean : np.ndarray
        Shape (dim,).
    cov : np.ndarray
        Shape (dim, dim).
    """
    mean_vector = RandomState(seed).randn(dim)
    covariance = given_singular_spd_cov(dim, random_state=seed)
    return mean_vector, covariance


def select_mean_and_cov_gmm(trial, *args, **kwargs):
    """Return component means and covariances of a Gaussian mixture for ``trial``.

    Both return values are two-element lists. Index 0 holds a single P_0
    component and index 1 holds the mixture components stacked along axis 0.

    Parameters
    ----------
    trial : int or float
        Selects the configuration:

        * within 0.2 of 1: fixed 1-D mixture with two components;
        * within 0.4 of 2: fixed 2-D mixture with two components;
        * greater than 3: ``centered_mean_gm(int(trial), args[0], args[1])``.
    *args
        Only used when ``trial > 3``. ``args[0]`` is the number of components
        and ``args[1]`` is the seed.
    **kwargs
        Accepted for call compatibility; unused.

    Returns
    -------
    mean, cov : list of np.ndarray

    Raises
    ------
    ValueError
        If ``trial`` matches none of the configurations above. Previously this
        surfaced as an opaque ``UnboundLocalError``.
    """
    # * The first element in each list doesn't really matter,
    # * it just needs to keep the same dimension.
    if abs(trial - 1) <= 0.2:
        mean = [np.array([[0]]),
                np.array([[2],
                          [-2]])]
        cov = [np.array([[[5]]]),
               np.array([[[1]],
                         [[1]]])]
    elif abs(trial - 2) <= 0.4:
        mean = [np.array([[0, 0]]),
                np.array([[2, 2],
                          [-2, -2]])]
        cov = [np.array([[[1, 0], [0, 1]]]),
               np.array([[[1.5, -1], [-1, 1.5]],
                         [[1, 0], [0, 1]]])]
    elif trial > 3:
        dim = int(trial)
        num_component = args[0]
        seed = args[1]
        mean, cov = centered_mean_gm(dim, num_component, seed)
    else:
        raise ValueError(
            "Unsupported trial {!r}: expected a value within 0.2 of 1, "
            "within 0.4 of 2, or greater than 3.".format(trial))

    return mean, cov
