import SPA
import os
from time import sleep
import sys

def _hour_transform_suffix(hour_transform):
    """Return the file-name suffix that corresponds to ``hour_transform``.

    Parameters
    ----------
    hour_transform : str or None
        ``None`` (or the string ``"none"``) for no transform, ``"ohe"`` or
        ``"trig"``. Matching is case-insensitive.

    Returns
    -------
    str
        ``''``, ``'-OHEhour'`` or ``'-trighour'``.

    Raises
    ------
    ValueError
        If ``hour_transform`` is not one of the supported options.
    """
    if hour_transform is None:
        return ''
    suffixes = {'none': '', 'ohe': '-OHEhour', 'trig': '-trighour'}
    try:
        return suffixes[str(hour_transform).casefold()]
    except KeyError:
        raise ValueError(f'Unknown hour_transform {hour_transform!r}; expected None, "ohe", or "trig".') from None

def train_GEFCom(lag, cutoff = '4e-3', n_hours = 24, hour_transform = None):
    """Train one LCEN model per hour index on the GEFCom load data via SPA.

    For every ``hour_idx`` in ``range(n_hours)`` this calls ``SPA.main_SPA``
    with ``lag = [lag + hour_idx]`` and ``min_lag = hour_idx`` (presumably one
    model per hour ahead -- confirm against SPA), then renames the
    ``SPA_results*.json`` file found in the current working directory to
    ``GEFCom_loadonly_{lag}lag_{cutoff}cutoff_{hour:03}hour{suffix}.json`` and
    deletes any ``SPA_results*.p`` file. Hours whose results file already
    exists (with either a 2- or 3-digit hour number) are skipped.

    Parameters
    ----------
    lag : int
        Base lag; ``hour_idx`` is added to it for each model.
    cutoff : str, optional
        LCEN cutoff. Used verbatim in the output file names and converted
        with ``float`` before being passed to SPA. Default ``'4e-3'``.
    n_hours : int, optional
        Number of hour indices to train for. Default 24.
    hour_transform : str or None, optional
        ``None``/``"none"``, ``"ohe"`` or ``"trig"`` (case-insensitive).
        Selects which dataset files are read and which suffix the results
        files receive.

    Raises
    ------
    ValueError
        If ``hour_transform`` is not supported or ``cutoff`` cannot be
        converted to a float. Both are checked before any training starts.
    """
    # Manipulating the names of the datasets and saved files based on what hour transform is being used.
    # Resolved up front so that an unsupported value fails immediately with a clear message.
    hour_transform_string = _hour_transform_suffix(hour_transform)
    cutoff_value = float(cutoff)
    train_file = f'GEFCom2014_load{hour_transform_string}-data_train.csv'
    test_file = f'GEFCom2014_load{hour_transform_string}-data_test.csv'
    # Running the actual training
    for hour_idx in range(n_hours):
        hour = hour_idx + 1
        # Results may have been saved with a 2-digit hour by an older naming scheme; new files use 3 digits
        old_style_name = f'GEFCom_loadonly_{lag}lag_{cutoff}cutoff_{hour:02}hour{hour_transform_string}.json'
        results_name = f'GEFCom_loadonly_{lag}lag_{cutoff}cutoff_{hour:03}hour{hour_transform_string}.json'
        if os.path.isfile(old_style_name) or os.path.isfile(results_name): # Do not retrain if a file with results already exists
            continue
        _ = SPA.main_SPA(train_file, test_data = test_file, scale_X = False, cv_method = 'timeseries', degree = [1], trans_type = 'poly', model_name = ['LCEN'], LCEN_cutoff = cutoff_value, lag = [lag + hour_idx], min_lag = hour_idx, l1_ratio = [0, 0.333, 0.667, 0.99] )
        # Renaming the results files for convenience
        for this_file in os.listdir():
            if not this_file.startswith('SPA_results'):
                continue
            if this_file.endswith('.json'):
                os.rename(this_file, results_name)
            elif this_file.endswith('.p'):
                os.remove(this_file)

if __name__ == '__main__':
    # Command-line entry point: parses the arguments and forwards them to train_GEFCom.
    # Input setup
    import argparse
    parser = argparse.ArgumentParser(description = 'Trains multiple LCEN models (one for each hour ahead) for the prediction of power loads with the GEFCom data.')
    parser.add_argument('lag', type = int, nargs = 1, help = 'The lag to be used when making predictions.')
    parser.add_argument('-c', '--cutoff', metavar = '4e-3', nargs = 1, type = str, default = ['4e-3'], help = 'The LCEN cutoff that to be used when making predictions in string format. Optional, default = "4e-3".')
    # metavar must be a string; as for --cutoff, it shows the default value in the help output
    parser.add_argument('-n', '--n_hours', metavar = '24', nargs = 1, type = int, default = [24], help = 'The maximum number of future hours to predict. Optional, default = 24.')
    parser.add_argument('-ht', '--hour_transform', metavar = 'None', nargs = 1, type = str, default = [None], help = 'What transform (if any) to use for the hours. Options are None for no transform, "ohe" for one-hot encoding, and "trig" for a sine-and-cosine-based transform. Optional, default = None.')
    args = parser.parse_args()
    lag = args.lag[0] # [0] to convert from list to int
    cutoff = args.cutoff[0]
    n_hours = args.n_hours[0]
    hour_transform = args.hour_transform[0]
    # "-ht None" arrives as the string "None"; map it to the None object that train_GEFCom expects
    if hour_transform is not None and hour_transform.casefold() == 'none':
        hour_transform = None
    train_GEFCom(lag, cutoff, n_hours, hour_transform)
