import torch
import matplotlib.pyplot as plt
from src.constants import SRC_PATH, RAW_DATA_PATH
import os


def sample_gaussian_mixture(n_samples, n_mixtures, means):
    """Draw samples from an equal-weight mixture of unit-variance Gaussians.

    Each sample picks a component index uniformly from ``[0, n_mixtures)``
    and returns ``means[index]`` plus isotropic standard-normal noise.

    Args:
        n_samples: Number of points to draw.
        n_mixtures: Number of components to sample from. Only the first
            ``n_mixtures`` rows of ``means`` are used, so it must not exceed
            ``means.shape[0]`` (otherwise indexing fails).
        means: 2-D tensor of shape ``(n_components, dim)`` holding the
            component centres.

    Returns:
        Tensor of shape ``(n_samples, dim)`` on the same device as ``means``.
    """
    _, dim = means.shape
    device = means.device

    # Component index per sample, shape (n_samples, 1). Created on the same
    # device as ``means`` so GPU-resident centres are supported.
    mixture = torch.randint(n_mixtures, (n_samples, 1), device=device)
    # (n_samples, 1, dim) -> (n_samples, dim)
    mu = means[mixture].squeeze(dim=1)

    # Standard-normal noise recentred on the chosen component means. Drawn on
    # ``means.device`` so the addition does not fail for non-CPU tensors.
    return torch.randn(n_samples, dim, device=device) + mu

def generate_data(n_samples, mix_offsets, n_mixtures):
    """Return ``n_samples`` points drawn from the mixture centred at ``mix_offsets``.

    Convenience wrapper that forwards to ``sample_gaussian_mixture`` with the
    arguments rearranged; ``mix_offsets`` is passed through as its ``means``.
    """
    return sample_gaussian_mixture(
        n_samples=n_samples,
        n_mixtures=n_mixtures,
        means=mix_offsets,
    )


if __name__ == "__main__":
    dim = 50
    n_mixtures = 52

    # Seed BEFORE drawing the mixture centres. Previously the seed was set only
    # after the centres were sampled, so every run produced different centres
    # and the saved dataset was not reproducible.
    torch.manual_seed(111)

    # Mixture centres drawn uniformly from [-3, 3) in every coordinate.
    mix_offsets = torch.rand(n_mixtures, dim) * 6 - 3

    dataset_tag = f"gauss_mixture_sphere_dim_{dim}_mix_{n_mixtures}"
    data_dir = os.path.join(RAW_DATA_PATH, dataset_tag)
    # Create the output directory once; exist_ok avoids a separate exists() check.
    os.makedirs(data_dir, exist_ok=True)

    splits = [("train", 100000), ("val", 10000), ("test", 20000)]
    for name, n_samples in splits:
        x = generate_data(n_samples, mix_offsets, n_mixtures)

        # Visual sanity check of the first two coordinates. The split name is
        # part of the filename so the three plots no longer overwrite each
        # other, and the figure is closed to release its memory.
        fig = plt.figure()
        plt.scatter(x[:, 0], x[:, 1])
        fig.savefig(os.path.join(SRC_PATH, f"{dataset_tag}_{name}.png"))
        plt.close(fig)

        # NOTE(review): the tensor is serialized with torch.save but given a
        # ".py" extension. Kept unchanged because loaders elsewhere presumably
        # expect this exact filename; consider switching to ".pt".
        torch.save(x, os.path.join(data_dir, f"{name}set.py"))