

from model import logistic_regression, logistic_regression_subsample, DFR_model, JTT_model
from weights import *
from weight_searcher import *
from data import Toy
import numpy as np
import matplotlib.pyplot as plt
from helpers import *
import sys
import matplotlib as mpl
from data import WB, CelebA, multiNLI
from metrics import *
import argparse
import time
import pandas as pd
import os, sys
import pickle 
from param_search_last_layer import efficient_split


# Methods this script can run; anything else previously failed much later with a
# NameError on ``logreg_model``.
_SUPPORTED_METHODS = ('WS:ERM-GW', 'WS:ERM-SUBG', 'WS:GDRO', 'WS:DFR', 'WS:JTT')

# Datasets for which an embedding loader is wired up in ``_load_data``.
_SUPPORTED_DATASETS = ('WB', 'CelebA', 'multiNLI')

# Accepted values for ``start_p_value``; anything else left ``start_p`` unbound.
_SUPPORTED_START_P = ('standard', 'train')


def _load_data(dataset, seed, batch_size_model, early_stopping_model, augmentation_data, device):
    """Load the train/val/test embeddings, labels and groups of ``dataset`` for ``seed``.

    Returns:
        ``(data_obj, X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test)``,
        where ``data_obj`` is the dataset object (used later for ``get_p_dict`` and,
        for JTT, to load stored predictions).

    Raises:
        ValueError: if ``dataset`` is not one of ``_SUPPORTED_DATASETS``.
    """
    if dataset == 'WB':
        data_obj = WB()
        splits = data_obj.return_WB_embeddings(seed, batch_size_model, early_stopping_model, augmentation_data, JTT_folder=False)
    elif dataset == 'CelebA':
        data_obj = CelebA()
        splits = data_obj.return_CelebA_data(seed, batch_size_model, early_stopping_model, augmentation_data, device)
    elif dataset == 'multiNLI':
        data_obj = multiNLI()
        # the multiNLI loader takes no augmentation flag
        splits = data_obj.load_embeddings(seed, batch_size_model, early_stopping_model, device)
    else:
        raise ValueError('Unknown dataset {!r}; expected one of {}'.format(dataset, _SUPPORTED_DATASETS))

    X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test = splits
    return data_obj, X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test


def _split_validation_for_DFR(X_val, y_val, g_val, seed, frac_DFR_val_data):
    """Split the validation data into a new training part and a held-out validation part.

    The indices are shuffled with the global NumPy RNG seeded by ``seed`` (so the split
    is reproducible); the first ``int(frac_DFR_val_data * n_val)`` shuffled samples become
    the training part and the rest the validation part.

    Returns:
        ``(X_train, y_train, g_train, X_val, y_val, g_val)``.
    """
    np.random.seed(seed)
    indices = np.arange(X_val.shape[0])
    np.random.shuffle(indices)

    n_frac = int(frac_DFR_val_data * X_val.shape[0])
    idx_train, idx_val = indices[:n_frac], indices[n_frac:]

    return (X_val[idx_train], y_val[idx_train], g_val[idx_train],
            X_val[idx_val], y_val[idx_val], g_val[idx_val])


def _save_model(model, model_folder, method, run_key):
    """Pickle ``model`` to ``<model_folder>/<method>/<run_key>.pkl``, creating the folder if needed."""
    method_folder = os.path.join(model_folder, method)
    os.makedirs(method_folder, exist_ok=True)
    with open(os.path.join(method_folder, run_key + '.pkl'), 'wb') as f:
        pickle.dump(model, f)


