import argparse
import os
import time

import numpy as np
import tensorflow as tf

import climate.src.climate_utils as climate_utils
from models import BaseModel, SWIM, LSTM

# Largest value representable as a signed 64-bit integer. Not referenced in the
# visible part of this file; presumably an upper bound for generated random
# seeds -- TODO confirm against importers of this module.
MAX_SEED = np.iinfo(np.int64).max
# Predictions whose absolute value exceeds this are treated as a diverged
# model (see run_model).
DIVERGENCE_THRESHOLD = 1e4
# Split labels; load_data also uses them as the base names of the CSV files.
TRAIN, VAL, TEST = "train", "val", "test"


def load_data(
    data_folder: str,
    size_limit: int | None = None
) -> dict[str, climate_utils.ClimateSplit]:
    """Load the train, validation, and test splits stored in ``data_folder``.

    Every split is read from ``<data_folder>/<label>.csv`` where ``<label>``
    is one of ``TRAIN``, ``VAL``, ``TEST``.

    Args:
        data_folder: Folder containing ``train.csv``, ``val.csv`` and
            ``test.csv``.
        size_limit: Passed through unchanged as the second positional
            argument of ``climate_utils.ClimateSplit``; its exact meaning is
            defined there.

    Returns:
        Mapping from split label to the loaded ``ClimateSplit``, in the order
        train, val, test.
    """
    return {
        split_name: climate_utils.ClimateSplit(
            os.path.join(data_folder, f"{split_name}.csv"), size_limit
        )
        for split_name in (TRAIN, VAL, TEST)
    }


