import argparse
import os
import time
import warnings
from typing import Union

# Ignore every FutureWarning for the rest of the process.
# NOTE(review): this sits ahead of the numpy/tensorflow imports, presumably so
# that warnings emitted while importing them are silenced too -- confirm.
warnings.simplefilter(action="ignore", category=FutureWarning)

import numpy as np
import tensorflow as tf

import traceback
import computational_experiments.real_world.src.data_utils as data_utils
from models import SWIM, ESN, BaseModel

# Largest value representable as a signed 64-bit integer. Not referenced by the
# code visible in this file.
MAX_SEED = np.iinfo(np.int64).max
# Predictions whose absolute value exceeds this are treated as divergence by
# run_model.
DIVERGENCE_THRESHOLD = 1e4
# Split labels; used both as dictionary keys and as CSV file stems
# ("<label>.csv") in load_data.
TRAIN, VAL, TEST = "train", "val", "test"


def load_data(
        data_folder: str, size_limit: Union[int, None] = None
) -> dict[str, data_utils.DataSplit]:
    """Build one ``DataSplit`` per split label from CSV files in a folder.

    For each of ``TRAIN``, ``VAL`` and ``TEST`` (in that order) the file
    ``<data_folder>/<label>.csv`` is handed to ``data_utils.DataSplit``.

    Args:
        data_folder: Directory holding ``train.csv``, ``val.csv`` and
            ``test.csv``.
        size_limit: Passed through unchanged as the second argument of
            ``data_utils.DataSplit``.

    Returns:
        Mapping from split label to the corresponding ``DataSplit``.
    """
    return {
        split_name: data_utils.DataSplit(
            os.path.join(data_folder, f"{split_name}.csv"), size_limit
        )
        for split_name in (TRAIN, VAL, TEST)
    }


def get_model_class(model_name) -> type[BaseModel]:
    """Return the model class registered under ``model_name``.

    Recognised names are ``"kirnn"`` (mapped to ``SWIM``) and ``"esn"``
    (mapped to ``ESN``).

    Raises:
        ValueError: If ``model_name`` matches none of the recognised names.
    """
    known_models = (("kirnn", SWIM), ("esn", ESN))
    for candidate, model_cls in known_models:
        if model_name == candidate:
            return model_cls

    raise ValueError(f"Unknown model name: {model_name}.")


def run_model(
        model: BaseModel, splits: dict[str, data_utils.DataSplit], metric: str = "abs"
) -> data_utils.DataOutput:
    """Train ``model`` and evaluate it on the train, val and test splits.

    The model is trained once on the train and validation data. Each split is
    then predicted and scored independently; a failure on one split is
    reported on stdout and does not prevent the remaining splits from being
    evaluated.

    Args:
        model: Model exposing ``train``, ``predict`` and ``compute_metric``.
        splits: Mapping from split label (``TRAIN``, ``VAL``, ``TEST``) to an
            object whose ``data`` attribute is fed to the model.
        metric: Metric name forwarded to ``model.compute_metric``.

    Returns:
        A ``data_utils.DataOutput`` holding the per-split predictions, metrics
        and prediction times, plus the total training time. A split for which
        ``predict`` or ``compute_metric`` raised has no entry in the
        corresponding dictionaries.
    """
    train_start = time.time()
    model.train(splits[TRAIN].data, splits[VAL].data)
    train_end = time.time()
    training_time = train_end - train_start
    print("Training: done.")

    predictions = dict()
    metrics = dict()
    prediction_times = dict()
    for label in [TRAIN, VAL, TEST]:
        split = splits[label]
        try:
            predict_start = time.time()
            split_predictions = model.predict(split.data)
            predict_end = time.time()

            predictions[label] = split_predictions
            metrics[label] = model.compute_metric(
                split_predictions, split.data, metric=metric
            )
            prediction_times[label] = predict_end - predict_start

            # Checked after recording, so the diverged values stay available
            # in the returned output; the split is only reported as failed.
            if np.any(np.abs(split_predictions) > DIVERGENCE_THRESHOLD):
                raise RuntimeError(f"Training diverged for split={label}.")

            print(f"Evaluating {label}: done.")

        except Exception as e:
            # Best-effort evaluation: report the failure and carry on with the
            # next split. The three-argument form of format_exception is used
            # because it is accepted by every Python 3 version.
            msg = f"Prediction failed for split={label}: {e}, model={model}\n\n"
            msg += "".join(traceback.format_exception(type(e), e, e.__traceback__))
            print(msg)

    return data_utils.DataOutput(
        predictions=predictions,
        metrics=metrics,
        training_time=training_time,
        prediction_times=prediction_times,
    )


