import numpy as np
from tqdm import tqdm  # Import tqdm for progress bar
from data_generation import generate_L, top_k_eigen, generate_X, compute_covariance_matrix, compute_covariance_matrix_scale_d
import os
# from sklearn.datasets import fetch_openml

# def generate_downscale_mnist(total_samples, D, N, k, input_is_cov, predict_vector, file_name):
#     X_data = []
#     Y_data = []
#     Y_vector_data = []
#     os.makedirs('dataset', exist_ok=True)

  
#     mnist = fetch_openml('mnist_784', version=1)
#     X = mnist.data
    
    
#     X_centered = X - np.mean(X, axis=0)
    
#     U, S, Vt = np.linalg.svd(X_centered, full_matrices=False) 
#     K = D # use top-K eigenvector project the image to K-dimension space
#     X_reduced = np.dot(X_centered, Vt.T[:, :K])


  
#     for i in tqdm(range(total_samples), desc=f"Generating {k}-dimention real world dataset"):
#         # mean = 0
#         # L = generate_L(D)
#         # X = generate_X(N, D, mean, L)
#         sample_covariance = compute_covariance_matrix(X_reduced[i, :])
        
#         # if input_is_cov:
#         #     X = sample_covariance

#         Y, Y_vector = top_k_eigen(sample_covariance, k)
        
#         X_data.append(X)
#         Y_data.append(Y)
#         if predict_vector:
#             Y_vector_data.append(Y_vector)

#     # Convert lists to NumPy arrays
#     X_data = np.array(X_data)
#     Y_data = np.array(Y_data)
#     if predict_vector:
#         Y_vector_data = np.array(Y_vector_data)

#     # Save data to .npz file
#     if predict_vector:
#         np.savez(file_name, X_data=X_data, Y_data=Y_data, Y_vector_data=Y_vector_data)
#     else:
#         np.savez(file_name, X_data=X_data, Y_data=Y_data)
    
#     print(f"Dataset saved to {file_name}")