def _write_results_table(results_run_table, result_folder, dataset):
    """Append ``results_run_table`` to ``<result_folder>/main_ws_<dataset>.csv``.

    A ``time`` column (local time, to the second) is added first. If the file already
    exists, new columns are added (missing values become NaN) and, for duplicate
    ``key`` values, only the newest row is kept. Otherwise the file is created.
    """
    results_run_table['time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())

    main_path = os.path.join(result_folder, 'main_ws_{}.csv'.format(dataset))
    if os.path.exists(main_path):
        main_table = pd.read_csv(main_path)
        main_table = pd.concat([main_table, results_run_table], axis=0, sort=False)
        main_table = main_table.drop_duplicates(subset='key', keep='last')
        main_table.to_csv(main_path, index=False)
    else:
        if result_folder:
            os.makedirs(result_folder, exist_ok=True)
        results_run_table.to_csv(main_path, index=False)


def main(dataset, seeds, method, device,  solver, tol,
          batch_size_model, augmentation_data, early_stopping_model
          , DFR_data='val', selection_criteria='wg_val_avg',
          T=100, patience=100, lr=0.1, eta_q_GDRO=0.1, lr_schedule='constant', decay=0.1, p_min=10e-4, stable_exp=False, grad_type='ift', momentum=None, grad_clip=None, GDRO=False, start_p_value='standard', eps=0,
           verbose=True, penalty_type_predictor=None, penalty_strength_predictor=None, 
            create_identifier=False, C_GDRO=0.0, use_SGDClassifier=False, 
              fraction_original_data=1.0, parallel_fit=False, frac_DFR_val_data=0.5, p_y_JTT=0.5,
                result_folder='results', model_folder='models', use_hyper_param_standard=True, save_trajectory=False, save=True, eta_param_GDRO=1e-4, lambda_JTT=1.0,
                penalty_strength=None, penalty_type='l1', add_lr_mom_to_key=False, train_val_split='original'):
    """Run weight search (WS) for ``method`` on ``dataset`` for every seed and record the results.

    For each seed: load the embeddings, optionally subsample / re-split the data, pick the
    logistic-regression hyper-parameters (from ``<result_folder>/main_param_search.csv`` when
    ``use_hyper_param_standard``), run the weight searcher for ``method``, refit with the found
    weights, evaluate on val/test, and pickle the fitted model under ``model_folder``.
    Afterwards, a summary table is printed and (if ``save``) appended to
    ``<result_folder>/main_ws_<dataset>.csv``.

    Note: ``stable_exp`` and ``grad_type`` are overridden inside every method branch, and
    ``p_min`` is overridden for 'WS:ERM-SUBG' / 'WS:DFR'.

    Raises:
        ValueError: for an unknown ``dataset``, ``method`` or ``start_p_value``, an empty
            ``seeds``, or a ``DFR_data`` other than 'val'/'train' when ``method == 'WS:DFR'``.
    """
    # Validate categorical arguments up front; these used to surface much later as
    # NameErrors on ``data_obj`` / ``logreg_model`` / ``start_p`` / ``model_param``.
    if dataset not in _SUPPORTED_DATASETS:
        raise ValueError('Unknown dataset {!r}; expected one of {}'.format(dataset, _SUPPORTED_DATASETS))
    if method not in _SUPPORTED_METHODS:
        raise ValueError('Unknown method {!r}; expected one of {}'.format(method, _SUPPORTED_METHODS))
    if start_p_value not in _SUPPORTED_START_P:
        raise ValueError('Unknown start_p_value {!r}; expected one of {}'.format(start_p_value, _SUPPORTED_START_P))
    if method == 'WS:DFR' and DFR_data not in ('val', 'train'):
        raise ValueError("DFR_data must be 'val' or 'train' for WS:DFR, got {!r}".format(DFR_data))
    seeds = list(seeds)
    if not seeds:
        raise ValueError('seeds must contain at least one seed')

    start_time = time.time()

    # one run key per seed
    keys = []

    # per-seed results
    results = {'seed': [],
               'loss_val': [], 'loss_test': [],
               'wg_val': [], 'wg_test': [],
               'weighted_acc_val': [], 'weighted_acc_test': []}

    # with the standard hyper-parameters, load the matching rows of the parameter search
    if use_hyper_param_standard:
        # the method name in the parameter search is the part after ':'
        method_for_param = method.split(':')[1] if ':' in method else method

        # WS:GDRO without SGDClassifier reuses the hyper-parameters selected for ERM-GW
        if method == 'WS:GDRO' and not use_SGDClassifier:
            method_for_param = 'ERM-GW'

        param_search_df = pd.read_csv(os.path.join(result_folder, 'main_param_search.csv'))
        param_search_df = param_search_df[
            (param_search_df['dataset'] == dataset)
            & (param_search_df['method'] == method_for_param)
            & (param_search_df['early_stopping_model'] == early_stopping_model)
            & (param_search_df['augmentation_data'] == augmentation_data)
            & (param_search_df['DFR_data'] == DFR_data)
            & (param_search_df['fraction_original_data'] == fraction_original_data)
        ]

    for seed in seeds:

        data_obj, X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test = _load_data(
            dataset, seed, batch_size_model, early_stopping_model, augmentation_data, device)

        # optionally keep only a fraction of the original train/val data
        if fraction_original_data < 1.0:
            set_seed(seed)
            X_train, y_train, g_train, indices_train = get_fraction_original_data(fraction_original_data, X_train, y_train, g_train)
            X_val, y_val, g_val, indices_val = get_fraction_original_data(fraction_original_data, X_val, y_val, g_val)
            print('Fraction of original data used: {}'.format(fraction_original_data))

        # optionally re-split train/val
        if train_val_split != 'original':
            X_train, y_train, g_train, X_val, y_val, g_val = efficient_split(X_train, y_train, g_train, X_val, y_val, g_val, train_val_split)
            print('g in train: ', torch.unique(g_train, return_counts=True))
            print('g in val: ', torch.unique(g_val, return_counts=True))

        # with DFR_data == 'val', train on part of the validation data and validate on the rest
        if DFR_data == 'val':
            X_val_orig, y_val_orig, g_val_orig = X_val, y_val, g_val
            X_train, y_train, g_train, X_val, y_val, g_val = _split_validation_for_DFR(
                X_val, y_val, g_val, seed, frac_DFR_val_data)

        # uniform target group distribution (4 groups assumed) and empirical group distributions
        p_ood = {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25}
        p_train = data_obj.get_p_dict(g=g_train)
        p_val = data_obj.get_p_dict(g=g_val)
        p_test = data_obj.get_p_dict(g=g_test)

        # sample weights that reweight val/test towards p_ood
        weights_val_obj = weights(p_weights=p_ood, p_train=p_val)
        weights_val = weights_val_obj.get_weights_sample(g_val)
        weights_test_obj = weights(p_weights=p_ood, p_train=p_test)
        weights_test = weights_test_obj.get_weights_sample(g_test)

        # pick the hyper-parameters for this seed
        if use_hyper_param_standard:
            param_search_df_seed = param_search_df[param_search_df['seed'] == seed]
            if train_val_split != 'original':
                param_search_df_seed = param_search_df_seed[param_search_df_seed['train_val_split'] == train_val_split]

            best_param, _ = get_best_hyper_param_method(method_for_param, param_search_df_seed, selection_criteria=selection_criteria)

            row_best_param = param_search_df_seed[param_search_df_seed['key'] == best_param['key']]
            penalty_strength = row_best_param['penalty_strength'].values[0]
            penalty_type = row_best_param['penalty_type'].values[0]
            solver = row_best_param['solver'].values[0]
            tol = row_best_param['tol'].values[0]
            T_param = row_best_param['T'].values[0]
            if method == 'WS:JTT':
                lambda_JTT = row_best_param['lambda_JTT'].values[0]
        else:
            T_param = T
            # NOTE(review): this overrides the caller's ``lambda_JTT`` argument; kept as-is
            lambda_JTT = 1.0

        model_param = {'penalty_type': penalty_type,
                       'penalty_strength': penalty_strength,
                       'solver': solver,
                       'tol': tol,
                       'seed': seed,
                       'T': T_param,
                       'parallel_fit': parallel_fit, 'eta0': 0.0, 'learning_rate': 'optimal', 'use_SGDClassifier': use_SGDClassifier}

        print('-----')
        print('Model param to use: ', model_param)

        if method == 'WS:ERM-GW':
            analytical_hessian = True
            stable_exp = True
            grad_type = 'ift'

            start_p = {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25} if start_p_value == 'standard' else p_train

            # NOTE(review): this p_ood-weighted fit is replaced by the refit below; kept so the
            # sequence of fits (and any RNG use inside them) is unchanged.
            logreg_model = logistic_regression(model_param, p_weights=p_ood, p_train=p_train, add_intercept=True)
            logreg_model.fit(X_train.cpu(), y_train.squeeze(-1).cpu(), g_train.cpu())

            searcher = weight_searcher(logistic_regression, p_train, X_train, y_train, g_train, X_val, y_val, g_val, weights_val_obj, penalty_type=penalty_type, penalty_strength=penalty_strength, solver=solver, seed=seed, tol=tol,
                                       grad_type=grad_type, GDRO=GDRO, T=T, parallel_fit=parallel_fit, use_SGDClassifier=use_SGDClassifier, eta0=model_param['eta0'], learning_rate=model_param['learning_rate'])
            set_seed(seed)
            p_tilde = searcher.exp_gradient_descent(start_p, lr, eps, patience, T=T, save_trajectory=save_trajectory,
                                                    momentum=momentum, gradient_clip=grad_clip, lr_schedule=lr_schedule, decay=decay, p_min=p_min, stable_exp=stable_exp,
                                                    analytical_hessian=analytical_hessian, diff=0.01, eta=eta_q_GDRO, subsample_weights=False, normalize=True, lock_in_p_g=None)

            # refit with the searched weights
            logreg_model = logistic_regression(model_param, p_weights=p_tilde, p_train=p_train, add_intercept=True)
            logreg_model.fit(X_train.cpu(), y_train.squeeze(-1).cpu(), g_train.cpu())

        elif method == 'WS:ERM-SUBG':
            analytical_hessian = True
            stable_exp = False
            grad_type = 'ift'

            # the group with the smallest p_g in p_train
            lock_in_p_g = key_of_min(p_train)
            print('Lock in p_g: ', lock_in_p_g)

            # lower clip for the subsampling weights, derived from the group counts
            groups_train, counts_train = np.unique(g_train, return_counts=True)
            n_dict_train = dict(zip(groups_train, counts_train))
            p_min_subsample = subsample_clip_p_min(n_dict_train)
            print('p_min: ', p_min_subsample)

            if start_p_value == 'standard':
                start_p = calc_subsample_ood_weights(p_train, X_train.shape[0])
                print('Start p: ', start_p)
            else:
                start_p = {1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}

            set_seed(seed)
            searcher = weight_searcher(logistic_regression_subsample, p_train, X_train, y_train, g_train, X_val, y_val, g_val, weights_val_obj, penalty_type=penalty_type, penalty_strength=penalty_strength, solver=solver, seed=seed, tol=tol,
                                       grad_type=grad_type, GDRO=GDRO, T=T, parallel_fit=parallel_fit, use_SGDClassifier=use_SGDClassifier, eta0=model_param['eta0'], learning_rate=model_param['learning_rate'])
            p_tilde = searcher.exp_gradient_descent(start_p, lr, eps, patience, T=T, save_trajectory=save_trajectory,
                                                    momentum=momentum, gradient_clip=grad_clip, lr_schedule=lr_schedule, decay=decay, p_min=p_min_subsample, stable_exp=stable_exp,
                                                    analytical_hessian=analytical_hessian, subsample_weights=True, normalize=False, lock_in_p_g=lock_in_p_g)

            set_seed(seed)
            logreg_model = logistic_regression_subsample(model_param, p_weights=p_tilde, p_train=p_train, add_intercept=True)
            logreg_model.fit(X_train.cpu(), y_train.squeeze(-1).cpu(), g_train.cpu())

        elif method == 'WS:GDRO':
            analytical_hessian = True
            stable_exp = True
            grad_type = 'ift'

            start_p = {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25} if start_p_value == 'standard' else p_train

            searcher = weight_searcher(logistic_regression, p_train, X_train, y_train, g_train, X_val, y_val, g_val, weights_val_obj, penalty_type=penalty_type, penalty_strength=penalty_strength, solver=solver, seed=seed, tol=tol,
                                       grad_type=grad_type, GDRO=True, T=T, parallel_fit=parallel_fit, use_SGDClassifier=use_SGDClassifier, eta0=model_param['eta0'], learning_rate=model_param['learning_rate'])
            set_seed(seed)
            p_tilde = searcher.exp_gradient_descent(start_p, lr, eps, patience, T=T, save_trajectory=save_trajectory,
                                                    momentum=momentum, gradient_clip=grad_clip, lr_schedule=lr_schedule, decay=decay, p_min=p_min, stable_exp=stable_exp,
                                                    analytical_hessian=analytical_hessian, diff=0.01, eta=eta_q_GDRO, subsample_weights=False, normalize=True, lock_in_p_g=None)

            logreg_model = logistic_regression(model_param, p_weights=p_tilde, p_train=p_train, add_intercept=True)
            logreg_model.fit(X_train.cpu(), y_train.squeeze(-1).cpu(), g_train.cpu())

        elif method == 'WS:DFR':
            analytical_hessian = True
            stable_exp = False
            grad_type = 'ift'

            # the group with the smallest p_g in p_train
            lock_in_p_g = key_of_min(p_train)
            print('Lock in p_g: ', lock_in_p_g)

            # group distribution of the full validation data; only defined when DFR_data == 'val'
            # (computing it unconditionally raised NameError for DFR_data == 'train')
            p_val_orig = data_obj.get_p_dict(g=g_val_orig) if DFR_data == 'val' else None

            groups_train, counts_train = np.unique(g_train, return_counts=True)
            n_dict_train = dict(zip(groups_train, counts_train))
            p_min_subsample = subsample_clip_p_min(n_dict_train)
            print('p_min: ', p_min_subsample)

            set_seed(seed)
            if start_p_value == 'standard':
                start_p = calc_subsample_ood_weights(p_train, X_train.shape[0])
            else:
                start_p = {1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}

            searcher = weight_searcher(DFR_model, p_train, X_train, y_train, g_train, X_val, y_val, g_val, weights_val_obj, penalty_type=penalty_type, penalty_strength=penalty_strength, solver=solver, seed=seed, tol=tol,
                                       grad_type=grad_type, GDRO=GDRO, T=T, parallel_fit=parallel_fit, use_SGDClassifier=use_SGDClassifier, eta0=model_param['eta0'], learning_rate=model_param['learning_rate'])

            set_seed(seed)
            p_tilde = searcher.exp_gradient_descent(start_p, lr, eps, patience, T=T, save_trajectory=save_trajectory,
                                                    momentum=momentum, gradient_clip=grad_clip, lr_schedule=lr_schedule, decay=decay, p_min=p_min_subsample, stable_exp=stable_exp,
                                                    analytical_hessian=analytical_hessian, subsample_weights=True, normalize=False, lock_in_p_g=lock_in_p_g)

            set_seed(seed)
            if DFR_data == 'val':
                # final DFR model is fit on the full (unsplit) validation data
                logreg_model = DFR_model(model_param, p_weights=p_tilde, p_train=p_val_orig, add_intercept=True)
                logreg_model.fit(X_val_orig.cpu(), y_val_orig.squeeze(-1).cpu(), g_val_orig.cpu())
            else:  # DFR_data == 'train' (validated above)
                logreg_model = DFR_model(model_param, p_weights=p_tilde, p_train=p_train, add_intercept=True)
                logreg_model.fit(X_train.cpu(), y_train.squeeze(-1).cpu(), g_train.cpu())

        elif method == 'WS:JTT':
            # class distribution of the training set: classes 0 = groups {1, 2}, 1 = groups {3, 4}
            p_train_identifier = {0: p_train[1] + p_train[2], 1: p_train[3] + p_train[4]}
            model_param_identifier = model_param
            # fall back to the identifier's parameters unless both predictor params are given
            if penalty_type_predictor is None or penalty_strength_predictor is None:
                model_param_predictor = model_param
                print('Copying model param')
            else:
                model_param_predictor = {'penalty_type': penalty_type_predictor, 'penalty_strength': penalty_strength_predictor, 'solver': solver, 'tol': tol, 'seed': 1, 'T': 100}

            # unless we create the identifier here, load the stored identifier predictions
            if create_identifier:
                pred_train, pred_val, pred_test = None, None, None
            else:
                if dataset == 'WB':
                    pred_train, pred_val, pred_test = data_obj.return_WB_pred(seed, 32, True, False, JTT_folder=True)
                elif dataset == 'CelebA':
                    pred_train, pred_val, pred_test = data_obj.return_CelebA_pred(seed)
                elif dataset == 'multiNLI':
                    pred_train, pred_val, pred_test = data_obj.return_multiNLI_pred(seed)

                pred_train, pred_val, pred_test = pred_train.to(device), pred_val.to(device), pred_test.to(device)

                # keep predictions aligned with the subsampled data
                if fraction_original_data < 1.0:
                    pred_train = pred_train[indices_train]
                    pred_val = pred_val[indices_val]

            JTT_model_obj = JTT_model(model_param_identifier, model_param_predictor, p_weights=p_ood, p_train=p_train, add_intercept=True, class_balanced_identifier=True, p_train_identifier=p_train_identifier, create_identifier=create_identifier)

            # NOTE(review): with create_identifier=True, pred_train is None and torch.round fails here
            y_hat_class = torch.round(pred_train).squeeze(-1)
            batched = dataset == 'CelebA'
            g_train_JTT = JTT_model_obj.get_g_train_JTT(y_train.squeeze(-1), y_hat_class, batched=batched).squeeze(-1)

            p_train_JTT = JTT_model_obj.get_p_dict(g=g_train_JTT)

            save_trajectory = False
            analytical_hessian = True
            stable_exp = True
            grad_type = 'ift'

            set_seed(seed)
            start_p = {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25} if start_p_value == 'standard' else p_train

            searcher = weight_searcher(logistic_regression, p_train_JTT, X_train, y_train, g_train_JTT, X_val, y_val, g_val, weights_val_obj, penalty_type=model_param_predictor['penalty_type'],
                                       penalty_strength=model_param_predictor['penalty_strength'], solver=solver, seed=seed, tol=tol, grad_type=grad_type,
                                       GDRO=GDRO, T=T, parallel_fit=parallel_fit, use_SGDClassifier=use_SGDClassifier, eta0=model_param['eta0'], learning_rate=model_param['learning_rate'])

            set_seed(seed)
            p_tilde = searcher.exp_gradient_descent(start_p, lr, eps, patience, T=T, save_trajectory=save_trajectory,
                                                    momentum=momentum, gradient_clip=grad_clip, lr_schedule=lr_schedule, decay=decay, p_min=p_min, stable_exp=stable_exp,
                                                    analytical_hessian=analytical_hessian, subsample_weights=False, normalize=True, lock_in_p_g=None)

            logreg_model = logistic_regression(model_param_predictor, p_weights=p_tilde, p_train=p_train_JTT, add_intercept=True)
            logreg_model.fit(X_train.cpu(), y_train.squeeze(-1).cpu(), g_train_JTT.cpu())

        # evaluate on val data
        y_val_pred = logreg_model.predict(X_val.cpu())
        loss_val = logreg_model.loss(y_val_pred, y_val.cpu(), weights_val)
        wg_val, _, weighted_acc_val = wg_acc(y_val.cpu(), torch.round(torch.sigmoid(y_val_pred)), g_val.cpu(), return_all=True)

        # evaluate on test data
        y_test_pred = logreg_model.predict(X_test.cpu())
        loss_test = logreg_model.loss(y_test_pred, y_test.cpu(), weights_test)
        wg_test, _, weighted_acc_test = wg_acc(y_test.cpu(), torch.round(torch.sigmoid(y_test_pred)), g_test.cpu(), return_all=True)

        results['seed'].append(seed)
        results['loss_val'].append(loss_val.item())
        results['wg_val'].append(wg_val.item())
        results['weighted_acc_val'].append(weighted_acc_val)
        results['loss_test'].append(loss_test.item())
        results['wg_test'].append(wg_test.item())
        results['weighted_acc_test'].append(weighted_acc_test)

        if verbose:
            print('At seed: {}, the selected penalty is {}, loss val/test is: {:.3f}/{:.3f}, wg val/test is: {:.3f}/{:.3f}, weighted acc val/test is: {:.3f}/{:.3f}'.format(seed, penalty_strength, loss_val, loss_test, wg_val, wg_test, weighted_acc_val, weighted_acc_test))

        # build the key identifying this run
        run_key = get_key(method, dataset, penalty_type, penalty_strength, solver, tol, batch_size_model, augmentation_data, early_stopping_model, lambda_JTT, DFR_data, penalty_type_predictor, penalty_strength_predictor, create_identifier, p_y_JTT, T, C_GDRO, eta_param_GDRO, eta_q_GDRO, fraction_original_data, seed)
        if method == 'WS:JTT' and GDRO:
            run_key += '_GDRO'
        run_key += '_WS'
        if add_lr_mom_to_key:
            run_key += '_lr_{}_mom_{}'.format(lr, momentum)
        if train_val_split != 'original':
            run_key += '_train_val_split_' + str(train_val_split).replace('.', '')
        keys.append(run_key)

        # models are always pickled, independent of ``save``
        _save_model(logreg_model, model_folder, method, run_key)

    # summary entry over all seeds
    results_run = create_result_entry(results, method, dataset, penalty_type, model_param['penalty_strength'], solver, tol, batch_size_model, augmentation_data, early_stopping_model, lambda_JTT, DFR_data, penalty_type_predictor, penalty_strength_predictor, create_identifier, p_y_JTT, T, C_GDRO, eta_param_GDRO, eta_q_GDRO, fraction_original_data)
    results_run['key'] = keys

    # weight-search hyper-parameters of this run
    ws_settings = {'lr': lr,
                   'lr_schedule': lr_schedule,
                   'decay': decay,
                   'momentum': momentum,
                   'grad_clip': grad_clip,
                   'GDRO': GDRO,
                   'start_p_value': start_p_value,
                   'T': T,
                   'patience': patience,
                   'use_SGDClassifier': use_SGDClassifier}
    for setting_name, setting_value in ws_settings.items():
        results_run[setting_name] = setting_value

    if train_val_split != 'original':
        results_run['train_val_split'] = train_val_split

    print('Results run: ', results_run)
    results_run_table = pd.DataFrame(results_run)

    if verbose:
        # average and standard error over seeds
        _, _, wg_test_avg, wg_test_se, weighted_acc_test_avg, weighted_acc_test_se = return_results(results, seeds, type_result='test')

        print('Param combination: {}'.format(model_param))
        print('WG test - Avg: {:.3f}, SE: {:.3f}'.format(wg_test_avg, wg_test_se))
        print('Weighted acc test - Avg: {:.3f}, SE: {:.3f}'.format(weighted_acc_test_avg, weighted_acc_test_se))

    print('Time taken: {}'.format(time.time() - start_time))

    print('--- Results ---')

    if save:
        # uses the ``result_folder`` parameter (previously read the global ``args``)
        _write_results_table(results_run_table, result_folder, dataset)







def str_to_bool(text):
    """Parse a 'true'/'false' command-line flag (case-insensitive) into a bool.

    Raises:
        ValueError: if ``text`` is neither 'true' nor 'false'. Previously such values
            silently became ``None`` and were treated as False.
    """
    lowered = text.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    raise ValueError("expected 'True' or 'False', got {!r}".format(text))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run weight search (WS) for last-layer retraining')
    parser.add_argument('--dataset', type=str, default='WB', help='Dataset to use')
    parser.add_argument('--seeds', type=str, default='1', help="Seeds to use, separated by '-' (e.g. '1-2-3')")
    parser.add_argument('--method', type=str, default='last_layer', help='Method to use')
    parser.add_argument('--device', type=str, default='cpu', help='Device to use')
    parser.add_argument('--solver', type=str, default='liblinear', help='Solver to use')
    parser.add_argument('--tol', type=float, default=1e-4, help='Tolerance to use')
    parser.add_argument('--batch_size_model', type=int, default=64, help='Batch size to use for the model')
    parser.add_argument('--augmentation_data', type=str, default='False', help='Whether to use data augmentation')
    parser.add_argument('--early_stopping_model', type=str, default='False', help='Whether to use early stopping')
    parser.add_argument('--selection_criteria', type=str, default='wg_val', help='Selection criteria to use')
    parser.add_argument('--DFR_data', type=str, default='train', help='Data to use for DFR')
    parser.add_argument('--penalty_type_predictor', type=str, default='l1', help='Penalty type for the predictor')
    parser.add_argument('--penalty_strength_predictor', type=float, default=None, help='Penalty strength for the predictor')
    parser.add_argument('--create_identifier', type=str, default='False', help='Whether to create an identifier')
    parser.add_argument('--T', type=int, default=100, help='Number of iterations ')
    parser.add_argument('--eta_q_GDRO', type=float, default=0.1, help='Eta q for GDRO')
    parser.add_argument('--C_GDRO', type=float, default=0.0, help='C for GDRO')
    parser.add_argument('--fraction_original_data', type=float, default=1.0, help='Fraction of original data to use')
    parser.add_argument('--train_val_split', default='original', help='Split to use for training and validation data')
    parser.add_argument('--result_folder', type=str, default='results/20092024_set', help='Folder to read the parameter search from and write results to')
    parser.add_argument('--model_folder', type=str, default='models/last_layer_models', help='Folder to save the fitted models to')
    parser.add_argument('--use_SGDClassifier', type=str, default='False', help='Whether to use SGDClassifier')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate')
    parser.add_argument('--lr_schedule', type=str, default='constant', help='Learning rate schedule')
    parser.add_argument('--decay', type=float, default=0.1, help='Decay')
    parser.add_argument('--momentum', type=float, default=None, help='Momentum')
    parser.add_argument('--patience', type=int, default=100, help='Patience')
    parser.add_argument('--grad_clip', type=str, default='False', help='Grad clip to use')
    parser.add_argument('--GDRO', type=str, default='False', help='Whether to use GDRO')
    parser.add_argument('--start_p_value', type=str, default='standard', help='Start p value')
    parser.add_argument('--penalty_strength', type=float, default=None, help='Penalty strength')
    parser.add_argument('--penalty_type', type=str, default='l1', help='Penalty type')
    parser.add_argument('--save', type=str, default='True', help='Whether to save the results')

    args = parser.parse_args()

    # seeds are passed as a '-'-separated list, e.g. '1-2-3'
    try:
        seeds = [int(seed) for seed in args.seeds.split('-')]
    except ValueError:
        parser.error('--seeds: expected integers separated by "-", got {!r}'.format(args.seeds))

    # boolean flags arrive as 'True'/'False' strings; reject anything else explicitly
    bool_flags = ('early_stopping_model', 'augmentation_data', 'create_identifier',
                  'use_SGDClassifier', 'grad_clip', 'GDRO', 'save')
    for flag in bool_flags:
        try:
            setattr(args, flag, str_to_bool(getattr(args, flag)))
        except ValueError as err:
            parser.error('--{}: {}'.format(flag, err))

    # a disabled grad clip is passed on as None
    if not args.grad_clip:
        args.grad_clip = None

    # a non-'original' train/val split is a float fraction
    if args.train_val_split != 'original':
        try:
            args.train_val_split = float(args.train_val_split)
        except ValueError:
            parser.error("--train_val_split: expected 'original' or a float, got {!r}".format(args.train_val_split))

    # note: p_min=10e-4 equals 1e-3 (same as main's default)
    main(args.dataset, seeds, args.method, args.device,   args.solver, args.tol,
          args.batch_size_model, args.augmentation_data, args.early_stopping_model,
          DFR_data=args.DFR_data, selection_criteria=args.selection_criteria, 
          T=args.T, patience=args.patience, lr=args.lr, eta_q_GDRO=args.eta_q_GDRO, lr_schedule=args.lr_schedule, 
          decay=args.decay, p_min=10e-4, momentum=args.momentum, grad_clip=args.grad_clip, GDRO=args.GDRO, start_p_value=args.start_p_value, eps=0,
          verbose=True, penalty_type_predictor=args.penalty_type_predictor, penalty_strength_predictor=args.penalty_strength_predictor, 
          create_identifier=args.create_identifier, C_GDRO=args.C_GDRO, use_SGDClassifier=args.use_SGDClassifier,
          fraction_original_data=args.fraction_original_data, result_folder=args.result_folder, model_folder=args.model_folder,
            penalty_strength=args.penalty_strength, penalty_type=args.penalty_type, save=args.save, train_val_split=args.train_val_split)

