from typing import Dict, List, Union

import numpy as np
from autogluon.timeseries.dataset import TimeSeriesDataFrame
from autogluon.timeseries.utils.datetime import get_seasonality

# NOTE(review): this constant is not referenced by any code visible in this module.
# Presumably it names the stored evaluation run (5 validation windows, refit on
# every window) that other modules load -- confirm against the callers.
DATA_RUN_NAME = "BaseModelEvaluations-5windows-refitevery1"


def timeseriesdataframe_to_tensor(tsdf: Union[List, Dict, TimeSeriesDataFrame]):
    """
    Convert a (possibly nested) collection of TimeSeriesDataFrames into one dense array.

    Outputs arrays of size:
    #folds x #items x #timesteps [x #features [x #models]]
    where the last two dimensions are optional. The feature axis only appears if
    the TimeSeriesDataFrame has a number of columns other than one (e.g. quantiles)
    or if its single column is named "0.5". The model axis only appears if the
    input is a dict, which is assumed to contain the different model predictions.

    Lists are treated as folds and concatenated along the first axis; dicts are
    treated as models and stacked along a new last axis. Both may be nested.

    A single data frame is sorted by its index before conversion, so items are
    ordered by their (sorted) identifier in the first index level. An empty data
    frame yields an empty tensor with zero items and zero timesteps.

    Raises:
        ValueError: if the items of a data frame do not all have the same number
            of rows, since the rows then cannot be laid out on a regular
            items x timesteps grid.
    """
    if isinstance(tsdf, list):
        # list elements are individual folds, so concatenate along the fold axis
        return np.concatenate([timeseriesdataframe_to_tensor(y) for y in tsdf], axis=0)

    if isinstance(tsdf, dict):
        # dict elements are different models, so introduce a new models axis
        return np.stack([timeseriesdataframe_to_tensor(y) for y in tsdf.values()], axis=-1)

    tsdf = tsdf.sort_index()
    item_ids = tsdf.index.get_level_values(0)
    n_rows, n_columns = tsdf.shape
    n_items = len(item_ids.unique())
    # Integer division keeps the arithmetic exact; an empty frame gives zero timesteps
    # instead of dividing by zero.
    n_timesteps = n_rows // n_items if n_items else 0

    if n_items:
        # After sorting, the rows of each item are contiguous. The reshape below is
        # only meaningful if every consecutive block of n_timesteps rows belongs to
        # a single item, i.e. each block starts and ends with the same identifier.
        ids = item_ids.to_numpy()
        aligned = n_items * n_timesteps == n_rows and np.array_equal(
            ids[::n_timesteps], ids[n_timesteps - 1 :: n_timesteps]
        )
        if not aligned:
            raise ValueError(
                f"Cannot convert to a tensor: expected all {n_items} items to have the same number "
                f"of timesteps, but the {n_rows} rows are not evenly distributed across items"
            )

    shape = [1, n_items, n_timesteps]
    # Keep an explicit feature axis for multi-column frames and for a lone "0.5" column.
    if n_columns != 1 or tsdf.columns[0] == "0.5":
        shape.append(n_columns)
    return tsdf.values.reshape(tuple(shape))


def tensor_to_timeseriesdataframe(output_template: TimeSeriesDataFrame, tensor: np.ndarray):
    """
    Write the values of a single-fold tensor into a copy of ``output_template``.

    ``tensor`` must have 3 or 4 dimensions, laid out as
    1 x #items x #timesteps [x #features]; a missing feature axis counts as one
    feature. The number of features must equal the number of columns of
    ``output_template``. The template itself is left untouched; a copy with all
    of its columns overwritten by the flattened tensor values is returned.
    """
    # NOTE(review): the shape checks below are plain asserts, so they are skipped
    # when Python runs with -O.
    rank = tensor.ndim
    assert rank in (3, 4)

    n_folds, n_items, n_timesteps = tensor.shape[:3]
    assert n_folds == 1

    n_features = tensor.shape[-1] if rank == 4 else 1

    result = output_template.copy()
    assert result.shape[1] == n_features

    # Collapse the item and timestep axes into one row axis matching the template rows.
    flat_values = tensor.reshape((n_items * n_timesteps, n_features))
    result[result.columns] = flat_values
    return result


# Models whose test predictions are replaced by those of their "_FULL" counterpart
# whenever such a counterpart is present in the test predictions.
_MODELS_USING_FULL_TEST_PREDS = frozenset(
    {
        "DeepAR",
        "DLinear",
        "PatchTST",
        "SimpleFeedForward",
        "TemporalFusionTransformer",
        "TiDE",
        "WaveNet",
    }
)


