

from model import logistic_regression, logistic_regression_subsample, DFR_model, JTT_model, BERT_model
from weights import *
from weight_searcher import *
from data import Toy
import numpy as np
import matplotlib.pyplot as plt
from helpers import *
import sys
import matplotlib as mpl
from data import WB, CelebA, multiNLI
from metrics import *
import argparse
import time
import pandas as pd
import os
import pickle 
import gc

def efficient_split(X_train, y_train, g_train, X_val, y_val, g_val, train_val_split):
    """Pool train and validation data, shuffle it with a fixed seed, and re-split it.

    The pooled rows are permuted right after ``set_seed(0)``, so the new split
    is the same no matter which seed the surrounding run uses. The first
    ``int(float(train_val_split) * n_rows)`` rows of the permutation form the
    new training set; the remaining rows form the new validation set.

    Args:
        X_train, y_train, g_train: training features, labels and group ids.
        X_val, y_val, g_val: validation features, labels and group ids.
        train_val_split: fraction of the pooled rows that goes to the new
            training set (anything ``float()`` accepts).

    Returns:
        ``(X_train_new, y_train_new, g_train_new, X_val_new, y_val_new, g_val_new)``.
    """
    # Concatenate each train/val pair along the sample axis (X, then y, then g).
    pooled = [
        torch.concatenate((train_part, val_part), axis=0)
        for train_part, val_part in ((X_train, X_val), (y_train, y_val), (g_train, g_val))
    ]
    del X_train, y_train, g_train, X_val, y_val, g_val

    # Fixed seed: the permutation must not depend on the run's seed.
    order = np.arange(pooled[0].shape[0])
    set_seed(0)
    np.random.shuffle(order)
    print('i is: {}'.format(order))
    X_all, y_all, g_all = (tensor[order] for tensor in pooled)
    del pooled

    cut = int(float(train_val_split) * X_all.shape[0])
    new_train = (X_all[:cut], y_all[:cut], g_all[:cut])
    new_val = (X_all[cut:], y_all[cut:], g_all[cut:])
    del X_all, y_all, g_all

    # Drop the temporaries created above.
    gc.collect()

    return (*new_train, *new_val)


# Methods and datasets handled by ``main``; anything else would leave no model / no data.
_SUPPORTED_METHODS = ('ERM', 'ERM-GW', 'ERM-SUBG', 'DFR', 'JTT', 'GDRO')
_SUPPORTED_DATASETS = ('WB', 'CelebA', 'multiNLI')


def _as_list(value):
    """Return ``value`` as a list; a scalar is wrapped in a one-element list."""
    if np.isscalar(value):
        return [value]
    return list(value)


def _load_dataset(dataset, seed, batch_size_model, early_stopping_model, augmentation_data, device):
    """Load the data of ``dataset`` for ``seed``.

    Returns ``(data_obj, arrays)`` where ``arrays`` is the 9-tuple
    ``(X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test)``
    returned by the dataset loader.
    """
    if dataset == 'WB':
        data_obj = WB()
        arrays = data_obj.return_WB_embeddings(seed, batch_size_model, early_stopping_model, augmentation_data, JTT_folder=False)
    elif dataset == 'CelebA':
        data_obj = CelebA()
        # The train/val re-split is done by main itself, so always load the original split here.
        arrays = data_obj.return_CelebA_data(seed, batch_size_model, early_stopping_model, augmentation_data, device, train_val_split='original')
    elif dataset == 'multiNLI':
        data_obj = multiNLI()
        arrays = data_obj.load_embeddings(seed, batch_size_model, early_stopping_model, device)
    else:
        raise ValueError('Unknown dataset: {}. Choose one of {}.'.format(dataset, _SUPPORTED_DATASETS))
    return data_obj, arrays


def _group_counts(g):
    """Return ``np.unique(g, return_counts=True)``; ``g`` is moved to the CPU first."""
    return np.unique(g.cpu(), return_counts=True)


