

from weights import *
from weight_searcher import *
from helpers import *
from data import WB, CelebA, multiNLI
from metrics import *
from helpers import get_best_hyper_param_method
import argparse
import pandas as pd
import os
import pickle
from param_search_last_layer import efficient_split

    


def main(dataset, seeds, method, device, batch_size_model, augmentation_data, early_stopping_model, create_identifier=False,
            DFR_data='train', result_folder=None, model_folder=None,fraction_original_data=1.0, selection_criteria='wg_val', save=True, verbose=True, train_val_split='original'):
    """Evaluate, per seed, the best model found in the parameter search on the test data.

    For every seed the function loads the data of ``dataset``, picks the best
    hyper-parameters for ``method`` from ``<result_folder>/main_param_search.csv``
    (via ``get_best_hyper_param_method``), unpickles the matching model from
    ``<model_folder>/<method>/<key>.pkl``, computes loss / worst-group accuracy /
    weighted accuracy on the test and validation data, and stores the test
    metrics in the selected parameter-search row. The rows of all seeds are
    concatenated and optionally merged into
    ``<result_folder>/main_results_<dataset>.csv``.

    Args:
        dataset: One of 'WB', 'CelebA', 'multiNLI'.
        seeds: Non-empty iterable of integer seeds.
        method: Method name; used to filter the parameter search and as the
            sub-folder of ``model_folder``.
        device: Device handed to the data loaders and used for the predictions.
        batch_size_model: Forwarded to the data loaders.
        augmentation_data: Forwarded to the data loaders and used to filter the
            parameter search.
        early_stopping_model: Forwarded to the data loaders and used to filter
            the parameter search.
        create_identifier: If True, no predictions are loaded (they are set to None).
        DFR_data: Used to filter the parameter search.
        result_folder: Folder containing ``main_param_search.csv``; results are
            written here when ``save`` is True.
        model_folder: Folder containing the pickled models.
        fraction_original_data: Used to filter the parameter search; if below
            1.0 the train and validation data are sub-sampled.
        selection_criteria: Criterion passed to ``get_best_hyper_param_method``
            and stored in the result rows.
        save: If True, write/merge the results into the main results csv.
        verbose: If True, print the average of the metrics over the seeds.
        train_val_split: 'original' or a string that parses as a float; in the
            latter case the parameter search is filtered on it and the data is
            re-split with ``efficient_split``.

    Raises:
        ValueError: If ``dataset`` is not supported, ``seeds`` is empty, a folder
            is missing, or a seed has no entries in the parameter search.
    """

    # fail fast on invalid input, before any file is touched
    supported_datasets = ('WB', 'CelebA', 'multiNLI')
    if dataset not in supported_datasets:
        raise ValueError('Unknown dataset {!r}; expected one of {}'.format(dataset, supported_datasets))

    seeds = list(seeds)
    if not seeds:
        raise ValueError('seeds must contain at least one seed')

    if result_folder is None or model_folder is None:
        raise ValueError('result_folder and model_folder must both be provided')

    # create dict for results
    results_per_seed = dict.fromkeys(seeds)

    # select all entries with the dataset, method, and the dataset settings (early stopping, augmentation data, DFR_data, fraction_original_data)
    param_search_df = pd.read_csv(os.path.join(result_folder, 'main_param_search.csv'))
    settings_mask = (
        (param_search_df['dataset'] == dataset)
        & (param_search_df['method'] == method)
        & (param_search_df['early_stopping_model'] == early_stopping_model)
        & (param_search_df['augmentation_data'] == augmentation_data)
        & (param_search_df['DFR_data'] == DFR_data)
        & (param_search_df['fraction_original_data'] == fraction_original_data)
    )
    param_search_df = param_search_df[settings_mask]

    use_custom_split = train_val_split != 'original'
    if use_custom_split:
        train_val_split_float = float(train_val_split)
        param_search_df = param_search_df[param_search_df['train_val_split'] == train_val_split_float]

    print('Selecting from penalties: ', param_search_df['penalty_strength'].unique())

    # sense check; if there are multiple values of T and the method is ERM-GW, ERM-SUBG or DFR,
    # keep the value of T that appears last in the csv.
    # NOTE(review): unique() keeps order of first appearance, so this is the most recent
    # T only if rows are appended chronologically - confirm.
    if method in ['ERM-GW', 'ERM-SUBG', 'DFR']:
        unique_T = param_search_df['T'].unique()
        if len(unique_T) > 1:
            param_search_df = param_search_df[param_search_df['T'] == unique_T[-1]]

            print('param_search_df: ', param_search_df.head())
            print('value of T: ', param_search_df['T'].unique())

    print('These results are from the time: ', param_search_df['time'].unique())

    # loop over the seeds
    for seed in seeds:

        # load the data of the requested dataset
        if dataset == 'WB':
            data_obj = WB()
            X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test = data_obj.return_WB_embeddings(seed, batch_size_model, early_stopping_model, augmentation_data, JTT_folder=False)
        elif dataset == 'CelebA':
            data_obj = CelebA()
            # the loader is always asked for the 'original' split; a custom split is applied below
            X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test = data_obj.return_CelebA_data(seed, batch_size_model, early_stopping_model, augmentation_data, device,  train_val_split='original')
        else:  # multiNLI (dataset was validated above)
            data_obj = multiNLI()
            X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test = data_obj.load_embeddings(seed, batch_size_model, early_stopping_model, device)

        # change fraction of original data
        if fraction_original_data < 1.0:
            set_seed(seed)
            X_train, y_train, g_train, indices_train = get_fraction_original_data(fraction_original_data, X_train, y_train, g_train)
            X_val, y_val, g_val, indices_val = get_fraction_original_data(fraction_original_data, X_val, y_val, g_val)
            print('Fraction of original data used: {}'.format(fraction_original_data))

        # re-split train/validation data if a custom split is requested
        if use_custom_split:
            # NOTE(review): the string value (not train_val_split_float) is passed on, as before - confirm efficient_split expects this
            X_train, y_train, g_train, X_val, y_val, g_val = efficient_split(X_train, y_train, g_train, X_val, y_val, g_val, train_val_split)
            print('After split train size: {}, val size: {}, ratio is {}'.format(X_train.shape[0], X_val.shape[0], train_val_split_float))
            print('Division of groups in train: {}'.format(np.unique(g_train, return_counts=True)))
            print('Division of groups in val: {}'.format(np.unique(g_val, return_counts=True)))
            print('Division of groups in test: {}'.format(np.unique(g_test, return_counts=True)))

        # get the predictions
        # NOTE(review): the predictions are loaded but not used further in this function
        if create_identifier:
            pred_train, pred_val, pred_test = None, None, None
        else:
            if dataset == 'WB':
                preds = data_obj.return_WB_pred(seed, 32, True, False, JTT_folder=True)
            elif dataset == 'CelebA':
                preds = data_obj.return_CelebA_pred(seed)
            else:
                preds = data_obj.return_multiNLI_pred(seed)

            # put the predictions on the right device
            pred_train, pred_val, pred_test = (pred.to(device) for pred in preds)

        # define the (ood) weights: uniform over the four groups
        p_ood = {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25}
        p_test = data_obj.get_p_dict(g=g_test)

        print('------------------------------------')

        # define a weight object for the test data based on p_ood
        weights_test_obj = weights(p_weights=p_ood, p_train=p_test)
        weights_test = weights_test_obj.get_weights_sample(g_test)

        # from the param_search_df, get all results for seed
        param_search_df_seed = param_search_df[param_search_df['seed'] == seed]
        if param_search_df_seed.empty:
            raise ValueError('No parameter-search entries found for seed {} (dataset {}, method {})'.format(seed, dataset, method))

        # get the best hyper-parameters
        best_param, param_to_search = get_best_hyper_param_method(method, param_search_df_seed,  selection_criteria=selection_criteria)

        # get the key
        model_key = best_param['key']
        print('model key: {}'.format(model_key))
        print('For seed {} of dataset {}, with method {}, the best hyper parameters are: {}'.format(seed, dataset, method, best_param[param_to_search]))

        # from the model folder, get the model in the folder of the method
        # it is saved as a pickle file under the key
        # NOTE: pickle.load can execute arbitrary code; only point model_folder at trusted files
        model_path = '{}/{}/{}.pkl'.format(model_folder, method, model_key)
        with open(model_path, 'rb') as f:
            logreg_model = pickle.load(f)
            print('Loaded model from {}'.format(model_path))
            print('Coefficients are: {}'.format(logreg_model.Beta))

        # measure the loss on test data
        print('X_test: {}'.format(X_test))
        print('X_val {}'.format(X_val))
        y_test_pred = logreg_model.predict(X_test.cpu())
        loss_test = logreg_model.loss(y_test_pred, y_test.cpu(), weights_test)
        wg_test, _, weighted_acc_test = wg_acc(y_test.cpu(), torch.round(torch.sigmoid(y_test_pred)), g_test.cpu(), return_all=True)

        print("At seed {},  loss_test: {:.3f}, wg_test: {:.3f}, weighted_acc_test: {:.3f}".format(seed,  loss_test,  wg_test, weighted_acc_test))

        # same metrics on the validation data (printed only; not written to the row)
        y_val_pred = logreg_model.predict(X_val.cpu())
        loss_val = logreg_model.loss(y_val_pred, y_val.cpu())
        wg_val, _, weighted_acc_val = wg_acc(y_val.cpu(), torch.round(torch.sigmoid(y_val_pred)), g_val.cpu(), return_all=True)
        print("At seed {},  loss_val: {:.3f}, wg_val: {:.3f}, weighted_acc_val: {:.3f}".format(seed,  loss_val,  wg_val, weighted_acc_val))

        # get the row of the selected model
        row = param_search_df_seed[param_search_df_seed['key'] == model_key]
        print('time: ', row['time'])

        # work on a copy so that adding columns does not cause a pandas warning
        row = row.copy()
        first_idx = row.index[0]

        # at row, put results
        row.loc[first_idx, 'loss_test'] = loss_test.item()
        row.loc[first_idx, 'wg_test'] = wg_test.item()
        row.loc[first_idx, 'weighted_acc_test'] = weighted_acc_test

        # add information - the selection criteria
        row.loc[first_idx, 'selection_criteria'] = selection_criteria

        # save in table dict
        results_per_seed[seed] = row

    # combine the rows of all seeds
    result_table = pd.concat(results_per_seed.values())
    metrics = ['loss_val', 'loss_test', 'wg_val', 'wg_test', 'weighted_acc_val', 'weighted_acc_test']

    # print the results
    if verbose:
        result_table_avg = result_table[metrics].mean()
        print('Results for dataset: {}, method: {}'.format(dataset, method))
        print('Average results: ', result_table_avg)

    if save:
        # use the result_folder argument (not the CLI namespace) so main() also works when imported
        print('Saving results in {}'.format(result_folder))

        main_path = os.path.join(result_folder, 'main_results_{}.csv'.format(dataset))
        if os.path.exists(main_path):

            # load the main file
            main_table = pd.read_csv(main_path)

            # concat in such a way that if there are new columns, they are added
            # where the empty values are filled with NaN
            main_table = pd.concat([main_table, result_table], axis=0, sort=False)

            # entries with the same key: keep only the newest one
            main_table = main_table.drop_duplicates(subset='key', keep='last')

            # write back to the main file
            main_table.to_csv(main_path, index=False)

        else:
            result_table.to_csv(main_path, index=False)


    


