﻿import numpy
import pickle
from sklearn import svm

import Database.ScalingEngine


class SVM:
    """Support-vector classifier that separates training samples by net id.

    The classifier is trained on the database's training samples, re-scaled
    into its own feature range, with one class per unique net id. The fitted
    model and some summary statistics can be exported to ``outputDir``.

    Args:
        database: Sample database. Must provide ``getAllTrainingSamples()``,
            ``unscaleSamplesX()`` and ``getIds()``.
        outputDir: Output location, used as a plain string prefix for the
            exported file names (presumably ending with a path separator).
        config: Object exposing an ``attributes`` member, which is stored
            but not otherwise used by this class.
    """

    def __init__(self, database, outputDir, config):
        self._database = database
        self._outputDir = outputDir
        self._config = config.attributes

        # Fixed configuration
        self._sampleFeatureRange = (-1, 1)
        kernel = 'rbf'
        # A very large C penalises misclassification heavily (close to a
        # hard-margin SVM).
        regularizationParam = 9999

        self._sampleScalingEngine = Database.ScalingEngine.ScalingEngine(self._sampleFeatureRange)
        self._svm = svm.SVC(kernel=kernel, C=regularizationParam)
        self._svmSettings = {'sampleFeatureRange': self._sampleFeatureRange}

        # Filled in by classify().
        self._samples = None
        self._uniqueNetIds = None
        self._labels = None
        self._parameters = None

    # --- Classification ---

    def classify(self):
        """Build samples and labels from the database and fit the SVM.

        The SVM is only fitted when at least two distinct labels exist.
        Otherwise there is nothing to separate, and the model size reported
        by ``getParamters()`` is 0.
        """
        self._samples = self.defineSamples()
        self._uniqueNetIds, self._labels = self.defineLabels()

        if self._hasMultipleClasses():
            self._svm.fit(self._samples, self._labels)
            # Model size = total number of support vectors over all classes.
            self._parameters = numpy.sum(self._svm.n_support_)
        else:
            self._parameters = 0

    def _hasMultipleClasses(self):
        """Return True if ``self._labels`` contains at least two classes."""
        return len(numpy.unique(self._labels)) > 1

    def defineSamples(self):
        """Return the training samples re-scaled for the SVM.

        The database's samples are first mapped back to raw values with
        ``unscaleSamplesX``. They are then passed through this instance's
        scaling engine via ``FitAndScale``; that engine was constructed with
        ``self._sampleFeatureRange``.
        """
        databaseScaledSamples = self._database.getAllTrainingSamples()[0]
        unscaledSamples = self._database.unscaleSamplesX(databaseScaledSamples)
        return self._sampleScalingEngine.FitAndScale(unscaledSamples)

    def defineLabels(self):
        """Map every net id to a dense integer class label.

        Assumes ``getIds()`` returns one id per training sample, in the same
        order as the samples (TODO: confirm against the database).

        Returns:
            A tuple ``(uniqueIds, labels)``. ``uniqueIds`` holds the sorted
            distinct ids, and ``labels[i]`` is the index of ``ids[i]`` in
            ``uniqueIds``. Labels are always integers, independent of the
            dtype of the ids.
        """
        ids = numpy.asarray(self._database.getIds())
        uniqueIds, inverse = numpy.unique(ids, return_inverse=True)
        # The reshape keeps labels aligned with ids regardless of how the
        # installed numpy version shapes the inverse for non-1-D input.
        labels = inverse.reshape(ids.shape)
        return uniqueIds, labels

    # --- Classification ---

    # --- Plot ---

    def predictLabels(self, samples):
        """Predict one class label for each row of ``samples``.

        ``samples`` must be in the same scaled space the SVM was fitted in,
        i.e. the space produced by ``defineSamples()``. When only one class
        exists, no SVM was fitted and every sample gets label 0.

        Returns:
            1-D array with one label per sample.
        """
        if self._hasMultipleClasses():
            return self._svm.predict(samples)
        # One label per sample (not one per feature value), matching the
        # shape returned by SVC.predict.
        return numpy.zeros(len(samples), dtype=int)

    def unscaleSamples(self, scaledSamples):
        """Map SVM-scaled samples back through the SVM's scaling engine."""
        return self._sampleScalingEngine.Unscale(scaledSamples)

    def countOccurrences(self):
        """Return the fraction of training samples predicted for each label.

        Predictions are made on the SVM-scaled training samples stored by
        ``classify()``, which is the space the SVM was fitted in. Fractions
        are listed in ascending label order. Labels that are never predicted
        are omitted.
        """
        labels = numpy.asarray(self.predictLabels(self._samples))
        _, counts = numpy.unique(labels, return_counts=True)
        # tolist() yields plain Python floats, so the list prints cleanly.
        return (counts / labels.size).tolist()

    def getSampleFeatureRange(self):
        """Return the feature range the SVM's samples are scaled into."""
        return self._sampleFeatureRange

    def getAccuracy(self):
        """Return the SVM's mean accuracy on its own training samples.

        Only valid after ``classify()`` has fitted a model, which requires at
        least two classes. Otherwise sklearn raises NotFittedError.
        """
        return self._svm.score(self._samples, self._labels)

    def getParamters(self):
        """Return the model size (total number of support vectors).

        This is 0 if no SVM was fitted, or None before ``classify()`` ran.
        """
        return self._parameters

    # Correctly spelled alias; getParamters is kept for existing callers.
    getParameters = getParamters

    # --- Plot ---

    # --- Export ---

    def export(self):
        """Write the model, its accuracy and the prediction distribution.

        All files are named by appending a fixed file name to ``outputDir``.
        """
        self._exportModel()
        self._exportAccuracy()
        self._exportCompetitionStats()

    def _exportModel(self):
        """Pickle the SVM and its settings, or write svm.txt if none was fitted."""
        # Normally the case.
        # NOTE: loading pickle files can execute arbitrary code; the exported
        # .pkl files should only ever be loaded from a trusted output dir.
        if self._hasMultipleClasses():
            with open(self._outputDir + 'svm.pkl', mode='wb') as svmFile:
                pickle.dump(self._svm, svmFile)
            with open(self._outputDir + 'svmSettings.pkl', mode='wb') as svmSettingsFile:
                pickle.dump(self._svmSettings, svmSettingsFile)
            return

        # If only one net has survived the partition algorithm, an SVM makes
        # no sense. In this case we write a text file describing the problem
        # instead.
        with open(self._outputDir + 'svm.txt', mode='w') as svmFile:
            svmFile.write(
                'No SVM-Model created: ' + '\n\n' + 'Only one '
                + 'feature has been detected by the partitioning algorithm. This '
                + 'is why in this case no svm model has been created and an '
                + 'application of our sampling strategy makes no sense.'
            )

    def _exportAccuracy(self):
        """Write the training accuracy ('None' without a model) to svmAccuracy.txt."""
        if self._hasMultipleClasses():
            text = 'Accuracy score: ' + str(self.getAccuracy())
        else:
            text = 'Accuracy score: None'
        with open(self._outputDir + 'svmAccuracy.txt', 'w') as svmAccuracyFile:
            svmAccuracyFile.write(text)

    def _exportCompetitionStats(self):
        """Write the number of predicted labels and their distribution."""
        occurrences = self.countOccurrences()
        with open(self._outputDir + 'competitionStats.txt', 'w') as competitionStatsFile:
            competitionStatsFile.write(f'Number of experts: {len(occurrences)}\nDistribution: {occurrences}')

    def getNetIdsOrderedBySVM(self):
        """Return the sorted unique net ids; label ``i`` maps to entry ``i``."""
        return self._uniqueNetIds

    # --- Export ---