def _print_split_summary(X_train, X_val, g_train, g_val, g_test):
    """Print the train/val sizes and the group counts of train, val and test."""
    print('After split train size: {}, val size: {}'.format(X_train.shape[0], X_val.shape[0]))
    print('Division of groups in train: {}'.format(_group_counts(g_train)))
    print('Division of groups in val: {}'.format(_group_counts(g_val)))
    print('Division of groups in test: {}'.format(_group_counts(g_test)))


def _split_val_for_DFR(X_val, y_val, g_val, seed, frac_DFR_val_data):
    """Randomly split validation data into a fitting part and a held-out part.

    The first ``int(frac_DFR_val_data * n)`` shuffled rows form the fitting part.
    Returns ``((X_fit, y_fit, g_fit), (X_held_out, y_held_out, g_held_out))``.
    """
    np.random.seed(seed)
    indeces = np.arange(X_val.shape[0])
    np.random.shuffle(indeces)
    n_frac = int(frac_DFR_val_data * X_val.shape[0])
    indeces_1, indeces_2 = indeces[:n_frac], indeces[n_frac:]
    return ((X_val[indeces_1], y_val[indeces_1], g_val[indeces_1]),
            (X_val[indeces_2], y_val[indeces_2], g_val[indeces_2]))


def _save_result_table(result_table, result_folder, main_file='main_param_search.csv'):
    """Append the per-combination result tables to ``result_folder/main_file``.

    Rows whose ``key`` already exists in the file are replaced by the new rows.
    New columns are added; missing values are filled with NaN.
    """
    table = pd.concat(result_table, axis=0)

    # Timestamp of the run, to the second.
    table['time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())

    if result_folder:
        os.makedirs(result_folder, exist_ok=True)
    path = os.path.join(result_folder, main_file)
    if os.path.exists(path):
        main_table = pd.read_csv(path)
        table = pd.concat([main_table, table], axis=0, sort=False)
        # Keep only the newest entry per key.
        table = table.drop_duplicates(subset='key', keep='last')
    table.to_csv(path, index=False)


def main(dataset, seeds, method, device, penalty_strengths,  penalty_type, solver, tol,
          batch_size_model, augmentation_data, early_stopping_model, 
          lambda_JTT, DFR_data='val', selection_criteria='wg_val_avg',
            warm_start=False, verbose=True, penalty_type_predictor=None, penalty_strength_predictor=None, 
            create_identifier=False,  p_y_JTT = 0.5, T=50, C_GDRO=0.0, eta_param_GDRO=0.00001, eta_q_GDRO=0.1,
              fraction_original_data=1.0, parallel_fit=False, frac_DFR_val_data=0.5, result_folder='results', model_folder='models', train_val_split='original', save=False):
    """Run a hyper-parameter search for last-layer models over several seeds.

    For every seed the data of ``dataset`` is loaded once and prepared:
    optionally subsampled (``fraction_original_data``), optionally re-split
    (``train_val_split``), and, when ``DFR_data == 'val'``, the validation
    data is split into a fitting part and a held-out part. Then, for every
    combination of ``penalty_strengths`` x ``lambda_JTT`` x ``C_GDRO`` x
    ``eta_param_GDRO``, a model of type ``method`` is fitted for each seed.
    It is evaluated on validation and test data (weighted loss, worst-group
    accuracy, weighted accuracy) and pickled to
    ``model_folder/<method>/<run_key>.pkl``.

    Args:
        dataset: 'WB', 'CelebA' or 'multiNLI'.
        seeds: iterable of integer seeds.
        method: one of 'ERM', 'ERM-GW', 'ERM-SUBG', 'DFR', 'JTT', 'GDRO'.
        penalty_strengths, lambda_JTT, C_GDRO, eta_param_GDRO: a single value
            or a list of values; all combinations are evaluated.
        DFR_data: if 'val', the validation data is split into a fitting part
            (fraction ``frac_DFR_val_data``) and a held-out part. This applies
            to every method, not only DFR. For DFR, the final model is then
            re-fitted on the full validation data before testing.
        selection_criteria: accepted for interface compatibility; unused here.
        train_val_split: 'original' or a fraction used to re-split the pooled
            train+val data.
        save: if True, append the results to
            ``result_folder/main_param_search.csv``.

    Raises:
        ValueError: if ``method`` or ``dataset`` is not supported.
    """
    if method not in _SUPPORTED_METHODS:
        raise ValueError('Unknown method: {}. Choose one of {}.'.format(method, _SUPPORTED_METHODS))
    if dataset not in _SUPPORTED_DATASETS:
        raise ValueError('Unknown dataset: {}. Choose one of {}.'.format(dataset, _SUPPORTED_DATASETS))

    # seeds is iterated once per parameter combination, so materialize it.
    seeds = list(seeds)

    # Create parameter combinations of penalty strength, lambda_JTT, C_GDRO and eta_param_GDRO.
    # Scalars are accepted and treated as single-value grids.
    penalty_strength_grid = _as_list(penalty_strengths)
    lambda_JTT_grid = _as_list(lambda_JTT)
    C_GDRO_grid = _as_list(C_GDRO)
    eta_param_GDRO_grid = _as_list(eta_param_GDRO)
    param_combinations = [{'penalty_strength': penalty_strength, 'lambda_JTT': lambda_JTT_i, 'C_GDRO': C_GDRO_i, 'eta_param_GDRO': eta_param_GDRO_i}
                          for penalty_strength in penalty_strength_grid
                          for lambda_JTT_i in lambda_JTT_grid
                          for C_GDRO_i in C_GDRO_grid
                          for eta_param_GDRO_i in eta_param_GDRO_grid]

    # One result table per combination, filled in below.
    n_combinations = len(param_combinations)
    result_table = dict.fromkeys(range(n_combinations))

    start_time = time.time()

    # ---- Load and prepare the data once per seed ----
    list_data_dict = []
    for seed in seeds:
        data_obj, (X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test) = _load_dataset(
            dataset, seed, batch_size_model, early_stopping_model, augmentation_data, device)

        # Change the fraction of original data.
        if fraction_original_data < 1.0:
            set_seed(seed)
            X_train, y_train, g_train, indices_train = get_fraction_original_data( fraction_original_data, X_train, y_train, g_train)
            X_val, y_val, g_val, indices_val = get_fraction_original_data(fraction_original_data,X_val, y_val, g_val)
            print('Fraction of original data used: {}'.format(fraction_original_data))
        else:
            indices_train, indices_val = np.arange(X_train.shape[0]), np.arange(X_val.shape[0])

        # If train_val_split is not original: combine, shuffle and re-split the data.
        # NOTE(review): after this re-split, indices_train / indices_val no longer index
        # the re-split rows (they are used to subset JTT predictions) — confirm intended.
        if train_val_split != 'original':
            X_train, y_train, g_train, X_val, y_val, g_val = efficient_split(
                X_train, y_train, g_train, X_val, y_val, g_val, float(train_val_split))
            _print_split_summary(X_train, X_val, g_train, g_val, g_test)

        # Use one part of the validation data for fitting and keep the rest as validation data.
        if DFR_data == 'val':
            X_val_orig, y_val_orig, g_val_orig = X_val, y_val, g_val
            (X_train, y_train, g_train), (X_val, y_val, g_val) = _split_val_for_DFR(
                X_val, y_val, g_val, seed, frac_DFR_val_data)

        # Target (ood) group weights: uniform over the four groups.
        p_ood = {1: 0.25, 2: 0.25, 3: 0.25, 4: 0.25}

        p_train = data_obj.get_p_dict(g=g_train)
        p_val = data_obj.get_p_dict(g=g_val)
        p_test = data_obj.get_p_dict(g=g_test)

        print('n in train: {}, n in val: {}, n in test: {}'.format(X_train.shape[0], X_val.shape[0], X_test.shape[0]))

        # Per-sample weights that reweight val/test towards p_ood.
        weights_val = weights(p_weights=p_ood, p_train=p_val).get_weights_sample(g_val)
        weights_test = weights(p_weights=p_ood, p_train=p_test).get_weights_sample(g_test)

        data_dict = {'X_train':X_train, 'y_train':y_train, 'X_val':X_val, 
                     'y_val':y_val, 'g_train':g_train, 'g_val':g_val,
                     'X_test':X_test, 'y_test':y_test, 'g_test':g_test,
                     'weights_val':weights_val, 'weights_test':weights_test,
                     'p_train':p_train, 'p_ood':p_ood, 'p_val':p_val, 'p_test':p_test,
                     'indices_train':indices_train, 'indices_val':indices_val,
                     'data_obj':data_obj}
        if DFR_data == 'val':
            data_dict['X_val_orig'] = X_val_orig
            data_dict['y_val_orig'] = y_val_orig
            data_dict['g_val_orig'] = g_val_orig

        list_data_dict.append(data_dict)

    # ---- Fit and evaluate every parameter combination on every seed ----
    for i, param_combination in enumerate(param_combinations):

        keys = []
        results = {'seed':[], 
                   'loss_val':[], 'loss_test':[],
                   'wg_val':[],'wg_test':[],
                   'weighted_acc_val':[], 'weighted_acc_test':[]}

        # Pair each seed with the data loaded for it (seeds need not be 1..n).
        for seed, data_dict in zip(seeds, list_data_dict):

            print('seed: {}'.format(seed))
            data_obj = data_dict['data_obj']
            X_train, y_train, X_val, y_val, g_train, g_val = data_dict['X_train'], data_dict['y_train'], data_dict['X_val'], data_dict['y_val'], data_dict['g_train'], data_dict['g_val']
            # Test data of *this* seed (not whatever the loading loop left behind).
            X_test, y_test, g_test = data_dict['X_test'], data_dict['y_test'], data_dict['g_test']
            p_train, p_ood = data_dict['p_train'], data_dict['p_ood']
            weights_val, weights_test = data_dict['weights_val'], data_dict['weights_test']

            # Model parameters of the logistic regression; built per seed so 'seed' is the current one.
            model_param = {'penalty_type':penalty_type, 
                           'penalty_strength':param_combination['penalty_strength'], 
                           'solver':solver, 
                           'tol':tol,
                           'seed':seed, 
                           'T':T, 'eta0':0.0, 'learning_rate':'optimal', 'parallel_fit':parallel_fit,
                           'use_SGDClassifier':method == 'GDRO'}

            # Define the model (GDRO builds its model in the fitting step below).
            if method == 'ERM':
                logreg_model = logistic_regression(model_param, p_weights=p_train, p_train=p_train, add_intercept=True, warm_start=warm_start)
            elif method == 'ERM-GW':
                logreg_model = logistic_regression(model_param, p_weights=p_ood, p_train=p_train, add_intercept=True, warm_start=warm_start)
            elif method == 'ERM-SUBG':
                p_SUBG = calc_subsample_ood_weights(p_train, X_train.shape[0])
                set_seed(seed)
                logreg_model = logistic_regression_subsample(model_param, p_weights=p_SUBG, p_train=p_train, add_intercept=True)
            elif method == 'DFR':
                p_DFR = calc_subsample_ood_weights(p_train, X_train.shape[0])
                set_seed(seed)
                logreg_model = DFR_model(model_param, p_weights=p_DFR, p_train=p_train, add_intercept=True)
            elif method == 'JTT':
                # Class proportions of the training set (groups 1,2 -> class 0; groups 3,4 -> class 1).
                p_train_identifier = {0:p_train[1]+ p_train[2], 1:p_train[3]+ p_train[4]}
                model_param_identifier = model_param

                # The predictor reuses the identifier's parameters unless both predictor settings are given.
                if penalty_type_predictor is None or penalty_strength_predictor is None:
                    model_param_predictor = model_param
                else:
                    # NOTE(review): unlike model_param, this dict has no 'parallel_fit' /
                    # 'use_SGDClassifier' keys and fixes 'seed' to 1 — confirm intended.
                    model_param_predictor = {'penalty_type':penalty_type_predictor, 'penalty_strength':penalty_strength_predictor, 'solver':solver,  'tol':tol, 'seed':1, 'T':T, 'eta0':None, 'learning_rate':'optimal'}

                logreg_model = JTT_model(model_param_identifier, model_param_predictor, p_weights=p_ood, p_train=p_train, add_intercept=True, class_balanced_identifier=True, p_train_identifier=p_train_identifier, create_identifier=create_identifier)

            # Fit the model.
            if method in ('ERM', 'ERM-GW', 'ERM-SUBG', 'DFR'):
                set_seed(seed)
                logreg_model.fit(X_train.cpu(), y_train.squeeze(-1).cpu(), g_train.cpu())

            elif method == 'JTT':
                # Unless the identifier is created here, load the stored predictions.
                if create_identifier:
                    pred_train = None
                else:
                    if dataset == 'WB':
                        pred_train, _, _ = data_obj.return_WB_pred(seed, 32, True, False, JTT_folder=True)
                    elif dataset == 'CelebA':
                        pred_train, _, _ = data_obj.return_CelebA_pred(seed)
                    else:  # multiNLI
                        pred_train, _, _ = data_obj.return_multiNLI_pred(seed)
                    pred_train = pred_train.to(device)

                    # Keep the predictions aligned with the subsampled training rows.
                    if fraction_original_data < 1.0:
                        pred_train = pred_train[data_dict['indices_train']]

                logreg_model.fit(X_train, y_train, param_combination['lambda_JTT'], p_y_JTT=p_y_JTT, pred_train=pred_train)

            elif method == 'GDRO':
                logreg_model = logistic_regression(model_param, p_weights=p_train, p_train=p_train, add_intercept=True)

                # Full-batch GDRO; the best iterate is selected on the validation data.
                batch_size_GDRO = X_train.shape[0]
                set_seed(seed)
                opt_Beta, best_wg, best_t = logreg_model.optimize_GDRO_via_SGD(X_train, y_train, g_train, T, batch_size_GDRO, param_combination['eta_param_GDRO'], eta_q_GDRO, C=param_combination['C_GDRO'], use_val=True, X_val=X_val, y_val=y_val, g_val=g_val, early_stopping=False, learning_rate_schedule='constant')
                logreg_model.Beta = opt_Beta

                if verbose:
                    print('Selected opt. param at t ={} with a WG of {}'.format(best_t, best_wg))

            # Evaluate on the training data.
            y_train_pred = logreg_model.predict(X_train.cpu())
            wg_train, _, weighted_acc_train = wg_acc(y_train.cpu(), torch.round(torch.sigmoid(y_train_pred)), g_train.cpu(), return_all=True)

            _print_split_summary(X_train, X_val, g_train, g_val, g_test)

            # Evaluate on the validation data.
            y_val_pred = logreg_model.predict(X_val.cpu())
            loss_val = logreg_model.loss(y_val_pred, y_val.cpu(), weights_val)
            wg_val, _, weighted_acc_val = wg_acc(y_val.cpu(), torch.round(torch.sigmoid(y_val_pred)), g_val.cpu(), return_all=True)

            # For DFR on validation data, re-fit on the entire validation data before testing.
            if method == 'DFR' and DFR_data == 'val':
                X_val_orig, y_val_orig, g_val_orig = data_dict['X_val_orig'], data_dict['y_val_orig'], data_dict['g_val_orig']

                # The subsample weights and group proportions describe the data actually
                # being fitted (the full validation set), mirroring the first DFR fit.
                p_val_orig = data_obj.get_p_dict(g=g_val_orig)
                p_DFR = calc_subsample_ood_weights(p_val_orig, X_val_orig.shape[0])

                set_seed(seed)
                logreg_model = DFR_model(model_param, p_weights=p_DFR, p_train=p_val_orig, add_intercept=True)
                logreg_model.fit(X_val_orig.cpu(), y_val_orig.squeeze(-1).cpu(), g_val_orig.cpu())
                print('Successfully re-fitted the model on the entire validation data.')

            # Evaluate on the test data.
            y_test_pred = logreg_model.predict(X_test.cpu())
            loss_test = logreg_model.loss(y_test_pred, y_test.cpu(), weights_test)
            wg_test, _, weighted_acc_test = wg_acc(y_test.cpu(), torch.round(torch.sigmoid(y_test_pred)), g_test.cpu(), return_all=True)

            results['seed'].append(seed)
            results['loss_val'].append(loss_val.item())
            results['wg_val'].append(wg_val.item())
            results['weighted_acc_val'].append(weighted_acc_val)
            results['loss_test'].append(loss_test.item())
            results['wg_test'].append(wg_test.item())
            results['weighted_acc_test'].append(weighted_acc_test)

            if verbose:
                print('Seed: {}, Loss val: {:.3f}, WG val: {:.3f}, Weighted acc val: {:.3f}, Loss test: {:.3f}, WG test: {:.3f}, Weighted acc test: {:.3f}'.format(seed, loss_val.item(), wg_val.item(), weighted_acc_val, loss_test.item(), wg_test.item(), weighted_acc_test))
                print('For the train data: WG train: {:.3f}, Weighted acc train: {:.3f}'.format(wg_train, weighted_acc_train))

            # Build the key identifying this run.
            run_key = get_key(method, dataset, penalty_type, param_combination['penalty_strength'], solver, tol, batch_size_model, augmentation_data, early_stopping_model, param_combination['lambda_JTT'], DFR_data, penalty_type_predictor, penalty_strength_predictor, create_identifier, p_y_JTT, T, param_combination['C_GDRO'], param_combination['eta_param_GDRO'], eta_q_GDRO, fraction_original_data, seed)
            if train_val_split != 'original':
                train_val_split_str = str(train_val_split).replace('.', '')
                run_key = run_key + '_train_val_split_' + train_val_split_str
            keys.append(run_key)

            # Pickle the fitted model to model_folder/<method>/<run_key>.pkl.
            method_model_folder = os.path.join(model_folder, method)
            os.makedirs(method_model_folder, exist_ok=True)
            with open(os.path.join(method_model_folder, run_key + '.pkl'), 'wb') as f:
                pickle.dump(logreg_model, f)

        # Create the result entry for this combination.
        results_run = create_result_entry(results, method, dataset, penalty_type, param_combination['penalty_strength'], solver, tol, batch_size_model, augmentation_data, early_stopping_model, param_combination['lambda_JTT'], DFR_data, penalty_type_predictor, penalty_strength_predictor, create_identifier, p_y_JTT, T,  param_combination['C_GDRO'],  param_combination['eta_param_GDRO'], eta_q_GDRO, fraction_original_data)
        results_run['key'] = keys
        if train_val_split != 'original':
            results_run['train_val_split'] = train_val_split

        results_run_table = pd.DataFrame(results_run)
        result_table[i] = results_run_table
        print('results_run_table: {}'.format(results_run_table))

        if verbose:
            # Average and standard error over the seeds.
            _, _, wg_val_avg, wg_val_se, weighted_acc_val_avg, weighted_acc_val_se = return_results(results, seeds, type_result='val')
            print('Param combination: {}'.format(param_combination))
            print('WG val - Avg: {:.3f}, SE: {:.3f}'.format(wg_val_avg, wg_val_se))
            print('Weighted acc val - Avg: {:.3f}, SE: {:.3f}'.format(weighted_acc_val_avg, weighted_acc_val_se))

    time_after = time.time()
    print('Time taken: {}'.format(time_after-start_time))

    print('--- Results ---')
    if save:
        print('Saving the results to the result folder.')
        _save_result_table(result_table, result_folder)







if __name__ == "__main__":
    import re

    def parse_dash_list(text, cast):
        """Parse a '-'-separated CLI value such as '0.1-1-10' into a list of ``cast`` values.

        A '-' directly after 'e'/'E' is treated as an exponent sign, not a
        separator, so '1e-05-0.1' parses to [1e-05, 0.1].
        """
        return [cast(part) for part in re.split(r'(?<![eE])-', text)]

    parser = argparse.ArgumentParser(description='Dataset preparation')
    parser.add_argument('--dataset', type=str, default='WB', help='Dataset to use')
    parser.add_argument('--seeds', type=str, default='1', help="Seeds to use, separated by '-' (e.g. 1-2-3)")
    parser.add_argument('--method', type=str, default='last_layer', help='Method to use')
    parser.add_argument('--device', type=str, default='cpu', help='Device to use')
    parser.add_argument('--penalty_strengths', type=str, required=True, help="Penalty strengths to use, separated by '-'")
    parser.add_argument('--penalty_type', type=str, default='l1', help='Penalty type to use')
    parser.add_argument('--solver', type=str, default='liblinear', help='Solver to use')
    parser.add_argument('--tol', type=float, default=1e-4, help='Tolerance to use')
    parser.add_argument('--batch_size_model', type=int, default=64, help='Batch size to use for the model')
    parser.add_argument('--augmentation_data', type=str, default='False', help='Whether to use data augmentation')
    parser.add_argument('--early_stopping_model', type=str, default='False', help='Whether to use early stopping')
    parser.add_argument('--lambda_JTT', type=str, default='1.0', help='Lambda for JTT')
    parser.add_argument('--selection_criteria', type=str, default='wg_val_avg', help='Selection criteria to use')
    parser.add_argument('--DFR_data', type=str, default='train', help='Data to use for DFR')
    parser.add_argument('--penalty_type_predictor', type=str, default='l1', help='Penalty type for the predictor')
    parser.add_argument('--penalty_strength_predictor', type=float, default=None, help='Penalty strength for the predictor')
    parser.add_argument('--create_identifier', type=str, default='False', help='Whether to create an identifier')
    parser.add_argument('--T', type=int, default=100, help='Number of iterations ')
    parser.add_argument('--C_GDRO', type=str, default='0.0', help='C for GDRO')
    parser.add_argument('--eta_param_GDRO', type=str, default='0.00001', help='Eta param for GDRO')
    parser.add_argument('--eta_q_GDRO', type=float, default=0.1, help='Eta q for GDRO')
    parser.add_argument('--fraction_original_data', type=float, default=1.0, help='Fraction of original data to use')
    parser.add_argument('--result_folder', type=str, default='results/20092024_set', help='Folder in which the result CSV is stored')
    parser.add_argument('--model_folder', type=str, default='models/last_layer_models', help='Folder in which the fitted models are stored')
    parser.add_argument('--train_val_split',  default='original', help='Whether to use train/val split')
    parser.add_argument('--save', type=str, default='True', help='Whether to save the results')

    args = parser.parse_args()

    # List-valued options are passed as '-'-separated strings.
    seeds = parse_dash_list(args.seeds, int)
    penalty_strengths = parse_dash_list(args.penalty_strengths, float)
    lambda_JTT = parse_dash_list(args.lambda_JTT, float)
    C_GDRO = parse_dash_list(args.C_GDRO, float)
    eta_param_GDRO = parse_dash_list(args.eta_param_GDRO, float)

    def str_to_bool(text):
        """Map 'true'/'false' (any case) to a bool; anything else is a CLI error."""
        lowered = text.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
        # Previously an unrecognised value silently became None (i.e. falsy).
        parser.error("expected 'True' or 'False', got {!r}".format(text))

    args.early_stopping_model = str_to_bool(args.early_stopping_model)
    args.augmentation_data = str_to_bool(args.augmentation_data)
    args.create_identifier = str_to_bool(args.create_identifier)
    args.save = str_to_bool(args.save)

    # Run the main function
    main(args.dataset, seeds, args.method, args.device, penalty_strengths, args.penalty_type, args.solver, args.tol, args.batch_size_model, args.augmentation_data, args.early_stopping_model,
         selection_criteria=args.selection_criteria,  lambda_JTT=lambda_JTT, DFR_data=args.DFR_data, 
         penalty_type_predictor=args.penalty_type_predictor, penalty_strength_predictor=args.penalty_strength_predictor, create_identifier=args.create_identifier,
            T=args.T, C_GDRO=C_GDRO, eta_param_GDRO=eta_param_GDRO, eta_q_GDRO=args.eta_q_GDRO, fraction_original_data=args.fraction_original_data, result_folder=args.result_folder, model_folder=args.model_folder, train_val_split=args.train_val_split, save=args.save)