if __name__ == "__main__":

    def str_to_bool(text):
        """Convert 'true'/'false' (case-insensitive) to a bool; reject anything else."""
        lowered = text.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
        # previously any other value silently became None
        raise argparse.ArgumentTypeError("expected 'True' or 'False', got {!r}".format(text))

    parser = argparse.ArgumentParser(description='Evaluate the selected models on the test data')
    parser.add_argument('--dataset', type=str, default='WB', help='Dataset to use')
    parser.add_argument('--seeds', type=str, default='1', help='Seeds to use, separated by dashes (e.g. 1-2-3)')
    parser.add_argument('--method', type=str, default='last_layer', help='Method to use')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    parser.add_argument('--batch_size_model', type=int, default=32, help='Batch size to use')
    parser.add_argument('--early_stopping_model', type=str_to_bool, default=False, help='Whether to use early stopping')
    parser.add_argument('--augmentation_data', type=str_to_bool, default=False, help='Whether to use data augmentation')
    parser.add_argument('--DFR_data', type=str, default='train', help='Data to use for DFR')
    parser.add_argument('--fraction_original_data', type=float, default=1.0, help='Fraction of original data to use')
    parser.add_argument('--result_folder', default='results/20092024_set/', help='Folder to save results')
    parser.add_argument('--model_folder', default='models/last_layer_models', help='Folder to get models')
    parser.add_argument('--selection_criteria', type=str, default='wg_val', help='Selection criteria')
    parser.add_argument('--save', type=str_to_bool, default=True, help='Whether to save the results')
    parser.add_argument('--train_val_split', type=str, default='original', help='The train-val split')

    args = parser.parse_args()

    # convert the dash-separated seed string to a list of ints
    seeds = [int(seed) for seed in args.seeds.split('-')]

    # Run the main function
    main(args.dataset, seeds, args.method, args.device, args.batch_size_model, args.augmentation_data, args.early_stopping_model,
            DFR_data=args.DFR_data, result_folder=args.result_folder, model_folder=args.model_folder, fraction_original_data=args.fraction_original_data,
            selection_criteria=args.selection_criteria, save=args.save, train_val_split=args.train_val_split)