def run_experiment(config: data_utils.DataConfig) -> data_utils.DataOutput:
    """Run one experiment end to end: seed, load, train, evaluate, save.

    Args:
        config: Experiment settings. The fields read here are ``seed``,
            ``data_folder``, ``model_name``, ``target``, ``time_delay``,
            ``horizon``, ``width``, ``activation``, ``regularization``,
            ``output_folder``, ``run_id`` and ``save_predictions``.

    Returns:
        The ``DataOutput`` produced by ``run_model``.
    """
    # Cap both TensorFlow thread pools at four threads.
    tf_threading = tf.config.threading
    tf_threading.set_intra_op_parallelism_threads(4)
    tf_threading.set_inter_op_parallelism_threads(4)

    # Seed the global generators as a fallback, and create the local generator
    # that is handed to the model.
    seed = config.seed
    np.random.seed(seed)
    tf.random.set_seed(seed)
    rng = np.random.default_rng(seed)

    # Load the three data splits.
    splits = load_data(config.data_folder)

    # Resolve the model class first, then assemble its constructor arguments.
    model_cls = get_model_class(config.model_name)
    model_kwargs = {
        "target": config.target,
        "time_delay": config.time_delay,
        "horizon": config.horizon,
        "rng": rng,
        "layer_width": config.width,
        "activation": config.activation,
        "regularization_scale": config.regularization,
        "n_features": splits[TRAIN].data.shape[1],
    }
    model = model_cls(**model_kwargs)

    # Train and evaluate.
    output = run_model(model, splits)

    # Plot the results and write everything to disk.
    figure = data_utils.plot_results(splits, output, target=config.target)
    data_utils.save_results(
        figure,
        output,
        output_folder=config.output_folder,
        base_filename=f"run_{config.run_id}",
        save_predictions=config.save_predictions,
    )
    return output


def parse_bool_flag(value: str) -> bool:
    """Convert a command-line string into a boolean.

    ``argparse`` applies ``type`` to the raw string, so ``type=bool`` turns
    every non-empty value -- including ``"False"`` -- into ``True``. This
    converter interprets the text instead.

    Args:
        value: Raw command-line value. Matching ignores case and surrounding
            whitespace.

    Returns:
        ``True`` for ``1/true/t/yes/y/on``; ``False`` for
        ``0/false/f/no/n/off`` and for the empty string.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    # Command-line entry point: every option maps onto a DataConfig field.
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--run_id", type=int, help="ID of the config.")
    parser.add_argument("-m", "--model_name", type=str, help="Model name.")
    parser.add_argument(
        "-f",
        "--data_folder",
        type=str,
        help="Folder with train.csv, val.csv, and test.csv.",
    )
    parser.add_argument("-t", "--target", type=str, help="Name of the target feature.")
    parser.add_argument("-o", "--output_folder", type=str, help="Output folder.")
    parser.add_argument(
        "-z", "--horizon", type=int, default=None, help="Prediction horizon."
    )
    parser.add_argument("-w", "--width", type=int, help="Layer width.")
    parser.add_argument("-a", "--activation", type=str, help="Activation function.")
    parser.add_argument(
        "-r", "--regularization", type=float, help="Regularization scale."
    )
    parser.add_argument("-s", "--seed", type=int, help="Random seed.")
    parser.add_argument("-d", "--time_delay", type=int, help="Time delay size.")
    parser.add_argument(
        "-p",
        "--save_predictions",
        # Not type=bool: bool("False") is True.
        type=parse_bool_flag,
        help="If True, save predictions of the model.",
    )
    args = parser.parse_args()
    config = data_utils.DataConfig(**vars(args))

    run_experiment(config)