def get_model_class(model_name) -> type[BaseModel]:
    """Map a model name to the class that implements it.

    Args:
        model_name: ``"swim"`` or ``"lstm"``.

    Returns:
        ``SWIM`` for ``"swim"`` and ``LSTM`` for ``"lstm"``. The class itself
        is returned, not an instance.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "swim":
        return SWIM
    if model_name == "lstm":
        return LSTM
    raise ValueError(f"Unknown model name: {model_name}.")


def run_model(
    model: BaseModel, splits: dict[str, climate_utils.ClimateSplit], metric: str = "abs"
) -> climate_utils.ClimateOutput:
    """Train ``model`` and evaluate it on the train, validation, and test splits.

    Args:
        model: Model to fit. It is trained on ``splits[TRAIN].data`` with
            ``splits[VAL].data`` passed as the second argument of
            ``model.train``.
        splits: Mapping from split label (``TRAIN``, ``VAL``, ``TEST``) to the
            corresponding split; only the ``data`` attribute is read here.
        metric: Name of the metric, forwarded to ``model.compute_metric``.

    Returns:
        A ``ClimateOutput`` holding, per split label, the predictions, the
        metric value and the prediction time, plus the total training time.
        All durations are in seconds.

    Raises:
        RuntimeError: If any prediction of a split exceeds
            ``DIVERGENCE_THRESHOLD`` in absolute value.
    """
    # perf_counter() is monotonic and high-resolution, so the measured
    # durations cannot be distorted by system clock adjustments the way
    # time.time() differences can.
    train_start = time.perf_counter()
    model.train(splits[TRAIN].data, splits[VAL].data)
    training_time = time.perf_counter() - train_start
    print("Training: done.")

    predictions = dict()
    metrics = dict()
    prediction_times = dict()
    for label in (TRAIN, VAL, TEST):
        split = splits[label]
        predict_start = time.perf_counter()
        split_predictions = model.predict(split.data)
        prediction_times[label] = time.perf_counter() - predict_start

        # Fail fast: check for divergence before computing metrics so that a
        # diverged model is reported as such rather than through whatever
        # error the metric computation might raise on huge values.
        # NOTE(review): NaN compares False against the threshold, so NaN
        # predictions are not flagged here -- confirm whether that is intended.
        if np.any(np.abs(split_predictions) > DIVERGENCE_THRESHOLD):
            raise RuntimeError(f"Training diverged for split={label}.")

        predictions[label] = split_predictions
        metrics[label] = model.compute_metric(split_predictions, split.data, metric=metric)

        print(f"Evaluating {label}: done.")

    return climate_utils.ClimateOutput(predictions=predictions,
                               metrics=metrics,
                               training_time=training_time,
                               prediction_times=prediction_times)


def run_experiment(config: climate_utils.ClimateConfig) -> climate_utils.ClimateOutput:
    """Run one experiment end to end and return its output.

    The steps are: configure TensorFlow threading, seed the random number
    generators, load the data splits from ``config.data_folder``, build the
    model selected by ``config.model_name``, train and evaluate it, then plot
    and save the results under ``config.output_folder``.

    Args:
        config: Experiment configuration. The attributes read here are
            ``seed``, ``data_folder``, ``model_name``, ``target``,
            ``time_delay``, ``horizon``, ``width``, ``activation``,
            ``regularization``, ``output_folder``, ``run_id`` and
            ``save_predictions``.

    Returns:
        The ``ClimateOutput`` produced by ``run_model``.
    """
    # Limit number of cores used: both TensorFlow thread pools get 4 threads.
    num_threads = 4
    tf.config.threading.set_intra_op_parallelism_threads(num_threads)
    tf.config.threading.set_inter_op_parallelism_threads(num_threads)

    # Seed the global generators as well (just in case some functions do not
    # use the local generator that is handed to the model).
    seed = config.seed
    np.random.seed(seed)
    tf.random.set_seed(seed)
    generator = np.random.default_rng(seed)

    # Load the train/val/test splits.
    data_splits = load_data(config.data_folder)

    # Resolve the model class first, then assemble its constructor arguments.
    model_type = get_model_class(config.model_name)
    model_kwargs = {
        "target": config.target,
        "time_delay": config.time_delay,
        "horizon": config.horizon,
        "rng": generator,
        "layer_width": config.width,
        "activation": config.activation,
        "regularization_scale": config.regularization,
        # Second dimension of the training data gives the feature count.
        "num_features": data_splits[TRAIN].data.shape[1],
    }
    model = model_type(**model_kwargs)

    # Train and evaluate the model.
    result = run_model(model, data_splits)

    # Plot and save the results.
    figure = climate_utils.plot_results(data_splits, result, target=config.target)
    climate_utils.save_results(
        figure,
        result,
        output_folder=config.output_folder,
        base_filename=f"run_{config.run_id}",
        save_predictions=config.save_predictions,
    )
    return result


def parse_bool(value: str) -> bool:
    """Convert a command-line string into a boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so the flag could
    never be switched off with an explicit value.

    Args:
        value: Raw command-line value; matching ignores case and surrounding
            whitespace.

    Returns:
        ``True`` for ``true/t/yes/y/1`` and ``False`` for ``false/f/no/n/0``
        or the empty string.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    # Command-line entry point: every option maps one-to-one onto a keyword
    # argument of climate_utils.ClimateConfig. Options that are not given
    # are passed as None.
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--run_id", type=int, help="ID of the config.")
    parser.add_argument("-m", "--model_name", type=str, help="Model name.")
    parser.add_argument(
        "-f",
        "--data_folder",
        type=str,
        help="Folder with train.csv, val.csv, and test.csv.",
    )
    parser.add_argument("-t", "--target", type=str, help="Name of the target feature.")
    parser.add_argument("-o", "--output_folder", type=str, help="Output folder.")
    parser.add_argument("-z", "--horizon", type=int, default=None, help="Prediction horizon.")
    parser.add_argument("-w", "--width", type=int, help="Layer width.")
    parser.add_argument("-a", "--activation", type=str, help="Activation function.")
    parser.add_argument(
        "-r", "--regularization", type=float, help="Regularization scale."
    )
    parser.add_argument("-s", "--seed", type=int, help="Random seed.")
    parser.add_argument("-d", "--time_delay", type=int, help="Time delay size.")
    # Accepts "-p True" / "-p False" (parsed by parse_bool) as well as a bare
    # "-p", which means True. The default remains None when the option is
    # omitted.
    parser.add_argument(
        "-p",
        "--save_predictions",
        type=parse_bool,
        nargs="?",
        const=True,
        help="If True, save predictions of the model.",
    )
    args = parser.parse_args()
    config = climate_utils.ClimateConfig(**vars(args))

    run_experiment(config)
