import numpy as np


"""
Build (X, y) datasets from a loss dict and regroup per-row predictions by model and time step.
"""


def create_X_y_data(data, features):
    """Flatten a nested loss dict into a feature matrix and a target vector.

    Args:
        data (dict): must provide 'loss_dict' (a mapping of model key to a
            mapping of time index to loss value) and 'datasets' (passed
            through unchanged to generate_x_feature).
        features: feature selection forwarded to generate_x_feature.

    Returns:
        tuple: (X, y) as numpy arrays, with one entry per (model, t) pair
        found in the loss dict, in iteration order.
    """
    losses_by_model = data['loss_dict']
    datasets = data['datasets']

    rows, targets = [], []
    for model_key, losses_by_time in losses_by_model.items():
        for t, loss_value in losses_by_time.items():
            rows.append(generate_x_feature(model_key, t, datasets, features))
            targets.append(loss_value)

    y = np.array(targets)
    X = np.array(rows)
    return X, y


def aggregate_pred(X_test, y_test):
    """Regroup per-row values into a nested dict.

    Column 0 of each row of X_test is used as the outer key and column 1 as
    the inner key; the value stored is the entry of y_test at the same row
    index. If the same (outer, inner) pair occurs more than once, the last
    row wins.

    Args:
        X_test: 2-D array; only columns 0 and 1 are read.
        y_test: sequence indexable by row number.

    Returns:
        dict: {X_test[i, 0]: {X_test[i, 1]: y_test[i]}} over all rows i.
    """
    grouped = {}
    n_rows = X_test.shape[0]
    for row in range(n_rows):
        outer_key = X_test[row, 0]
        inner_key = X_test[row, 1]
        # setdefault creates the inner dict the first time a key is seen.
        grouped.setdefault(outer_key, {})[inner_key] = y_test[row]
    return grouped


def compute_distribution_shift(t, datasets):
    """Return the L1 distance between the mean 'X_val' at t-1 and at t.

    Args:
        t: index of the current dataset; datasets[t - 1] must also exist.
        datasets: mapping whose entries expose an 'X_val' array.

    Returns:
        Sum of absolute differences between the two axis-0 means.
    """
    # Same lookup order as before: t-1 first, then t.
    previous_mean, current_mean = (
        np.mean(datasets[index]['X_val'], axis=0) for index in (t - 1, t)
    )
    return np.sum(np.abs(previous_mean - current_mean))


def generate_x_feature(model_index: int, t: int, datasets: dict, features: str) -> list:
    """Build the feature vector for one (model, time index) pair.

    The vector always starts with ``[model_index, t]``. Optional features are
    appended in a fixed order, each one enabled when its name is found in
    ``features`` (checked with ``in``, so a string is matched by substring):

    * ``'time_since_train'``: appends ``t - model_index``.
    * ``'second_order'``: appends ``model_index**2`` then ``t**2``.
    * ``'shift_dist'``: appends the value of compute_distribution_shift for
      time ``t``, or for the most recent available dataset when ``t`` is
      beyond it; appends 0 when ``t`` is not positive.

    Args:
        model_index (int): index of the model
        t (int): time index
        datasets (dict): dict containing all available datasets, keyed by
            time index; only read for 'shift_dist' when t > 0
        features (str): specify which features we are using

    Returns:
        list: the feature values described above.
    """
    x_feat = [model_index, t]

    if 'time_since_train' in features:
        x_feat.append(t - model_index)

    if 'second_order' in features:
        x_feat.extend((model_index**2, t**2))

    if 'shift_dist' in features:
        average_X_shift = 0
        if t > 0:
            # Use the shift at t when that dataset is available; otherwise
            # fall back to the most recent shift we have access to.
            shift_index = min(t, max(datasets.keys()))
            average_X_shift = compute_distribution_shift(shift_index, datasets)
        x_feat.append(average_X_shift)

    return x_feat
