import os
import copy
import pickle

import ScalingEngine
import NetworkStrategy.ParallelNetworks.AbstractDataSeparator as AbstractDataSeparator


class PartitioningDataSeparator(AbstractDataSeparator.AbstractDataSeparator):
    """Split a data set into one subset per class of a pickled classifier.

    The classifier (an SVM, going by the method names in this class) and its
    accompanying setting dictionary are unpickled once at construction time.
    The classifier must expose ``classes_`` and ``predict`` -- presumably a
    fitted scikit-learn estimator; confirm against whatever writes the
    partitioning file.
    """

    def __init__(self, config, inputDir):
        """Store the configuration and load the classifier from ``inputDir``.

        Args:
            config: Object whose ``attributes`` member is the nested
                configuration mapping read by this class (``['input']`` with
                the keys ``partitioningFile``, ``partitioningSettingFile``
                and ``xDimensions``).
            inputDir: Prefix that is concatenated directly in front of the
                configured file names.

        Raises:
            IOError: If one of the two partitioning files is missing.
        """
        self._config = config.attributes

        loadedPair = self.loadSVM(inputDir)
        self._partitioning, self._partitioningSetting = loadedPair

        # Cache the class labels known to the classifier; they define both
        # the number and the order of the partitions.
        self._partitioningClasses = self._partitioning.classes_


    def separateData(self, data):
        """Return one subset of ``data`` per known partitioning class.

        The first ``xDimensions`` columns of ``data`` are labelled with the
        loaded classifier; the rows are then grouped by predicted label.

        Args:
            data: Two-dimensionally indexable array (assumed to be a NumPy
                array -- TODO confirm) whose leading columns are the x values.

        Returns:
            A list with one entry per class, in the same order as
            ``getLabels()``, each holding the rows predicted as that class.
        """
        xColumnCount = self._config['input']['xDimensions']
        predictedLabels = self.labelData(data[:, :xColumnCount])

        # Boolean-mask row selection, one mask per class.
        return [
            data[predictedLabels == knownClass]
            for knownClass in self._partitioningClasses
        ]


    def getLabels(self):
        """Return the class labels of the loaded classifier."""
        return self._partitioningClasses


    def loadSVM(self, inputDir):
        """Unpickle the classifier and its setting object.

        Args:
            inputDir: Prefix concatenated directly (no separator is inserted)
                in front of the configured file names.

        Returns:
            Tuple ``(partitioning, partitioningSetting)``.

        Raises:
            IOError: If either file does not exist as a regular file.
        """
        inputConfig = self._config['input']
        filePaths = (
            inputDir + inputConfig['partitioningFile'],
            inputDir + inputConfig['partitioningSettingFile'],
        )

        # Both files are required; bail out before opening anything.
        missingPaths = [path for path in filePaths if not os.path.isfile(path)]
        if missingPaths:
            raise IOError('Either partitioning or partitioning setting file or both could not be found.')

        # NOTE: pickle can execute arbitrary code while loading -- these
        # files must come from a trusted source.
        unpickled = []
        for path in filePaths:
            with open(path, mode='rb') as pickledFile:
                unpickled.append(pickle.load(pickledFile))

        return tuple(unpickled)


    def labelData(self, unscaledSamplesX):
        """Predict a partitioning label for every sample.

        Args:
            unscaledSamplesX: Raw x values of the samples to label.

        Returns:
            Whatever the loaded classifier's ``predict`` returns for the
            scaled samples.
        """
        # Scale x with the feature range stored alongside the classifier.
        # NOTE(review): the scaler is fitted on the samples passed in here
        # (fitAndScale), not on stored training statistics -- confirm this
        # matches how the classifier was trained.
        featureRange = self._partitioningSetting['sampleFeatureRange']
        scaler = ScalingEngine.ScalingEngine(featureRange)

        # Work on a deep copy; presumably this keeps the scaling from
        # touching the caller's array.
        samplesCopy = copy.deepcopy(unscaledSamplesX)
        scaledSamplesX = scaler.fitAndScale(samplesCopy)

        return self._partitioning.predict(scaledSamplesX)
