import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 20})
import pandas as pd
import numpy as np
import os
from json import load as jsonload
from sklearn.metrics import mean_squared_error as MSE

def _hour_transform_suffix(hour_transform):
    # Map the hour_transform option to the suffix used in the dataset/JSON/SVG filenames
    if hour_transform is None or hour_transform.casefold() == 'none': # The CLI help advertises "None", which arrives as a string
        return ''
    if hour_transform.casefold() == 'ohe':
        return '-OHEhour'
    if hour_transform.casefold() == 'trig':
        return '-trighour'
    raise ValueError(f'Unknown hour_transform "{hour_transform}". Options are None, "ohe", and "trig".')


def _load_lcen_coefficients(lag, cutoff, max_hours, hour_transform_string):
    # Read the per-horizon LCEN JSON files (001, 002, ... until one is missing) into coefficient arrays
    n_hour_coeffs = {'-OHEhour': 24, '-trighour': 2}.get(hour_transform_string, 0)
    coeffs = []
    intercepts = []
    hour_coeffs = []
    file_count = 1
    while True:
        file_name = f'GEFCom_loadonly_{lag}lag_{cutoff}cutoff_{file_count:03}hour{hour_transform_string}.json'
        if not os.path.isfile(file_name):
            break
        with open(file_name, 'r') as f:
            model_params = jsonload(f)['LCEN']['model_params']
        # The intercept, when present, is stored as the first ('intercept', value) entry
        if model_params and model_params[0][0] == 'intercept':
            intercepts.append(model_params[0][1])
        else:
            intercepts.append(0)
        temp_coeffs = np.zeros(lag)
        temp_hour_coeffs = np.zeros(n_hour_coeffs)
        for elem in model_params:
            if 't-' in elem[0]: # y (power loads) at previous times, named 'y(t-X)'
                idx = int(elem[0].split('-')[-1][:-1]) # X from 'y(t-X)'; [:-1] strips the ')' (X may be >= 10)
                temp_coeffs[idx - 1] = elem[1] # idx - 1 because Python is 0-indexed but the strings are 1-indexed
            elif 'x' in elem[0]: # hour transform features, named 'xN'
                temp_hour_coeffs[int(elem[0][1:])] = elem[1]
        coeffs.append(temp_coeffs)
        hour_coeffs.append(temp_hour_coeffs)
        if max_hours and len(coeffs) >= max_hours:
            break
        file_count += 1
    if not coeffs:
        raise FileNotFoundError(f'No LCEN coefficient file "GEFCom_loadonly_{lag}lag_{cutoff}cutoff_001hour{hour_transform_string}.json" found in {os.getcwd()}')
    return np.array(coeffs), intercepts, np.array(hour_coeffs)


def _plot_window(these_true, these_preds, mse, rel_err, file_name):
    # Save a line plot comparing one forecast window's test data against the LCEN predictions
    n_points = len(these_preds)
    tick_step = 1 + (n_points-1)//48
    size_horizontal = 16 + 10*n_points//48
    fig, ax = plt.subplots(figsize = (size_horizontal, 9))
    ax.plot(range(n_points), these_true, label = 'Test Data')
    ax.plot(range(n_points), these_preds, label = 'LCEN')
    ax.set_xticks(range(0, n_points, tick_step), labels = range(1, n_points+1, tick_step))
    ax.set_xlim([-0.2, n_points+0.2])
    ax.set_xlabel('Time (h)')
    ax.set_ylabel('Load (MW)')
    ax.legend(title = f'RMSE = {mse**0.5:.2f} MW\nRel. Error = {rel_err*100:.1f}%')
    fig.tight_layout()
    fig.savefig(file_name, bbox_inches = 'tight')
    plt.close(fig)


def _plot_relative_error_histogram(relative_error, file_name):
    # Save a histogram of the per-window relative errors
    fig, ax = plt.subplots(figsize = (16, 9))
    counts, _, _ = ax.hist(relative_error, bins = np.concatenate((np.arange(0, 1.1, 0.1), [5])))
    ax.set_xlabel('Relative Error')
    ax.set_ylabel('Number')
    # y-axis limit equals (the highest count among the bins from 0.3 upward + 5) or 50, whichever is higher
    ax.set_ylim([0, np.max(np.concatenate((counts[3:], [45]))) + 5])
    ax.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    fig.tight_layout()
    fig.savefig(file_name, bbox_inches = 'tight')
    plt.close(fig) # Release the figure; otherwise repeated calls accumulate open figures


