﻿import copy
import numpy

import Database.DataSet
import Database.DataLoader
import Database.ScalingEngine


class DataBase():
    """Scaled training data plus, per sample, the predictions/losses of every net.

    Each entry of ``self._sampleSet`` is a dict with:
      * ``'sample'``: ``{'x': ..., 'y': ...}`` -- single-row views
        (shape ``(1, k)``) into ``self._cachedTrainingData``; ``x`` holds the
        first ``xDimensions`` columns, ``y`` the remaining ones.
      * ``'byId'``:   ``[predictions, losses, ids]`` in the order of the
        ``netIds`` passed to ``submitPredictionsAndLosses``.
      * ``'byPerf'``: the same three arrays sorted by ascending loss.
        -> The first id in ``'byPerf'[2]`` is the net this sample is assigned to.
    """

    # Value returned by getReplacabilityOfNet when a net cannot sensibly be replaced.
    _NOT_REPLACEABLE = 99

    # --- Initialization ---

    def __init__(self, config):
        """Load and scale the training data and build the per-sample bookkeeping.

        :param config: object whose ``attributes`` dict contains (at least) the
            ``'input'`` (``xDimensions``, ``yDimensions``) and ``'test'``
            (``featureRangeX``, ``featureRangeY``) sections used here.
        """
        self._config = config.attributes
        self.loadAndScaleData()
        #self.shortDataForDebugging()
        self.initSampleSet()


    def loadAndScaleData(self):
        """Load the data via DataLoader and scale inputs and outputs separately.

        Sets ``self._scalingEngineX``, ``self._scalingEngineY`` and
        ``self._cachedTrainingData`` (scaled training data, x columns first).
        """
        # Load data
        dataLoader = Database.DataLoader.DataLoader(self._config)
        unscaledTrainingData, unscaledTestData = dataLoader.getData()

        # Scale data -> scales automatically each dimension independently, if there are multiple in- or output dimensions
        self._scalingEngineX = Database.ScalingEngine.ScalingEngine(self._config['test']['featureRangeX'])
        self._scalingEngineY = Database.ScalingEngine.ScalingEngine(self._config['test']['featureRangeY'])

        xDim = self._config['input']['xDimensions']

        # zeros_like inherits the input dtype; for integer data that would
        # silently truncate the scaled (fractional) values, so use float there.
        if numpy.issubdtype(unscaledTrainingData.dtype, numpy.floating):
            bufferDtype = unscaledTrainingData.dtype
        else:
            bufferDtype = float
        self._cachedTrainingData = numpy.zeros_like(unscaledTrainingData, dtype=bufferDtype)
        # Pass copies so the scaling engines can never modify the loader's arrays.
        self._cachedTrainingData[:, :xDim] = self._scalingEngineX.FitAndScale(unscaledTrainingData[:, :xDim].copy())
        self._cachedTrainingData[:, xDim:] = self._scalingEngineY.FitAndScale(unscaledTrainingData[:, xDim:].copy())

        # NOTE: test data is loaded but currently not scaled/cached:
        #self._cachedTestData = numpy.zeros_like(unscaledTestData)
        #self._cachedTestData[:, :xDim] = self._scalingEngineX.Scale(unscaledTestData[:, :xDim].copy())
        #self._cachedTestData[:, xDim:] = self._scalingEngineY.Scale(unscaledTestData[:, xDim:].copy())


    def shortDataForDebugging(self, samplesForDebugging=100):
        """Thin the cached training data out to roughly ``samplesForDebugging`` rows.

        Must be called before ``initSampleSet`` so the sample set matches.
        """
        numberOfSamples = self._cachedTrainingData.shape[0]
        # round() yields 0 for small datasets, and a zero slice step raises
        # ValueError -> keep every row in that case.
        step = max(1, int(round(numberOfSamples / samplesForDebugging)))
        self._cachedTrainingData = self._cachedTrainingData[::step, :]


    def _emptyRanking(self):
        """Return an empty ``[predictions, losses, ids]`` triple."""
        return [numpy.zeros((0, self._config['input']['yDimensions'])),  # predictions, one row per net
                numpy.zeros((0,)),                                       # losses (scaled prediction vs. sample)
                numpy.zeros((0,))]                                       # net IDs


    def initSampleSet(self):
        """Build one bookkeeping dict per row of ``self._cachedTrainingData``."""
        xDim = self._config['input']['xDimensions']
        data = self._cachedTrainingData
        self._sampleSet = [{'sample': {'x': data[index, None, :xDim],   # (1, xDim) view
                                       'y': data[index, None, xDim:]},  # (1, rest) view
                            'byId':   self._emptyRanking(),
                            'byPerf': self._emptyRanking()}
                           for index in range(data.shape[0])]

    # --- Initialization ---

    # --- Manipulation ---

    def removeNet(self, netId):
        """Drop every prediction/loss/id entry belonging to ``netId`` from all samples."""
        for sample in self._sampleSet:
            for key in ('byId', 'byPerf'):
                ranking = sample[key]
                # Boolean row mask keeps the 2-D shape of the prediction array
                # (numpy.delete without axis would flatten it).
                keep = ranking[2] != netId
                ranking[0] = ranking[0][keep]
                ranking[1] = ranking[1][keep]
                ranking[2] = ranking[2][keep]

    # --- Manipulation ---

    # ---- Submit predictions ----

    def submitPredictionsAndLosses(self, predsAndLosses, netIds):
        """Replace all stored predictions/losses and re-rank them by performance.

        :param predsAndLosses: mapping ``netId -> (predictions, losses)``;
            ``predictions[i]`` / ``losses[i]`` belong to sample ``i``.
        :param netIds: the nets to store, in the desired ``'byId'`` order.
        """
        self.submitPredAndLossById(predsAndLosses, netIds)
        self.submitPredAndLossByPerf()


    def submitPredAndLossById(self, predsAndLosses, netIds):
        """Overwrite each sample's ``'byId'`` triple with the given nets' results."""
        yDim = self._config['input']['yDimensions']
        netIds = list(netIds)  # iterated several times below
        predictionsPerNet = [predsAndLosses[netId][0] for netId in netIds]
        lossesPerNet = [predsAndLosses[netId][1] for netId in netIds]

        for index, sample in enumerate(self._sampleSet):
            ranking = sample['byId']
            # Seed with an empty float array so the result keeps shape
            # (0, yDim) and float dtype even when netIds is empty.
            ranking[0] = numpy.concatenate([numpy.zeros((0, yDim))]
                                           + [predictions[index, None, :] for predictions in predictionsPerNet])
            ranking[1] = numpy.append(numpy.zeros((0,)), [losses[index] for losses in lossesPerNet])
            ranking[2] = numpy.append(numpy.zeros((0,)), netIds)


    def submitPredAndLossByPerf(self):
        """Rebuild each sample's ``'byPerf'`` triple by sorting ``'byId'`` by ascending loss."""
        for sample in self._sampleSet:
            predictions, losses, ids = sample['byId']
            # Stable sort: on equal losses the earlier-submitted net wins,
            # making the sample-to-net assignment deterministic.
            order = numpy.argsort(losses, kind='stable')
            sample['byPerf'][:] = [predictions[order], losses[order], ids[order]]

    # ---- Submit predictions ----

    # --- Unscaling ---

    def unscaleSamplesX(self, scaledSamplesX):
        """Map scaled inputs back to the original input range."""
        return self._scalingEngineX.Unscale(copy.deepcopy(scaledSamplesX))

    def unscaleSamplesY(self, scaledSamplesY):
        """Map scaled outputs back to the original output range."""
        return self._scalingEngineY.Unscale(copy.deepcopy(scaledSamplesY))

    # --- Unscaling ---

    # --- Getters ----

    # --- DataSets ---

    def getDataSetMappedToNet(self, netId):
        """Wrap the samples assigned to ``netId`` (x and y columns side by side) in a DataSet."""
        return Database.DataSet.DataSet(self._config, numpy.concatenate(self.getTrainingSamplesMappedToNet(netId), axis=1))

    # --- Samples ---

    def getAllTrainingSamples(self):
        """Return ``(x, y)`` of all cached (scaled) training samples."""
        xDim = self._config['input']['xDimensions']
        return self._cachedTrainingData[:, :xDim], \
               self._cachedTrainingData[:, xDim:]

    def _collectSamples(self, predicate):
        """Return ``[x, y]`` stacked over all samples for which ``predicate(sample)`` holds.

        Empty selections yield arrays of shape ``(0, xDimensions)`` / ``(0, yDimensions)``.
        """
        selected = [sample['sample'] for sample in self._sampleSet if predicate(sample)]
        return [numpy.concatenate([numpy.zeros((0, self._config['input']['xDimensions']))]
                                  + [entry['x'] for entry in selected]),
                numpy.concatenate([numpy.zeros((0, self._config['input']['yDimensions']))]
                                  + [entry['y'] for entry in selected])]

    def getTrainingSamplesMappedToNet(self, netId):
        """Return ``[x, y]`` of the samples whose best-performing net is ``netId``."""
        return self._collectSamples(lambda sample: sample['byPerf'][2][0] == netId)

    def getTrainingSamplesNotMappedToNet(self, netId):
        """Return ``[x, y]`` of the samples whose best-performing net is not ``netId``."""
        return self._collectSamples(lambda sample: sample['byPerf'][2][0] != netId)

    def getTrainingSamplesAboveLossBound(self, lossBound):
        """Return ``[x, y]`` of the samples whose best loss exceeds ``lossBound``."""
        return self._collectSamples(lambda sample: sample['byPerf'][1][0] > lossBound)

    # --- Predictions ---

    def getPredictionMappedToNet(self, netId):
        """Return ``netId``'s predictions for the samples assigned to it, one row per sample."""
        rows = [numpy.zeros((0, self._config['input']['yDimensions']))]
        rows.extend(sample['byPerf'][0][0, None, :]
                    for sample in self._sampleSet
                    if sample['byPerf'][2][0] == netId)
        return numpy.concatenate(rows)

    def getPredictionNotMappedToNet(self, netId):
        """Return ``netId``'s predictions for the samples assigned to other nets.

        Raises IndexError if ``netId`` has no entry in a sample's ``'byId'`` triple.
        """
        rows = [numpy.zeros((0, self._config['input']['yDimensions']))]
        for sample in self._sampleSet:
            if sample['byPerf'][2][0] != netId:
                netIndex = numpy.where(sample['byId'][2] == netId)[0][0]
                rows.append(sample['byId'][0][netIndex, None, :])
        return numpy.concatenate(rows)

    # --- Losses ---

    def getLoss(self):
        """Return the best (lowest) loss of every sample."""
        return numpy.array([sample['byPerf'][1][0] for sample in self._sampleSet], dtype=float)

    def getLossAboveLossBound(self, lossBound):
        """Return the best losses that exceed ``lossBound``, in sample order."""
        return numpy.array([sample['byPerf'][1][0] for sample in self._sampleSet
                            if sample['byPerf'][1][0] > lossBound], dtype=float)

    # --- IDs ---

    def getIds(self):
        """Return the id of the net each sample is assigned to (best performer)."""
        return numpy.array([sample['byPerf'][2][0] for sample in self._sampleSet], dtype=float)

    # --- Replacability ---

    def getReplacabilityOfNet(self, netId):
        """Return the ratio of total best loss without ``netId`` to total best loss with it.

        Samples assigned to ``netId`` fall back to their second-best loss. Returns
        ``_NOT_REPLACEABLE`` (99) when there are no samples, at most one net,
        or the total loss with the net is zero.
        """
        # Without samples or with a single net, nothing can be replaced.
        if not self._sampleSet or self._sampleSet[0]['byId'][2].size <= 1:
            return self._NOT_REPLACEABLE

        lossesWithNet = 0
        lossesWithoutNet = 0
        for sample in self._sampleSet:
            bestLosses = sample['byPerf'][1]
            lossesWithNet += bestLosses[0]
            # If netId is this sample's best net, the runner-up takes over.
            if sample['byPerf'][2][0] == netId:
                lossesWithoutNet += bestLosses[1]
            else:
                lossesWithoutNet += bestLosses[0]

        # If loss with net is zero (happened for small datasets), there is no need to replace a net
        if lossesWithNet == 0:
            return self._NOT_REPLACEABLE
        return lossesWithoutNet / lossesWithNet

    # --- Getters ----
