import matplotlib.pyplot as plt
import numpy as np
import math as m
import warnings

from keras.callbacks import History, Callback

# Single-letter matplotlib color codes, indexed by metric position in
# plot_history (one color per metric). Note that 'w' is white and therefore
# hard to see on a white background.
colors = 'bgrcmykw'

def plot_history(history, metric_names = None, title = None, ignore_epochs = False, new_figure = True, val = True):
    """
    Plot training curves stored in a keras history.

    Each metric gets its own color; the training curve is drawn with a solid
    line and, when available, the matching 'val_' curve with a dashed line of
    the same color. Only the training curves appear in the legend.

    Parameters
    ----------
    history : keras History object, or dictionary with keys 'epoch' and
        'history' (same layout as the History attributes).
    metric_names : list of the metrics/costs to be displayed. If empty or
        None, every key of the history that does not start with 'val_' is used.
    title : optional title of the plot.
    ignore_epochs : if True, the recorded epoch indices are ignored and the
        x axis is 0..n-1, n being the length of the first displayed metric.
    new_figure : if True, a new matplotlib figure is created before plotting.
    val : if True, validation curves ('val_' + metric) are drawn when present.

    Raises
    ------
    ValueError
        If history is neither a dictionary nor a keras History object.
    """
    if isinstance(history, dict):
        epoch = history['epoch']
        history = history['history']
    elif isinstance(history, History):
        epoch = history.epoch
        history = history.history
    else:
        raise ValueError(
            "history must be a keras History object or a dict with keys "
            "'epoch' and 'history', got %s" % type(history).__name__
        )

    if not metric_names:
        metric_names = [metric for metric in history if not metric.startswith('val_')]

    # Guard on metric_names: with an empty history there is no first metric
    # to take the length from.
    if ignore_epochs and metric_names:
        epoch = range(len(history[metric_names[0]]))

    if new_figure:
        plt.figure()

    handles = []
    for i, metric in enumerate(metric_names):
        # Cycle through the palette so that plotting more metrics than there
        # are colors reuses colors instead of raising IndexError.
        color = colors[i % len(colors)]
        handle, = plt.plot(epoch, history[metric], '-' + color, label = metric)
        handles.append(handle)
        if val and 'val_' + metric in history:
            plt.plot(epoch, history['val_' + metric], '--' + color)

    plt.legend(handles = handles)
    plt.xlabel('epoch')

    if title:
        plt.title(title)

def history_todict(history):
    """
    Pack the content of a keras History object into a plain dictionary.

    The result has two keys, 'epoch' and 'history', holding the attributes of
    the same name (the objects themselves, not copies). A plain dictionary is
    easier to pickle than the History object.
    """
    return dict(epoch=history.epoch, history=history.history)

def lr_schedule(initial_lr,factor,epochs):
    '''
    Build a step-wise learning rate schedule.

    initial_lr is the learning rate used before any milestone is reached.
    factor is the multiplier applied to the learning rate at each milestone.
    epochs is a collection of milestone epoch indices (epoch indices start
    at 0). Any collection that can be iterated repeatedly is accepted
    (list, tuple, set, ...); a one-shot iterator would be exhausted after the
    first call of the schedule.

    Returns a function schedule(epoch) giving the learning rate for that
    epoch: initial_lr multiplied by factor once for every milestone that is
    less than or equal to epoch.
    '''
    def schedule(epoch):
        lr = initial_lr

        # Iterate over the milestones directly rather than by index, so that
        # non-indexable collections (e.g. sets) work as well.
        for milestone in epochs:
            if epoch >= milestone:
                lr *= factor
        return lr
    return schedule


class StoppingCriteria(Callback):
    '''
    Callback that stops training before the announced number of epochs when some criteria are met.

    Three independent criteria are checked at the end of every epoch:
    - not working: the training accuracy is still too low at a given epoch,
    - finished: the training loss is low enough,
    - converged: the training loss has stayed exactly the same for too many
      consecutive epochs.
    '''
    def __init__(self, not_working=(0.,-1), finished = 0., converged = np.inf):
        '''
        not_working is a tuple (acc,nbepochs) with the accuracy that should be reached after nbepochs to consider the training as working
        (the check is done at the end of the epoch whose index equals nbepochs; the default -1 never matches a non-negative epoch index)
        finished is a training loss value for which the training can be considered as finished
        converged is the number of epochs with unchanged training loss which indicates that the network doesn't change anymore
        '''
        super().__init__()
        self.acc, self.nbepochs = not_working
        self.finished = finished
        self.converged = converged

        # -1 is the "no loss seen yet" marker.
        self.previous_loss = -1
        self.counter = 0

    def on_train_begin(self, logs=None):
        # Reset the convergence tracking so that the same instance can be
        # reused for several trainings without carrying over stale state.
        self.previous_loss = -1
        self.counter = 0

    def on_epoch_end(self, epoch, logs=None):
        # logs defaults to None in the signature; treat that as "no metrics".
        logs = logs or {}

        if epoch == self.nbepochs:
            accuracy = logs.get('accuracy')
            if accuracy is None:
                # NOTE(review): older Keras versions appear to report the
                # metric as 'acc' -- fall back to that key.
                accuracy = logs.get('acc')
            if accuracy is None:
                warnings.warn(
                    "StoppingCriteria: no 'accuracy' value in logs, "
                    "the 'not_working' criterion is skipped.",
                    RuntimeWarning,
                )
            elif accuracy <= self.acc:
                self.model.stop_training = True

        loss = logs.get('loss')
        if loss is None:
            warnings.warn(
                "StoppingCriteria: no 'loss' value in logs, "
                "the 'finished' and 'converged' criteria are skipped.",
                RuntimeWarning,
            )
            return

        if loss <= self.finished:
            self.model.stop_training = True

        # Exact equality is intended: the criterion looks for a loss that
        # does not change at all from one epoch to the next.
        if loss == self.previous_loss:
            self.counter += 1
            if self.counter >= self.converged:
                self.model.stop_training = True
        else:
            self.counter = 0
            self.previous_loss = loss
            
