import argparse
import json
import os
import time

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName
from gluonts.torch.model.forecast import DistributionForecast as PTDistributionForecast
from gluonts.mx.model.forecast import DistributionForecast as MXDistributionForecast
from gluonts.model.forecast import QuantileForecast
from gluonts.evaluation.backtest import make_evaluation_predictions
from accuracy_evaluator import evaluate_wrmsse

# Forecast horizon in days.
PREDICTION_LENGTH = 28
# Number of series (rows of sales_train_evaluation.csv).
N_TS = 30_490
# 1-based day indices (d_<n>) where the validation and test forecast windows
# start, counted back from day 1969, the last day of the M5 calendar.
VAL_START = 1969 - 3 * PREDICTION_LENGTH + 1   # 1886
TEST_START = 1969 - 2 * PREDICTION_LENGTH + 1  # 1914


def convert_price_file(m5_input_path):
    """Expand the weekly M5 sell prices into a per-day price table aligned with the sales rows.

    Reads ``calendar.csv``, ``sales_train_evaluation.csv`` and ``sell_prices.csv`` from
    ``m5_input_path`` and writes ``converted_price_evaluation.csv`` into the same directory.
    The output has one row per row of the sales file, in the same order. Each row holds the
    six key columns (``id``, ``item_id``, ``dept_id``, ``cat_id``, ``store_id``, ``state_id``)
    followed by one price column per calendar day. The day columns are named and ordered as
    in ``calendar['d']``. Store/item/day combinations without a listed price are left empty
    (NaN).

    Args:
        m5_input_path: directory holding the M5 CSV files; the output is written there too.
    """
    calendar = pd.read_csv(f'{m5_input_path}/calendar.csv')
    sales_train_evaluation = pd.read_csv(f'{m5_input_path}/sales_train_evaluation.csv')
    sell_prices = pd.read_csv(f'{m5_input_path}/sell_prices.csv')

    # Day columns in calendar order (d_1 ... d_1969 for the M5 calendar). They are taken
    # from the calendar instead of a hard-coded day count, so the output always matches
    # the supplied calendar.
    day_columns = calendar['d'].tolist()

    # Prices are listed per (store, item, week). Joining on the week number attaches each
    # weekly price to every day of that week.
    week_and_day = calendar[['wm_yr_wk', 'd']]
    price_all_days_items = pd.merge(week_and_day, sell_prices, on=['wm_yr_wk'], how='left')
    price_all_days_items = price_all_days_items.drop(['wm_yr_wk'], axis=1)

    # One row per (store_id, item_id), one column per day. pivot_table sorts the day labels
    # as strings (d_1, d_10, d_100, ...), so they are put back into calendar order below.
    price_all_items = price_all_days_items.pivot_table(
        values='sell_price', index=['store_id', 'item_id'], columns='d'
    )
    price_all_items = price_all_items.reset_index(drop=False)
    price_all_items = price_all_items.reindex(['store_id', 'item_id'] + day_columns, axis=1)

    # A left join keeps every sales row, in sales-file order.
    sales_keys = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']
    price_converted = pd.merge(
        sales_train_evaluation[sales_keys], price_all_items, on=['store_id', 'item_id'], how='left'
    )

    price_converted.to_csv(f'{m5_input_path}/converted_price_evaluation.csv', index=False)


def _build_list_dataset(target_values, start_dates, dynamic_real, stat_cat, feature_dict):
    """Assemble a daily ``ListDataset`` with one entry per series (entry i uses row i of every input).

    Args:
        target_values: per-series target arrays.
        start_dates: per-series start timestamps.
        dynamic_real: per-series ``FEAT_DYNAMIC_REAL`` arrays.
        stat_cat: static categorical features. When ``feature_dict`` is True this is a dict
            mapping a field name to an array whose row i belongs to series i; otherwise it is
            a 2-D array whose row i holds all static categories of series i.
        feature_dict: selects which of the two ``stat_cat`` layouts is used.
    """
    entries = []
    for i, (target, start, fdr) in enumerate(zip(target_values, start_dates, dynamic_real)):
        entry = {
            FieldName.TARGET: target,
            FieldName.START: start,
            FieldName.FEAT_DYNAMIC_REAL: fdr,
        }
        if feature_dict:
            # One named field per static feature.
            entry.update((name, values[i]) for name, values in stat_cat.items())
        else:
            entry[FieldName.FEAT_STATIC_CAT] = stat_cat[i]
        entries.append(entry)
    return ListDataset(entries, freq='D')


