import argparse
import json
import os
import time
import warnings
import numpy as np
import pandas as pd
from tqdm import trange

import utils
import ncp
from assmf import ASSMF


def main_assmf(data, args):
    """Run ASSMF online forecasting over ``data``.

    An ``ASSMF`` model is built from ``data.shape`` and the period /
    component / regime settings in ``args`` and initialized on ``data``.
    Then, for every forecast origin ``t`` (starting at ``args.train_span``,
    advancing by ``args.forecast_step`` and stopping before
    ``n_samples - args.forecast_step``), the model is updated on sliding
    windows of length ``args.period`` and asked for the next
    ``args.forecast_step`` predictions.

    Args:
        data: array whose last axis is indexed as time.
        args: namespace providing ``period``, ``n_components``,
            ``n_regimes``, ``train_span``, ``forecast_step`` and the number
            of initialization cycles, read from ``init_cycles`` or, when
            that attribute is missing, from ``init_cycle``.

    Returns:
        A tuple ``(pred, model)``: ``pred`` is a zero-initialized array of
        ``data.shape`` with the forecasts written along the last axis, and
        ``model`` is the ``ASSMF`` instance that produced them.
    """
    model = ASSMF(data.shape,
                  args.period,
                  args.n_components,
                  args.n_regimes)

    # The command-line parser in this file defines ``--init_cycle``
    # (singular); accept that spelling when ``init_cycles`` is absent so the
    # function also works with the namespace produced by this script.
    try:
        init_cycles = args.init_cycles
    except AttributeError:
        init_cycles = args.init_cycle

    model.initialize(data, args.period, init_cycles)

    # Online forecasting
    pred = np.zeros(data.shape)
    n_samples = data.shape[-1]

    for t in range(args.train_span,
                   n_samples - args.forecast_step,
                   args.forecast_step):

        # Each forecast origin replays the updates for every window ending
        # at tt = period .. t-1; the window covers the last `period` steps.
        for tt in trange(args.period, t, desc="train"):
            model.update(data[..., tt - args.period + 1:tt + 1], tt)

        # Forecast step by step and store the result at its absolute index.
        for tt in trange(args.forecast_step, desc="test"):
            pred[..., t + tt] = model.predict(t, tt)

    return pred, model


def main_list_base(data, model, args, outpath):
    """Train and forecast online on list-format (index-sliced) data.

    The time axis is rebuilt as a regular ``pd.date_range`` between the
    smallest and largest value of ``data.index`` at frequency ``args.freq``.
    For every forecast origin ``t`` (from ``args.train_span``, in strides of
    ``args.forecast_step``) the relevant slice of ``data`` is converted to a
    tensor with ``utils.list2tensor_from_index``, the model is fitted on it,
    a forecast of ``args.forecast_step`` steps is produced, and the model,
    the forecast and the elapsed time are written to ``<outpath>/<t>/``.

    Args:
        data: object supporting ``.max()``, ``.index`` and label slicing
            (presumably a pandas DataFrame indexed by timestamps -- confirm
            against ``utils.load_citibike``).
        model: model object exposing ``fit`` and ``save`` plus ``predict``
            or ``predict_seq``, and ``initialize`` for the models that are
            initialized here.
        args: namespace with ``model``, ``freq``, ``period``, ``init_cycle``,
            ``train_span`` and ``forecast_step``.
        outpath: root directory for the per-step result folders.

    Raises:
        ValueError: if ``args.model`` has no training rule in this function.
    """
    # One plus the first two entries of the per-column maxima.
    n_attributes = data.max()[:2] + 1
    model.n_attributes = n_attributes
    print(n_attributes, n_attributes.keys())

    timestamp = pd.date_range(
        start=data.index.min(), end=data.index.max(), freq=args.freq)
    train_span = args.train_span
    forecast_step = args.forecast_step
    duration = len(timestamp) - forecast_step
    model.dataset_length = duration

    print(timestamp)
    print(duration)

    # Initialization

    if args.model in ["trmf", "smf", "ssmf", "assmf"]:
        # Get a tensor covering the first `init_cycle` periods
        t_train = timestamp[:args.period*args.init_cycle]
        train_tensor = utils.list2tensor_from_index(
            data[:t_train[-1]], t_train, n_attributes)

        model.initialize(train_tensor)

    # Online processing

    for t in range(train_span, duration, forecast_step):

        # Pick the start of the data slice, the start of the timestamp
        # window handed to the tensor builder, and the time offset given
        # to `fit`, depending on the kind of model.
        if args.model in ["ncp", "fold", "trmf"]:
            # Offline methods: data from the very beginning, offset 0.
            # NOTE(review): the timestamp window still starts at
            # t - train_span while the data starts at 0 -- confirm that
            # this mismatch is intended.
            data_start = 0
            index_start = t - train_span
            offset = 0

        elif args.model in ["smf", "ssmf"]:
            # Online methods: only the latest training span.
            data_start = index_start = offset = t - train_span

        elif args.model in ["assmf"]:
            # Latest training span plus one extra period before it. The
            # start is clamped at 0: on the first step (and whenever
            # period > elapsed steps) the raw value is negative and would
            # otherwise wrap around to the end of the time axis.
            data_start = index_start = offset = max(
                t - train_span - args.period, 0)

        else:
            raise ValueError(
                f"unsupported model for list-based data: {args.model!r}")

        # Train
        train_data = data[timestamp[data_start]:timestamp[t]]
        train_tensor = utils.list2tensor_from_index(
            train_data, timestamp[index_start:t], n_attributes)
        elapsed_time = model.fit(train_tensor, offset)

        # Test
        if args.model in ["smf", "ssmf"]:
            pred = model.predict_seq(t, forecast_step)
        else:
            pred = model.predict(forecast_step)

        # Save each test result separately; exist_ok lets a rerun reuse the
        # directory, matching main_tensor_base.
        tmp_dir = os.path.join(outpath, str(t))
        os.makedirs(tmp_dir, exist_ok=True)
        model.save(tmp_dir)
        np.save(os.path.join(tmp_dir, "pred.npy"), pred)
        np.savetxt(os.path.join(tmp_dir, "time.txt"), elapsed_time)
        print("Elapsed time:", elapsed_time, "sec.")


