import argparse
import json
import os
import pickle
import sys
from datetime import datetime

import numpy as np
import torch
from sklearn import metrics

sys.path.append("./")

from acfssa import (
    MSSA,
    MultivariateARIMA,
    Prophet,
    RealDataGenerator,
    ExperimentManager,
    ModelClass,
    DeepAR,
    LSTM_
)

# Model names accepted as positional `models` arguments on the command line.
# main() uses these same strings as the names passed to
# ExperimentManager.register and joins them into the results filename.
NAME_ARIMA = "ARIMA"
NAME_DEEPAR = "DeepAR"
NAME_LSTM = "LSTM"
NAME_MSSA = "mSSA"
NAME_PROPHET = "Prophet"
NAME_SAMOSSA = "SAMoSSA"
NAME_TRMF = "TRMF"

# Per-dataset experiment settings.
#   path:        .npy file handed to RealDataGenerator
#   returnsPath: alternative file used when --returns is given (only datasets
#                that define it are affected by --returns)
#   numTrain:    training length passed to ExperimentManager.run
#   horizon:     the evaluation and test phases each use
#                ``horizon // window`` windows of length ``window``
_DATASET_CONFIGS = {
    "synthetic": {
        "path": "datasets/samossa/synthetic_20230503.npy",
        "numTrain": 10000,
        "horizon": 25,
    },
    "electricity": {
        "path": "datasets/electricity/electricity.npy",
        "numTrain": 25824,
        "horizon": 48,
    },
    "exchange": {
        "path": "datasets/exchange/exchange.npy",
        "returnsPath": "datasets/exchange/exchange_returns.npy",
        "numTrain": 7528,
        "horizon": 30,
    },
    "traffic": {
        "path": "datasets/traffic/traffic.npy",
        "numTrain": 10248,
        "horizon": 48,
    },
}


def _parse_args():
    """Parse and validate the command-line arguments.

    Validation uses argparse ``choices`` and ``parser.error`` instead of
    ``assert`` so that it still runs under ``python -O``. The process exits
    with a usage message on invalid input.
    """
    knownModels = (
        NAME_ARIMA,
        NAME_MSSA,
        NAME_SAMOSSA,
        NAME_PROPHET,
        NAME_DEEPAR,
        NAME_TRMF,
        NAME_LSTM,
    )
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str, choices=sorted(_DATASET_CONFIGS))
    parser.add_argument("models", type=str, nargs="+", choices=knownModels)
    parser.add_argument("--startIncl", type=int, default=None)
    parser.add_argument("--endExcl", type=int, default=None)
    parser.add_argument("--window", type=int, default=1)
    parser.add_argument("--trial", action="store_true")
    parser.add_argument("--standardize", action="store_true")
    parser.add_argument("--returns", action="store_true")
    args = parser.parse_args()

    if NAME_TRMF in args.models:
        # TRMF is not imported from acfssa and has no registration below.
        # Accepting it would silently produce a run without TRMF results.
        parser.error(f"model {NAME_TRMF} is currently disabled in this script")
    if args.window < 1:
        parser.error("--window must be a positive integer")
    horizon = _DATASET_CONFIGS[args.dataset]["horizon"]
    if args.window > horizon:
        # horizon // window would be 0, leaving no evaluation or test windows.
        parser.error(
            f"--window must be at most {horizon} for dataset {args.dataset!r}"
        )
    return args


def _build_generator(args):
    """Create the RealDataGenerator for ``args.dataset``.

    Returns:
        ``(generator, numTrain, numWindows)``. ``numWindows`` is used for
        both the evaluation and the test phase.
    """
    config = _DATASET_CONFIGS[args.dataset]
    path = config["path"]
    numTrain = config["numTrain"]
    if args.returns and "returnsPath" in config:
        path = config["returnsPath"]
        # The returns file is run with one fewer training step (presumably
        # because computing returns drops the first observation).
        numTrain -= 1
    generator = RealDataGenerator(
        path,
        startIndexIncl=args.startIncl,
        endIndexExcl=args.endExcl,
        standardize=args.standardize,
    )
    numWindows = config["horizon"] // args.window
    return generator, numTrain, numWindows