def _normalized_prices(converted_price, price_cols, cache_file):
    """Return ``(per_item, per_group)`` normalised price arrays, cached in ``cache_file``.

    ``per_item`` standardises each row by its own NaN-ignoring mean/std over all days.
    ``per_group`` standardises each day column by the mean/std of that day across the rows
    sharing the same ``dept_id``. NaNs in the input stay NaN. The ``+ 1e-6`` guards against
    a zero standard deviation. If ``cache_file`` exists, its arrays are returned without
    recomputation.
    """
    if cache_file.exists():
        # The context manager closes the .npz handle; each array is read when indexed.
        with np.load(cache_file) as cached:
            return cached['per_item'], cached['per_group']

    price_feature = converted_price[price_cols].values
    price_mean_per_item = np.nanmean(price_feature, axis=1, keepdims=True)
    price_std_per_item = np.nanstd(price_feature, axis=1, keepdims=True)
    per_item = (price_feature - price_mean_per_item) / (price_std_per_item + 1e-6)

    # Aggregate only the price columns. The remaining columns hold string ids, which older
    # pandas silently dropped as nuisance columns but pandas >= 2 refuses to aggregate.
    dept_groups = converted_price.groupby('dept_id')[price_cols]
    price_mean_per_dept = dept_groups.transform(np.nanmean)
    price_std_per_dept = dept_groups.transform(np.nanstd)
    per_group = (
        (converted_price[price_cols] - price_mean_per_dept) / (price_std_per_dept + 1e-6)
    ).values

    np.savez(cache_file, per_item=per_item, per_group=per_group)
    return per_item, per_group