def generate_dataset(total_samples, D, N, k, input_is_cov, predict_vector, file_name):
    """Generate a synthetic dataset and save it to an ``.npz`` archive.

    For every sample this draws ``L = generate_L(D)`` and
    ``X = generate_X(N, D, 0, L)``, computes
    ``compute_covariance_matrix(X)`` and feeds that covariance to
    ``top_k_eigen(..., k)`` to obtain the targets.

    Args:
        total_samples: Number of samples to generate.
        D: Dimension argument forwarded to ``generate_L`` and ``generate_X``.
        N: First argument forwarded to ``generate_X`` (presumably the number
            of observations per sample -- confirm against ``data_generation``).
        k: Number of leading eigen-components requested from ``top_k_eigen``.
        input_is_cov: If true, the stored input is the sample covariance
            instead of the raw ``X``.
        predict_vector: If true, the second output of ``top_k_eigen`` is also
            stored, under the key ``Y_vector_data``.
        file_name: Destination handed to ``np.savez``.

    The archive always contains ``X_data`` and ``Y_data``; ``Y_vector_data``
    is present only when ``predict_vector`` is true.
    """
    # Create the directory that will actually receive the file, and do it
    # before the (potentially very long) generation loop so that a bad path
    # cannot waste the whole run. File-like objects are passed through as-is.
    if isinstance(file_name, (str, os.PathLike)):
        out_dir = os.path.dirname(file_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    X_data = []
    Y_data = []
    Y_vector_data = []
    mean = 0

    # Use tqdm to display progress bar
    for _ in tqdm(range(total_samples), desc=f"Generating {D}-dimention dataset"):
        L = generate_L(D)
        X = generate_X(N, D, mean, L)
        sample_covariance = compute_covariance_matrix(X)

        # Targets are always derived from the covariance, whichever
        # representation is stored as the input.
        Y, Y_vector = top_k_eigen(sample_covariance, k)

        X_data.append(sample_covariance if input_is_cov else X)
        Y_data.append(Y)
        if predict_vector:
            Y_vector_data.append(Y_vector)

    # Convert lists to NumPy arrays and assemble the archive contents once.
    arrays = {"X_data": np.array(X_data), "Y_data": np.array(Y_data)}
    if predict_vector:
        arrays["Y_vector_data"] = np.array(Y_vector_data)

    np.savez(file_name, **arrays)

    print(f"Dataset saved to {file_name}")

def generate_dataset_scale_d(total_samples, D, N, k, input_is_cov, predict_vector, file_name):
    """Generate a synthetic dataset using the ``scale_d`` covariance and save it.

    For every sample this draws ``L = generate_L(D)`` and
    ``X = generate_X(N, D, 0, L)``, computes
    ``compute_covariance_matrix_scale_d(X, D)`` and feeds that covariance to
    ``top_k_eigen(..., k)`` to obtain the targets.

    Args:
        total_samples: Number of samples to generate.
        D: Dimension argument forwarded to ``generate_L``, ``generate_X`` and
            ``compute_covariance_matrix_scale_d``.
        N: First argument forwarded to ``generate_X`` (presumably the number
            of observations per sample -- confirm against ``data_generation``).
        k: Number of leading eigen-components requested from ``top_k_eigen``.
        input_is_cov: If true, the stored input is the sample covariance
            instead of the raw ``X``.
        predict_vector: If true, the second output of ``top_k_eigen`` is also
            stored, under the key ``Y_vector_data``.
        file_name: Destination handed to ``np.savez``.

    The archive always contains ``X_data`` and ``Y_data``; ``Y_vector_data``
    is present only when ``predict_vector`` is true.
    """
    # Create the directory that will actually receive the file, and do it
    # before the (potentially very long) generation loop so that a bad path
    # cannot waste the whole run. File-like objects are passed through as-is.
    if isinstance(file_name, (str, os.PathLike)):
        out_dir = os.path.dirname(file_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    X_data = []
    Y_data = []
    Y_vector_data = []
    mean = 0

    for _ in tqdm(range(total_samples), desc=f"Generating {D}-dimention dataset"):
        L = generate_L(D)
        X = generate_X(N, D, mean, L)
        sample_covariance = compute_covariance_matrix_scale_d(X, D)

        # Targets are always derived from the covariance, whichever
        # representation is stored as the input.
        Y, Y_vector = top_k_eigen(sample_covariance, k)

        X_data.append(sample_covariance if input_is_cov else X)
        Y_data.append(Y)
        if predict_vector:
            Y_vector_data.append(Y_vector)

    # Convert lists to NumPy arrays and assemble the archive contents once.
    arrays = {"X_data": np.array(X_data), "Y_data": np.array(Y_data)}
    if predict_vector:
        arrays["Y_vector_data"] = np.array(Y_vector_data)

    np.savez(file_name, **arrays)

    print(f"Dataset saved to {file_name}")



# Generation settings for the run below.
total_samples = 1280000  # number of samples written to each archive
N = 10  # forwarded to generate_X as its first argument
k = 5  # number of leading eigen-components requested from top_k_eigen
input_is_cov = False  # store raw X as the input rather than its covariance
predict_vector = True  # also store the second output of top_k_eigen
D_list = [5]  # one archive is generated per dimension listed here


# NOTE(review): this loop executes whenever the module is run or imported;
# consider moving it behind an ``if __name__ == "__main__":`` guard if the
# module is ever imported for its functions.
for D in D_list:
    # file_path = os.path.join("dataset", f"{N}_column_multivariate_gaussian_dataset_D_{D}_{total_samples}.npz")
    # N and k are interpolated so the file name always matches the settings
    # above instead of silently going stale when they are edited.
    file_path = os.path.join(
        "dataset",
        f"scale_N_{N}_k_{k}_multivariate_gaussian_dataset_D_{D}_{total_samples}.npz",
    )
    # generate_dataset(total_samples, D, N, k, input_is_cov, predict_vector, file_path)
    # generate_dataset(1280, D, N, k, input_is_cov, predict_vector, file_path)
    generate_dataset_scale_d(total_samples, D, N, k, input_is_cov, predict_vector, file_path)



