import torch
from transformers import BertTokenizerFast, BertForTokenClassification

import json
import random
import argparse
from tqdm import tqdm

from custom_datasets import RelationExtractionDataset

def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: one attribute per declared flag, plus a derived
        ``device`` attribute that is ``'cuda:<gpu_num>'`` when ``--use_gpu``
        is passed and ``'cpu'`` otherwise.
    """
    cli = argparse.ArgumentParser()

    # (flag, keyword arguments for add_argument), registered in this order.
    option_table = (
        ('--num_epochs', {'type': int, 'default': 2}),
        ('--train_test_split', {'type': float, 'default': 0.8}),
        ('--train_batch_size', {'type': int, 'default': 32}),
        ('--test_batch_size', {'type': int, 'default': 128}),
        ('--lr', {'type': float, 'default': 1e-4}),
        ('--dataset_name', {'type': str,
                            'default': 'mquake',
                            'choices': ['mquake', 'rippleedit']}),
        ('--dataset_path', {'type': str,
                            'default': '../data/RelationBERT/MQuAKE-CF-3k-V2_train.json'}),
        ('--save_path', {'type': str, 'default': './r_extractor.pt'}),
        ('--use_gpu', {'action': 'store_true'}),
        ('--gpu_num', {'type': int, 'default': 0}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()

    # Derive the torch device string once so callers never rebuild it.
    if parsed.use_gpu:
        parsed.device = f'cuda:{parsed.gpu_num}'
    else:
        parsed.device = 'cpu'
    return parsed

def _evaluate(model, loader, device, num_classes):
    """Run ``model`` over ``loader`` and print accuracy and per-class F1.

    Args:
        model: token-classification model whose output exposes ``logits``.
        loader: iterable of batches; each batch is a mapping with
            ``'input_ids'``, ``'attention_mask'`` and ``'label'`` entries.
        device: torch device string the batch tensors are moved to.
        num_classes: number of label classes to report F1 for.
    """
    token_hits = 0
    token_total = 0
    sequence_hits = 0
    sequence_total = 0

    # Per-class confusion counts, sized from num_classes rather than a
    # hard-coded 5 so that 3-class datasets do not report phantom classes.
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes

    # Disable dropout and skip autograd bookkeeping while scoring.
    model.eval()
    with torch.no_grad():
        for data in tqdm(loader, desc='Testing...'):
            labels = data.pop('label')
            # assumes 'label' is one-hot per token, i.e. (batch, seq, classes)
            # -- TODO confirm against RelationExtractionDataset
            labels = torch.argmax(labels, dim=-1).to(device)

            outputs = model(data['input_ids'].to(device),
                            token_type_ids=None,
                            attention_mask=data['attention_mask'].to(device))
            preds = torch.argmax(outputs.logits, dim=-1)

            # NOTE(review): every position is counted, including any padding
            # the dataset may emit -- confirm whether padding should be masked.
            correct = preds == labels
            token_hits += correct.sum().item()
            token_total += preds.shape[0] * labels.shape[-1]

            # A sequence counts as correct only if all its tokens match.
            sequence_hits += correct.all(dim=-1).sum().item()
            sequence_total += preds.shape[0]

            for cls in range(num_classes):
                is_label = labels == cls
                is_pred = preds == cls
                tp[cls] += (is_label & is_pred).sum().item()
                fp[cls] += (~is_label & is_pred).sum().item()
                fn[cls] += (is_label & ~is_pred).sum().item()

    # Guard the ratios so an empty test split does not raise.
    token_pct = token_hits / token_total * 100 if token_total else 0.0
    sequence_pct = sequence_hits / sequence_total * 100 if sequence_total else 0.0
    print(f'\nTest Token Accuracy   : {token_hits} / {token_total} or {token_pct:.2f}%')
    print(f'Test Sequence Accuracy: {sequence_hits} / {sequence_total} or {sequence_pct:.2f}%')

    for cls in range(num_classes):
        denominator = 2 * tp[cls] + fp[cls] + fn[cls]
        # F1 is undefined when the class appears in neither labels nor
        # predictions; report nan instead of dividing by zero.
        f1 = (2 * tp[cls]) / denominator if denominator else float('nan')
        print(f'F1-Score for Class {cls}: {f1}')


def main(args):
    """Fine-tune BERT for token classification, evaluate each epoch, and save.

    Loads the JSON dataset at ``args.dataset_path``, shuffles it, splits it
    into train/test portions by ``args.train_test_split``, trains for
    ``args.num_epochs`` epochs with AdamW, prints test metrics after every
    epoch and finally writes the model with ``save_pretrained``.

    Args:
        args: namespace with the attributes ``dataset_name``,
            ``dataset_path``, ``train_test_split``, ``train_batch_size``,
            ``test_batch_size``, ``lr``, ``num_epochs``, ``save_path`` and
            ``device``.
    """
    num_classes = 5 if args.dataset_name == 'mquake' else 3

    # NOTE(review): use_fast looks like an AutoTokenizer option; confirm it is
    # intended when BertTokenizerFast is requested explicitly.
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased', use_fast=False, model_max_length=512)
    model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=num_classes).to(args.device)

    # Read the records and release the file handle before building datasets.
    with open(args.dataset_path, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    random.shuffle(records)
    split_at = int(args.train_test_split * len(records))
    train_ds = RelationExtractionDataset(data=records[:split_at],
                                         tokenizer=tokenizer,
                                         device=args.device,
                                         ds_name=args.dataset_name,
                                         num_classes=num_classes)
    test_ds = RelationExtractionDataset(data=records[split_at:],
                                        tokenizer=tokenizer,
                                        device=args.device,
                                        ds_name=args.dataset_name,
                                        num_classes=num_classes)

    training_loader = torch.utils.data.DataLoader(train_ds, batch_size=args.train_batch_size, shuffle=True)
    # The metrics are order-independent, so the test set is not shuffled.
    testing_loader = torch.utils.data.DataLoader(test_ds, batch_size=args.test_batch_size, shuffle=False)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    for _ in tqdm(range(args.num_epochs), desc='Epoch Count'):
        # from_pretrained hands back the model in eval mode and _evaluate
        # switches to eval again, so re-enable dropout every epoch.
        model.train()
        for data in tqdm(training_loader, desc='Batch Count'):
            # The optimizer owns every model parameter, so this clears all
            # gradients; a separate model.zero_grad() is unnecessary.
            optimizer.zero_grad()

            labels = data.pop('label')
            labels = torch.argmax(labels, dim=-1).to(args.device)

            # .to() is a no-op for tensors already on args.device.
            outputs = model(data['input_ids'].to(args.device),
                            token_type_ids=None,
                            attention_mask=data['attention_mask'].to(args.device),
                            labels=labels)

            outputs.loss.backward()
            # Clip the global gradient norm to 1.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        _evaluate(model, testing_loader, args.device, num_classes)

    model.save_pretrained(args.save_path)

if __name__ == '__main__':
    # Script entry point: parse the CLI options, then train and evaluate.
    cli_args = parse_args()
    main(cli_args)