import numpy as np
from tqdm import tqdm  # Import tqdm for progress bar
from data_generation import generate_L, top_k_eigen, generate_X, compute_covariance_matrix, compute_covariance_matrix_scale_d
import os
from sklearn.datasets import fetch_openml
import torchvision
import torchvision.transforms as transforms



def generate_downscale_mnist(num_train, D, N, k, input_is_cov, predict_vector, file_name):
    """Build a PCA-reduced MNIST dataset of covariance eigen-targets and save it.

    The MNIST images are centered, projected onto their top-``D`` right
    singular vectors, and cut into consecutive non-overlapping groups of
    ``N`` rows. For every group the sample covariance is computed and passed
    to ``top_k_eigen``; the group itself is stored as the input and the two
    values returned by ``top_k_eigen`` as the targets.

    Args:
        num_train: Number of leading *groups* (not images) that go into the
            training split; the remaining groups form the test split.
        D: Number of principal directions kept by the projection.
        N: Number of rows (images) per group.
        k: Forwarded to ``top_k_eigen``.
        input_is_cov: Accepted for signature parity with ``generate_dataset``
            but currently ignored; the stored input is always the raw group.
        predict_vector: When true, the second value returned by
            ``top_k_eigen`` is also collected and saved.
        file_name: Destination passed to ``np.savez``.

    Side effects:
        Creates the ``dataset`` directory, downloads MNIST through
        ``fetch_openml``, prints array shapes and writes ``file_name`` with
        the keys ``X_train``, ``Y_train``, ``X_test``, ``Y_test`` (plus
        ``Y_vector_train`` / ``Y_vector_test`` when ``predict_vector``).
    """
    X_data = []
    Y_data = []
    Y_vector_data = []
    os.makedirs('dataset', exist_ok=True)

    mnist = fetch_openml('mnist_784', version=1)
    # Depending on the scikit-learn version `mnist.data` can be a pandas
    # DataFrame; convert once so all later indexing is plain ndarray slicing.
    X = np.asarray(mnist.data, dtype=np.float64) / 100
    print(X.shape)

    X_centered = X - np.mean(X, axis=0)

    # Project onto the top-D right singular vectors of the centered data.
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    X_reduced = np.dot(X_centered, Vt.T[:, :D])
    print("X_reduced shape", X_reduced.shape)

    num_groups = X_reduced.shape[0] // N
    for i in tqdm(range(num_groups), desc=f"Generating {k}-dimention real world dataset"):
        # Non-overlapping block of N rows (stride N, not 1) so that groups
        # share no images and the whole reduced matrix is consumed.
        start = i * N
        one_input = X_reduced[start:start + N, :]
        sample_covariance = compute_covariance_matrix(one_input)

        Y, Y_vector = top_k_eigen(sample_covariance, k)

        X_data.append(one_input)
        Y_data.append(Y)
        if predict_vector:
            Y_vector_data.append(Y_vector)

    print("begin to transform x to numpy")
    # Convert lists to NumPy arrays
    X_data = np.array(X_data)
    print("X_data shape", X_data.shape)
    Y_data = np.array(Y_data)
    print("Y_data shape", Y_data.shape)
    if predict_vector:
        Y_vector_data = np.array(Y_vector_data)
        print("Y_vector_data shape", Y_vector_data.shape)

    # Train split: the first num_train groups.
    X_train = X_data[:num_train]
    print("X_train shape: ", X_train.shape)
    Y_train = Y_data[:num_train]
    print("Y_train shape: ", Y_train.shape)
    if predict_vector:
        Y_vector_train = Y_vector_data[:num_train]
        print("Y_vector_train shape: ", Y_vector_train.shape)

    # Test split: everything after the first num_train groups.
    X_test = X_data[num_train:]
    print("X_test shape: ", X_test.shape)
    Y_test = Y_data[num_train:]
    print("Y_test shape: ", Y_test.shape)
    if predict_vector:
        # Must use the same [num_train:] range as X_test / Y_test so the
        # vectors stay aligned with the test inputs.
        Y_vector_test = Y_vector_data[num_train:]
        print("Y_vector_test shape: ", Y_vector_test.shape)

    # Save data to .npz file
    if predict_vector:
        np.savez(file_name, X_train=X_train, Y_train=Y_train, Y_vector_train=Y_vector_train, X_test=X_test, Y_test=Y_test, Y_vector_test=Y_vector_test)
    else:
        np.savez(file_name, X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test)

    print(f"Dataset saved to {file_name}")


