import argparse
import ast
import os
import warnings
from time import time

import numpy as np
import pandas as pd
from statsmodels.api import tsa as sm
from tqdm import tqdm

# warnings.filterwarnings("ignore")
from data.DataPrepare import load_and_prep_df

# Map each accepted --method value to the statsmodels model class it builds.
# Insertion order is preserved and is also the order argparse lists the choices.
METHODS = dict(
    ARIMA=sm.ARIMA,
    VARMAX=sm.VARMAX,
    SARIMAX=sm.SARIMAX,
)


def parse_args(argv=None):
    """Parse the command-line options for ``calculate_arima``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, i.e. the previous behaviour.

    Returns:
        argparse.Namespace with ``name``, ``params``, ``method``, ``refit``,
        ``num_steps`` and ``train_size``. ``params`` is the tuple default
        ``(1, 0, 0)`` unless supplied on the command line, in which case it is
        the raw string and ``calculate_arima`` parses it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", "-n", type=str, default="Eintritt Geburtshaus",
                        help="Venue column to model. Ignored by multivariate "
                             "methods (VARMAX), which model all venues.")
    # NOTE(review): statsmodels VARMAX takes a (p, q) order, so this 3-tuple
    # default only suits ARIMA/SARIMAX -- pass --params explicitly for VARMAX.
    parser.add_argument("--params", "-p", default=(1, 0, 0),
                        help="Model order as a tuple literal, e.g. '(1, 0, 0)'.")
    parser.add_argument("--method", "-m", default="ARIMA", choices=METHODS.keys(),
                        help="statsmodels model class to fit.")
    parser.add_argument("--refit", action="store_true",
                        help="Forecast one step at a time, refitting the model "
                             "after each observed test value.")
    parser.add_argument("--num_steps", type=int,
                        help="Number of steps to predict. Entire test set is predicted if not given.")
    parser.add_argument("--train_size", type=int,
                        help="Number of steps used for training")
    # parser.add_argument("--maxiter", type=int, help="Number of iterations when estimating model parameters.")
    return parser.parse_args(argv)


def calculate_arima(args):
    """Fit a statsmodels time-series model and write its test-set forecasts to CSV.

    Training/test frames come from ``load_and_prep_df()``; their columns are
    accessed as a ``venues`` group (target series) and a ``features`` group
    (exogenous regressors).
    NOTE(review): that column schema is inferred from the attribute access
    below -- confirm against ``data.DataPrepare``.

    Methods whose name does not start with ``"V"`` are univariate and model
    only the venue column ``args.name``; ``VARMAX`` models all venues jointly.
    Methods whose name ends in ``"X"`` receive ``features`` as ``exog``.

    Args:
        args: Namespace as produced by ``parse_args``. Uses ``name``,
            ``params`` (a tuple, or a string holding a tuple literal),
            ``method``, ``refit``, ``num_steps`` and ``train_size``.

    Raises:
        ValueError: if ``train_size`` or ``num_steps`` is given but is not
            positive, or if a string ``params`` is not a Python literal.

    Side effects:
        Writes ``./results/<method>/<file>.csv``. The file name encodes the
        POI (univariate methods only), the order, and a ``_refit`` suffix for
        refit runs.
    """
    arg_poi_name = args.name
    # literal_eval accepts the same tuple literals as eval ("(1, 0, 0)",
    # "1,0,0") but cannot execute arbitrary code passed on the command line.
    if isinstance(args.params, tuple):
        arg_params = args.params
    else:
        arg_params = ast.literal_eval(args.params)
    arg_method = args.method
    model_cs = METHODS[arg_method]
    uses_exog = arg_method.endswith("X")
    is_multivariate = arg_method.startswith("V")

    # A size of 0 would silently select every row via `[-0:]`, and negative
    # sizes would cut rows from the wrong end, so reject both up front.
    for option in ("train_size", "num_steps"):
        value = getattr(args, option)
        if value is not None and value <= 0:
            raise ValueError(f"--{option} must be a positive integer, got {value}")

    df_train, df_test = load_and_prep_df()
    df_train.reset_index(drop=True, inplace=True)

    # Positional slices, copied so the column surgery below works on
    # independent frames instead of slices of the loaded data.
    if args.train_size is not None:
        df_train = df_train.iloc[-args.train_size:].copy()
    if args.num_steps is not None:
        df_test = df_test.iloc[:args.num_steps].copy()

    # Remember the real dates for the output, then make the test rows continue
    # the training rows' integer index (the refit loop appends them one by one).
    test_dates = df_test.index
    df_test.index = np.arange(df_train.index[-1] + 1, df_train.index[-1] + 1 + len(df_test))

    print(f'Fitting {arg_method}{arg_params}')
    if not is_multivariate:
        print('poi_name:', arg_poi_name)
        # Replace the "venues" group with just the requested venue column.
        single_train = df_train.venues[[arg_poi_name]]
        single_test = df_test.venues[[arg_poi_name]]
        del df_train["venues"]
        del df_test["venues"]
        df_train["venues"] = single_train
        df_test["venues"] = single_test

    if uses_exog:
        model = model_cs(df_train.venues, exog=df_train.features, order=arg_params)
    else:
        model = model_cs(df_train.venues, order=arg_params)

    start_time = time()
    if arg_method == "VARMAX":
        model_fit = model.fit(disp=True)
    else:
        model_fit = model.fit()
    elapsed = time() - start_time
    print(f"Fitting done after {elapsed:.2f} s")
    print()

    if args.refit:
        col_predictions = []
        print("Predicting and refitting")
        for new_idx in tqdm(df_test.index):
            exog = df_test.features.loc[[new_idx]] if uses_exog else None
            col_predictions.append(model_fit.forecast(exog=exog))
            # Extend the sample with the observed value and re-estimate.
            model_fit = model_fit.append(df_test.venues.loc[[new_idx]],
                                         exog=exog,
                                         refit=True)
        overall_predictions = pd.concat(col_predictions)
        # Measured from the initial fit's start_time on purpose.
        elapsed = time() - start_time
        print(f"Predicting done after {elapsed:.2f} s (including initial fit)")
    else:
        print("Predicting all at once")
        start_time = time()
        # Models without a regression component get no exog, matching the
        # refit branch above.
        exog = df_test.features if uses_exog else None
        overall_predictions = model_fit.forecast(len(df_test), exog=exog)
        elapsed = time() - start_time
        print(f"Predicting done after {elapsed:.2f} s")
    overall_predictions.index = test_dates

    params_tag = str(arg_params).replace(' ', '')
    # The refit tag is applied to every method so refit runs never overwrite
    # the batch-forecast file for the same POI and order.
    refit_tag = "_refit" if args.refit else ""
    if is_multivariate:
        output_filename = f"{params_tag}{refit_tag}.csv"
    else:
        output_filename = f"{arg_poi_name.replace(' ', '_')}_{params_tag}{refit_tag}.csv"
    output_filedir = f'./results/{arg_method}'
    os.makedirs(output_filedir, exist_ok=True)
    output_fullname = os.path.join(output_filedir, output_filename)

    overall_predictions.to_csv(output_fullname)


if __name__ == "__main__":
    # Script entry point: read the CLI options, then fit and forecast.
    cli_args = parse_args()
    calculate_arima(cli_args)
