import argparse
from torch import optim
import torch.nn.functional as F
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import BertTokenizer
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup
from sklearn.neighbors import LocalOutlierFactor
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
import matplotlib.pyplot as plt

import dataloader
import BERT
import utill
import labelsmoothing

def plot_logits(args, klr, logits, threshold):
    '''
    Plot histograms of two groups of logits with the threshold marked, and save the figure as a PNG.

    Args:
        args : parsed arguments; `dataset`, `seed` and `smoothing` are read for the title and the file name
        klr : known label ratio * 100 (ex. 25.0, 50.0, 75.0) -> type float
        logits : [0] : logits plotted as 'known', [1] : logits plotted as 'ood' -> type list
        threshold : position of the vertical threshold line -> type float
                    (a single-element tensor is accepted and converted with float())

    Return:
        0 (kept so existing callers are unaffected)
    '''
    import os

    # Draw on a dedicated figure: the implicit global pyplot figure would keep the
    # histograms of earlier calls and overlay them on this plot.
    fig = plt.figure()
    try:
        plt.hist(logits[0], bins=20, alpha=0.5, label='known')
        plt.hist(logits[1], bins=20, alpha=0.5, label='ood')
        plt.xlabel('logit', fontsize=18)
        plt.ylabel('count', fontsize=18)

        # float() lets a single-element (possibly CUDA) tensor be used as the line position
        plt.axvline(x=float(threshold), linewidth=2, label='Threshold', color='r')
        plt.title('logit plot KLR '+str(klr)+'/ smoothing '+str(args.smoothing))
        plt.legend(loc = 'upper left')

        # savefig does not create missing directories
        os.makedirs(f'./result/{args.dataset}', exist_ok=True)
        plt.savefig(f'./result/{args.dataset}/KLR'+str(klr)+'_seed_'+str(args.seed)+'smoothing_'+str(args.smoothing)+'_train_logit_plot.png')
    finally:
        # Release the figure even when plotting or saving fails
        plt.close(fig)

    return 0

def train(train_dataloader, model, optimizer, device, criterion):
    '''
    Run one training pass over train_dataloader.

    Args:
        train_dataloader : iterable of (input_ids, attention_mask, label_id) tensor batches
        model : called as model(input_ids, attention_mask); its output is used as logits
        optimizer : optimizer stepped once per batch
        device : device every tensor of a batch is moved to
        criterion : loss, called as criterion(logits, flattened labels)

    Return:
        (mean of the per-batch losses, mean of the per-batch accuracies)

    Raises:
        SystemExit : with code 1 as soon as a batch produces a non-finite loss
    '''
    model.train()

    train_loss, train_acc = 0, 0

    for batch in tqdm(train_dataloader, desc='Batch'):
        optimizer.zero_grad()

        input_ids, attention_mask, label_id = (t.to(device) for t in batch)
        # Flatten once so the loss and the accuracy use the same label layout
        target = label_id.view(-1)

        # output = batch_size * num_class
        output = model(input_ids, attention_mask) # output == logits

        # CrossEntropyLoss
        loss = criterion(output, target)

        # Stop before the backward pass if the loss is NaN/inf
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ')
            # exit() is a helper injected by the site module for interactive use;
            # raising SystemExit directly ends the program the same way.
            raise SystemExit(1)

        train_loss += loss.item()
        acc = (torch.argmax(output, -1) == target).sum().item()
        train_acc += acc / label_id.size(0)

        loss.backward()
        optimizer.step()

    return train_loss / len(train_dataloader), train_acc / len(train_dataloader)