def _register_models(manager, models, dataset, numSeries, windowSize, optimalMatrixSize):
    """Register each requested model with ``manager``.

    Registration order is fixed and does not depend on the order of
    ``models``. The two positional dicts passed to ModelClass are fixed
    constructor kwargs followed by a hyper-parameter grid. The trailing
    ``True`` for ARIMA and Prophet is an acfssa ModelClass flag (TODO:
    document its meaning).
    """
    coefGrid = [
        optimalMatrixSize,
        optimalMatrixSize // 3,
        optimalMatrixSize // 5,
    ]

    if NAME_ARIMA in models:
        manager.register(
            NAME_ARIMA,
            ModelClass(
                MultivariateARIMA,
                dict(numSeries=numSeries),
                dict(
                    arOrder=[1, 2, 3],
                    diffOrder=[0, 1],
                    maOrder=[1, 2, 3],
                ),
                True,
            ),
        )

    if NAME_DEEPAR in models:
        manager.register(
            NAME_DEEPAR,
            ModelClass(
                DeepAR,
                dict(predictionLength=windowSize, numEpochs=50, numSamples=100),
            ),
        )

    if NAME_LSTM in models:
        manager.register(
            NAME_LSTM,
            ModelClass(
                LSTM_,
                dict(predictionLength=windowSize, epochs=50),
                dict(num_layers=[2, 3, 4]),
            ),
        )

    # NOTE: TRMF registration is disabled (TRMF is not imported from acfssa);
    # _parse_args rejects NAME_TRMF so it cannot be requested by accident.

    if NAME_MSSA in models:
        manager.register(
            NAME_MSSA,
            ModelClass(
                MSSA,
                dict(
                    numSeries=numSeries,
                    numCoefs=optimalMatrixSize,
                    arOrder=None,
                ),
                dict(
                    rankEst=["donoho", "energy", "fixed"],
                    numCoefs=list(coefGrid),
                ),
            ),
        )

    if NAME_SAMOSSA in models:
        # SAMoSSA is the same MSSA class with the AR order added to the grid.
        manager.register(
            NAME_SAMOSSA,
            ModelClass(
                MSSA,
                dict(numSeries=numSeries),
                dict(
                    rankEst=["donoho", "energy", "fixed"],
                    numCoefs=list(coefGrid),
                    arOrder=[0, 1, 2, 3],
                ),
            ),
        )

    if NAME_PROPHET in models:
        manager.register(
            NAME_PROPHET,
            ModelClass(
                Prophet,
                dict(numSeries=numSeries, freq=("D" if dataset == "exchange" else "H")),
                dict(
                    changepointPriorScale=[0.001, 0.05, 0.2],
                    seasonalityPriorScale=[0.01, 10],
                    seasonalityMode=["additive", "multiplicative"],
                ),
                True,
            ),
        )


def _mean_r2_scores(results):
    """Return ``{name: mean per-output R^2 of preds vs results["groundTruth"]}``.

    The ``groundTruth`` and ``noisyObservations`` entries are excluded from
    scoring.
    """
    return {
        name: np.mean(
            metrics.r2_score(results["groundTruth"], preds, multioutput="raw_values")
        )
        for name, preds in results.items()
        if name not in ("groundTruth", "noisyObservations")
    }


def main():
    """Run the forecasting benchmark selected on the command line.

    Parses the CLI arguments (dataset, models, slicing/window/flags) and
    seeds NumPy and torch. It then builds the data generator, registers the
    requested models with an ExperimentManager and runs it. Unless
    ``--trial`` is given, the results plus run metadata are pickled under
    ``results/``. Finally, the mean R^2 per model is printed.

    Returns:
        The dict returned by ``ExperimentManager.run``.
    """
    args = _parse_args()

    np.random.seed(420)
    torch.manual_seed(420)

    generator, numTrain, numWindows = _build_generator(args)
    windowSize = args.window
    numWindowsEval = numWindows
    numWindowsTest = numWindows
    numSeries = generator.data.shape[1]
    # Side length of a square matrix with numSeries * numTrain entries; used
    # as the base of the mSSA/SAMoSSA numCoefs grid.
    optimalMatrixSize = int(np.sqrt(numSeries * numTrain))

    startIndexIncl = args.startIncl if args.startIncl is not None else 0
    endIndexExcl = args.endExcl if args.endExcl is not None else generator.numSeries

    manager = ExperimentManager(generator)
    _register_models(
        manager, args.models, args.dataset, numSeries, windowSize, optimalMatrixSize
    )

    results = manager.run(numTrain, windowSize, numWindowsEval, numWindowsTest)

    # Save results (skipped for trial runs).
    dt_string = datetime.now().strftime("%Y%m%d%H%M%S")
    modelNames = "_".join(sorted(args.models))
    suffix = "std" if args.standardize else ""
    filename = f"results/{args.dataset}{suffix}_{startIndexIncl}_{endIndexExcl}_h{windowSize}_{modelNames}_{dt_string}.pkl"
    if args.trial:
        print(f"Trial run, not saving to {filename}")
    else:
        # Make sure the output directory exists so a finished run is not lost
        # to a missing folder.
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        metadata = {
            "windowSize": windowSize,
            "numTrain": numTrain,
            "models": args.models,
            "dataset": args.dataset,
            # Extra keys so runs with different preprocessing/slices can be
            # told apart from the saved file alone.
            "standardize": args.standardize,
            "returns": args.returns,
            "startIndexIncl": startIndexIncl,
            "endIndexExcl": endIndexExcl,
        }
        with open(filename, "wb") as f:
            pickle.dump({"results": results, "metadata": metadata}, f)

    r2 = _mean_r2_scores(results)
    print(r2)
    print(f"Finished experiment on {args.dataset} for {args.models}, starting index {startIndexIncl} and ending {endIndexExcl}. Mean R^2 scores are {r2}.")
    return results

if __name__ == "__main__":
    # Bind the returned results at module level so they remain inspectable
    # after the run (e.g. when launched with `python -i`).
    results = main()
    