
import random
from typing import List, Tuple, Union


class BaseRetrainAlgo():
    """Base class for algorithms that decide when to retrain a model.

    Tracks an internal time-step counter ``t`` (starting at ``t_offline``), the
    index of the most recent model available for use, and a log of every
    ``new_info`` passed to ``update_at_t``. Subclasses customise the policy by
    overriding ``decide`` and, optionally, the other hooks.
    """

    def __init__(self, T: int, t_offline: int, relative_pe: bool = False):
        super().__init__()
        # NOTE(review): T is stored but not read in this class; presumably the horizon.
        self.T = T
        self.t_offline = t_offline
        # Before any retrain, the available model is indexed one step before
        # t_offline (presumably the model produced offline — confirm).
        self.most_recent_available_model = t_offline-1
        self.t = t_offline
        self.relative_pe = relative_pe
        self.info = []

    # some set up before running the eval loop. Help for compatibility with CARA code, but should not be needed.

    def initialize_eval(self, all_data):
        """Hook run before the evaluation loop; a no-op by default."""
        pass

    def update_at_t(self, new_info):
        """Advance the internal step counter and record ``new_info``."""
        self.t += 1
        self.info.append(new_info)

    # by default, no training
    def train_offline(self, training_data, testing_data=None):
        """Offline-training hook; a no-op by default."""
        pass

    # the retrain cost accessible to the algo is not necessarily the true one
    def set_retrain_cost(self, train_retrain_cost):
        """Store the retrain cost visible to the algorithm."""
        self.retrain_cost = train_retrain_cost

    def decide(self, t: int) -> Tuple[bool, int]:
        """Decide whether to retrain at time step ``t``.

        Retrains when ``t`` appears in ``self.fixed_retrain_indices`` (set by
        fixed-schedule subclasses). If that attribute is not defined, no
        retrain ever happens.

        Args:
            t: The time step being decided on; only used for schedule lookup.

        Returns:
            ``(retrain, most_recent_available_model)``. On a retrain, the most
            recent model becomes the internal counter ``self.t`` (not ``t``).
        """
        # Fall back to an empty schedule so subclasses without a fixed
        # schedule don't raise AttributeError here.
        schedule = getattr(self, 'fixed_retrain_indices', ())
        retrain = t in schedule
        if retrain:
            self.most_recent_available_model = self.t
        return retrain, self.most_recent_available_model


"""
Randomly decides whether to retrain at each step (a fair coin flip), without following any schedule.
"""


class Random(BaseRetrainAlgo):
    """Retrain on a fair coin flip at every call to ``decide``, ignoring any info."""

    def __init__(self, T: int, t_offline: int, relative_pe: bool = False):
        # relative_pe is forwarded so this policy can be configured like the base class.
        super().__init__(T, t_offline, relative_pe)
        self.name = 'random'

    # We ignore any info, we randomly retrain
    def decide(self, t: int) -> Tuple[bool, int]:
        """Flip a fair coin to decide whether to retrain.

        Args:
            t: Accepted for interface compatibility; ignored.

        Returns:
            ``(retrain, most_recent_available_model)``. On a retrain, the most
            recent model becomes the internal counter ``self.t``.
        """
        # Uses the module-level ``random`` RNG, so random.seed() makes runs reproducible.
        retrain = bool(random.randint(0, 1))
        if retrain:
            self.most_recent_available_model = self.t

        return retrain, self.most_recent_available_model


"""
Retrains exactly at a fixed, predefined list of time steps.
"""


class FixedRetrain(BaseRetrainAlgo):
    """Retrain exactly at the time steps listed in ``fixed_retrain_indices``."""

    def __init__(self, T: int, t_offline: int, fixed_retrain_indices: List[int],
                 relative_pe: bool = False):
        # relative_pe is forwarded so this policy can be configured like the base class.
        super().__init__(T, t_offline, relative_pe)
        self.fixed_retrain_indices = fixed_retrain_indices
        self.name = 'fixed_'+str(fixed_retrain_indices)

    def decide(self, t: int) -> Tuple[bool, int]:
        """Retrain iff ``t`` is one of the scheduled time steps.

        Defined here explicitly so the schedule lookup lives with the class
        that owns ``fixed_retrain_indices``.

        Args:
            t: The time step being decided on.

        Returns:
            ``(retrain, most_recent_available_model)``. On a retrain, the most
            recent model becomes the internal counter ``self.t`` (not ``t``).
        """
        retrain = t in self.fixed_retrain_indices
        if retrain:
            self.most_recent_available_model = self.t
        return retrain, self.most_recent_available_model

    