def generate_downscale_fashion_mnist(num_train, D, N, k, input_is_cov, predict_vector, file_name):
    """Build a PCA-reduced Fashion-MNIST dataset of covariance eigen-targets and save it.

    The Fashion-MNIST training images are flattened to one row per image,
    centered, projected onto their top-``D`` right singular vectors, and cut
    into consecutive non-overlapping groups of ``N`` rows. For every group
    the sample covariance is computed and passed to ``top_k_eigen``; the
    group itself is stored as the input and the two values returned by
    ``top_k_eigen`` as the targets.

    Args:
        num_train: Number of leading *groups* (not images) that go into the
            training split; the remaining groups form the test split.
        D: Number of principal directions kept by the projection.
        N: Number of rows (images) per group.
        k: Forwarded to ``top_k_eigen``.
        input_is_cov: Accepted for signature parity with ``generate_dataset``
            but currently ignored; the stored input is always the raw group.
        predict_vector: When true, the second value returned by
            ``top_k_eigen`` is also collected and saved.
        file_name: Destination passed to ``np.savez``.

    Side effects:
        Creates the ``dataset`` directory, downloads Fashion-MNIST into
        ``./data``, prints array shapes and writes ``file_name`` with the
        keys ``X_train``, ``Y_train``, ``X_test``, ``Y_test`` (plus
        ``Y_vector_train`` / ``Y_vector_test`` when ``predict_vector``).
    """
    X_data = []
    Y_data = []
    Y_vector_data = []
    os.makedirs('dataset', exist_ok=True)

    # NOTE(review): `.data` is read directly below, so this transform does
    # not appear to influence X -- confirm it is only kept for the dataset
    # constructor.
    transform = transforms.Compose([transforms.ToTensor()])
    training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)

    # Flatten every image into a single row, then apply the fixed 1/100 scaling.
    X = training_set.data.view(training_set.data.size(0), -1).numpy() / 100
    print(X.shape)

    X_centered = X - np.mean(X, axis=0)

    # Project onto the top-D right singular vectors of the centered data.
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    X_reduced = np.dot(X_centered, Vt.T[:, :D])
    print("X_reduced shape", X_reduced.shape)

    num_groups = X_reduced.shape[0] // N
    for i in tqdm(range(num_groups), desc=f"Generating {k}-dimention real world dataset"):
        # Non-overlapping block of N rows (stride N, not 1) so that groups
        # share no images and the whole reduced matrix is consumed.
        start = i * N
        one_input = X_reduced[start:start + N, :]
        sample_covariance = compute_covariance_matrix(one_input)

        Y, Y_vector = top_k_eigen(sample_covariance, k)

        X_data.append(one_input)
        Y_data.append(Y)
        if predict_vector:
            Y_vector_data.append(Y_vector)

    print("begin to transform x to numpy")

    X_data = np.array(X_data)
    print("X_data shape", X_data.shape)
    Y_data = np.array(Y_data)
    print("Y_data shape", Y_data.shape)
    if predict_vector:
        Y_vector_data = np.array(Y_vector_data)
        print("Y_vector_data shape", Y_vector_data.shape)

    # Train split: the first num_train groups.
    X_train = X_data[:num_train]
    print("X_train shape: ", X_train.shape)
    Y_train = Y_data[:num_train]
    print("Y_train shape: ", Y_train.shape)
    if predict_vector:
        Y_vector_train = Y_vector_data[:num_train]
        print("Y_vector_train shape: ", Y_vector_train.shape)

    # Test split: everything after the first num_train groups.
    X_test = X_data[num_train:]
    print("X_test shape: ", X_test.shape)
    Y_test = Y_data[num_train:]
    print("Y_test shape: ", Y_test.shape)
    if predict_vector:
        # Must use the same [num_train:] range as X_test / Y_test so the
        # vectors stay aligned with the test inputs.
        Y_vector_test = Y_vector_data[num_train:]
        print("Y_vector_test shape: ", Y_vector_test.shape)

    if predict_vector:
        np.savez(file_name, X_train=X_train, Y_train=Y_train, Y_vector_train=Y_vector_train, X_test=X_test, Y_test=Y_test, Y_vector_test=Y_vector_test)
    else:
        np.savez(file_name, X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test)

    print(f"Dataset saved to {file_name}")


