import torch
from transformers import BertTokenizerFast, BertForTokenClassification

import json
import argparse
from tqdm import tqdm

from custom_datasets import RelationExtractionDataset

def parse_args():
    """Parse command-line options for the relation-extractor evaluation script.

    Returns:
        argparse.Namespace holding every option registered below, plus a
        derived ``device`` attribute: ``'cuda:<gpu_num>'`` when ``--use_gpu``
        is passed, otherwise ``'cpu'``.
    """
    cli = argparse.ArgumentParser()

    # (flag, add_argument kwargs) in registration order, so --help output and
    # parsing behaviour match a sequence of explicit add_argument calls.
    option_specs = [
        ('--ds_name', dict(type=str,
                           default='mquake',
                           choices=['kebench', 'mquake', 'rippleedit', 'webnlg'])),
        ('--path_to_ds_json', dict(type=str,
                                   default='../data/KEBench/multi_hop_knowledge.json')),
        ('--path_to_model_pt', dict(type=str, default='./r_extractor.pt')),
        # Start/end indices of the slice of the loaded dataset to evaluate.
        ('--ds_start', dict(type=int, default=0)),
        ('--ds_end', dict(type=int, default=-1)),
        ('--use_gpu', dict(action='store_true')),
        ('--gpu_num', dict(type=int, default=0)),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()
    if parsed.use_gpu:
        parsed.device = f'cuda:{parsed.gpu_num}'
    else:
        parsed.device = 'cpu'
    return parsed

def main(args):
    """Load the tokenizer, model and dataset slice, then run the evaluator for ``args.ds_name``.

    ``kebench`` / ``webnlg`` are routed to ``manual_evaluation`` (prints
    per-token predictions); ``mquake`` / ``rippleedit`` are routed to
    ``automated_evaluation`` (prints aggregate metrics).

    Args:
        args: namespace from ``parse_args``; reads ``path_to_model_pt``,
            ``device``, ``path_to_ds_json``, ``ds_start``, ``ds_end`` and
            ``ds_name``.
    """
    # NOTE(review): `use_fast` is an AutoTokenizer option; BertTokenizerFast is
    # always the fast tokenizer, so this flag appears to be a no-op — confirm and drop.
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased', use_fast=False, model_max_length=512)
    model = BertForTokenClassification.from_pretrained(args.path_to_model_pt, num_labels=5).to(args.device)
    # Inference only: make sure dropout is disabled regardless of loader defaults.
    model.eval()

    # `--ds_end` defaults to -1 as a "through the end" sentinel. Slicing with a
    # literal -1 would silently drop the final example, so map it to None.
    ds_end = None if args.ds_end == -1 else args.ds_end
    with open(args.path_to_ds_json, 'r', encoding='utf-8') as file:
        ds = json.load(file)[args.ds_start:ds_end]

    if args.ds_name in ['kebench', 'webnlg']:
        manual_evaluation(model, tokenizer, ds, args)
    elif args.ds_name in ['mquake', 'rippleedit']:
        automated_evaluation(model, tokenizer, ds, args)

def manual_evaluation(model, tokenizer, ds, args):
    """Print the tokens and the model's predicted class id per token, one example at a time.

    Args:
        model: token-classification model called as ``model(**tokens)``; its
            ``.logits`` are argmax-ed over the last dimension.
        tokenizer: tokenizer used to encode each prompt and decode each token id.
        ds: list of raw examples. For ``kebench`` the prompt is
            ``c['two_hop_question']``; for ``webnlg`` it is
            ``c['questions'][0]['text']`` and entries whose ``size`` exceeds 4
            are skipped.
        args: namespace; reads ``ds_name`` and ``device``.

    Raises:
        ValueError: if ``args.ds_name`` is not ``'kebench'`` or ``'webnlg'``.
    """
    if args.ds_name not in ('kebench', 'webnlg'):
        raise ValueError(f'manual_evaluation does not support ds_name={args.ds_name!r}')

    # Pure inference: skip building autograd graphs.
    with torch.no_grad():
        for c in ds:
            if args.ds_name == 'kebench':
                prompts = c['two_hop_question']
            else:  # webnlg
                # Check the size first so oversized entries are skipped before their text is read.
                if int(c['size']) > 4:
                    continue
                print(int(c['size']))
                prompts = c['questions'][0]['text']

            tokens = tokenizer(prompts, padding=False, return_tensors='pt').to(args.device)

            output = model(**tokens)
            predictions = torch.argmax(output.logits, dim=-1)

            print(f'Tokens: {[tokenizer.decode(t) for t in tokens.input_ids[0]]}')
            print(f'Output Predictions: {predictions}')

def automated_evaluation(model, tokenizer, ds, args):
    """Score the model against token-level gold labels and print aggregate metrics.

    Wraps ``ds`` in a ``RelationExtractionDataset`` and reports:
      * token accuracy: correct label positions / all label positions,
      * sequence accuracy: sequences whose every position is correct,
      * per-class "Proportion of Correct Classes" (recall, TP / #gold),
      * per-class F1 = 2TP / (2TP + FP + FN).

    Args:
        model: token-classification model whose ``.logits`` are argmax-ed over
            the last dimension.
        tokenizer: forwarded to ``RelationExtractionDataset``.
        ds: list of raw examples (already sliced by the caller).
        args: namespace; reads ``device`` and ``ds_name``.

    Ratios with a zero denominator (e.g. a class that never occurs) print as ``nan``.
    An empty dataset prints a notice instead of metrics.
    """
    num_classes = 5

    def ratio(num, den):
        # Zero denominators yield nan instead of raising ZeroDivisionError.
        return num / den if den else float('nan')

    test_ds = RelationExtractionDataset(data=ds,
                                        tokenizer=tokenizer,
                                        device=args.device,
                                        ds_name=args.ds_name,
                                        num_classes=num_classes)
    # Metrics are order-independent aggregates; shuffling adds only nondeterminism.
    testing_loader = torch.utils.data.DataLoader(test_ds, batch_size=128, shuffle=False)

    token_acc = 0
    token_total = 0
    sequence_acc = 0
    sequence_total = 0

    correct_total = {j: 0 for j in range(num_classes)}  # gold count per class
    tp = {j: 0 for j in range(num_classes)}
    fp = {j: 0 for j in range(num_classes)}
    fn = {j: 0 for j in range(num_classes)}

    # Pure inference: no autograd graph, no loss computation.
    with torch.no_grad():
        for data in tqdm(testing_loader, desc='Testing...'):
            # The label tensor is reduced with argmax over its last dim to get
            # class ids (presumably one-hot per token — confirm against the dataset).
            labels = torch.argmax(data.pop('label'), dim=-1).to(args.device)

            outputs = model(data['input_ids'].to(args.device),
                            token_type_ids=None,
                            attention_mask=data['attention_mask'].to(args.device))
            preds = torch.argmax(outputs.logits, dim=-1)

            # Vectorized over the whole batch (labels/preds assumed (batch, seq)).
            # NOTE(review): every label position counts, including any padding the
            # dataset adds; if padded positions carry a class id, consider masking
            # with attention_mask.
            correct = preds == labels
            token_acc += correct.sum().item()
            token_total += labels.numel()
            sequence_acc += correct.all(dim=-1).sum().item()
            sequence_total += labels.shape[0]

            for j in range(num_classes):
                is_gold = labels == j
                is_pred = preds == j
                correct_total[j] += is_gold.sum().item()
                tp[j] += (is_gold & is_pred).sum().item()
                fp[j] += (~is_gold & is_pred).sum().item()
                fn[j] += (is_gold & ~is_pred).sum().item()

    if sequence_total == 0:
        print('\nNo examples to evaluate.')
        return

    per_class = ', '.join(f'{j} - {ratio(tp[j], correct_total[j]) * 100:.2f}%' for j in range(num_classes))
    print(f'\nTest Token Accuracy          : {token_acc} / {token_total} or {ratio(token_acc, token_total) * 100:.2f}%')
    print(f'Test Sequence Accuracy       : {sequence_acc} / {sequence_total} or {ratio(sequence_acc, sequence_total) * 100:.2f}%')
    print(f'Proportion of Correct Classes: {per_class}')

    for i in range(num_classes):
        print(f'F1-Score for Class {i}: {ratio(2 * tp[i], 2 * tp[i] + fp[i] + fn[i])}')

if __name__ == '__main__':
    # Script entry point: parse the CLI, then hand the namespace to main().
    cli_args = parse_args()
    main(cli_args)