def predict_and_plot_GEFCom(lag, cutoff = '4e-3', max_hours = 0, hour_transform = None, plot_windows = False):
    """Predict GEFCom2014 power loads with saved LCEN models, print error statistics, and plot them.

    One LCEN model per forecast horizon is read from
    ``GEFCom_loadonly_{lag}lag_{cutoff}cutoff_{NNN}hour{suffix}.json`` in the current working
    directory (NNN = 001, 002, ... until a file is missing). Model NNN predicts the load NNN hours
    ahead. The models are applied to ``GEFCom2014_load{suffix}-data_test.csv`` (with
    ``GEFCom2014_load{suffix}-data_train.csv`` supplying the history before the test set), and every
    multi-hour forecast window is compared against the test data. Windows whose mean relative error
    is >= 0.6 are set to NaN and excluded from the statistics and from R^2.

    Parameters
    ----------
    lag : int
        Number of past loads, y(t-1) ... y(t-lag), that the models were trained on.
    cutoff : str, optional
        Cutoff string embedded in the JSON filenames. Default '4e-3'.
    max_hours : int, optional
        Maximum number of horizon models to load. 0 means no limit.
    hour_transform : {None, 'ohe', 'trig'}, optional
        Hour-of-day encoding used in training (case-insensitive; the string 'none' is treated as None).
    plot_windows : bool, optional
        If True, save an SVG for selected forecast windows: the first few windows and any window
        with relative error >= 0.6. Default False (no per-window plots).

    Raises
    ------
    ValueError
        If hour_transform is not recognized, or if the test set is shorter than the number of horizons.
    FileNotFoundError
        If the coefficient file for the first horizon does not exist.
    """
    hour_transform_string = _hour_transform_suffix(hour_transform)
    # Obtaining the models' coefficients from the .json files
    coeffs, intercepts, hour_coeffs = _load_lcen_coefficients(lag, cutoff, max_hours, hour_transform_string)
    n_horizons = coeffs.shape[0]
    # Loading the data
    data_train = pd.read_csv(f'GEFCom2014_load{hour_transform_string}-data_train.csv', header = None).values
    data_test = pd.read_csv(f'GEFCom2014_load{hour_transform_string}-data_test.csv', header = None).values
    data_all = np.concatenate((data_train, data_test))
    n_test = data_test.shape[0]
    if n_test < n_horizons:
        raise ValueError(f'The test set has {n_test} points, fewer than the {n_horizons} forecast horizons.')
    # Predicting. preds[X, Y] is the prediction for test timestep X made by the model for horizon Y+1 hours
    preds = np.zeros((n_test, n_horizons))
    for hour_idx in range(n_horizons):
        for idx, coeff in enumerate(coeffs[hour_idx]):
            # For the (hour_idx+1)-hours-ahead model, lag feature y(t-(idx+1)) lies (idx + hour_idx + 1) steps
            # before each target, so shift the whole load column back by that amount
            offset = idx + hour_idx + 1
            preds[:, hour_idx] += coeff*data_all[-n_test-offset : -offset, -1]
        preds[:, hour_idx] += intercepts[hour_idx]
        # Hour transforms add some values to the prediction
        if hour_transform_string == '-OHEhour':
            preds[:, hour_idx] += (hour_coeffs[hour_idx]*data_test[:, :-1]).sum(axis=1)
        elif hour_transform_string == '-trighour':
            # arange(1, 25) because the first point of the test set occurs at 1:00, not 0:00
            trig_vals = np.vstack(( np.sin(np.pi*np.arange(1, 25)/12), np.cos(np.pi*np.arange(1, 25)/12) )).T
            trig_vals[np.abs(trig_vals) < 1e-8] = 0 # Rounding very small values that should be 0 but are O(10^-16)
            daily_contribution = (hour_coeffs[hour_idx]*trig_vals).sum(axis=1)
            # Repeat the 24-hour pattern across the test set; modulo indexing also handles partial days
            preds[:, hour_idx] += daily_contribution[np.arange(n_test) % 24]
    # Evaluation: one window per possible forecast start, i.e. n_test - n_horizons + 1 windows
    n_windows = n_test - n_horizons + 1
    mse_values = np.empty(n_windows)
    relative_error = np.empty_like(mse_values)
    SSErr = 0.0
    SST = 0.0
    for idx in range(n_windows):
        # Diagonal = [preds[idx, 0], preds[idx+1, 1], ...]: every entry uses only data before test hour idx,
        # so together they form one multi-step forecast of test hours idx ... idx+n_horizons-1
        these_preds = preds[idx : idx+n_horizons].diagonal()
        these_true = data_test[idx : idx+n_horizons, -1]
        squared_errors = (these_true - these_preds)**2
        mse_values[idx] = squared_errors.mean()
        relative_error[idx] = np.mean( np.abs(these_true - these_preds)/these_true )
        if plot_windows and ((idx <= n_horizons*8 and (not(idx % n_horizons) or idx in {6, 12, 18})) or relative_error[idx] >= 0.6):
            # Plot before masking so outlier windows show their actual errors
            _plot_window(these_true, these_preds, mse_values[idx], relative_error[idx],
                f'GEFCom_loadonly_{lag}lag_{cutoff}cutoff_{n_horizons}hours{hour_transform_string}_start-at-{idx}hours.svg')
        if relative_error[idx] >= 0.6:
            relative_error[idx] = np.nan
            mse_values[idx] = np.nan
        else:
            SSErr += squared_errors.sum()
            SST += ((these_true - these_true.mean())**2).sum()
    # NaN-aware statistics throughout, since masked windows are stored as NaN
    print(f'MSE stats: mean = {np.nanmean(mse_values):.3f}; min = {np.nanmin(mse_values):.3f}; max = {np.nanmax(mse_values):.3f}; median = {np.nanmedian(mse_values):.3f}')
    print(f'RMSE stats: mean = {np.nanmean(mse_values**0.5):.3f}; min = {np.nanmin(mse_values**0.5):.3f}; max = {np.nanmax(mse_values**0.5):.3f}; median = {np.nanmedian(mse_values**0.5):.3f}')
    r_squared = 1 - SSErr/SST if SST > 0 else np.nan # SST is 0 when no valid window has any variance
    print(f'R^2 = {r_squared:.4f}')
    print(f'Relative error stats: mean = {np.nanmean(relative_error):.3f}; min = {np.nanmin(relative_error):.3f}; max = {np.nanmax(relative_error):.3f}; median = {np.nanmedian(relative_error):.3f}')
    # Relative error histogram (only drawn when no window was masked)
    if not(np.any(np.isnan(relative_error))):
        _plot_relative_error_histogram(relative_error,
            f'GEFCom_loadonly_{lag}lag_{cutoff}cutoff_{n_horizons}hours{hour_transform_string}_RelErr-histogram.svg')