def generate_downscale_cifar10(num_train, D, N, k, input_is_cov, predict_vector, file_name):
    """Build a PCA-reduced CIFAR-10 dataset of covariance eigen-targets and save it.

    The CIFAR-10 training images are flattened to one row per image,
    centered, projected onto their top-``D`` right singular vectors, and cut
    into consecutive non-overlapping groups of ``N`` rows. For every group
    the sample covariance is computed and passed to ``top_k_eigen``; the
    group itself is stored as the input and the two values returned by
    ``top_k_eigen`` as the targets.

    Args:
        num_train: Number of leading *groups* (not images) that go into the
            training split; the remaining groups form the test split.
        D: Number of principal directions kept by the projection.
        N: Number of rows (images) per group.
        k: Forwarded to ``top_k_eigen``.
        input_is_cov: Accepted for signature parity with ``generate_dataset``
            but currently ignored; the stored input is always the raw group.
        predict_vector: When true, the second value returned by
            ``top_k_eigen`` is also collected and saved.
        file_name: Destination passed to ``np.savez``.

    Side effects:
        Creates the ``dataset`` directory, downloads CIFAR-10 into
        ``./data``, prints array shapes and writes ``file_name`` with the
        keys ``X_train``, ``Y_train``, ``X_test``, ``Y_test`` (plus
        ``Y_vector_train`` / ``Y_vector_test`` when ``predict_vector``).
    """
    X_data = []
    Y_data = []
    Y_vector_data = []
    os.makedirs('dataset', exist_ok=True)

    # NOTE(review): `.data` is read directly below, so ToTensor/Grayscale do
    # not appear to be applied to X (all colour channels end up flattened
    # into each row) -- confirm whether grayscale input was intended.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Grayscale()])
    training_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=transform)

    # Flatten every image into a single row, then apply the fixed 1/100 scaling.
    X = training_set.data.reshape(training_set.data.shape[0], -1) / 100
    print(X.shape)

    X_centered = X - np.mean(X, axis=0)

    # Project onto the top-D right singular vectors of the centered data.
    _, _, Vt = np.linalg.svd(X_centered, full_matrices=False)
    X_reduced = np.dot(X_centered, Vt.T[:, :D])
    print("X_reduced shape", X_reduced.shape)

    num_groups = X_reduced.shape[0] // N
    for i in tqdm(range(num_groups), desc=f"Generating {k}-dimention real world dataset"):
        # Non-overlapping block of N rows (stride N, not 1) so that groups
        # share no images and the whole reduced matrix is consumed.
        start = i * N
        one_input = X_reduced[start:start + N, :]
        sample_covariance = compute_covariance_matrix(one_input)

        Y, Y_vector = top_k_eigen(sample_covariance, k)

        X_data.append(one_input)
        Y_data.append(Y)
        if predict_vector:
            Y_vector_data.append(Y_vector)

    print("begin to transform x to numpy")

    X_data = np.array(X_data)
    print("X_data shape", X_data.shape)
    Y_data = np.array(Y_data)
    print("Y_data shape", Y_data.shape)
    if predict_vector:
        Y_vector_data = np.array(Y_vector_data)
        print("Y_vector_data shape", Y_vector_data.shape)

    # Train split: the first num_train groups.
    X_train = X_data[:num_train]
    print("X_train shape: ", X_train.shape)
    Y_train = Y_data[:num_train]
    print("Y_train shape: ", Y_train.shape)
    if predict_vector:
        Y_vector_train = Y_vector_data[:num_train]
        print("Y_vector_train shape: ", Y_vector_train.shape)

    # Test split: everything after the first num_train groups.
    X_test = X_data[num_train:]
    print("X_test shape: ", X_test.shape)
    Y_test = Y_data[num_train:]
    print("Y_test shape: ", Y_test.shape)
    if predict_vector:
        # Must use the same [num_train:] range as X_test / Y_test so the
        # vectors stay aligned with the test inputs.
        Y_vector_test = Y_vector_data[num_train:]
        print("Y_vector_test shape: ", Y_vector_test.shape)

    if predict_vector:
        np.savez(file_name, X_train=X_train, Y_train=Y_train, Y_vector_train=Y_vector_train, X_test=X_test, Y_test=Y_test, Y_vector_test=Y_vector_test)
    else:
        np.savez(file_name, X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test)

    print(f"Dataset saved to {file_name}")