def load_datasets(data_dir, feature_dict=False):
    """Load the M5 data and build the train / validation / test GluonTS datasets.

    Targets come from ``sales_train_evaluation.csv``:
    - training series drop the last ``2 * PREDICTION_LENGTH`` days;
    - validation series drop the last ``PREDICTION_LENGTH`` days;
    - test series are complete.

    ``FEAT_DYNAMIC_REAL`` holds two normalised price rows (per item, and per department and
    day) followed by the numeric calendar columns. It is cut at ``VAL_START - 1`` (train),
    ``TEST_START - 1`` (validation) and ``-PREDICTION_LENGTH`` (test). With the standard M5
    files (1941 sales days, 1969 calendar days) these cut lengths equal the corresponding
    target lengths.

    Args:
        data_dir: directory with the M5 CSV files. ``converted_price_evaluation.csv`` and
            ``normalized_price_evaluation.npz`` are created there on first use and reused
            afterwards.
        feature_dict: if True, each static categorical feature becomes its own field
            (``item_ids``, ``dept_ids``, ``cat_ids``, ``store_ids``, ``state_ids``) and the
            cardinalities are returned as a dict. Otherwise all five go into
            ``FieldName.FEAT_STATIC_CAT`` and the cardinalities are a list in that order.

    Returns:
        ``(train_ds, val_ds, test_ds, stat_cat_cardinalities)``.
    """
    calendar = pd.read_csv(f'{data_dir}/calendar.csv')
    sales_train_evaluation = pd.read_csv(f'{data_dir}/sales_train_evaluation.csv')
    n_series = len(sales_train_evaluation)
    key_cols = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']

    # Calendar features. With the standard M5 calendar, the columns left after this drop
    # are event_type_1/2 and snap_CA/TX/WI.
    cal_features = calendar.drop(
        ['date', 'wm_yr_wk', 'weekday', 'wday', 'month', 'year', 'event_name_1', 'event_name_2', 'd'],
        axis=1
    )
    # Event types become 0/1 "an event happens on this day" indicators.
    for event_col in ('event_type_1', 'event_type_2'):
        cal_features[event_col] = cal_features[event_col].apply(lambda x: 0 if str(x) == 'nan' else 1)
    event_features = cal_features.values.T  # (n_calendar_features, T)
    # A broadcast view shared by all series. np.concatenate below materialises the full
    # array anyway, so np.tile here would only add a second full-size temporary.
    event_features_expand = np.broadcast_to(event_features, (n_series,) + event_features.shape)

    # Static categorical features: integer category codes per series.
    static_fields = {
        'item_ids': 'item_id',
        'dept_ids': 'dept_id',
        'cat_ids': 'cat_id',
        'store_ids': 'store_id',
        'state_ids': 'state_id',
    }
    static_codes = {
        name: sales_train_evaluation[column].astype('category').cat.codes.values
        for name, column in static_fields.items()
    }
    cardinalities = {name: len(np.unique(codes)) for name, codes in static_codes.items()}
    if feature_dict:
        stat_cat = {name: codes.reshape(-1, 1) for name, codes in static_codes.items()}
        stat_cat_cardinalities = cardinalities
    else:
        # (n_series, 5): one row per series, columns in static_fields order.
        stat_cat = np.stack(list(static_codes.values()), axis=1)
        stat_cat_cardinalities = list(cardinalities.values())

    # Targets.
    sales_values = sales_train_evaluation.drop(key_cols, axis=1).values
    test_target_values = sales_values.copy()
    val_target_values = [ts[:-PREDICTION_LENGTH] for ts in sales_values]
    train_target_values = [ts[:-(2 * PREDICTION_LENGTH)] for ts in sales_values]

    # Price features.
    converted_price_file = Path(f'{data_dir}/converted_price_evaluation.csv')
    if not converted_price_file.exists():
        convert_price_file(data_dir)
    converted_price = pd.read_csv(converted_price_file)
    price_cols = [col for col in converted_price.columns if col not in key_cols]
    normalized_price_per_item, normalized_price_per_group = _normalized_prices(
        converted_price, price_cols, Path(f'{data_dir}/normalized_price_evaluation.npz')
    )

    # Days without a listed price (NaN) become 0 after normalisation.
    all_price_features = np.stack(
        [np.nan_to_num(normalized_price_per_item), np.nan_to_num(normalized_price_per_group)],
        axis=1,
    )  # (n_series, 2, T)
    # Shape (n_series, 2 + n_calendar_features, T), i.e. 7 rows with the standard M5
    # calendar. The estimators' num_feat_dynamic_real / dynamic_feature_dims must equal
    # this row count.
    dynamic_real = np.concatenate([all_price_features, event_features_expand], axis=1)

    train_dynamic_real = dynamic_real[..., :VAL_START - 1]
    val_dynamic_real = dynamic_real[..., :TEST_START - 1]
    test_dynamic_real = dynamic_real[..., :-PREDICTION_LENGTH]

    # Every series shares the same start date.
    m5_dates = [pd.Timestamp('2011-01-29')] * n_series

    train_ds = _build_list_dataset(train_target_values, m5_dates, train_dynamic_real, stat_cat, feature_dict)
    val_ds = _build_list_dataset(val_target_values, m5_dates, val_dynamic_real, stat_cat, feature_dict)
    test_ds = _build_list_dataset(test_target_values, m5_dates, test_dynamic_real, stat_cat, feature_dict)

    return train_ds, val_ds, test_ds, stat_cat_cardinalities


