import torch
import matplotlib.pyplot as plt
from src.constants import SRC_PATH, RAW_DATA_PATH
import os


def sample_gaussian_mixture(n_samples, n_mixtures, means):
    """Draw samples from an equally weighted mixture of unit-variance Gaussians.

    Each sample picks a component index uniformly from ``range(n_mixtures)``
    and adds standard-normal noise to the corresponding row of ``means``.

    Args:
        n_samples: Number of samples to draw.
        n_mixtures: Number of components to choose from; only the first
            ``n_mixtures`` rows of ``means`` can be selected.
        means: 2-D tensor of component centers, shape ``(n_rows, dim)``,
            expected to satisfy ``n_rows >= n_mixtures``.

    Returns:
        Tensor of shape ``(n_samples, dim)``.
    """
    _, dim = means.shape

    # One component index per sample. It is created on the same device as
    # ``means`` so that indexing and the noise addition below also work when
    # ``means`` lives on a GPU.
    mixture = torch.randint(n_mixtures, (n_samples,), device=means.device)
    mu = means[mixture]  # (n_samples, dim)

    # Unit-variance noise recentered on the selected component means.
    noise = torch.randn(n_samples, dim, device=means.device)
    return noise + mu

def generate_data(n_samples, mix_offsets, n_mixtures=None):
    """Sample ``n_samples`` points from a Gaussian mixture centered at ``mix_offsets``.

    Args:
        n_samples: Number of points to draw.
        mix_offsets: 2-D tensor of component centers, one center per row.
        n_mixtures: Number of components to sample from. Defaults to the
            number of rows in ``mix_offsets``. This replaces a hard-coded 10,
            which failed for fewer rows and silently ignored extra rows.

    Returns:
        Tensor of shape ``(n_samples, mix_offsets.shape[1])``.
    """
    if n_mixtures is None:
        n_mixtures = mix_offsets.shape[0]
    return sample_gaussian_mixture(n_samples, n_mixtures, mix_offsets)


if __name__ == "__main__":
    dim = 3
    n_mixtures = 10

    # Seed BEFORE drawing the mixture centers so that the centers, not just
    # the samples, are reproducible across runs. (Data regenerated with this
    # script will therefore differ from files produced by earlier versions.)
    torch.manual_seed(111)

    # Every coordinate of every center is drawn from {-5, +5}.
    mix_offsets = (
        torch.randint(low=0, high=2, size=(n_mixtures, dim)) * 10 - 5
    ).float()

    # NOTE(review): the dataset name says "dim_2" although dim = 3. The splits
    # are also written with a ".py" extension even though torch.save writes a
    # serialized tensor, not Python source. Both paths are kept unchanged so
    # that existing loaders reading them keep working. Rename them together
    # with those loaders once confirmed.
    dataset_name = "gauss_mixture_dim_2_mix_10"
    out_dir = os.path.join(RAW_DATA_PATH, dataset_name)
    os.makedirs(out_dir, exist_ok=True)

    for n_samples, split in zip([100000, 10000, 20000], ["train", "val", "test"]):
        x = generate_data(n_samples, mix_offsets)

        # Scatter of the first two coordinates only. Writing one image per
        # split keeps the train/val plots from being overwritten by the test plot.
        fig = plt.figure()
        plt.scatter(x[:, 0], x[:, 1])
        fig.savefig(os.path.join(SRC_PATH, f"{dataset_name}_{split}.png"))
        plt.close(fig)  # release the figure so figures don't pile up across splits

        torch.save(x, os.path.join(out_dir, f"{split}set.py"))