def generate_dataset(total_samples, D, N, k, input_is_cov, predict_vector, file_name):
    """Generate a synthetic dataset of covariance eigen-targets and save it.

    For each of ``total_samples`` draws, a fresh ``L`` is produced by
    ``generate_L(D)``, data is drawn with ``generate_X(N, D, 0, L)``, and the
    sample covariance from ``compute_covariance_matrix`` is passed to
    ``top_k_eigen``.

    Args:
        total_samples: Number of (input, target) pairs to generate.
        D: Forwarded to ``generate_L`` and ``generate_X``.
        N: Forwarded to ``generate_X``.
        k: Forwarded to ``top_k_eigen``.
        input_is_cov: When true the stored input is the sample covariance
            instead of the drawn data.
        predict_vector: When true, the second value returned by
            ``top_k_eigen`` is also collected and saved as ``Y_vector_data``.
        file_name: Destination passed to ``np.savez``.

    Side effects:
        Creates the ``dataset`` directory and writes ``file_name`` with the
        keys ``X_data`` and ``Y_data`` (plus ``Y_vector_data`` when
        ``predict_vector``).
    """
    os.makedirs('dataset', exist_ok=True)

    inputs, targets, target_vectors = [], [], []

    # tqdm wraps the range only to show a progress bar.
    progress = tqdm(range(total_samples), desc=f"Generating {D}-dimention dataset")
    for _ in progress:
        drawn = generate_X(N, D, 0, generate_L(D))
        covariance = compute_covariance_matrix(drawn)
        target, target_vector = top_k_eigen(covariance, k)

        inputs.append(covariance if input_is_cov else drawn)
        targets.append(target)
        if predict_vector:
            target_vectors.append(target_vector)

    arrays = {"X_data": np.array(inputs), "Y_data": np.array(targets)}
    if predict_vector:
        arrays["Y_vector_data"] = np.array(target_vectors)
    np.savez(file_name, **arrays)

    print(f"Dataset saved to {file_name}")