def _keep_known_items(tsdf, known_items):
    """Return the rows of ``tsdf`` whose first index level value is in ``known_items``."""
    return tsdf[tsdf.index.get_level_values(0).isin(known_items)]


def _drop_mean_column(tsdf):
    """Return ``tsdf`` restricted to every column except the one named "mean"."""
    return tsdf[[col for col in tsdf.columns if col != "mean"]]


def process_simulation_artifact(data):
    """
    Post-process a simulation artifact dict in place and return it.

    The following entries of ``data`` are read and rewritten:

    - ``y_test`` and every frame in ``pred_proba_dict_test`` are restricted to the
      items (first index level) that occur in the first validation window
      ``y_val[0]``.
    - ``models`` is set to the keys of ``pred_proba_dict_val`` that do not contain
      "_FULL".
    - The "mean" column is dropped from all validation and test predictions.
    - ``y_val_preds`` and ``y_test_preds`` are created, keyed by the entries of
      ``models``. For the models listed in ``_MODELS_USING_FULL_TEST_PREDS`` the
      test predictions of "<model>_FULL" are used when they exist; all other
      models keep their own test predictions.
    - ``eval_metric_seasonal_period`` is inferred from the frequency of ``y_test``
      when it is missing or falsy.
    - ``y_test_in`` and ``y_val_in`` hold the corresponding targets sliced with
      ``slice_by_timestep(None, -prediction_length)``, i.e. the part that is
      passed to the model for prediction.

    Raises:
        KeyError: if a required entry (e.g. ``y_val``, ``y_test``,
            ``prediction_length``) or the test predictions of a model are missing.
    """
    # Filter test data such that we only keep items that appear in the train/val data
    known_items = data["y_val"][0].index.get_level_values(0).unique()
    data["y_test"] = _keep_known_items(data["y_test"], known_items)

    # Get the list of models
    models = [m for m in data["pred_proba_dict_val"] if "_FULL" not in m]
    data["models"] = models

    # drop the "mean" from the columns (and apply the item filter to the test predictions)
    data["pred_proba_dict_val"] = {
        name: [_drop_mean_column(tsdf) for tsdf in tsdfs] for name, tsdfs in data["pred_proba_dict_val"].items()
    }
    data["pred_proba_dict_test"] = {
        name: _drop_mean_column(_keep_known_items(tsdf, known_items))
        for name, tsdf in data["pred_proba_dict_test"].items()
    }

    val_preds = data["pred_proba_dict_val"]
    test_preds = data["pred_proba_dict_test"]
    data["y_val_preds"] = {model_name: val_preds[model_name] for model_name in models}

    # Only the models in _MODELS_USING_FULL_TEST_PREDS switch to their "_FULL" test predictions
    data["y_test_preds"] = {}
    for model_name in models:
        full_name = f"{model_name}_FULL"
        use_full = model_name in _MODELS_USING_FULL_TEST_PREDS and full_name in test_preds
        data["y_test_preds"][model_name] = test_preds[full_name if use_full else model_name]

    # for convenience infer the seasonal period if it has not been set;
    # .get() also covers artifacts in which the key is absent altogether
    if not data.get("eval_metric_seasonal_period"):
        data["eval_metric_seasonal_period"] = get_seasonality(data["y_test"].freq)

    # save the "input" (i.e. the quantity that gets passed to the model for prediction) more explicitly
    prediction_length = data["prediction_length"]
    data["y_test_in"] = data["y_test"].slice_by_timestep(None, -prediction_length)
    data["y_val_in"] = [y_val.slice_by_timestep(None, -prediction_length) for y_val in data["y_val"]]

    return data


def limit_data_to_n_validation_windows(data, n_windows):
    """
    Return a shallow copy of ``data`` that keeps only the last ``n_windows`` validation windows.

    The entries ``y_val`` (a sequence of windows) as well as ``pred_proba_dict_val``
    and ``y_val_preds`` (mappings from model to a sequence of windows) are
    truncated; all other entries are carried over unchanged. The input dict
    itself is not modified.

    Raises:
        ValueError: if ``n_windows`` exceeds the number of windows in ``y_val``
            or is smaller than 1.
    """
    limited = data.copy()
    n_windows_available = len(limited["y_val"])

    if n_windows > n_windows_available:
        raise ValueError(f"Asked for {n_windows} validation windows, but only {n_windows_available} are available")
    if n_windows < 1:
        raise ValueError(f"Asked for {n_windows} validation windows, but must be >= 1")

    # Index of the first window that is kept; everything before it is dropped.
    first_kept = n_windows_available - n_windows

    for key in ("pred_proba_dict_val", "y_val_preds"):
        per_model = limited[key]
        limited[key] = {model: windows[first_kept:] for model, windows in per_model.items()}

    limited["y_val"] = limited["y_val"][first_kept:]
    return limited
