import torch
import matplotlib.pyplot as plt
from src.constants import SRC_PATH, RAW_DATA_PATH
import os
import numpy as np



def generate_data(n_samples, d, mean=5.0, std_dev=1.0):
    """Draw i.i.d. samples from an isotropic Gaussian distribution.

    Every coordinate is sampled independently from N(mean, std_dev**2).
    The defaults reproduce the original dataset (mean 5, standard deviation 1).

    Args:
        n_samples: Number of samples (rows) to draw.
        d: Dimensionality of each sample (columns).
        mean: Mean of every coordinate. Defaults to 5.0.
        std_dev: Standard deviation of every coordinate. Defaults to 1.0.

    Returns:
        A tensor of shape ``(n_samples, d)`` with the default float dtype.
    """
    # torch.normal with scalar mean/std plus an explicit size draws the whole
    # matrix in one call; torch itself rejects a negative std_dev.
    return torch.normal(mean=float(mean), std=float(std_dev), size=(n_samples, d))



if __name__ == "__main__":
    dim = 100

    # Output locations; the dataset directory resolves to the same path the
    # tensors were always written to.
    dataset_dir = os.path.join(RAW_DATA_PATH, f"simple_gaussian_dim_{dim}")
    plot_dir = os.path.join(SRC_PATH, "datasets_sections")
    # Create both directories once, up front; exist_ok avoids a check-then-create race.
    os.makedirs(dataset_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)

    for n_samples, name in zip([100000, 10000, 20000], ["train", "val", "test"]):
        x = generate_data(n_samples, dim)

        # Scatter of the first two coordinates as a quick visual sanity check.
        # The split name is part of the filename so each split gets its own plot
        # instead of every iteration overwriting the same image.
        fig = plt.figure()
        plt.scatter(x[:, 0], x[:, 1])
        fig.savefig(os.path.join(plot_dir, f"simple_gaussian_dim_{dim}_{name}.png"))
        # Release the figure; otherwise one open figure accumulates per split.
        plt.close(fig)

        # NOTE(review): torch.save writes a pickle, yet the file uses a ".py"
        # extension (".pt" is conventional). Kept unchanged because loaders
        # elsewhere may expect this exact path; confirm before renaming.
        torch.save(x, os.path.join(dataset_dir, f"{name}set.py"))