import argparse, os, yaml
from tqdm import tqdm
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from models.model_res_tiny import IterNetwork
from utils import MeanIoUIgnore, SparseCategoricalAccuracyIgnore, SparseCategoricalCrossentropyIgnore
from data.datasets import Camvid, Cityscapes

parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
parser.add_argument('-md', '--model_dir', type=str, default='/home//exp/interactive')
parser.add_argument('-sp', '--save_predictions', action='store_true')
# vars() is the public way to view the Namespace as a mapping; keep a copy so
# the CLI values can be re-applied after the training config is merged in.
cliParams = dict(vars(parser.parse_args()))
params = dict(cliParams)

filenameConfig = 'config.yml'

# Merge the configuration stored next to the checkpoints into params.
with open(os.path.join(params['model_dir'], filenameConfig), 'r', encoding='utf-8') as configFile:
    params.update(yaml.safe_load(configFile))
# Command-line values win over the stored config: the config may contain keys
# such as 'model_dir' from training time, which would otherwise silently
# replace the directory that was just used to locate the config.
params.update(cliParams)

# evaluation is done one image at a time (evaluate_model indexes [0, :])
params['batch_size'] = 1

# load data
if params['dataset'] == 'camvid':
    dataset = Camvid(params['pathDataTrain'], params['pathDataValid'], pathTest=params['pathDataTest'],
            num_layers=params['num_layers'], batch_size=params['batch_size'], crop=params['crop'], merge_train_valid=params['merge_train_valid'])
elif params['dataset'] == 'cityscapes':
    dataset = Cityscapes(params['pathData'],
            num_layers=params['num_layers'], batch_size=params['batch_size'], crop=params['crop'], merge_train_valid=params['merge_train_valid'])
else:
    raise NotImplementedError('unsupported dataset: ' + repr(params['dataset']))
dataTrain = dataset.get_train_set()
dataTest = dataset.get_test_set()
inputShape = dataset.get_shape()

# build model
# NOTE(review): custom_objects is not referenced in this script; presumably kept
# for keras.models.load_model — confirm before removing.
custom_objects = {'sparse_categorical_accuracy_ignore': SparseCategoricalAccuracyIgnore(), 'miou_ignore': MeanIoUIgnore(params['numClasses']), 'sparse_categorical_crossentropy_ignore': SparseCategoricalCrossentropyIgnore(from_logits=True)}
model = IterNetwork(num_classes=params['numClasses'],
        num_blocks=params['num_blocks'], num_iterations=params['num_layers'], num_feat_base=params['num_features'], num_repeats=params['num_repeats'],
        dropout_rate=params['dropout_rate'], weight_decay=params['weight_decay'],
        downsample=params['downsample'], global_pool=params['global_pool'], conv_type=params['conv_type'], share_parameters=params['share_parameters'], shuffle=False)
model.build((None,) + inputShape)
model.summary()
metricsList = [SparseCategoricalAccuracyIgnore(), MeanIoUIgnore(params['numClasses'])]
# Learning rate 0.0: the single train step below only exists to create all
# variables; the real weights are loaded afterwards in evaluate_model.
model.compile(optimizer=keras.optimizers.SGD(0.0), loss=SparseCategoricalCrossentropyIgnore(from_logits=True), metrics=metricsList, run_eagerly=True)
# pass one batch to initialize model like suggested here:
# https://www.tensorflow.org/guide/keras/save_and_serialize
images, labels = next(iter(dataTrain))
model.train_on_batch(images, labels)

