from operator import ge,le,gt,lt
from hyperopt.early_stop import no_progress_loss

class EarlystopWithWarmup:
    """Early-stopping callback with a warm-up period.

    Intended for use as an early-stop function (e.g. hyperopt's
    ``fmin(..., early_stop_fn=...)``): it is called with the trials object after
    a trial completes and returns ``(stop, kwargs)``.

    During the first ``warmup`` calls the best loss is tracked but stopping is
    disabled. After warm-up, each trial that does not improve on the best loss
    seen so far increments a counter, and an improving trial resets it. The
    search stops once ``patience`` consecutive non-improving post-warm-up
    trials have been observed.

    Args:
        patience: Consecutive non-improving post-warm-up trials that trigger
            the stop. A value < 1 stops on the first post-warm-up trial.
        warmup: Number of initial calls during which stopping is disabled.
            ``None`` (the default) means no warm-up.
        higher_better: If True, a larger ``loss`` counts as an improvement;
            otherwise a smaller one does.
        allow_eq: If True, a loss equal to the best so far also counts as an
            improvement (and therefore resets patience).
        verbose: If True, print a line when patience is reset or the search
            is stopped.
    """

    # (higher_better, allow_eq) -> comparator answering "is `new` better than `best`?"
    _COMPARATORS = {
        (True, True): ge,
        (True, False): gt,
        (False, True): le,
        (False, False): lt,
    }

    def __init__(self, patience=1, warmup=None, higher_better=False,
                 allow_eq=False, verbose=True):
        self.patience = patience
        # None means "no warm-up"; normalise so the `<=` check in __call__ works.
        self.warmup = 0 if warmup is None else warmup
        self.higher_better = higher_better
        self.allow_eq = allow_eq
        self.verbose = verbose
        self.compare = self._COMPARATORS[(bool(higher_better), bool(allow_eq))]

        self.n_trial = 0       # calls seen so far, warm-up included
        self.es_count = 0      # consecutive non-improving trials after warm-up
        self.prev_best = None  # best loss so far; None until a loss is seen

    def __call__(self, trial, *args):
        """Inspect the latest trial result and decide whether to stop.

        Args:
            trial: Trials object exposing a ``results`` list of dicts; only the
                last entry's ``"loss"`` is read. A result without ``"loss"``
                (e.g. a failed trial) counts as non-improving.
            *args: Ignored; accepted for the early-stop calling convention.

        Returns:
            ``(stop, {})`` where ``stop`` is True when the search should end.
        """
        results = trial.results
        if not results:
            return False, {}
        loss = results[-1].get("loss")
        self.n_trial += 1

        # The first valid loss always counts as an improvement, so it can
        # never consume patience regardless of the comparator's strictness.
        improved = loss is not None and (
            self.prev_best is None or self.compare(loss, self.prev_best)
        )
        if improved:
            self.prev_best = loss

        if self.n_trial <= self.warmup:
            return False, {}

        self.es_count += 1
        if improved:
            if self.verbose:
                print(f"{self.n_trial},{self.es_count}->Reset patience")
            self.es_count = 0

        if self.es_count >= self.patience:
            if self.verbose:
                print(f"{self.n_trial},{self.es_count}->End search")
            return True, {}
        return False, {}