import torch
import matplotlib.pyplot as plt
from src.constants import SRC_PATH, RAW_DATA_PATH
import os
import numpy as np



def generate_data(n_samples, d, sig=3, clip_y=12, mean=5):
    """Draw samples from a funnel-shaped distribution (a Neal's-funnel variant).

    The first coordinate ``y`` is drawn from ``N(0, sig**2)`` and clamped to
    ``[-clip_y, clip_y]``. The remaining ``d - 1`` coordinates are drawn
    independently given ``y`` as ``mean + exp(-y / 2) * N(0, 1)``, i.e. with
    standard deviation ``exp(-y / 2)``.

    Args:
        n_samples: Number of samples (rows) to draw.
        d: Total dimensionality including the ``y`` coordinate; must be >= 1.
        sig: Standard deviation of ``y`` before clamping.
        clip_y: Symmetric clamp bound on ``y``; caps the conditional scale
            ``exp(-y / 2)`` at ``exp(clip_y / 2)``.
        mean: Location of the ``x`` coordinates (defaults to the previously
            hard-coded value 5).

    Returns:
        A float tensor of shape ``(n_samples, d)``: column 0 is ``y`` and
        columns ``1:`` are the ``x`` coordinates.
    """
    # Scale the draw directly. Re-wrapping an existing tensor in torch.tensor()
    # makes a redundant copy and emits a UserWarning.
    y = (sig * torch.randn(n_samples, 1)).clamp(-clip_y, clip_y)
    # (n_samples, 1) scale broadcasts across the d - 1 x-columns.
    x = torch.randn(n_samples, d - 1) * torch.exp(-y / 2) + mean

    return torch.cat((y, x), dim=1)


if __name__ == "__main__":
    # Generate train/val/test splits of a 100-dimensional funnel dataset, save a
    # diagnostic scatter plot for each split, and store each split with torch.save.
    dim = 100

    out_dir = os.path.join(RAW_DATA_PATH, f"funnel_dim_{dim}/")
    plot_dir = os.path.join(SRC_PATH, "datasets_sections")
    # exist_ok makes these no-ops when the directories already exist.
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)

    for n_samples, name in zip([100000, 10000, 20000], ["train", "val", "test"]):
        data = generate_data(n_samples, dim)

        # Column 0 is the funnel coordinate y; column 1 is the first x coordinate.
        fig = plt.figure()
        plt.scatter(data[:, 0], data[:, 1])
        # One file per split. Previously every split overwrote the same image.
        plt.savefig(os.path.join(plot_dir, f"funnel_dim_{dim}_{name}.png"))
        plt.close(fig)  # avoid accumulating open figures across splits

        # NOTE(review): torch.save output is written under a ".py" extension.
        # Kept for compatibility with existing loaders; consider ".pt".
        torch.save(data, os.path.join(out_dir, f"{name}set.py"))