from typing import List
from src.retraining_task_processing.f_generator import PredictorFGenerator
from src.seq_dataset import SeqDataset


def train_f_on_datasets(timesteps_to_train: List[int], seqDataset: SeqDataset, f_generator: PredictorFGenerator):
    """Obtain one predictor ``f`` per requested timestep.

    If ``f_generator.already_stored_f`` is truthy, no training happens. The
    result of ``f_generator.get_empty_models(timesteps_to_train, seqDataset)``
    is returned as-is. Per the original author's note, this is the path for
    the wild dataset, whose predictors are already trained and only need
    loading.

    Otherwise, for each timestep ``t`` (in the given order), the train split
    is fetched with ``seqDataset.get_X_Y(t, split='train')``. A model is then
    fitted on it with ``f_generator.get_trained_f(X, Y)``.

    Args:
        timesteps_to_train: timesteps for which a model is needed.
        seqDataset: dataset object exposing ``get_X_Y(t, split=...)``.
        f_generator: object exposing ``already_stored_f``,
            ``get_empty_models`` and ``get_trained_f``.

    Returns:
        In the training path, a dict mapping each timestep to the model
        trained on that timestep's train split. In the pre-stored path,
        whatever ``get_empty_models`` returns (presumably a dict of the same
        shape; confirm against the generator implementation).
    """
    if f_generator.already_stored_f:
        # Models already exist; hand off to the generator instead of retraining.
        return f_generator.get_empty_models(timesteps_to_train, seqDataset)

    def _fit_on_timestep(step):
        # Fit a fresh model on the training split of a single timestep.
        features, targets = seqDataset.get_X_Y(step, split='train')
        return f_generator.get_trained_f(features, targets)

    return {step: _fit_on_timestep(step) for step in timesteps_to_train}


def generate_loss_matrix(timesteps_to_generate: List[int], seqDataset: SeqDataset, f_generator: PredictorFGenerator, dict_trained_f: dict, relative_pe: bool = False):
    """Build the (upper-triangular) loss matrix ``M[f_index][t]``.

    For every timestep ``t`` in ``timesteps_to_generate``, the test split of
    ``t`` is fetched with ``seqDataset.get_X_Y(t, split='test')``. Every model
    ``dict_trained_f[f_index]`` whose timestep ``f_index`` is also in
    ``timesteps_to_generate`` and satisfies ``f_index <= t`` is then evaluated
    on it with ``f_generator.get_loss_metric(f, X, y)``.

    Args:
        timesteps_to_generate: timesteps forming both the model axis and the
            evaluation axis. They do not need to be contiguous.
        seqDataset: dataset object exposing ``get_X_Y(t, split=...)``.
        f_generator: object exposing ``get_loss_metric(f, X, y)``.
        dict_trained_f: mapping timestep -> trained model. It must contain
            every timestep in ``timesteps_to_generate``.
        relative_pe: if True, each entry is reported relative to the model's
            loss on its own timestep, i.e. ``loss[f][t] - loss[f][f]``.

    Returns:
        Nested dict ``M`` with ``M[f_index][t]`` defined for every pair of
        requested timesteps with ``f_index <= t``.

    Raises:
        KeyError: if a requested timestep has no entry in ``dict_trained_f``.
    """
    L_matrix = {t: {} for t in timesteps_to_generate}
    for t in timesteps_to_generate:
        # Test data of timestep t, shared by every model evaluated on it.
        X_t_to_eval, y_t_to_eval = seqDataset.get_X_Y(t, split='test')

        # Evaluate every requested model trained at or before t. Filtering the
        # requested timesteps (instead of range(first, t + 1)) avoids KeyError
        # when the timesteps are not contiguous.
        previous_models_indices = [s for s in timesteps_to_generate if s <= t]
        for f_index in previous_models_indices:
            f = dict_trained_f[f_index]  # pretrained model of timestep f_index
            L_matrix[f_index][t] = f_generator.get_loss_metric(
                f, X_t_to_eval, y_t_to_eval)

    if not relative_pe:
        return L_matrix

    # Relative setting: subtract each model's loss on its own timestep,
    # loss[f][f], from its loss at every later timestep t. Iterating over the
    # computed entries (rather than range(t + 1)) keeps this consistent with
    # the loop above for any starting timestep.
    return {
        f_index: {t: loss - row[f_index] for t, loss in row.items()}
        for f_index, row in L_matrix.items()
    }