def test(test_dataloader, model, device, criterion):
    model.eval()
    
    Threshold = 1e8 
    t_list = []
    test_loss, test_acc = 0, 0
    
    with torch.no_grad():
        for batch in test_dataloader:
            batch = tuple(t.to(device) for t in batch)
            input_ids, attention_mask, label_id = batch
            output = model(input_ids, attention_mask)

            # CrossEntropyLoss
            loss = criterion(output, label_id.view(-1))

            test_loss += loss.item()
            acc = (torch.argmax(output, -1) == label_id.squeeze()).sum().item()
            test_acc += acc / label_id.size(0)
            
            # find minimum threshold in valid dataset 
            max_val, max_idx = torch.max(output, 1)
            
            # set threshold to average of logits vector
            if max_idx.tolist() == label_id.tolist():
                t_list.append(max_val) # correct avg
    
    if len(t_list) == 0:
        Threshold = 0
    else:   
        Threshold = sum(t_list) / len(t_list)
        
    return test_loss / len(test_dataloader), test_acc / len(test_dataloader), Threshold
    
def ood_test(data, test_dataloader, model, device, criterion, threshold):
    '''
    Evaluate threshold-based OOD rejection on the test set. Works with any batch size.

    Args:
        data (class) : Data class Object; `known_label_id` is read, and its largest id is used as the ood label
        test_dataloader : Test dataloader yielding (input_ids, attention_mask, label_id) tensor batches
        model : called as model(input_ids, attention_mask); its output is used as logits
        device : device every tensor of a batch is moved to
        criterion : not used by this function (kept for interface compatibility)
        threshold (float) : Criteria score for ood rejection from test Function
                            (a single-element tensor is accepted as well)

    Return:
        inclass_result (dict) : Count of correct and wrong in known class
        oodclass_result (dict) : Count of correct and wrong in unknown class
        total_result (dict) : Based on the threshold, the model predicts rejecting unknown class(OOD intent) +1 correct or did not +1 wrong.
        total_logit (list) : Based on the threshold, save known and ood logit, respectively.
        total_y_true (list) : Test dataset's labels for F1 score
        total_y_pred (list) : Model inference labels of test_dataset for F1 score
    '''

    model.eval()

    print(f"Threshold : {threshold}")

    # Compare plain Python numbers from here on
    threshold = float(threshold)

    total_result = {'correct':0, 'wrong':0}
    total_logit = [[],[]] # [0] : known logit, [1] : ood logit

    inclass_result = {'correct':0, "wrong":0}
    oodclass_result = {'correct':0, "wrong":0}

    # except ood label in known label id
    sorted_label_id = sorted(data.known_label_id)
    ood_label = sorted_label_id[-1]
    known_label_id_except_ood = sorted_label_id[:-1]

    total_y_true, total_y_pred = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids, attention_mask, label_id = (t.to(device) for t in batch)
            output = model(input_ids, attention_mask)

            max_val, max_idx = torch.max(output, 1)

            # Flatten so labels shaped (batch,) and (batch, 1) are handled alike,
            # then judge every sample of the batch on its own.
            labels = label_id.view(-1).tolist()
            total_y_true.extend(labels)

            for label, score, pred in zip(labels, max_val.tolist(), max_idx.tolist()):
                # Threshold criteria for logit
                if score >= threshold: # pass
                    total_y_pred.append(pred)

                    if label in known_label_id_except_ood: # known label
                        total_result['correct'] += 1
                        total_logit[0].append(score)
                    else: # ood label
                        total_result['wrong'] += 1
                        total_logit[1].append(score)

                else: # reject
                    total_y_pred.append(ood_label)

                    if label == ood_label: # ood label
                        total_result['correct'] += 1
                        total_logit[1].append(score)
                    else: # known label
                        total_result['wrong'] += 1
                        total_logit[0].append(score)

                # Label criteria
                if label in known_label_id_except_ood:
                    # in class : only the predicted class decides
                    if pred == label:
                        inclass_result['correct'] += 1
                    else:
                        inclass_result['wrong'] += 1
                elif pred == label or score < threshold:
                    # ood class : predicted as ood directly, or rejected by the threshold
                    oodclass_result['correct'] += 1
                else:
                    oodclass_result['wrong'] += 1

    return inclass_result, oodclass_result, total_result, total_logit, total_y_true, total_y_pred