def get_deepar_estimator(args, cardinality, ckpt_dir, callbacks):
    """Build a GluonTS (PyTorch) DeepAR estimator with a negative-binomial output.

    Args:
        args: parsed command-line options (see ``parse_args``).
        cardinality: static categorical cardinalities, one per static feature.
        ckpt_dir: forwarded as ``trainer_kwargs['ckpt_kwargs']['dirpath']``.
        callbacks: trainer callbacks (e.g. early stopping).

    Returns:
        An untrained ``DeepAREstimator``.
    """
    from gluonts.torch.model.deepar import DeepAREstimator
    from gluonts.torch.distributions import NegativeBinomialOutput
    estimator = DeepAREstimator(
        prediction_length=PREDICTION_LENGTH,
        num_layers=args.n_block,
        hidden_size=args.hidden_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        dropout_rate=args.dropout,
        freq='D',
        # Must equal the number of FEAT_DYNAMIC_REAL rows built by load_datasets:
        # 2 normalised-price rows + 5 calendar rows. The previous value of 10 did not
        # match the data, so the model's input size was wrong.
        num_feat_dynamic_real=7,
        num_feat_static_cat=5,
        cardinality=cardinality,
        distr_output=NegativeBinomialOutput(),
        # NOTE(review): some GluonTS releases treat an empty lags_seq as "use the
        # frequency's default lags" (``lags_seq or ...``). Confirm lags really are disabled.
        lags_seq=[],
        time_features=[],
        batch_size=args.batch_size,
        num_batches_per_epoch=1 if args.debug else (N_TS // args.batch_size + 1),
        trainer_kwargs={
            'accelerator': 'gpu',
            'devices': 1,
            'max_epochs': 1 if args.debug else 300,
            'callbacks': callbacks,
            'ckpt_kwargs': {'dirpath': ckpt_dir},
        },
    )
    return estimator


def get_nbeats_estimator(args, cardinality, ckpt_dir, callbacks):
    """Build a GluonTS (MXNet) N-BEATS estimator made of generic ('G') stacks.

    ``cardinality`` and ``ckpt_dir`` are accepted only so that every
    ``get_*_estimator`` factory shares one signature; N-BEATS uses neither.

    Returns:
        An untrained ``NBEATSEstimator`` whose MXNet ``Trainer`` receives ``callbacks``.
    """
    from gluonts.mx import NBEATSEstimator
    from gluonts.mx.trainer import Trainer

    debug = args.debug
    trainer = Trainer(
        learning_rate=args.lr,
        epochs=1 if debug else 300,
        num_batches_per_epoch=1 if debug else (N_TS // args.batch_size + 1),
        callbacks=callbacks,
    )
    return NBEATSEstimator(
        freq="D",
        prediction_length=PREDICTION_LENGTH,
        context_length=args.seq_len,
        num_stacks=args.n_stack,
        # One value shared by all stacks (sharing/stack settings given as 1-element lists).
        num_blocks=[args.n_block],
        num_block_layers=[4],
        widths=[args.hidden_size],
        sharing=[False],
        stack_types=['G'],
        loss_function='MAPE',
        scale=True,
        trainer=trainer,
    )


def get_tft_estimator(args, cardinality, ckpt_dir, callbacks):
    """Build a GluonTS (MXNet) Temporal Fusion Transformer estimator with a single output.

    Args:
        args: parsed command-line options (see ``parse_args``).
        cardinality: mapping of static feature name -> cardinality, used as
            ``static_cardinalities``.
        ckpt_dir: unused; accepted so every ``get_*_estimator`` shares one signature.
        callbacks: MXNet trainer callbacks (e.g. early stopping).

    Returns:
        An untrained ``TemporalFusionTransformerEstimator``.
    """
    from gluonts.mx import TemporalFusionTransformerEstimator
    from gluonts.mx.trainer import Trainer

    estimator = TemporalFusionTransformerEstimator(
        freq="D",
        prediction_length=PREDICTION_LENGTH,
        context_length=args.seq_len,
        hidden_dim=args.hidden_size,
        num_heads=args.n_head,
        num_outputs=1,
        dropout_rate=args.dropout,
        static_cardinalities=cardinality,
        static_feature_dims={},
        # Must equal the number of FEAT_DYNAMIC_REAL rows built by load_datasets
        # (2 normalised-price rows + 5 calendar rows). The previous value of 10 did not
        # match the data.
        dynamic_feature_dims={FieldName.FEAT_DYNAMIC_REAL: 7},
        past_dynamic_features=[],
        trainer=Trainer(
            learning_rate=args.lr,
            epochs=1 if args.debug else 300,
            num_batches_per_epoch=1 if args.debug else (N_TS // args.batch_size + 1),
            callbacks=callbacks
        )
    )
    return estimator


def get_prob_tft_estimator(args, cardinality, ckpt_dir, callbacks):
    """Build a probabilistic (negative-binomial) Temporal Fusion Transformer estimator (MXNet).

    Args:
        args: parsed command-line options (see ``parse_args``).
        cardinality: mapping of static feature name -> cardinality, used as
            ``static_cardinalities``.
        ckpt_dir: unused; accepted so every ``get_*_estimator`` shares one signature.
        callbacks: MXNet trainer callbacks (e.g. early stopping).

    Returns:
        An untrained ``ProbabilisticTemporalFusionTransformerEstimator``.
    """
    from gluonts.mx import ProbabilisticTemporalFusionTransformerEstimator
    from gluonts.mx.trainer import Trainer
    from gluonts.mx.distribution.neg_binomial import NegativeBinomialOutput

    estimator = ProbabilisticTemporalFusionTransformerEstimator(
        freq="D",
        prediction_length=PREDICTION_LENGTH,
        context_length=args.seq_len,
        hidden_dim=args.hidden_size,
        num_heads=args.n_head,
        distr_output=NegativeBinomialOutput(),
        dropout_rate=args.dropout,
        static_cardinalities=cardinality,
        static_feature_dims={},
        # Must equal the number of FEAT_DYNAMIC_REAL rows built by load_datasets
        # (2 normalised-price rows + 5 calendar rows). The previous value of 10 did not
        # match the data.
        dynamic_feature_dims={FieldName.FEAT_DYNAMIC_REAL: 7},
        past_dynamic_features=[],
        trainer=Trainer(
            learning_rate=args.lr,
            epochs=1 if args.debug else 300,
            num_batches_per_epoch=1 if args.debug else (N_TS // args.batch_size + 1),
            callbacks=callbacks
        )
    )
    return estimator


def get_prob_patchtst_estimator(args, cardinality, ckpt_dir, callbacks):
    """Build a probabilistic PatchTST estimator with a negative-binomial output (PyTorch).

    Patches are 8 days long with a stride of 4. The feed-forward width equals
    ``args.hidden_size``. ``cardinality`` is unused; it is accepted so every
    ``get_*_estimator`` shares one signature.

    Returns:
        An untrained ``ProbabilisticPatchTSTEstimator``.
    """
    from gluonts.torch.model.prob_patchtst import ProbabilisticPatchTSTEstimator
    from gluonts.torch.distributions import NegativeBinomialOutput

    debug = args.debug
    trainer_kwargs = {
        'accelerator': 'gpu',
        'devices': 1,
        'max_epochs': 1 if debug else 300,
        'callbacks': callbacks,
        'ckpt_kwargs': {'dirpath': ckpt_dir},
    }
    return ProbabilisticPatchTSTEstimator(
        prediction_length=PREDICTION_LENGTH,
        context_length=args.seq_len,
        n_block=args.n_block,
        hidden_size=args.hidden_size,
        n_head=args.n_head,
        patch_len=8,
        stride=4,
        d_ff=args.hidden_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        dropout_rate=args.dropout,
        batch_size=args.batch_size,
        freq='D',
        distr_output=NegativeBinomialOutput(),
        num_batches_per_epoch=1 if debug else (N_TS // args.batch_size + 1),
        trainer_kwargs=trainer_kwargs,
    )


def get_fedformer_estimator(args, cardinality, ckpt_dir, callbacks):
    """Build a FEDformer estimator with a negative-binomial output (PyTorch).

    Future dynamic features are always disabled here. Static categorical features are
    dropped when ``args.disable_static`` is set. The head count is fixed at 8
    (``args.n_head`` is not used).

    Args:
        args: parsed command-line options (see ``parse_args``).
        cardinality: static categorical cardinalities, one per static feature.
        ckpt_dir: forwarded as ``trainer_kwargs['ckpt_kwargs']['dirpath']``.
        callbacks: trainer callbacks (e.g. early stopping).

    Returns:
        An untrained ``FEDformerEstimator``.
    """
    from gluonts.torch.model.fedformer import FEDformerEstimator
    from gluonts.torch.distributions import NegativeBinomialOutput

    estimator = FEDformerEstimator(
        prediction_length=PREDICTION_LENGTH,
        context_length=args.seq_len,
        n_block=args.n_block,
        hidden_size=args.hidden_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        dropout_rate=args.dropout,
        n_head=8,
        # Must equal the number of FEAT_DYNAMIC_REAL rows built by load_datasets
        # (2 normalised-price rows + 5 calendar rows), as in get_tsmixer_estimator.
        # The previous value of 10 did not match the data.
        num_feat_dynamic_real=7,
        disable_future_feature=True,
        num_feat_static_cat=0 if args.disable_static else 5,
        cardinality=cardinality,
        batch_size=args.batch_size,
        freq='D',
        distr_output=NegativeBinomialOutput(),
        num_batches_per_epoch=1 if args.debug else (N_TS // args.batch_size + 1),
        trainer_kwargs={
            'accelerator': 'gpu',
            'devices': 1,
            'max_epochs': 1 if args.debug else 300,
            'callbacks': callbacks,
            'ckpt_kwargs': {'dirpath': ckpt_dir},
        },
    )
    return estimator

def get_tsmixer_estimator(args, cardinality, ckpt_dir, callbacks):
    """Build a TSMixer estimator with a negative-binomial output (PyTorch).

    ``args.disable_future`` switches off future dynamic features. ``args.disable_static``
    drops the five static categorical features.

    Returns:
        An untrained ``TSMixerEstimator``.
    """
    from gluonts.torch.model.tsmixer import TSMixerEstimator
    from gluonts.torch.distributions import NegativeBinomialOutput

    debug = args.debug
    n_static = 0 if args.disable_static else 5
    trainer_kwargs = {
        'accelerator': 'gpu',
        'devices': 1,
        'max_epochs': 1 if debug else 300,
        'callbacks': callbacks,
        'ckpt_kwargs': {'dirpath': ckpt_dir},
    }
    return TSMixerEstimator(
        prediction_length=PREDICTION_LENGTH,
        context_length=args.seq_len,
        n_block=args.n_block,
        hidden_size=args.hidden_size,
        lr=args.lr,
        weight_decay=args.weight_decay,
        dropout_rate=args.dropout,
        # 2 normalised-price rows + 5 calendar rows, as built by load_datasets.
        num_feat_dynamic_real=7,
        disable_future_feature=args.disable_future,
        num_feat_static_cat=n_static,
        cardinality=cardinality,
        batch_size=args.batch_size,
        freq='D',
        distr_output=NegativeBinomialOutput(),
        num_batches_per_epoch=1 if debug else (N_TS // args.batch_size + 1),
        trainer_kwargs=trainer_kwargs,
    )


def evaluate(data_dir, dataset, predictor, prediction_start, debug=False):
    """Forecast ``dataset`` with ``predictor`` and score the point forecasts by WRMSSE.

    Each forecast is reduced to a point forecast of length ``PREDICTION_LENGTH``:
    - distribution and quantile forecasts use their ``.mean``;
    - any other forecast type uses the mean of its ``.samples`` over the sample axis.

    In debug mode only the first forecast is computed, and it is reused for every series.

    Args:
        data_dir: directory passed on to ``evaluate_wrmsse``.
        dataset: GluonTS dataset whose last ``PREDICTION_LENGTH`` target values are held out.
        predictor: trained GluonTS predictor.
        prediction_start: first day index of the forecast window, passed to ``evaluate_wrmsse``.
        debug: use the single-forecast shortcut described above.

    Returns:
        The result of ``evaluate_wrmsse(..., score_only=True)``.
    """
    forecast_it, _ = make_evaluation_predictions(
        dataset=dataset,
        predictor=predictor,
        num_samples=100
    )
    n_series = len(dataset)
    if debug:
        forecasts = [next(forecast_it)] * n_series
    else:
        forecasts = list(tqdm(forecast_it, total=n_series))

    use_mean_attr = isinstance(
        forecasts[0], (PTDistributionForecast, MXDistributionForecast, QuantileForecast)
    )
    point_forecasts = np.zeros((len(forecasts), PREDICTION_LENGTH))
    for row, forecast in enumerate(forecasts):
        point_forecasts[row] = forecast.mean if use_mean_attr else np.mean(forecast.samples, axis=0)

    return evaluate_wrmsse(data_dir, point_forecasts, prediction_start, score_only=True)


def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        ``argparse.Namespace`` with one attribute per option listed below, registered in
        this order.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        ('--data_dir', dict(default='./data')),
        ('--ckpt_dir', dict(default='./ckpt')),
        ('--model', dict(default='tsmixer')),
        ('--seq_len', dict(type=int, default=35)),
        ('--n_block', dict(type=int, default=1)),
        ('--n_stack', dict(type=int, default=30)),
        ('--n_head', dict(type=int, default=4)),
        ('--temporal_hidden_size', dict(type=int, default=16)),
        ('--hidden_size', dict(type=int, default=128)),
        ('--lr', dict(type=float, default=1e-3)),
        ('--weight_decay', dict(type=float, default=0)),
        ('--dropout', dict(type=float, default=0.05)),
        ('--batch_size', dict(type=int, default=64)),
        ('--patience', dict(type=int, default=30)),
        ('--seed', dict(type=int, default=0)),
        ('--disable_static', dict(action='store_true')),
        ('--disable_future', dict(action='store_true')),
        ('--debug', dict(action='store_true')),
        ('--result_path', dict(default='result.csv')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's RNG with ``seed``.

    MXNet's RNG is not seeded. To seed it too, add ``import mxnet as mx`` and
    ``mx.random.seed(seed)``.
    """
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)


def _experiment_id(args):
    """Return the run identifier that names the checkpoint sub-directory under ``args.ckpt_dir``.

    The format depends on the model family, so only the hyper-parameters relevant to that
    family appear in the name.
    """
    if args.model == 'nbeats':
        return f'{args.model}_nb{args.n_block}_ns{args.n_stack}_hs{args.hidden_size}_s{args.seed}'
    if 'tft' in args.model:
        return f'{args.model}_nh{args.n_head}_dp{args.dropout}_hs{args.hidden_size}_s{args.seed}'
    if args.model == 'prob_patchtst':
        # The "_nb_" spelling is kept so existing checkpoint paths stay valid.
        return f'{args.model}_nb_{args.n_block}_nh{args.n_head}_dp{args.dropout}_hs{args.hidden_size}_s{args.seed}'
    return (f'{args.model}_nb{args.n_block}_dp{args.dropout}_hs{args.hidden_size}'
            f'_ds{int(args.disable_static)}_df{int(args.disable_future)}_s{args.seed}')


def main():
    """Train the model selected on the command line, score it, and append a result row.

    The run proceeds as follows:
    1. Trains via the matching ``get_<model>_estimator`` factory with early stopping.
    2. Reports validation and test WRMSSE.
    3. Appends one row (scores, training time, hyper-parameters) to ``args.result_path``,
       writing a CSV header only when the file does not exist yet.

    Raises:
        ValueError: if no ``get_<model>_estimator`` function exists for ``--model``.
    """
    args = parse_args()
    set_seed(args.seed)

    # Resolve the estimator factory before the slow, memory-hungry data load. A mistyped
    # --model then fails immediately with a clear message instead of a KeyError minutes later.
    estimator_factory = globals().get(f'get_{args.model}_estimator')
    if estimator_factory is None:
        raise ValueError(f'Unknown --model {args.model!r}: no get_{args.model}_estimator function is defined')

    exp_id = _experiment_id(args)

    train_ds, val_ds, test_ds, stat_cat_cardinalities = load_datasets(
        args.data_dir, feature_dict=('tft' in args.model))

    if args.model in ['tft', 'nbeats', 'prob_tft']:
        # These factories build MXNet estimators, which take the project's own callback.
        from custom_callback import EarlyStopping
        early_stop_callback = EarlyStopping(patience=args.patience)
    else:
        from pytorch_lightning.callbacks.early_stopping import EarlyStopping
        early_stop_callback = EarlyStopping(monitor='val_loss', patience=args.patience)
    ckpt_dir = f'{args.ckpt_dir}/{exp_id}'
    estimator = estimator_factory(
        args, cardinality=stat_cat_cardinalities, ckpt_dir=ckpt_dir, callbacks=[early_stop_callback])

    # perf_counter is monotonic, so wall-clock adjustments cannot distort the duration.
    start_training_time = time.perf_counter()
    predictor = estimator.train(train_ds, validation_data=val_ds, num_workers=8)
    elapsed_training_time = time.perf_counter() - start_training_time
    print(f'Training finished in {elapsed_training_time} seconds')

    val_wrmsse = evaluate(args.data_dir, val_ds, predictor, VAL_START, debug=args.debug)
    print(f'val wrmsse: {val_wrmsse}')
    test_wrmsse = evaluate(args.data_dir, test_ds, predictor, TEST_START, debug=args.debug)
    print(f'test wrmsse: {test_wrmsse}')

    # TSMixer rows record the ablation flags in the model name. A local variable is used
    # so that args is not mutated.
    model_name = args.model
    if 'tsmixer' in args.model:
        model_name = f'tsmixer_ds{args.disable_static}_df{args.disable_future}'

    data = [{
        'model': model_name,
        'seq_len': args.seq_len,
        'val_wrmsse': val_wrmsse,
        'test_wrmsse': test_wrmsse,
        'training_time': elapsed_training_time,
        'n_block': args.n_block,
        'temporal_hidden_size': args.temporal_hidden_size,
        'hidden_size': args.hidden_size,
        'lr': args.lr,
        'dropout': args.dropout,
        'n_stack': args.n_stack,
        'n_head': args.n_head,
        'seed': args.seed,
    }]

    df = pd.DataFrame.from_records(data)
    # Append mode creates the file if needed; the header is written only for a new file.
    write_header = not os.path.exists(args.result_path)
    df.to_csv(args.result_path, mode='a', index=False, header=write_header)


if __name__ == "__main__":
    # Runs only when executed as a script; importing this module does not start training.
    main()

