import argparse
from time import time

import itertools

import pandas as pd

from Models import rnn_models
from Models.pl_rnn import LitRNNModel
from config.experiments import all_models, all_sizes


def parse_args(argv=None):
    """Parse the command-line arguments for this script.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            original behavior.

    Returns:
        argparse.Namespace with ``mode`` ('train' or 'eval', default 'train')
        and ``steps`` (int, default 5).
    """
    parser = argparse.ArgumentParser()
    # choices: an unrecognised mode is rejected with a usage error instead of
    # the script silently doing nothing.
    parser.add_argument('mode', nargs='?', choices=('train', 'eval'),
                        default='train', help='train or eval')
    # type=int: without it a user-supplied value arrives as a str and breaks
    # range()/division when passed on as num_steps.
    parser.add_argument('--steps', type=int, default=5,
                        help='number of prediction repetitions averaged per '
                             'model/size in eval mode')
    return parser.parse_args(argv)


def get_training_times(csv_path='../logs/tb_data.csv'):
    """Return the median training duration per model class, in minutes.

    The CSV is read with its first column as the index. It must contain
    ``model_cls``, ``date`` and ``end_date`` columns. Each run's duration is
    ``end_date - date``.

    Args:
        csv_path: Path to the logged run table. The default is the original
            hard-coded location, resolved relative to the current working
            directory.

    Returns:
        pandas Series named ``'time'``, indexed by ``model_cls``, holding the
        median duration in minutes for each class.
    """
    df = pd.read_csv(csv_path, index_col=0)
    # Keep the column name 'time': it becomes the header of the LaTeX table.
    df['time'] = pd.to_datetime(df['end_date']) - pd.to_datetime(df['date'])
    median_durations = df.groupby('model_cls')['time'].median()
    return median_durations.dt.total_seconds() / 60


def get_prediction_times(num_steps=5):
    """Measure the mean time of a single prediction call per model and size.

    For every model name in ``all_models`` and size in ``all_sizes``:
      - a ``LitRNNModel`` is built once and switched to eval mode;
      - it is then called ``num_steps`` times on the first sample of its
        eval dataset, sliced with ``[:1]``.

    Only those calls are timed. Model construction is kept out of the
    measurement, because it is not part of prediction.

    Args:
        num_steps: Number of timed calls to average over. Converted with
            ``int()``, so a numeric string is accepted. Must be >= 1.

    Returns:
        DataFrame indexed by model name with one column per size, holding the
        mean seconds per call. Model names that ``rnn_models`` does not define
        are reported on stderr and skipped.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    import sys
    from time import perf_counter

    num_steps = int(num_steps)
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1, got {}".format(num_steps))

    rows = {}
    for model in all_models:
        if not hasattr(rnn_models, model):
            # stderr, so the message does not end up inside the LaTeX table
            # the caller prints to stdout.
            print("ERROR: Unknown model type '{}'".format(model), file=sys.stderr)
            continue
        model_cls = getattr(rnn_models, model)
        for size in all_sizes:
            lit_model = LitRNNModel(model_cls, size)
            lit_model.eval()
            sample = lit_model.eval_dataset.x[0][:1]
            # NOTE(review): if LitRNNModel is a torch module, wrapping the
            # timed loop in torch.no_grad() would exclude autograd
            # bookkeeping from the measurement — confirm before adding.
            # perf_counter is a monotonic clock intended for interval timing.
            start = perf_counter()
            for _ in range(num_steps):
                lit_model(sample)
            elapsed = perf_counter() - start
            rows.setdefault(model, {})[size] = elapsed / num_steps
    return pd.DataFrame.from_dict(rows, orient='index')


if __name__ == "__main__":
    # Entry point: print a LaTeX table of training times ('train') or
    # per-call prediction times ('eval') to stdout.
    args = parse_args()
    if args.mode == 'train':
        print(get_training_times().to_latex())
    elif args.mode == 'eval':
        # int() keeps this working even if --steps was parsed as a string.
        print(get_prediction_times(num_steps=int(args.steps)).to_latex())
    else:
        # Fail loudly (message on stderr, non-zero exit) rather than
        # silently printing nothing.
        raise SystemExit(
            "Unknown mode '{}': expected 'train' or 'eval'".format(args.mode))
