import os, yaml
import numpy as np

def get_weights_sklearn(dataset, numClasses):
    """Compute 'balanced' class weights with scikit-learn.

    Reference implementation used to cross-check ``record_hist``; it keeps
    every label in memory at once, so it is only suitable for small datasets.

    Args:
        dataset: Iterable yielding ``(inputs, labels)`` pairs, where
            ``labels['output_1']`` is array-like and holds the class indices.
        numClasses: Number of classes; weights are computed for the class
            indices ``0 .. numClasses - 1``.

    Returns:
        Dict mapping each class index to its weight as a Python ``float``.

    Raises:
        ValueError: Propagated from NumPy when ``dataset`` yields nothing, or
            from scikit-learn when it rejects the label/class combination.
    """
    # Imported lazily so the rest of the module works without scikit-learn.
    from sklearn.utils.class_weight import compute_class_weight

    # Flatten each sample before joining so samples of different shapes
    # (e.g. a smaller final batch) can still be combined.
    labelCollection = np.concatenate(
        [np.asarray(labels['output_1']).ravel() for _, labels in dataset])

    # Label 255 is excluded from the statistics.
    labelCollection = labelCollection[labelCollection != 255]

    # `classes` and `y` must be passed by keyword: scikit-learn made them
    # keyword-only, and the positional form raises TypeError on current releases.
    weightsClass = compute_class_weight(
        'balanced', classes=np.arange(numClasses), y=labelCollection)

    return {classIndex: float(x) for classIndex, x in enumerate(weightsClass)}

def record_hist(dataset, numClasses):
    """Compute balanced class weights from a running label histogram.

    Every sample's ``labels['output_1']`` is histogrammed into 256 unit-width
    bins centred on the integers 0..255, and the per-sample histograms are
    summed. Only the first ``numClasses`` bins take part in the weighting, so
    labels >= ``numClasses`` (e.g. 255) are ignored. The weight of class ``c``
    is ``total / (numClasses * count[c])``, where ``total`` is the number of
    counted labels.

    Args:
        dataset: Iterable yielding ``(inputs, labels)`` pairs, where
            ``labels['output_1']`` is array-like and holds the class indices.
        numClasses: Number of classes to compute weights for.

    Returns:
        Dict mapping each class index to its weight as a Python ``float``.
        Classes that never occur get weight ``0.0`` instead of a division by
        zero (``inf``/``nan``).
    """
    # Accumulate one running histogram instead of storing one per sample;
    # this also keeps the result well-defined for an empty dataset.
    histClasses = np.zeros(256, dtype=np.int64)
    numSamples = 0
    for _, labels in dataset:
        hist, _ = np.histogram(labels['output_1'], bins=256, range=[-0.5, 255.5])
        histClasses += hist
        numSamples += 1

    print('loaded', numSamples, 'samples')
    # Show the raw count of every label value that occurs at all.
    print({classIndex: int(x) for classIndex, x in enumerate(histClasses) if x > 0})

    histClasses = histClasses[:numClasses]
    total = np.sum(histClasses)

    # Divide only where the class is present; absent classes keep weight 0.0.
    weightsClass = np.zeros(len(histClasses), dtype=np.float64)
    np.divide(total, numClasses * histClasses, out=weightsClass,
              where=histClasses > 0)

    return {classIndex: float(x) for classIndex, x in enumerate(weightsClass)}

# camvid
# Compute the class weights for the CamVid training set.
from datasets import Camvid

# NOTE(review): the path contains an empty component ('//'); it looks like a
# directory name was removed -- confirm the intended location.
dataDir = '/fs/scratch//Data'
# Tag embedded in the .npz file names below.
dataExt = '360x480'
pathDataTrain = os.path.join(dataDir, 'camvid', 'camvid-' + dataExt + '-train.npz')
pathDataValid = os.path.join(dataDir, 'camvid', 'camvid-' + dataExt + '-val.npz')
pathDataTest = os.path.join(dataDir, 'camvid', 'camvid-' + dataExt + '-test.npz')

# Output file that receives the weights of all datasets at the end of the script.
filenameClassBalance = 'class_balance.yml'

# merge_train_valid=True presumably folds the validation split into the
# training set -- confirm in datasets.Camvid.
dataset = Camvid(pathDataTrain, pathDataValid, pathTest=pathDataTest, num_layers=1, batch_size=1, merge_train_valid=True)
train = dataset.get_train_set()
numClasses = dataset.get_num_classes()
weights_camvid = record_hist(train, numClasses)
# Disabled cross-check of record_hist against the scikit-learn reference:
# weights_test = get_weights_sklearn(train, numClasses)
# TODO: the two results are similar but not identical; the cause is still open
# for classIndex, weight in weights_camvid.items():
    # assert(np.isclose(weight, weights_test[classIndex]))

# cityscapes
# Compute the class weights for the Cityscapes training set.
from datasets import Cityscapes

# Rebinds dataDir (previously the CamVid data root) to the Cityscapes root.
# NOTE(review): the path contains an empty component ('//') -- confirm.
dataDir = '/fs/scratch//OpenData/cityscapes'

# Unlike the CamVid run above, merge_train_valid is False here.
# dataset, train and numClasses are rebound; the CamVid objects are dropped.
dataset = Cityscapes(dataDir, num_layers=1, batch_size=1, merge_train_valid=False)
train = dataset.get_train_set()
numClasses = dataset.get_num_classes()
weights_cityscapes = record_hist(train, numClasses)

# Write both weight tables to one YAML file, keyed by dataset name.
# The file is created in the current working directory and overwritten if present.
with open(filenameClassBalance, 'w') as myFile:
    yaml.dump({'camvid': weights_camvid,
        'cityscapes': weights_cityscapes}, myFile)
