import os
from math import isnan

import numpy as np
import torch

from config import metrics

def early_stop_check_all(model,
		epoch: int,
		stoppers: dict,
		res_dict_plots: dict,
		res_dict_models: dict,
		path_save_dir: str):
	"""Run one early-stopping check per tracked metric and act on the outcome.

	For every ``strategy -> stopper`` pair whose stopper is still active and whose
	history ``res_dict_plots[strategy]`` is non-empty:

	- strategies whose name contains ``"box"`` hold a sequence of tensors per entry;
	  the median over the stacked latest entry is checked with
	  ``stopper.early_stop_wdown`` (metric expected to keep increasing);
	- every other strategy's latest value is checked with
	  ``stopper.early_stop_wup`` (metric expected to keep decreasing).

	Whenever the stopper's counter is 0 right after the check (the value just
	improved), the model is saved to
	``<path_save_dir>/best_models/params_best_model_<strategy>.zip``.
	When the stopper reports that its patience is exhausted, the epoch of the best
	model (``epoch - stopper.patience``) is stored in ``res_dict_models`` and the
	stopper is deactivated so it is never checked again.

	Args:
	- model: object exposing ``save(path)``.
	- epoch (int): current epoch. ``epoch - patience`` is the best epoch only if
	  this function is called exactly once per epoch (assumed by callers).
	- stoppers (dict): strategy name -> EarlyStopper.
	- res_dict_plots (dict): strategy name -> list of recorded metric values.
	- res_dict_models (dict): updated in place with ``"best_model_<strategy>"`` keys.
	- path_save_dir (str): root directory for saved checkpoints.
	Returns:
	- None. Results are written to ``res_dict_models``, to disk, and to the stoppers.
	"""
	for strategy, stopper in stoppers.items():
		# Skip finished stoppers before touching their history, so a strategy that
		# already stopped does not need an entry in res_dict_plots.
		if not stopper.active:
			continue
		history = res_dict_plots[strategy]
		if len(history) == 0:
			continue

		if "box" in strategy:
			# The latest entry is a collection of tensors; summarise it by its median.
			metric_value = torch.median(torch.stack(history[-1]))
			stopping = stopper.early_stop_wdown(metric_value)
		else:
			metric_value = history[-1]
			stopping = stopper.early_stop_wup(metric_value)

		# A counter of 0 right after the check means this value is the new best.
		if stopper.counter == 0:
			# model.save writes the checkpoint archive ("<name>.zip").
			model.save(os.path.join(path_save_dir, "best_models",
				"params_best_model_{}.zip".format(strategy)))

		# Patience exhausted: record the best epoch and stop checking this metric.
		if stopping:
			res_dict_models["best_model_{}".format(strategy)] = epoch - stopper.patience
			print(f"Early stopping on {strategy} at epoch {epoch}\n")
			print(f"Last value of {strategy} : {metric_value}\n")
			stopper.active = False


class EarlyStopper:
	"""Patience-based early-stopping counter for a single scalar metric.

	Use ``early_stop_wup`` for a metric that should decrease (e.g. val_loss) and
	``early_stop_wdown`` for one that should increase (e.g. acc, auc-pr, vus).
	A value counts as an improvement only if it beats the best value seen so far
	by more than ``delta``; any other value (including NaN) increments ``counter``.
	The check returns True once ``counter`` reaches ``patience``.

	Attributes (all read by callers):
	- patience (int): consecutive non-improving checks tolerated.
	- delta (float): minimum margin over the best value that counts as improvement.
	- counter (int): consecutive non-improving checks; 0 right after an improvement.
	- min_val_loss: best (lowest) value seen by ``early_stop_wup``.
	- max_val_acc: best (highest) value seen by ``early_stop_wdown``.
	- active (bool): flag callers clear once the stopper has fired.
	"""

	def __init__(self, patience: int = 5, delta: float = 0.0001):
		self.patience = patience
		self.delta = delta
		self.counter = 0
		self.min_val_loss = np.inf
		# -inf rather than 0: the first value must always count as an improvement,
		# even for metrics that can be zero or negative.
		self.max_val_acc = -np.inf
		self.active = True

	def early_stop_wup(self, val_loss):
		"""Early stop when metric stops decreasing (val_loss).

		Returns True once ``patience`` consecutive values have failed to drop below
		``min_val_loss - delta``, False otherwise.
		"""
		# delta belongs in the improvement test; NaN compares False, so it is
		# counted as a non-improvement instead of being silently ignored.
		if val_loss < self.min_val_loss - self.delta:
			self.min_val_loss = val_loss
			self.counter = 0
			return False
		self.counter += 1
		return self.counter >= self.patience

	def early_stop_wdown(self, val_acc):
		"""Early stop when metric stops increasing (acc, auc-pr, vus).

		Returns True once ``patience`` consecutive values have failed to rise above
		``max_val_acc + delta``, False otherwise.
		"""
		# Mirror of early_stop_wup: improvement must exceed delta; NaN is a miss.
		if val_acc > self.max_val_acc + self.delta:
			self.max_val_acc = val_acc
			self.counter = 0
			return False
		self.counter += 1
		return self.counter >= self.patience









# import torch
# from math import isnan

# class EarlyStopping:
#     def __init__(self, 
#                  patience:int = 10, 
#                  delta:float = 0.0):
#         """
#         Args:
#             patience (int): How long to wait after last time validation loss improved.
#                             Default: 10
#             delta (float): Minimum change in the monitored quantity to qualify as an improvement.
#                             Default: 0
#         Note : 
#             - Can be applied to any metric put in arg of EarlyStopping(metric). (Ex: val_loss, auc_pr, ...)
#         """
#         self.patience = patience
#         self.counter = 0
#         self.best_score = None
#         self.early_stop = False
#         self.delta = delta

#     def check(self, val_loss):
#         score = val_loss
#         if isnan(score):
#             self.early_stop = True
#             print("The test_loss score is NaN")
#             return self.early_stop
        
#         if self.best_score is None:
#             self.best_score = score

#         elif score > self.best_score - self.delta: # if the score don't beat the best
#             self.counter += 1
#             print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
#             if self.counter >= self.patience:
#                 print(f'The process ran out of patience : early stopped')
#                 self.early_stop = True

#         else: # the score beat the best
#             self.best_score = score
#             self.counter = 0

#         return self.early_stop

#     def reset(self):
#         self.counter = 0
#         self.best_score = None
#         self.early_stop = False
    