﻿import numpy
import matplotlib.pyplot
matplotlib.use('Agg')


class SamplePlotter:
    """Draw scatter plots showing which training samples are mapped to which net.

    Figures are written as PNG and SVG into ``outputDir``. Only the case of a
    one-dimensional input and one-dimensional output is currently drawn.
    """

    # Colour cycle used to tell nets apart (indexed by ``netId + 1``).
    _COLORS = ('tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
               'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan')

    def __init__(self, database, config, outputDir):
        """Store the plotting settings and choose the subsampling step.

        Args:
            database: Object providing ``getAllTrainingSamples()``; used here
                only to size the subsampling step.
            config: Object whose ``attributes`` mapping holds the run
                configuration (read as ``attributes['input']['xDimensions']``
                and ``['yDimensions']``).
            outputDir: String prepended verbatim to output file names, so it
                normally needs to end with a path separator.

        Note: this overwrites the global matplotlib ``rcParams`` font
        settings, affecting every figure drawn afterwards in the process.
        """
        self._config = config.attributes
        self._outputDir = outputDir
        self._markerSize = 6
        self.defineStepIndex(database)

        matplotlib.pyplot.rcParams['font.family'] = 'Times New Roman'
        matplotlib.pyplot.rcParams['font.size'] = 14

    def plotMapping(self, database, activeNetIds, zeroBasedEpoch):
        """Plot the sample-to-net mapping when input and output are both 1-D.

        Any other dimensionality is skipped silently, because no plot type is
        implemented for it.
        """
        inputConfig = self._config['input']
        if inputConfig['xDimensions'] != 1 or inputConfig['yDimensions'] != 1:
            return
        self.plotMapping2d(database, activeNetIds, zeroBasedEpoch)

    def plotMapping2d(self, database, activeNetIds, zeroBasedEpoch):
        """Scatter all samples plus each active net's predictions, then save.

        All training samples are drawn in translucent black. For each net,
        predictions on the samples mapped to it are drawn opaque in the net's
        colour, and predictions on the remaining samples are drawn nearly
        transparent in the same colour. The figure is saved as
        ``<outputDir>Mapping_after_epoch_<epoch>.png`` and ``.svg``. It is
        always closed afterwards, even if plotting or saving fails.

        Args:
            database: Source of training samples and per-net predictions.
            activeNetIds: Iterable of integer net ids to draw.
            zeroBasedEpoch: Epoch number used in the title and file names.
        """
        figure, axes = matplotlib.pyplot.subplots()
        try:
            samplesX, samplesY = database.getAllTrainingSamples()
            axes.scatter(self._subsample(samplesX), self._subsample(samplesY),
                         color='k', s=0.35 * self._markerSize, alpha=0.3,
                         label='Function')

            for netId in activeNetIds:
                # [0] keeps the first element of the returned pair (presumably
                # the sample inputs) -- TODO confirm against the database API.
                samplesMapped = database.getTrainingSamplesMappedToNet(netId)[0]
                predictionsMapped = database.getPredictionMappedToNet(netId)
                samplesNotMapped = database.getTrainingSamplesNotMappedToNet(netId)[0]
                predictionsNotMapped = database.getPredictionNotMappedToNet(netId)

                # Offset by one, so net 0 is drawn in 'tab:orange'.
                color = self._COLORS[(netId + 1) % len(self._COLORS)]
                axes.scatter(self._subsample(samplesMapped),
                             self._subsample(predictionsMapped),
                             s=self._markerSize, color=color, alpha=1,
                             label='Net ' + str(netId))
                axes.scatter(self._subsample(samplesNotMapped),
                             self._subsample(predictionsNotMapped),
                             s=self._markerSize, color=color, alpha=0.007)

            axes.set_xticks([-1, -0.5, 0, 0.5, 1])
            axes.set_yticks([-1, -0.5, 0, 0.5, 1])
            axes.set_title(self._mappingTitle(zeroBasedEpoch))

            legend = axes.legend(loc='lower left', framealpha=0.7, markerscale=3)
            # Make legend markers fully opaque even for translucent series.
            # Use the public setter, not the private ``_alpha`` attribute.
            for handle in legend.legend_handles:
                handle.set_alpha(1)

            fileStem = self._outputDir + 'Mapping_after_epoch_' + str(zeroBasedEpoch)
            figure.savefig(fileStem + '.png')
            figure.savefig(fileStem + '.svg')
        finally:
            matplotlib.pyplot.close(figure)

    @staticmethod
    def _mappingTitle(zeroBasedEpoch):
        """Return the plot title for the given epoch (epochs 0 and 1000 are special-cased)."""
        if zeroBasedEpoch == 0:
            return 'Mapping at start (after epoch 0)'
        if zeroBasedEpoch == 1000:
            return 'Mapping at end (after epoch 1000)'
        return 'Mapping after epoch ' + str(zeroBasedEpoch)

    def _subsample(self, values):
        """Return every ``self._stepIndex``-th element of *values*, starting at the first.

        Unlike ``values[0:-1:step]``, this does not unconditionally drop the
        final element.
        """
        return values[::self._stepIndex]

    def defineStepIndex(self, database, samplesToPlot=5000):
        """Set ``self._stepIndex`` so that roughly *samplesToPlot* points are drawn.

        The step is ``round(nSamples / samplesToPlot)``, clamped to at least 1.
        ``nSamples`` is the length of the first axis of the first array
        returned by ``database.getAllTrainingSamples()``.

        Raises:
            ValueError: If *samplesToPlot* is not positive.
        """
        if samplesToPlot <= 0:
            raise ValueError('samplesToPlot must be positive, got ' + str(samplesToPlot))
        nSamples = database.getAllTrainingSamples()[0].shape[0]
        self._stepIndex = max(int(round(nSamples / samplesToPlot)), 1)
        