﻿from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot
import numpy
# Use the non-interactive Agg backend: the figures below are only saved to
# files, never shown in a window.
matplotlib.use('Agg')


class EvaluationPlotter:
    """Render the plots of an evaluation run and save them as image files.

    The plotted data comes from the ``evaluation`` object, the ``'test'``
    section of the configuration decides which plots are produced, and every
    figure is written into the output directory.
    """

    def __init__(self, evaluation, outputDir, config):
        """Store the data source, the output location and the configuration.

        Args:
            evaluation: Object providing the getters used by the plot methods
                (``getNumOfNets``, ``getDroppingAddingEpochs``,
                ``getDroppingReplacabilities``, ``getAddingImprovements`` and
                ``getTrainingEpochsAdding``).
            outputDir: Directory the image files are written to.
            config: Configuration object; only its ``attributes`` mapping is
                kept and later read as ``attributes['test'][...]``.

        Note:
            This changes the global matplotlib ``rcParams`` (font family and
            font size), so it also affects figures created elsewhere in the
            same process afterwards.
        """
        self._evaluation = evaluation
        self._outputDir = outputDir
        self._config = config.attributes

        matplotlib.pyplot.rcParams['font.family'] = 'Times New Roman'
        matplotlib.pyplot.rcParams['font.size'] = 14

    def plotEvaluation(self):
        """Create every plot; plots whose feature is inactive skip themselves."""
        self.plotNumOfNets()
        self.plotProcessDroppingAdding()
        self.plotDecisionBoundaryDropping()
        self.plotDecisionBoundaryAdding()
        self.plotNumOfTrainingEpochsAdding()

    def plotNumOfNets(self):
        """Plot the number of nets over the epochs and save it as PNG and SVG."""
        pyplot = matplotlib.pyplot
        numOfNets = self._evaluation.getNumOfNets()
        # One-based epoch numbers, one for each entry of numOfNets.
        epochs = numpy.arange(start=1, stop=numOfNets.shape[0] + 1)

        pyplot.figure()
        # Net counts are whole numbers, so restrict the y ticks to integers.
        pyplot.gca().yaxis.set_major_locator(MaxNLocator(integer=True))
        pyplot.plot(epochs, numOfNets, color='k', linewidth=3)
        pyplot.xlabel('Epoch')
        pyplot.ylabel('Number of nets')
        pyplot.title('Number of nets')
        pyplot.savefig(self._outputPath('Number_of_nets.png'))
        pyplot.savefig(self._outputPath('Number_of_nets.svg'))
        pyplot.close()

    def plotProcessDroppingAdding(self):
        """Mark the epochs at which nets were dropped or added.

        Dropping events are drawn on one horizontal level and adding events
        on another. Nothing is plotted when both dropping and adding are
        inactive in the configuration.
        """
        testConfig = self._config['test']
        if not testConfig['droppingActive'] and not testConfig['addingActive']:
            return

        pyplot = matplotlib.pyplot
        droppingY = 1
        addingY = 2

        droppingEpochs, addingEpochs = self._evaluation.getDroppingAddingEpochs()
        # Random horizontal noise so that coinciding markers are less likely
        # to hide each other. Dropping is jittered first to keep the order in
        # which random numbers are drawn.
        droppingX = self.createJitter(droppingEpochs)
        addingX = self.createJitter(addingEpochs)

        # Open a dedicated figure, like every other plot method, instead of
        # drawing into whatever figure happens to be current.
        pyplot.figure()
        pyplot.scatter(droppingX, numpy.full_like(droppingEpochs, droppingY), c='k', marker='x')
        pyplot.scatter(addingX, numpy.full_like(addingEpochs, addingY), c='k', marker='x')
        pyplot.grid()
        pyplot.xlabel('Epoch')
        pyplot.ylabel('Adding and Dropping')
        pyplot.yticks([droppingY, addingY], ['Dropping', 'Adding'])
        pyplot.title('Adding and Dropping')
        pyplot.savefig(self._outputPath('Adding_and_dropping.png'))
        pyplot.close()

    def plotDecisionBoundaryDropping(self):
        """Plot the replacability values together with the dropping threshold.

        Column 0 of the replacability data is used as x value (axis label
        'Epoch') and column 1 as y value on a logarithmic axis. The threshold
        ``droppingReplacability`` from the configuration is drawn as a
        horizontal line. Skipped when dropping is inactive.
        """
        testConfig = self._config['test']
        if not testConfig['droppingActive']:
            return

        pyplot = matplotlib.pyplot
        replacabilities = self._evaluation.getDroppingReplacabilities()
        epochs = replacabilities[:, 0]
        threshold = testConfig['droppingReplacability']

        pyplot.figure()
        pyplot.yscale('log')
        pyplot.scatter(epochs, replacabilities[:, 1], c='k', s=0.6)
        # Horizontal threshold line spanning the range of the plotted epochs.
        pyplot.plot(
            [numpy.min(epochs), numpy.max(epochs)],
            [threshold, threshold],
            color='k',
            label='Nets below were dropped',
        )
        pyplot.grid()
        pyplot.xlabel('Epoch')
        pyplot.ylabel('Replacability')
        pyplot.legend(loc='upper right')
        pyplot.title('Decision boundary dropping')
        pyplot.savefig(self._outputPath('Decision_boundary_dropping.png'))
        pyplot.close()

    def plotDecisionBoundaryAdding(self):
        """Plot the improvement factors together with the adding threshold.

        Column 0 of the improvement data is used as x value (axis label
        'Epoch') and column 1 as y value. The threshold ``addingImprovement``
        from the configuration is drawn as a horizontal line. Skipped when
        adding is inactive.
        """
        testConfig = self._config['test']
        if not testConfig['addingActive']:
            return

        pyplot = matplotlib.pyplot
        improvements = self._evaluation.getAddingImprovements()
        epochs = improvements[:, 0]
        threshold = testConfig['addingImprovement']

        pyplot.figure()
        pyplot.scatter(epochs, improvements[:, 1], c='k', s=1.0)
        # Horizontal threshold line spanning the range of the plotted epochs.
        pyplot.plot(
            [numpy.min(epochs), numpy.max(epochs)],
            [threshold, threshold],
            color='k',
            label='Nets above were added',
        )
        pyplot.grid()
        pyplot.xlabel('Epoch')
        pyplot.ylabel('Improvement factor')
        pyplot.legend(loc='upper right')
        pyplot.title('Decision boundary adding')
        pyplot.savefig(self._outputPath('Decision_boundary_adding.png'))
        pyplot.close()

    def plotNumOfTrainingEpochsAdding(self):
        """Plot absolute and relative catch-up training epochs of added nets.

        The absolute values are drawn in blue against the left y axis, the
        relative values in red against a second y axis on the right. For both
        data sets column 0 is the x value and column 1 the y value. Skipped
        when adding is inactive.
        """
        if not self._config['test']['addingActive']:
            return

        pyplot = matplotlib.pyplot
        absoluteTrainingEpochs, relativeTrainingEpochs = self._evaluation.getTrainingEpochsAdding()

        pyplot.figure()

        ax1 = pyplot.gca()
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Trained epochs')
        ax1.scatter(absoluteTrainingEpochs[:, 0], absoluteTrainingEpochs[:, 1], c='b', s=0.6)
        ax1.tick_params(axis='y', labelcolor='b')

        # Second y axis that shares the x axis with ax1.
        ax2 = ax1.twinx()
        ax2.set_ylabel('Trained epochs / Allowed epochs for training')
        ax2.scatter(relativeTrainingEpochs[:, 0], relativeTrainingEpochs[:, 1], c='r', s=0.6)
        ax2.tick_params(axis='y', labelcolor='r')

        # Set the title before tight_layout() so the layout reserves room for
        # it; otherwise the title can be cut off at the top of the figure.
        pyplot.title('Epochs of catch-up training of added nets')
        pyplot.tight_layout()
        pyplot.savefig(self._outputPath('Epoch_catch_up_training_adding.png'))
        pyplot.close()

    def createJitter(self, array):
        """Return ``array`` with Gaussian noise added to every element.

        The noise is standard normal, scaled by 1.5 % of the value range
        (maximum minus minimum) of ``array``. If all values are equal the
        scale is zero and the values stay unchanged.

        Args:
            array: One-dimensional numpy array.

        Returns:
            A new array with the noise added, or ``array`` itself when it is
            empty (an empty array has no minimum or maximum).
        """
        if array.size == 0:
            return array

        noiseScale = 0.015 * (array.max() - array.min())
        return array + numpy.random.randn(len(array)) * noiseScale

    def _outputPath(self, fileName):
        """Return the path of ``fileName`` inside the output directory.

        Joining instead of concatenating keeps the path correct whether or
        not the output directory ends with a path separator.
        """
        import os  # Local import: the module header does not import os.

        return os.path.join(self._outputDir, fileName)
