"""
Evaluate SVMNLU models on the CamRest test set.

Metric:
    dataset-level (micro-averaged) precision/recall/F1 over
    (intent, slot, value) dialog-act triples

Usage:
    PYTHONPATH=../../../.. python evaluate.py [usr|sys|all]
"""
import json
import random
import sys
import zipfile

import numpy
import torch

from convlab.nlu.svm.camrest import SVMNLU

# Single seed shared by the Python, NumPy and PyTorch RNGs so that repeated
# evaluation runs are reproducible.
seed = 2019
for _seeder in (random.seed, numpy.random.seed, torch.manual_seed):
    _seeder(seed)
del _seeder  # keep the module namespace identical to the unrolled form


def da2triples(dialog_act):
    """Flatten a dialog-act mapping into ``[intent, slot, value]`` lists.

    Args:
        dialog_act: mapping from intent to an iterable of ``(slot, value)``
            pairs.

    Returns:
        A list of three-element lists, ordered by intent iteration order and
        then by the order of the pairs within each intent.
    """
    return [
        [intent, slot, value]
        for intent, slot_value_pairs in dialog_act.items()
        for slot, value in slot_value_pairs
    ]


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print("usage:")
        print("\t python evaluate.py mode")
        print("\t mode=usr|sys|all")
        sys.exit()
    mode = sys.argv[1]
    if mode== 'usr':
        model = SVMNLU(mode='usr')
    elif mode== 'sys':
        model = SVMNLU(mode='sys')
    elif mode== 'all':
        model = SVMNLU(mode='all')
    else:
        raise Exception("Invalid mode")

    archive = zipfile.ZipFile('../../../../data/camrest/test.json.zip', 'r')
    test_data = json.load(archive.open('test.json'))
    TP, FP, FN = 0, 0, 0
    sen_num = 0
    sess_num = 0
    for dialog in test_data:
        sess_num += 1
        if sess_num%10==0:
            print('Session [%d|%d]' % (sess_num, len(test_data)))
            precision = 1.0 * TP / (TP + FP)
            recall = 1.0 * TP / (TP + FN)
            F1 = 2.0 * precision * recall / (precision + recall)
            print('Model on {} session {} sentences:'.format(sess_num, sen_num))
            print('\t Precision: %.2f' % (100 * precision))
            print('\t Recall: %.2f' % (100 * recall))
            print('\t F1: %.2f' % (100 * F1))
        for turn in dialog['dial']:
            if mode == 'usr' or mode == 'all':
                sen_num += 1
                labels = da2triples(turn['usr']['dialog_act'])
                predicts = model.predict(turn['usr']['transcript'])
                for triple in predicts:
                    if triple in labels:
                        TP += 1
                    else:
                        FP += 1
                for triple in labels:
                    if triple not in predicts:
                        FN += 1
            if mode == 'sys' or mode == 'all':
                sen_num += 1
                labels = da2triples(turn['sys']['dialog_act'])
                predicts = model.predict(turn['sys']['sent'])
                for triple in predicts:
                    if triple in labels:
                        TP += 1
                    else:
                        FP += 1
                for triple in labels:
                    if triple not in predicts:
                        FN += 1
    print(TP,FP,FN)
    precision = 1.0 * TP / (TP + FP)
    recall = 1.0 * TP / (TP + FN)
    F1 = 2.0 * precision * recall / (precision + recall)
    print('Model on {} session {} sentences data_key={}'.format(len(test_data), sen_num, mode))
    print('\t Precision: %.2f' % (100 * precision))
    print('\t Recall: %.2f' % (100 * recall))
    print('\t F1: %.2f' % (100 * F1))