def generate_dataset_scale_d(total_samples, D, N, k, input_is_cov, predict_vector, file_name):
    """Generate a synthetic dataset using the ``scale_d`` covariance helper and save it.

    Same pipeline as ``generate_dataset`` except that the sample covariance
    comes from ``compute_covariance_matrix_scale_d(X, D)``.

    Args:
        total_samples: Number of (input, target) pairs to generate.
        D: Forwarded to ``generate_L``, ``generate_X`` and
            ``compute_covariance_matrix_scale_d``.
        N: Forwarded to ``generate_X``.
        k: Forwarded to ``top_k_eigen``.
        input_is_cov: When true the stored input is the sample covariance
            instead of the drawn data.
        predict_vector: When true, the second value returned by
            ``top_k_eigen`` is also collected and saved as ``Y_vector_data``.
        file_name: Destination passed to ``np.savez``.

    Side effects:
        Creates the ``dataset`` directory and writes ``file_name`` with the
        keys ``X_data`` and ``Y_data`` (plus ``Y_vector_data`` when
        ``predict_vector``).
    """
    os.makedirs('dataset', exist_ok=True)

    inputs, targets, target_vectors = [], [], []

    # tqdm wraps the range only to show a progress bar.
    progress = tqdm(range(total_samples), desc=f"Generating {D}-dimention dataset")
    for _ in progress:
        drawn = generate_X(N, D, 0, generate_L(D))
        covariance = compute_covariance_matrix_scale_d(drawn, D)
        target, target_vector = top_k_eigen(covariance, k)

        inputs.append(covariance if input_is_cov else drawn)
        targets.append(target)
        if predict_vector:
            target_vectors.append(target_vector)

    arrays = {"X_data": np.array(inputs), "Y_data": np.array(targets)}
    if predict_vector:
        arrays["Y_vector_data"] = np.array(target_vectors)
    np.savez(file_name, **arrays)

    print(f"Dataset saved to {file_name}")



# ---------------------------------------------------------------------------
# Script configuration: these values are passed straight to the generator
# called in the loop below. Everything here runs at import time.
# ---------------------------------------------------------------------------
# Used by the generate_downscale_* functions as the split index into the
# per-group arrays (X_data[:num_train] / X_data[num_train:]).
# NOTE(review): the number of groups is (number of images) // N, so confirm
# this value is meant to count groups; if it exceeds the group count the test
# split comes out empty.
num_train = 60000  
# Rows per group handed to the covariance computation.
N = 50   
# Forwarded to top_k_eigen.
k = 10
# Passed through, but the generate_downscale_* functions do not read it.
input_is_cov = False  
# When True the second output of top_k_eigen is saved as well.
predict_vector = True  
# One dataset file is generated per projection dimension D in this list.
D_list = [20]


for D in D_list:
  # Alternative datasets are kept below as commented-out calls; only the
  # Fashion-MNIST variant is active.
  # NOTE(review): the output file name hard-codes "N_50" and "k_10" instead of
  # interpolating N and k -- keep them in sync manually when changing the
  # configuration above. The commented-out lines also reference
  # `total_samples`, which is not defined in this section.
  # file_path = os.path.join("dataset", f"{N}_column_multivariate_gaussian_dataset_D_{D}_{total_samples}.npz")
  # file_path_mnist = os.path.join("dataset", f"100_divide_N_10_mnist_dataset_D_{D}_{num_train}_k_10.npz")
  file_path_fashion_mnist = os.path.join("dataset", f"100_divide_N_50_fashion_mnist_dataset_D_{D}_{num_train}_k_10.npz")
  # file_path_cifar10 = os.path.join("dataset", f"100_divide_N_10_cifar10_dataset_D_{D}_{num_train}_k_10.npz")
  # generate_downscale_cifar10(num_train, D, N, k, input_is_cov, predict_vector, file_path_cifar10)
  # generate_downscale_mnist(num_train, D, N, k, input_is_cov, predict_vector, file_path_mnist)
  generate_downscale_fashion_mnist(num_train, D, N, k, input_is_cov, predict_vector, file_path_fashion_mnist)
  # generate_dataset(1280, D, N, k, input_is_cov, predict_vector, file_path)
  # generate_dataset_scale_d(total_samples, D, N, k, input_is_cov, predict_vector, file_path)
    


