from typing import List
import copy
import operator
from enum import Enum, auto
import numpy as np

from torch.nn import Module


class StopVariable(Enum):
    """Quantity that ``EarlyStopping`` can monitor.

    ``LOSS`` is tracked under the name ``'loss'`` and counts as improved when it
    does not increase; ``ACCURACY`` is tracked as ``'acc'`` and counts as
    improved when it does not decrease. ``NONE`` is not handled by
    ``EarlyStopping.__init__``, so listing it monitors nothing.
    """

    # Values match what ``auto()`` would assign (1, 2, 3).
    LOSS = 1
    ACCURACY = 2
    NONE = 3


class Best(Enum):
    """Checkpoint-selection policy accepted by ``EarlyStopping(remember=...)``.

    ``ALL`` is meant to require every monitored variable to be at least as good
    as at the last snapshot; ``RANKED`` compares the variables in the priority
    order given to ``EarlyStopping``.
    """

    # Values match what ``auto()`` would assign (1, 2).
    RANKED = 1
    ALL = 2


# Default keyword arguments for ``EarlyStopping`` (keys mirror its ``__init__``
# parameters): monitor accuracy first and loss second, rank checkpoints in that
# order, and stop after 1400 consecutive non-improving ``check`` calls.
# ``max_epochs`` is not read by ``EarlyStopping`` itself; presumably the
# training loop uses it as a hard cap.
stopping_args = {
    'stop_varnames': [StopVariable.ACCURACY, StopVariable.LOSS],
    'patience': 1400,
    'max_epochs': 2000,
    'remember': Best.RANKED,
}

class EarlyStopping:
    """Patience-based early stopping that also snapshots the best model state.

    Each monitored variable gets a comparison operator and an initial "best"
    value: loss improves when it is ``<=`` the best so far, accuracy when it is
    ``>=``.

    Two sets of reference values are kept:

    * ``best_vals`` -- per-variable bests, used only to drive patience. Any
      single variable improving (or matching its best) resets patience.
    * ``remembered_vals`` -- the values observed at ``best_epoch``, used to
      decide whether the current model state replaces ``best_state``.
    """

    def __init__(
            self, model: Module, stop_varnames: List[StopVariable],
            patience: int = 10, max_epochs: int = 200, remember: Best = Best.ALL):
        """
        Args:
            model: Module whose ``state_dict()`` is snapshotted on improvement.
            stop_varnames: Variables to monitor, in priority order. ``check``
                expects its ``values`` in this same order. ``StopVariable.NONE``
                entries are skipped (nothing is monitored for them).
            patience: Number of consecutive non-improving ``check`` calls after
                which ``check`` returns True.
            max_epochs: Stored as ``self.max_epochs``; not consulted inside this
                class (presumably read by the training loop).
            remember: Snapshot policy. ``Best.ALL`` snapshots only when every
                variable is at least as good as at the last snapshot;
                ``Best.RANKED`` compares variables in priority order (see
                ``_improves_ranked``).
        """
        self.model = model
        self.comp_ops = []
        self.stop_vars = []
        self.best_vals = []
        for stop_varname in stop_varnames:
            if stop_varname is StopVariable.LOSS:
                self.stop_vars.append('loss')
                self.comp_ops.append(operator.le)
                self.best_vals.append(np.inf)
            elif stop_varname is StopVariable.ACCURACY:
                self.stop_vars.append('acc')
                self.comp_ops.append(operator.ge)
                self.best_vals.append(-np.inf)
        self.remember = remember
        self.remembered_vals = copy.copy(self.best_vals)
        self.max_patience = patience
        self.patience = self.max_patience
        self.max_epochs = max_epochs
        self.best_epoch = None
        self.best_state = None

    def check(self, values: List[np.floating], epoch: int, test_acc=None) -> bool:
        """Update patience and, if warranted, the remembered best model state.

        Args:
            values: Current metric values, in the same order as the
                ``stop_varnames`` given to ``__init__``.
            epoch: Recorded in ``best_epoch`` when a new snapshot is taken.
            test_acc: Unused; kept so existing call sites keep working.

        Returns:
            True once ``patience`` consecutive calls have brought no
            improvement in any monitored variable.
        """
        checks = [self.comp_ops[i](val, self.best_vals[i])
                  for i, val in enumerate(values)]
        if any(checks):
            # Per-variable bests: take the new value wherever it improved.
            self.best_vals = np.choose(checks, [self.best_vals, values])
            self.patience = self.max_patience

            if self._should_remember(values):
                self.best_epoch = epoch
                self.remembered_vals = copy.copy(values)
                # `.cpu()` returns the very same tensor when it already lives
                # on the CPU, and state_dict() tensors share storage with the
                # parameters; without an explicit copy the "best" snapshot
                # would keep changing as training continues.
                self.best_state = {
                    key: value.detach().to('cpu', copy=True)
                    for key, value in self.model.state_dict().items()}
        else:
            self.patience -= 1
        return self.patience == 0

    def _should_remember(self, values: List[np.floating]) -> bool:
        """Return whether ``values`` should replace the remembered snapshot."""
        if self.remember is Best.ALL:
            # Every variable must be at least as good as at the last snapshot.
            return all(self.comp_ops[i](val, self.remembered_vals[i])
                       for i, val in enumerate(values))
        return self._improves_ranked(values)

    def _improves_ranked(self, values: List[np.floating]) -> bool:
        """Lexicographic comparison of ``values`` against the last snapshot.

        The first variable whose value differs from the remembered one decides:
        the snapshot is replaced iff that variable moved in its "better"
        direction. If every variable ties, the snapshot is refreshed too. With
        ``[ACCURACY, LOSS]`` this selects the highest validation accuracy and,
        among equal accuracies, the lowest loss.
        """
        for comp, val, ref in zip(self.comp_ops, values, self.remembered_vals):
            if val != ref:
                return comp(val, ref)
        return True
