import itertools
import numpy as np
from src.models.base_models import BaseRetrainAlgo
from src.models.optimal_schedule import find_optimal_schedule

"""
The true oracle based from our formulation.
"""


class Oracle(BaseRetrainAlgo):
    """Retraining policy that chooses its schedule using the test-set losses.

    During the offline phase it exhaustively enumerates every binary retrain
    schedule over the remaining steps and keeps the one that
    ``find_optimal_schedule`` selects under the ground-truth (test) losses.
    Because it reads the test losses directly, it is a reference ("oracle")
    policy rather than one that could be deployed.
    """

    def __init__(self, T: int, t_offline: int, relative_pe: bool):
        """
        Args:
            T: forwarded to ``BaseRetrainAlgo`` (presumably the total number
                of steps — TODO confirm against the base class).
            t_offline: forwarded to ``BaseRetrainAlgo``; also used as the
                offset when turning schedule positions into retrain indices.
            relative_pe: forwarded unchanged to ``BaseRetrainAlgo``.
        """
        super().__init__(T, t_offline, relative_pe)
        self.name = 'Oracle'

    def train_offline(self, trainin_data, testing_data):
        """Search the optimal retrain schedule on the test set and store it.

        Nothing is returned; the result is kept on the instance.

        Args:
            trainin_data: unused; accepted so the signature matches the other
                retraining algorithms.
            testing_data: mapping whose ``'loss_dict'`` entry is passed to
                ``find_optimal_schedule`` as the ground-truth loss.

        Side effects:
            Sets ``self.ground_truth_loss`` to ``testing_data['loss_dict']``.
            Sets ``self.fixed_retrain_indices`` to a list of ints: the
            positions where the optimal schedule is non-zero (retrain),
            shifted by ``self.t_offline``.
        """
        # Oracle behaviour: use the test loss itself instead of an estimate.
        self.ground_truth_loss = testing_data['loss_dict']
        # Every 0/1 schedule over the remaining T - t steps (1 = retrain).
        # This is 2 ** (T - t) candidates, so it is only feasible for short
        # horizons.
        thetas = list(itertools.product([0, 1], repeat=self.T - self.t))

        optimal_theta = find_optimal_schedule(
            self.ground_truth_loss, thetas, self.retrain_cost,
            now=self.t,
            index_last_trained_model=self.most_recent_available_model,
            loss_is_01=None, unc=False)
        # np.argwhere returns a (k, 1) array for a 1-D schedule, which used to
        # produce a list of 1-element arrays. flatnonzero gives the flat
        # positions, so we store plain integer indices.
        # NOTE(review): positions are offset by t_offline while the search
        # starts at now=self.t — presumably self.t == self.t_offline when this
        # runs; confirm against BaseRetrainAlgo.
        self.fixed_retrain_indices = [
            int(pos) + self.t_offline for pos in np.flatnonzero(optimal_theta)
        ]
