import torch
import matplotlib.pyplot as plt
from src.constants import SRC_PATH, RAW_DATA_PATH
import os
# Sample from a Gaussian mixture.
def sample_gaussian_mixture(n_samples, n_mixtures, mix_offsets, std=1.0):
    """Draw samples from an equally weighted mixture of isotropic Gaussians.

    Each sample picks a component uniformly at random from the first
    ``n_mixtures`` rows of ``mix_offsets``. It then adds that row (the
    component mean) to standard-normal noise scaled by ``std``.

    Args:
        n_samples: Number of points to draw.
        n_mixtures: Number of components to sample from. Components are the
            first ``n_mixtures`` rows of ``mix_offsets``.
        mix_offsets: 2-D tensor of shape ``(rows, d)`` with ``rows >=
            n_mixtures``, holding the component means. Integer tensors are
            accepted; the result dtype follows torch type promotion between
            the float noise and ``mix_offsets``.
        std: Standard deviation shared by every component (default 1.0).

    Returns:
        Tensor of shape ``(n_samples, d)``.

    Raises:
        ValueError: If ``mix_offsets`` has fewer than ``n_mixtures`` rows.
    """
    n_available = mix_offsets.shape[0]
    if n_mixtures > n_available:
        raise ValueError(
            f"n_mixtures={n_mixtures} but mix_offsets only has {n_available} rows"
        )
    # Draw component indices first, then the noise, so RNG consumption order
    # (and therefore seeded outputs) is unchanged from the original version.
    mixture = torch.randint(n_mixtures, (n_samples, 1))
    noise = torch.randn(n_samples, mix_offsets.shape[-1]) * std
    # Index whole rows at once: (n_samples,) indices -> (n_samples, d) means.
    means = mix_offsets[mixture.squeeze(-1)]
    return noise + means

def generate_data(n_samples, mix_offsets, n_mixtures=None):
    """Generate ``n_samples`` points from the Gaussian mixture in ``mix_offsets``.

    Args:
        n_samples: Number of points to draw.
        mix_offsets: Tensor whose rows are the component means.
        n_mixtures: Number of components (leading rows of ``mix_offsets``) to
            sample from. Defaults to all rows. A fixed count of 4 would raise
            for fewer offsets and silently ignore any extra ones.

    Returns:
        The tensor of samples returned by ``sample_gaussian_mixture``.
    """
    if n_mixtures is None:
        n_mixtures = len(mix_offsets)
    return sample_gaussian_mixture(n_samples, n_mixtures, mix_offsets)


if __name__ == "__main__":
    # Component means. Alternatives kept for quick experimentation:
    # mix_offsets = torch.tensor([[-0.75, -0.75], [0.75, 0.75], [0.75, -0.75], [-0.75, 0.75]])
    # mix_offsets = torch.tensor([[-1.00, -1.00], [1.00, 1.00]])
    mix_offsets = torch.tensor([[-5, 5], [5, -5], [-5, -5], [5, 5]])
    # Seed once. The splits are drawn one after another from the same RNG
    # stream, so they are reproducible but distinct from each other.
    torch.manual_seed(111)

    out_dir = os.path.join(RAW_DATA_PATH, "gauss_mixture_v3")
    # Create the output directory once. exist_ok avoids the check-then-create
    # race of os.path.exists followed by os.makedirs.
    os.makedirs(out_dir, exist_ok=True)

    for n_samples, name in zip([100000, 10000, 20000], ["train", "val", "test"]):
        x = generate_data(n_samples, mix_offsets)

        # Visual sanity check. plt.show() blocks until the window is closed.
        plt.figure()
        plt.scatter(x[:, 0], x[:, 1], s=1)
        plt.title(f"{name} ({n_samples} samples)")
        # To save the plot, call savefig BEFORE plt.show(). After show() the
        # figure has been torn down and the saved image would be blank:
        # plt.savefig(os.path.join(SRC_PATH, f"gaussian_mixture_{name}.png"))
        plt.show()

        # NOTE(review): torch.save writes a pickle archive, so ".pt" would be
        # the conventional extension. ".py" is kept because loaders elsewhere
        # may expect this exact filename; confirm before renaming.
        torch.save(x, os.path.join(out_dir, f"{name}set.py"))

