import torch
import numpy as np
import os  
from math import isnan

from config import metrics

def early_stop_check_all(model,
						epoch: int,
						stoppers: dict,
						res_dict_plots: dict,
						res_dict_models: dict,
						path_save_dir: str):
	"""Update every active stopper with its latest metric and checkpoint new bests.

	For each strategy in `stoppers` whose stopper is still active and whose
	history in `res_dict_plots` is non-empty:

	* if the strategy name contains "box", the metric is the median of the
	  stacked tensors found in the last history entry, and it is passed to
	  `stopper.early_stop_wdown`;
	* otherwise the metric is the last history entry itself, and it is
	  passed to `stopper.early_stop_wup`.

	When the stopper's counter is 0 after the update (and the metric is not
	NaN), the model's `state_dict()` is saved to
	`<path_save_dir>/best_models/params_best_models_<strategy>.pth`.
	When the stopper reports a stop, `epoch - stopper.patience` is stored in
	`res_dict_models["best_model_<strategy>"]`, a message is printed, and the
	stopper is deactivated so it is skipped on later calls.

	Args:
		model: object exposing `state_dict()`; saved with `torch.save`.
		epoch: current epoch number.
		stoppers: maps strategy name -> stopper (must expose `active`,
			`counter`, `patience`, `early_stop_wup`, `early_stop_wdown`).
		res_dict_plots: maps strategy name -> metric history (a sized,
			indexable sequence). Every key of `stoppers` must be present.
		res_dict_models: updated in place with the recorded best epoch.
		path_save_dir: directory under which `best_models/` is created.

	Returns:
		None. Results are reported through `res_dict_models`, the stoppers'
		state and the checkpoint files.
	"""
	best_models_dir = os.path.join(path_save_dir, "best_models")

	for strategy, stopper in stoppers.items():
		history = res_dict_plots[strategy]
		# Nothing recorded yet, or this strategy has already triggered its stop.
		if len(history) == 0 or not stopper.active:
			continue

		if "box" in strategy:
			metric_value = torch.median(torch.stack(history[-1]))
			stopping = stopper.early_stop_wdown(metric_value)
		else:
			metric_value = history[-1]
			stopping = stopper.early_stop_wup(metric_value)

		# A counter of 0 means the stopper just recorded a new best value.
		# A NaN metric can never be a new best, but it may leave the counter
		# untouched at 0; skip it so a diverged model does not overwrite the
		# best checkpoint.
		if stopper.counter == 0 and not isnan(metric_value):
			os.makedirs(best_models_dir, exist_ok=True)
			checkpoint_path = os.path.join(
				best_models_dir, "params_best_models_{}.pth".format(strategy))
			torch.save(model.state_dict(), checkpoint_path)

		# Patience exhausted: record the best epoch and stop tracking this strategy.
		if stopping:
			res_dict_models["best_model_{}".format(strategy)] = epoch - stopper.patience
			print(f"Early stopping on {strategy} at epoch {epoch}\n")
			print(f"Last value of {strategy} : {metric_value}\n")
			stopper.active = False


class EarlyStopper:
	"""Patience-based early-stopping tracker for one validation metric.

	Attributes:
		patience: number of non-improving updates after which the
			`early_stop_*` methods report a stop.
		delta: tolerance used in the non-improvement comparison. With any
			positive value, every update that is not a strict improvement
			counts against patience.
		counter: non-improving updates seen since the last improvement;
			reset to 0 whenever a new best value is recorded.
		min_val_loss: best (lowest) value seen by `early_stop_wup`.
		max_val_acc: best (highest) value seen by `early_stop_wdown`.
		active: flag initialised to True; this class never changes it, it is
			left for callers to toggle.
	"""

	def __init__(self, patience=5, delta=0.0001):
		self.patience = patience
		self.delta = delta
		self.counter = 0
		self.min_val_loss = np.inf
		# Start at -inf (mirroring min_val_loss = +inf) so the first observed
		# score is always recorded as the best, even when it is 0 or negative.
		self.max_val_acc = -np.inf
		self.active = True

	def _register_no_improvement(self):
		"""Count one non-improving update; return True once patience is reached."""
		self.counter += 1
		return bool(self.counter >= self.patience)

	def early_stop_wup(self, val_loss):
		"""Track a metric that should decrease (e.g. validation loss).

		A value strictly below the best seen so far becomes the new best and
		resets the counter. A NaN value, or a value above `best - delta`,
		counts as one non-improving update.

		Returns:
			True when the counter has reached `patience`, False otherwise.
		"""
		if val_loss < self.min_val_loss:
			self.min_val_loss = val_loss
			self.counter = 0
		# NaN fails every ordering comparison, so it is checked explicitly;
		# otherwise a diverged run would never consume patience.
		elif isnan(val_loss) or val_loss > (self.min_val_loss - self.delta):
			return self._register_no_improvement()
		return False

	def early_stop_wdown(self, val_acc):
		"""Track a metric that should increase (acc, auc-pr, vus).

		A value strictly above the best seen so far becomes the new best and
		resets the counter. A NaN value, or a value below `best + delta`,
		counts as one non-improving update.

		Returns:
			True when the counter has reached `patience`, False otherwise.
		"""
		if val_acc > self.max_val_acc:
			self.max_val_acc = val_acc
			self.counter = 0
		# Same explicit NaN handling as in early_stop_wup.
		elif isnan(val_acc) or val_acc < (self.max_val_acc + self.delta):
			return self._register_no_improvement()
		return False