# evaluation helpers: load checkpoint weights, record per-output test predictions and compute mIoU
def evaluate_model(filenameModel):
    """Load a checkpoint and record per-image predictions on the test set.

    For every test sample and every output 'output_1' ... 'output_<num_layers>'
    this records the argmax class map (over the last axis) of the model output,
    the label with its trailing axis squeezed, and a per-image mean IoU.
    Outputs whose entry in params['loss_weights'] is not positive are dropped
    from the returned predictions and mIoUs.

    Args:
        filenameModel: checkpoint path passed to ``model.load_weights``; also
            the prefix of the ``.npz`` files written when
            ``params['save_predictions']`` is set.

    Returns:
        (images, labels, predictions, mious):
            images: stacked input images (batch axis of each sample removed).
            labels: stacked labels of 'output_1' (all outputs are assumed to
                share identical labels).
            predictions: dict output name -> stacked argmax class maps.
            mious: dict output name -> list of per-image mIoU values.

    Raises:
        ValueError: if the test set yields no samples.
    """
    print('loading weights from', filenameModel)
    model.load_weights(filenameModel)
    outputNames = ['output_' + str(i) for i in range(1, params['num_layers'] + 1)]
    predictions = {name: [] for name in outputNames}
    labels = {name: [] for name in outputNames}
    mious = {name: [] for name in outputNames}
    images = []
    for image, label in tqdm(dataTest, total=dataset.get_len_test(), desc='recording predictions'):
        images.append(image.numpy()[0, :]) # [0, :] due to single batch inference
        prediction = model.predict_on_batch(image)
        # Pair label 'output_i' with model output i-1 by name instead of relying
        # on the iteration order of the label dict, which need not match the
        # model's output order (e.g. sorted keys put 'output_10' before 'output_2').
        for outputIndex, outputName in enumerate(outputNames):
            outputPrediction = prediction[outputIndex]
            outputLabel = label[outputName]
            predictions[outputName].append(np.argmax(outputPrediction[0, :], -1))
            labels[outputName].append(np.squeeze(outputLabel.numpy()[0, :], axis=-1))
            # a fresh metric per image yields a per-image (not running) mIoU
            miou = MeanIoUIgnore(params['numClasses'])
            miou.update_state(outputLabel, outputPrediction)
            mious[outputName].append(miou.result().numpy())

    if not images:
        raise ValueError('test set is empty: no predictions recorded for ' + str(filenameModel))

    # keep only outputs that contributed to the training loss
    keptNames = [name for i, name in enumerate(outputNames) if params['loss_weights'][i] > 0]
    images = np.stack(images)
    predictions = {name: np.stack(predictions[name]) for name in keptNames}
    labels = np.stack(labels[outputNames[0]]) # TODO: assuming all labels to be identical, only select labels of first output
    mious = {name: mious[name] for name in keptNames}

    if params['save_predictions']:
        print('saving predictions')
        np.savez_compressed(filenameModel + '_mious', **mious)
        np.savez_compressed(filenameModel + '_images', images=images)
        np.savez_compressed(filenameModel + '_labels', labels=labels)
        np.savez_compressed(filenameModel + '_predictions', **predictions)

    return images, labels, predictions, mious

def calc_miou(labels, predictions):
    """Return the dataset-level mean IoU of each output.

    Args:
        labels: ground-truth class maps shared by every output.
        predictions: mapping output name -> predicted maps for all samples.

    Returns:
        dict output name -> mIoU as a Python float.

    NOTE(review): evaluate_model stores argmax class maps in its returned
    predictions but feeds raw model outputs to MeanIoUIgnore for per-image
    scores; confirm MeanIoUIgnore handles both input forms.
    """
    def _score(outputPrediction):
        # fresh metric per output so confusion matrices do not accumulate across outputs
        metric = MeanIoUIgnore(params['numClasses'])
        metric.update_state(labels, outputPrediction)
        score = metric.result().numpy()
        assert np.isscalar(score)
        return float(score)

    return {outputName: _score(outputPrediction) for outputName, outputPrediction in predictions.items()}

# Optional: also evaluate the best checkpoint ('model'). Note that
# evaluate_model returns four values (images, labels, predictions, mious).
# checkpointFile = os.path.join(params['model_dir'], 'model')
# images, labels, predictions, mious = evaluate_model(checkpointFile)
# miouBest = calc_miou(labels, predictions)

# Evaluate the checkpoint stored as 'model_last'.
checkpointFileLast = os.path.join(params['model_dir'], 'model_last')
images, labels, predictions, mious = evaluate_model(checkpointFileLast)
# miouLast = calc_miou(labels, predictions)

# Optional: write both dataset-level mIoUs (requires miouBest and miouLast above).
# with open(os.path.join(params['model_dir'], 'eval_post.yml'), 'w', encoding='utf-8') as myFile:
#     yaml.dump({'miou_best': miouBest, 'miou_last': miouLast}, myFile)