def main_tensor_base(data, model, args, outpath):
    """Train and forecast online on tensor-format data.

    The last axis of ``data`` is treated as time. For every forecast origin
    ``t`` (from the start step, in strides of ``args.forecast_step``, below
    ``data.shape[-1] - args.forecast_step``) the model is fitted, a forecast
    of ``args.forecast_step`` steps is produced, and the model, the forecast
    and the elapsed time are written to ``<outpath>/<t>/``.

    Args:
        data: array whose last axis is indexed as time.
        model: model object exposing ``fit`` and ``save`` plus ``predict``
            or ``predict_seq``, and ``initialize`` for the models that are
            initialized here. ``fit`` must return something that
            ``np.savetxt`` accepts and that has ``.sum()``.
        args: namespace with ``model``, ``period``, ``init_cycle``,
            ``train_span``, ``forecast_step`` and ``start_step``.
        outpath: root directory for the per-step result folders.

    Raises:
        AssertionError: if ``args.train_span != args.forecast_step``.
    """
    train_span = args.train_span
    forecast_step = args.forecast_step
    duration = data.shape[-1] - forecast_step

    assert train_span == forecast_step, \
        "train_span must be equal to forecast_step"

    if args.model in ["trmf"]:
        model.initialize(data[..., :args.period*args.init_cycle])
        model.dims = list(data.shape[:-1])

    elif args.model in ["smf", "ssmf", "assmf"]:
        model.dataset_length = duration
        model.initialize(data)

    start_step = forecast_step

    if args.start_step is not None:
        start_step = args.start_step
        # Run online algorithms till the start step
        if args.model in ["smf", "ssmf", "assmf"]:
            model.fit(data[..., :start_step], 0)

    for t in range(start_step, duration, forecast_step):

        print("Train")

        if args.model in ["ncp", "fold", "trmf"]:
            # Offline methods refit on everything seen so far
            elapsed_time = model.fit(data[..., :t], 0)

        elif args.model in ["assmf"]:
            # Latest span plus the season before it. The start is clamped
            # at 0: a negative slice start would wrap around to the end of
            # the time axis (yielding an empty or wrong window) and would
            # be passed to `fit` as a negative time offset.
            window_start = max(t - forecast_step - args.period, 0)
            print("Train data", window_start, t)
            elapsed_time = model.fit(
                data[..., window_start:t], window_start)

        else:
            # Online methods only see the latest span
            elapsed_time = model.fit(
                data[..., t-forecast_step:t], t-forecast_step)

        print("Test")

        if args.model in ["smf", "ssmf", "assmf"]:
            pred = model.predict_seq(t, forecast_step)
        else:
            pred = model.predict(forecast_step)

        # Save results
        tmp_dir = os.path.join(outpath, str(t))
        os.makedirs(tmp_dir, exist_ok=True)
        print("OUTPATH:", tmp_dir)
        model.save(tmp_dir)
        np.save(os.path.join(tmp_dir, "pred.npy"), pred)
        np.savetxt(os.path.join(tmp_dir, "time.txt"), elapsed_time)
        print("Elapsed time:", elapsed_time.sum(), "sec.")