if __name__ == '__main__':
    # Script entry point: parse the options, build the data and the model, then
    # either fine-tune and save the model ('train') or load a saved model and
    # report OOD detection F1 / accuracy (any other mode).

    import os  # only needed here, to create the result directory

    def _str2bool(value):
        '''Convert a command-line value such as "true", "False" or "1" to a bool.'''
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--known-class-ratio', default=0.25, type=float)
    parser.add_argument('--train-batch-size', default=128, type=int)
    parser.add_argument('--test-batch-size', default=1, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--mode', default='test', type=str)
    parser.add_argument('--dataset', default='banking', type=str, choices=['banking','stackoverflow','oos'])
    parser.add_argument('--seed', default=0, type=int)
    # Without a converter every non-empty string is truthy, so
    # "--label-smoothing False" would switch label smoothing on.
    parser.add_argument('--label-smoothing', default=False, type=_str2bool)
    parser.add_argument('--smoothing', default=0.0, type=float)
    args = parser.parse_args()

    # One device for the model, the loss and the batches; falls back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('---load tokenizer & pretrained BERT---')

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    print('---Done---')

    print('---data preprocessing---')

    data = dataloader.Data(args, tokenizer)

    train_dataloader = data.train_dataset
    valid_dataloader = data.valid_dataset
    test_dataloader = data.test_dataset

    print('total label cnt : {}'.format(len(data.all_label_id)))
    print('known label id : {}'.format(str(data.known_label_id)))

    print('---Done---')

    # The same file is written after training and read back for testing
    result_dir = f'./result/{args.dataset}'
    model_path = f'{result_dir}/model_KLR_{args.known_class_ratio}_seed_{args.seed}smoothing_{args.smoothing}.pt'

    model = BERT.BERT(len(data.all_label_id))
    if args.mode != 'train':
        # use my saved model; map_location lets weights saved on a GPU load on a CPU-only machine
        model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)

    if not args.label_smoothing:
        # CrossEntropy
        criterion = torch.nn.CrossEntropyLoss().to(device)
    else:
        criterion = labelsmoothing.LabelSmoothingLoss(len(data.all_label_id), args.smoothing).to(device)

    if args.mode == 'train':
        # Fail before training, not after the last epoch, if the output directory cannot be created
        os.makedirs(result_dir, exist_ok=True)

        # Prepare optimizer and schedule (warmup followed by cosine decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)

        # scheduler.step() runs once per epoch below, so the schedule length is
        # counted in epochs. Counting it in batches would leave the learning rate
        # inside the warmup phase for the whole run.
        # NOTE(review): the warmup multiplier is 0 at step 0, so with a non-zero
        # warmup the first epoch runs at learning rate 0 - confirm this is acceptable.
        t_total = args.epochs

        warmup_ratio = 0.1
        warmup_step = int(t_total * warmup_ratio)
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

        # train
        for _epoch in tqdm(range(args.epochs), desc='Epoch'):
            train_loss, train_acc = train(train_dataloader, model, optimizer, device, criterion)
            valid_loss, valid_acc, threshold = test(valid_dataloader, model, device, criterion)

            print("[Epoch: %d] train loss : %5.2f | train accuracy : %5.2f" % (_epoch, train_loss, train_acc))
            print("[Epoch: %d] val loss : %5.2f | val accuracy : %5.2f" % (_epoch, valid_loss, valid_acc))

            scheduler.step()

        # model save
        torch.save(model.state_dict(), model_path)

    else:
        # test: the rejection threshold comes from the validation set
        valid_loss, valid_acc, threshold = test(valid_dataloader, model, device, criterion)

        print('Threshold : ', threshold)

        in_result, ood_result, total_result, logits, total_y_true, total_y_pred = ood_test(data, test_dataloader, model, device, criterion, threshold)

        # plot maximum logit vectors
        # plot_logits(args, args.known_class_ratio * 100, logits, threshold)

        # Measure F1 score
        cm = confusion_matrix(total_y_true, total_y_pred)
        f1 = utill.get_F1score(cm)['F1-score']

        print(f"F1-score : {f1}")
        print(f"Acc : {total_result['correct']/(total_result['correct']+total_result['wrong'])}")
      