import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
from tqdm import tqdm
import math
import sys

# Sampling configuration: grid resolution used by the preview plot, number of
# input dimensions, the interval X is drawn from, and the amplitude scale.
num_points = 100
num_input_dims = 10
X_min = -1
X_max = 1
y_max = 10

# The frequency band comes from the command line: <freq_min> <freq_max>.
freq_min = int(sys.argv[1])
freq_max = int(sys.argv[2])

# Dataset generation parameters: total number of samples, samples produced
# per batch, and the files that receive the inputs (X) and targets (y).
dataset_size = 10**9
batch_size = 10**6
data_path_X = f"generated_data/sin_waves/input_dim_{num_input_dims}_min_freq_{freq_min}_max_freq_{freq_max}_X.npy"
data_path_y = f"generated_data/sin_waves/input_dim_{num_input_dims}_min_freq_{freq_min}_max_freq_{freq_max}_y.npy"

# Both files share one directory; make sure it exists before they are opened.
output_dir = os.path.dirname(data_path_X)
os.makedirs(output_dir, exist_ok=True)

# Run-mode switches.
#   DEBUG      - shrink the run to a small 2-D dataset and scatter-plot it
#                once generation has finished.
#   PLOT_GRAPH - only render the 2-D surface of f(X); the script exits before
#                any data is generated.
DEBUG = True
PLOT_GRAPH = False

# Both modes draw a 3-D figure, which needs exactly two input dimensions.
if PLOT_GRAPH or DEBUG:
    num_input_dims = 2

if DEBUG:
    dataset_size //= int(1e5)
    batch_size //= int(1e5)

# NOTE(review): data_path_X / data_path_y were formatted before the overrides
# above, so the file names still carry the original input dimension rather
# than the one actually written - confirm whether that is intended.

# mode="w+" creates or overwrites the backing files, so only open them when
# data will actually be written. In PLOT_GRAPH mode the script exits before
# the generation loop and must not clobber an existing dataset.
if not PLOT_GRAPH:
    fp_X = np.memmap(data_path_X, dtype=np.float32, mode="w+", shape=(dataset_size, num_input_dims))
    fp_y = np.memmap(data_path_y, dtype=np.float32, mode="w+", shape=(dataset_size,))

# Number of sinusoidal components that are summed to form f(X).
num_waves = 100

# Random parameters of every component, drawn in this order: per-dimension
# frequencies in [freq_min, freq_max), one amplitude in [0, y_max) per wave,
# and per-dimension phase shifts in [0, 2*pi).
per_dim_shape = (num_waves, num_input_dims)
frequencies = freq_min + (freq_max - freq_min) * np.random.rand(*per_dim_shape)
amplitudes = y_max * np.random.rand(num_waves)
phase_shifts = np.random.rand(*per_dim_shape) * 2 * np.pi

# Compute f(X) by combining the input dimensions and sine waves
def function_with_freq(X, num_waves, frequencies, phase_shifts, amplitudes, x_range=None):
    """Evaluate a weighted sum of separable sine waves at every row of ``X``.

    Wave ``i`` contributes::

        amplitudes[i] * prod_d sin(2*pi * frequencies[i][d] * X[:, d] / x_range
                                   + phase_shifts[i][d])

    Args:
        X: 2-D array with one sample per row and one input dimension per
            column.
        num_waves: Number of waves to sum; entries ``0 .. num_waves - 1`` of
            the three parameter arrays are used.
        frequencies: Per-wave frequencies; ``frequencies[i]`` must broadcast
            against ``X`` along its last axis.
        phase_shifts: Per-wave phase offsets, same layout as ``frequencies``.
        amplitudes: Per-wave scalar weights.
        x_range: Width of the input interval used to normalise ``X``. When
            omitted, the module-level ``X_max - X_min`` is used, which is what
            existing five-argument callers get.

    Returns:
        1-D array with one value per row of ``X``.

    Raises:
        ValueError: If ``X`` is not 2-D.
    """
    if x_range is None:
        # Backward-compatible default: the interval configured at module level.
        x_range = X_max - X_min

    # Unpacking (rather than X.shape[0]) also rejects inputs that are not 2-D.
    n_rows, _ = X.shape
    f_X = np.zeros(n_rows)

    # Accumulate one wave at a time so the only temporaries are X-sized,
    # instead of materialising every wave for every sample at once.
    for i in range(num_waves):
        angle = 2 * np.pi * frequencies[i] * X / x_range + phase_shifts[i]
        f_X += amplitudes[i] * np.prod(np.sin(angle), axis=-1)
    return f_X

def f(X):
    """Evaluate the randomly parameterised sine-wave mixture at the rows of ``X``."""
    return function_with_freq(X, num_waves, frequencies, phase_shifts, amplitudes)

# Generate X values for each input dimension

if PLOT_GRAPH:
    # Preview mode: evaluate f on a regular num_points x num_points grid over
    # [X_min, X_max]^2, show it as a surface, then stop.
    X = np.linspace(X_min, X_max, num_points)
    X1, X2 = np.meshgrid(X, X)
    # Flatten the grid into one (x1, x2) pair per row, the layout f expects.
    X = np.stack([X1.reshape(-1), X2.reshape(-1)], axis=-1)

    # Restore the grid shape so the values line up with X1 / X2.
    f_X = f(X).reshape(num_points, num_points)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(X1, X2, f_X, cmap="viridis", edgecolor="none")
    ax.set_xlabel("X1")
    ax.set_ylabel("X2")
    ax.set_zlabel("f(X)")
    ax.set_title("f(X) with 2D input")

    plt.show()

    # Use sys.exit rather than the bare exit() helper: exit() is injected by
    # the site module for interactive sessions and is missing under
    # `python -S` or in embedded/frozen interpreters.
    sys.exit(0)


# Fill the memory-mapped arrays one batch at a time so that only a single
# batch of samples is held in RAM.
for start in tqdm(range(0, dataset_size, batch_size)):
    end = min(start + batch_size, dataset_size)

    # Size the batch to the slice it fills. When dataset_size is not a
    # multiple of batch_size the final slice is shorter than batch_size, and
    # assigning a full-size batch to it would fail with a shape mismatch.
    X = np.random.rand(end - start, num_input_dims) * (X_max - X_min) + X_min
    y = f(X)

    fp_X[start:end] = X
    fp_y[start:end] = y

    # Push this batch to disk before generating the next one.
    fp_X.flush()
    fp_y.flush()

if DEBUG:
    # Visual sanity check: scatter every generated sample in 3-D, with both
    # the height and the colour of a point given by its target value.
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")

    x1_column, x2_column = fp_X[:, 0], fp_X[:, 1]
    ax.scatter(x1_column, x2_column, fp_y, s=5, c=fp_y, cmap="viridis")

    # Axis captions, then the figure title.
    for set_label, text in (
        (ax.set_xlabel, "X1"),
        (ax.set_ylabel, "X2"),
        (ax.set_zlabel, "f(X)"),
    ):
        set_label(text)
    ax.set_title("High-frequency data with 2D input (Scatter Plot)")

    plt.show()