def get_model(args):
    """Instantiate the model selected by ``args.model``.

    ``args`` must include all the parameters needed to initialize the
    specified model object:

    * ``"ncp"``   -- ``n_components``
    * ``"ssmf"``  -- ``period``, ``n_components``, ``n_regimes``,
      ``learning_rate``, ``init_cycle``
    * ``"assmf"`` -- the same as ``"ssmf"`` plus ``update_freq`` and
      ``compression``

    Returns:
        The newly constructed model object.

    Raises:
        ValueError: if ``args.model`` is not one of the names above.
    """
    if args.model == "ncp":
        # offline method
        print("model: NCP")
        return ncp.NCP(rank=args.n_components)

    if args.model == "ssmf":
        print("model: SSMF")
        # Imported lazily: the file's top-level imports do not bring `ssmf`
        # into scope. NOTE(review): assumes a sibling `ssmf` module exists
        # alongside `ncp` and `assmf` -- confirm.
        import ssmf
        return ssmf.SSMF(n_seasons=args.period,
                         n_components=args.n_components,
                         n_regimes=args.n_regimes,
                         learning_rate=args.learning_rate,
                         init_cycles=args.init_cycle)

    if args.model == "assmf":
        print("model: ASSMF")
        return ASSMF(args.period,
                     args.n_components,
                     args.n_regimes,
                     args.learning_rate,
                     init_cycles=args.init_cycle,
                     update_freq=args.update_freq,
                     compression=args.compression)

    # Fail here with a clear message instead of returning None and letting
    # the caller crash on the first attribute access.
    raise ValueError(f"unknown model: {args.model!r}")


if __name__ == "__main__":

    # ---- Command-line interface ------------------------------------
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "dataset",
        type=str,
        choices=[
            "tycho", "citibike", "nytaxi", "gtrends", "olist",
            "us_daily_sweets",
        ])
    cli.add_argument("model", type=str)
    cli.add_argument("period", type=int)

    # Optional arguments, registered in this order.
    optional_arguments = [
        ("--freq", dict(type=str, default="H")),
        ("--n_components", dict(type=int)),
        ("--n_regimes", dict(type=int, default=3)),
        ("--learning_rate", dict(type=float, default=0.1)),
        ("--init_cycle", dict(type=int, default=3)),
        ("--train_span", dict(type=int, default=1000)),
        ("--update_freq", dict(type=int, default=500)),
        ("--compression", dict(action="store_true")),
        ("--forecast_step", dict(type=int, default=500)),
        ("--outpath", dict(type=str, default="../out/")),
        ("--n_trial", dict(type=int, default=1)),
        ("--start_step", dict(type=int)),
        ("--start_date", dict(type=str, default="2017-03")),
        ("--end_date", dict(type=str, default="2021-03")),
        ("--citibike_key", dict(type=str, default="stationid")),
        ("--beta", dict(type=float, default=0.1, help="bias rate of MDL")),
        ("--as_tensor", dict(action="store_true")),
    ]
    for flag, options in optional_arguments:
        cli.add_argument(flag, **options)

    args = cli.parse_args()

    # ---- Dataset selection -----------------------------------------
    # One loader per dataset name. A dataset without an entry here
    # ("gtrends") leaves `data` undefined.
    loaders = {
        "tycho": lambda: utils.load_tycho(
            "../dat/project_tycho.csv", as_tensor=True),
        "citibike": lambda: utils.load_citibike(
            "../dat/citibike/",
            key=args.citibike_key,
            freq=args.freq,
            as_tensor=args.as_tensor),
        "nytaxi": lambda: utils.load_nytaxi("../dat/nytaxi/"),
        "olist": lambda: utils.load_olist("../dat/olist/"),
        # GoogleTrends
        "us_daily_sweets": lambda: utils.load_gtrends(args.dataset),
    }
    if args.dataset in loaders:
        data = loaders[args.dataset]()
        if args.dataset == "olist":
            print(data.shape)

    # ---- Main process ----------------------------------------------
    tensor_only_datasets = (
        "tycho", "nytaxi", "gtrends", "olist", "us_daily_sweets")

    for trial_id in range(args.n_trial):

        # Fresh model object for every trial
        model_object = get_model(args)
        params = model_object.set_params(args)

        uses_citibike = args.dataset == "citibike"
        if not uses_citibike and args.dataset not in tensor_only_datasets:
            raise ValueError

        # Root directory for outputs:
        #   <outpath>/<dataset>/[<citibike_key>/]<model>/
        #   forecast_step=<n>/<param>=<value>/.../<trial_id>
        path_parts = [args.outpath, args.dataset]
        if uses_citibike:
            path_parts.append(args.citibike_key)
        path_parts.append(args.model)
        path_parts.append(f"forecast_step={args.forecast_step}")
        path_parts.extend(
            "=".join([k, str(v)]) for k, v in params.items())
        path_parts.append(str(trial_id))
        outpath = os.path.join(*path_parts)

        os.makedirs(outpath, exist_ok=True)

        # Train/Test: only citibike without --as_tensor takes the
        # list-based path; everything else is handled as a tensor.
        if uses_citibike and not args.as_tensor:
            main_list_base(data, model_object, args, outpath)
        else:
            main_tensor_base(data, model_object, args, outpath)

        # Save experimental setting
        with open(outpath + "/setting.json", "w") as setting_file:
            json.dump(vars(args), setting_file, indent=4, sort_keys=True)
