﻿from matplotlib.ticker import MaxNLocator
import matplotlib.colors
import matplotlib.pyplot
import numpy
matplotlib.use('Agg')


class SVMPlotter():
    """Render SVM partitioning / clustering results to image files.

    The plot type is chosen from the configured input dimensionality
    (see ``plotSVM``). Images are written to ``outputDir`` by plain string
    concatenation, so ``outputDir`` presumably ends with a path separator
    (NOTE(review): confirm with caller).
    """

    def __init__(self, config, database, outputDir):
        """Store collaborators and set the default pyplot font.

        :param config: object exposing an ``attributes`` mapping; the
            ``['input']['xDimensions']`` / ``['yDimensions']`` entries select
            the plot type.
        :param database: object providing ``getAllTrainingSamples()``, which
            returns an ``(x, y)`` pair of training samples.
        :param outputDir: prefix prepended to every output file name.
        """
        self._config    = config.attributes
        self._database  = database
        self._outputDir = outputDir

        # pyplot.rcParams is global state: this changes the defaults for
        # every figure created afterwards in this process, not just ours.
        matplotlib.pyplot.rcParams['font.family'] = 'Times New Roman'
        matplotlib.pyplot.rcParams['font.size']   = 14


    def plotSVM(self, svm, netIds):
        """Dispatch to the plot that matches the configured dimensionality.

        1 input / 1 output uses ``plotSVM2d``, and 2 inputs / 1 output uses
        ``plotSVM3d``. Any other combination does nothing.

        :param svm: trained SVM wrapper (see the plot methods for the
            methods it must provide).
        :param netIds: identifiers used to colour the partitions in the 2-D plot.
        """
        inputConfig = self._config['input']
        xDims, yDims = inputConfig['xDimensions'], inputConfig['yDimensions']
        if xDims == 1 and yDims == 1:
            self.plotSVM2d(svm, netIds)
        elif xDims == 2 and yDims == 1:
            self.plotSVM3d(svm, netIds)


    @staticmethod
    def _samplesWithLabel(samples, labels, label):
        """Return the first feature of every sample whose label equals *label*.

        :param samples: ``(N, k)`` array of samples; only column 0 is returned.
        :param labels: ``N`` predicted labels, shaped ``(N,)`` or ``(N, 1)``.
        :param label: the label value to select.
        :returns: 1-D array of the matching samples' first feature.
        """
        # ravel so that column-vector label arrays give a 1-D boolean mask
        mask = numpy.ravel(labels) == label
        return numpy.asarray(samples)[mask, 0]


    def plotSVM2d(self, svm, netIds):
        """Plot training data above a colour strip of the SVM's partitions.

        10000 points are sampled linearly over ``svm.getSampleFeatureRange()``
        and labelled with ``svm.predictLabels``.

        - Top panel: training samples from the database, shown as black dots.
        - Bottom strip: each sampled point, coloured by its predicted label.

        Writes ``Partitioning.png`` and ``Partitioning.svg`` to the output
        directory.
        """
        samplesToPlot        = 10000
        sampleFeatureRange   = svm.getSampleFeatureRange()
        samples              = numpy.reshape(numpy.linspace(start=sampleFeatureRange[0], stop=sampleFeatureRange[1], num=samplesToPlot), (-1, 1))
        labels               = svm.predictLabels(samples)
        uniqueLabels         = numpy.unique(labels)
        functionX, functionY = self._database.getAllTrainingSamples()

        colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']
        nColors = len(colors)

        fig = matplotlib.pyplot.figure()
        try:
            matplotlib.pyplot.title('Partitioning')
            matplotlib.pyplot.axis('off')
            plot1 = matplotlib.pyplot.subplot2grid((4, 1), (0, 0), rowspan=3, colspan=1)
            plot2 = matplotlib.pyplot.subplot2grid((4, 1), (3, 0), rowspan=1, colspan=1)
            plot1.xaxis.set_major_locator(matplotlib.pyplot.NullLocator())
            plot2.yaxis.set_major_locator(matplotlib.pyplot.NullLocator())
            plot1.scatter(functionX, functionY, color='k', s=12)
            # NOTE(review): the strip uses the raw sampled values, while
            # plotSVM3d unscales its samples before plotting; confirm the two
            # panels share the same x scale.
            # zip pairs the i-th sorted unique label with the i-th netId;
            # presumably the caller orders netIds to match - confirm.
            for uniqueLabel, netId in zip(uniqueLabels, netIds):
                color = colors[numpy.mod(netId + 1, nColors)]
                # one scatter call per label instead of one per sample
                labelX = self._samplesWithLabel(samples, labels, uniqueLabel)
                plot2.scatter(labelX, numpy.zeros(labelX.size), color=color, s=20)
            matplotlib.pyplot.savefig(self._outputDir + 'Partitioning.png')
            matplotlib.pyplot.savefig(self._outputDir + 'Partitioning.svg')
        finally:
            # always release the figure, even if plotting or saving fails
            matplotlib.pyplot.close(fig)


    def plotSVM3d(self, svm, _):
        """Plot a filled contour map of SVM labels over a 2-D feature grid.

        A 500 x 500 grid spans ``svm.getSampleFeatureRange()`` on both axes.
        Labels come from ``svm.predictLabels``. The grid is mapped back with
        ``svm.unscaleSamples`` before plotting, and the SVM accuracy is
        annotated on the plot.

        Writes ``Clustering.png`` to the output directory. The second
        argument is ignored; it exists for signature parity with
        ``plotSVM2d``.
        """
        samplesToPlotPerDim = 500
        sampleFeatureRange = svm.getSampleFeatureRange()
        axisSamples = numpy.linspace(start=sampleFeatureRange[0], stop=sampleFeatureRange[1], num=samplesToPlotPerDim)
        x1Grid, x2Grid = numpy.meshgrid(axisSamples, axisSamples)
        # one (x1, x2) row per grid point, in row-major grid order
        samples = numpy.column_stack((x1Grid.ravel(), x2Grid.ravel()))

        labels          = svm.predictLabels(samples)
        unscaledSamples = svm.unscaleSamples(samples)

        unscaledSamplesX1 = numpy.reshape(unscaledSamples[:, 0], x1Grid.shape)
        unscaledSamplesX2 = numpy.reshape(unscaledSamples[:, 1], x2Grid.shape)
        labels            = numpy.reshape(labels,                x1Grid.shape)

        colorNorm = matplotlib.colors.Normalize(vmin=numpy.amin(labels), vmax=numpy.amax(labels))
        colorMap = 'viridis'

        fig = matplotlib.pyplot.figure()
        try:
            matplotlib.pyplot.contourf(unscaledSamplesX1, unscaledSamplesX2, labels, cmap=colorMap, norm=colorNorm)
            # NOTE(review): (100, 50) are data coordinates, so the score text
            # is only visible when the unscaled feature range covers that
            # point; confirm this is intended.
            matplotlib.pyplot.text(100, 50, 'Score\nof SVM-\nCluste-\nring\n' + str(round(svm.getAccuracy(), 3)))
            matplotlib.pyplot.xlabel('X1')
            matplotlib.pyplot.ylabel('X2')
            matplotlib.pyplot.title('Clustering')
            matplotlib.pyplot.savefig(self._outputDir + 'Clustering.png')
        finally:
            # always release the figure, even if plotting or saving fails
            matplotlib.pyplot.close(fig)
