
import os
import subprocess
import sys

# Best-effort upgrade of scikit-learn before it is imported further down.
# pip is invoked through the running interpreter (``python -m pip``) so the
# upgrade lands in the same environment this script imports from; a bare
# ``pip`` on PATH may belong to a different Python installation. As before,
# a failed install is not fatal -- a warning is printed and execution goes on.
try:
    _pip_status = subprocess.call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "scikit-learn"]
    )
except OSError as _pip_error:
    print(f"warning: could not run pip: {_pip_error}", file=sys.stderr)
else:
    if _pip_status != 0:
        print(
            f"warning: 'pip install --upgrade scikit-learn' exited with status {_pip_status}",
            file=sys.stderr,
        )

# Make the parent directory importable (presumably for a local ``gaminet``
# checkout used by the imports below).
# NOTE(review): '..' is resolved relative to the current working directory,
# not this file's location -- run from the script's directory or adjust.
PACKAGE_PARENT = '..'
sys.path.append(PACKAGE_PARENT)

import torch
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

from gaminet import GAMINetRegressor, GAMINetClassifier
from gaminet.utils import local_visualize
from gaminet.utils import global_visualize_density
from gaminet.utils import feature_importance_visualize
from gaminet.utils import plot_trajectory
from gaminet.utils import plot_regularization
import time
import sklearn
import argparse


def main(argv=None):
    """Train one GAMINet regressor per dataset replicate and report test MSE.

    Loads ``<data_path>/<data_name>.pt`` with ``torch.load(weights_only=True)``.
    The loaded object is indexed by the keys ``'X_train'``, ``'y_train'``,
    ``'X_test'`` and ``'y_test'``. The first axis of each entry selects a
    replicate; the replicates are assumed to be aligned across the four keys
    (TODO confirm against the dataset generator). For every replicate a fresh
    ``GAMINetRegressor`` is fitted on the training split, and the mean squared
    error of its predictions on the test split is recorded. The function prints
    the per-replicate MSE array, the total wall-clock time, and the mean and
    standard deviation of the MSEs.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            ``None`` (the default) parses ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument('--data_name', default='only_main_data', type=str,
                        help='type of the data to use')
    # The data directory used to be hard-coded; it is now overridable while
    # keeping the original location as the default.
    parser.add_argument('--data_path',
                        default='/home/users/yhung7/SDAM/Dataset/Numerical-Studies/',
                        type=str,
                        help='directory containing <data_name>.pt')
    args = parser.parse_args(argv)

    dataset = torch.load(os.path.join(args.data_path, args.data_name + '.pt'),
                         weights_only=True)

    # Convert each split to NumPy once, instead of once per replicate.
    X_train_all = np.asarray(dataset['X_train'])
    y_train_all = np.asarray(dataset['y_train'])
    X_test_all = np.asarray(dataset['X_test'])
    y_test_all = np.asarray(dataset['y_test'])

    length = len(X_train_all)
    if length == 0:
        # np.mean/np.std of an empty array would only yield NaN plus warnings.
        print("Dataset contains no replicates; nothing to evaluate.")
        return

    mse_list = np.zeros(length)
    random_state = 0  # the same seed is used for every replicate's model
    start_time = time.perf_counter()

    for i in range(length):
        # A fresh model per replicate; hyperparameter semantics are defined by
        # the gaminet package.
        model = GAMINetRegressor(interact_num=10,
                                 subnet_size_main_effect=(20, ),
                                 subnet_size_interaction=(20, 20),
                                 max_epochs=(1000, 1000, 1000),
                                 learning_rates=(0.001, 0.001, 0.0001),
                                 early_stop_thres=("auto", "auto", "auto"),
                                 batch_size=1000,
                                 reg_clarity=1,
                                 loss_threshold=0.01,
                                 warm_start=True,
                                 verbose=False,
                                 random_state=random_state)

        model.fit(X_train_all[i], y_train_all[i].reshape(-1))

        # Flatten both sides before differencing: if predict() returns an
        # (n, 1) column while the target is (n,), NumPy broadcasting would
        # produce an (n, n) matrix and a meaningless MSE.
        pred_test = np.asarray(model.predict(X_test_all[i])).reshape(-1)
        y_test = y_test_all[i].reshape(-1)
        mse_list[i] = np.mean((pred_test - y_test) ** 2)

    elapsed_time = time.perf_counter() - start_time
    print(mse_list)
    print(f"Program executed in {elapsed_time:.4f} seconds")
    print(f"Mean of MSE {np.mean(mse_list)}")
    print(f"Std of MSE {np.std(mse_list)}")
    
# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()