if __name__ == '__main__':
    # Command-line entry point: parse the options and run predict_and_plot_GEFCom with them
    import argparse
    parser = argparse.ArgumentParser(description = 'Reads LCEN coefficients from JSON files and makes a multiple-output prediction of power loads, comparing with the GEFCom test data.')
    parser.add_argument('lag', type = int, help = 'The lag that was used when making predictions.')
    parser.add_argument('-c', '--cutoff', metavar = 'CUTOFF', type = str, default = '4e-3', help = 'The cutoff that was used when making predictions. Optional, default = "4e-3".')
    parser.add_argument('-mh', '--max_hours', metavar = 'N', type = int, default = 0, help = 'The maximum number of future hours to predict. Leave at 0 for no limit. Optional, default = 0.')
    parser.add_argument('-ht', '--hour_transform', metavar = 'TRANSFORM', type = str, default = None, help = 'What transform (if any) to use for the hours. Options are None for no transform, "ohe" for one-hot encoding, and "trig" for a sine-and-cosine-based transform. Optional, default = None.')
    args = parser.parse_args()
    hour_transform = args.hour_transform
    # The help text advertises "None" as an option, but on the command line it arrives as a string
    if hour_transform is not None and hour_transform.casefold() == 'none':
        hour_transform = None
    predict_and_plot_GEFCom(args.lag, args.cutoff, args.max_hours, hour_transform)
