import os
import subprocess
import sys


def _pip_install(*pip_args):
    """Best-effort ``pip install`` into the interpreter running this script.

    Invokes ``sys.executable -m pip install <pip_args>`` so packages land in
    the same environment that will import them (a bare ``pip`` on PATH may
    belong to another Python). Failures are reported on stderr but never
    raised, so the script still proceeds when pip is unavailable or offline.

    Returns:
        True if pip exited with status 0, False otherwise.
    """
    cmd = [sys.executable, "-m", "pip", "install", *pip_args]
    try:
        returncode = subprocess.run(cmd).returncode
    except OSError as exc:
        print(f"warning: could not run {' '.join(cmd)}: {exc}", file=sys.stderr)
        return False
    if returncode != 0:
        print(f"warning: {' '.join(cmd)} exited with status {returncode}", file=sys.stderr)
    return returncode == 0


# Install/upgrade dependencies before sklearn is imported below.
_pip_install("--upgrade", "scikit-learn")
_pip_install("sklearn-compat")

# Set before torch is imported further down: lets ops unsupported on Apple's
# MPS backend fall back to the CPU instead of raising.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import sklearn.utils
# Older scikit-learn vendored joblib as sklearn.externals.joblib; newer
# releases dropped it, so fall back to the standalone joblib package.
try:
    from sklearn.externals.joblib import parallel_backend
except ImportError:
    from joblib import parallel_backend
# Re-expose parallel_backend on sklearn.utils, presumably because a dependency
# (likely nodegam) still looks it up there — TODO confirm which caller needs it.
sklearn.utils.parallel_backend = parallel_backend

import sys

# Add the parent directory to the import search path. NOTE: a relative entry
# like '..' is resolved against the current working directory, not this file's
# location, so this assumes the script is launched from its own directory.
# Kept above the ``nodegam`` import below in case nodegam is picked up from
# this path entry — TODO confirm whether nodegam is installed or local.
PACKAGE_PARENT = '..'
sys.path.append(PACKAGE_PARENT)

import torch
import numpy as np
import json
import pickle
import shutil
import time
from os.path import join as pjoin, exists as pexists

import pandas as pd
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

#import nodegam
#from nodegam.gams import model_utils

import sklearn
import argparse
from nodegam.sklearn import NodeGAMRegressor


def main():
    """Fit a NodeGAM regressor on every replicate of a saved dataset and report test MSE.

    Command-line arguments:
        --data_name: stem of the ``.pt`` file to load (default ``only_main_data``).
        --data_dir: directory holding the datasets (default
            ``../../Dataset/Numerical-Studies/``, relative to the current
            working directory).
        --device: device string passed to ``NodeGAMRegressor`` (default
            ``cpu``, the value previously hard-coded).

    The loaded object is indexed with the keys ``X_train``, ``y_train``,
    ``X_test`` and ``y_test``; ``X_train`` is unpacked as
    ``(replicates, samples, features)`` and every tensor is indexed by
    replicate along its first dimension. For each replicate a fresh model is
    trained and its test MSE stored. The script then prints the per-replicate
    MSEs, the wall-clock time, and their mean and standard deviation.
    """
    parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument('--data_name', default='only_main_data', type=str, help='type of the data to use')
    parser.add_argument('--data_dir', default='../../Dataset/Numerical-Studies/', type=str,
                        help='directory containing <data_name>.pt')
    parser.add_argument('--device', default='cpu', type=str,
                        help="device passed to NodeGAMRegressor (e.g. 'cpu' or 'cuda')")
    args = parser.parse_args()

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pt file cannot run arbitrary code on load.
    dataset = torch.load(pjoin(args.data_dir, args.data_name + '.pt'), weights_only=True)

    n_replicates, _, n_features = dataset['X_train'].size()
    if n_replicates == 0:
        print("No replicates found in dataset; nothing to do.")
        return
    mse_list = np.zeros(n_replicates)

    # Report the device the models actually train on (the script used to print
    # CUDA availability while always training on the CPU).
    print('device: ', args.device)

    start_time = time.perf_counter()
    for i in range(n_replicates):
        X_train = pd.DataFrame(dataset['X_train'][i].cpu().numpy())
        y_train = dataset['y_train'][i].cpu().numpy().reshape(-1)
        X_test = pd.DataFrame(dataset['X_test'][i].cpu().numpy())
        y_test = dataset['y_test'][i].cpu().numpy().reshape(-1)

        # NOTE(review): every replicate reuses the same ``name``; if NodeGAM
        # keys on-disk checkpoints/logs by name, a later fit could pick up
        # state from an earlier one — confirm against nodegam's fit().
        model = NodeGAMRegressor(in_features=n_features, name=args.data_name, device=args.device, ga2m=1)
        model.fit(X_train, y_train)

        # Flatten predictions: an (n, 1) output would otherwise broadcast
        # against the (n,) targets into an (n, n) matrix and corrupt the MSE.
        pred_test = np.asarray(model.predict(X_test)).reshape(-1)
        mse_list[i] = np.mean((pred_test - y_test) ** 2)

    elapsed_time = time.perf_counter() - start_time
    print(mse_list)
    print(f"Program executed in {elapsed_time:.4f} seconds")
    print(f"Mean of MSE {np.mean(mse_list)}")
    print(f"Std of MSE {np.std(mse_list)}")
    
if __name__ == '__main__':
    # Run the experiment only when executed as a script, never on import.
    main()


