from torch.utils.data.dataloader import DataLoader
from transformers import BertTokenizer
from util.data_loader import get_loader
from util.data_loader import FewShotNERDataset, NERDataset
from util.framework import FewShotNERFramework
from util.word_encoder import BERTWordEncoder
from transformers import AdamW
from model.proto import Proto
from model.bigproto import BigProto
#from model.nnshot import NNShot
import sys
import torch
from torch import optim, nn
# from torch.utils.tensorboard import SummaryWriter
import numpy as np
import json
import argparse
import os
import torch
import random
from util.fewshotsampler import FewshotSampler
from util.util import prepare_initial_tensor_supportset, get_tag_label_mapping, load_data_from_file
from tqdm import tqdm



def set_seed(seed):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (CPU plus all CUDA devices), then forces cuDNN to pick deterministic
    algorithms so repeated runs with the same seed behave the same.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Each generator is independent, so the order of these calls is irrelevant.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

def _build_arg_parser():
    """Build the command-line parser for few-shot / supervised NER runs."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='inter',
            help='training mode, must be in [inter, intra, supervised]')
    parser.add_argument('--trainN', default=5, type=int,
            help='N in train')
    parser.add_argument('--N', default=5, type=int,
            help='N way')
    parser.add_argument('--K', default=1, type=int,
            help='K shot')
    parser.add_argument('--Q', default=1, type=int,
            help='Num of query per class')
    parser.add_argument('--batch_size', default=1, type=int,
            help='batch size')
    parser.add_argument('--eval_batch_size', default=1, type=int,
            help='batch size')
    parser.add_argument('--train_iter', default=600, type=int,
            help='num of iters in training')
    parser.add_argument('--val_iter', default=100, type=int,
            help='num of iters in validation')
    parser.add_argument('--test_iter', default=500, type=int,
            help='num of iters in testing')
    parser.add_argument('--val_step', default=20, type=int,
            help='val after training how many iters')
    # ``choices`` rejects unknown models at parse time, before the encoder and
    # data are loaded (previously this surfaced later as a NameError).
    parser.add_argument('--model', default='proto', choices=['proto', 'bigproto'],
            help='model name, must be proto or bigproto')
    parser.add_argument('--max_length', default=100, type=int,
            help='max length')
    parser.add_argument('--lr', default=1e-4, type=float,
            help='learning rate')
    parser.add_argument('--proto_lr', default=1e-4, type=float,
            help='learning rate')
    parser.add_argument('--support_lr', default=1e-5, type=float,
            help='learning rate')
    parser.add_argument('--support_proto_lr', default=1e-5, type=float,
            help='learning rate')
    parser.add_argument('--grad_iter', default=1, type=int,
            help='accumulate gradient every x iterations')
    parser.add_argument('--load_ckpt', default=None,
            help='load ckpt')
    parser.add_argument('--save_ckpt', default=None,
            help='save ckpt')
    parser.add_argument('--fp16', action='store_true',
            help='use nvidia apex fp16')
    parser.add_argument('--only_test', action='store_true',
            help='only test')
    parser.add_argument('--ckpt_name', type=str, default='',
            help='checkpoint name.')
    parser.add_argument('--seed', type=int, default=0,
            help='random seed')
    parser.add_argument('--ignore_index', type=int, default=-1,
            help='label index to ignore when calculating loss and metrics')
    parser.add_argument('--train_on_support', action='store_true',
            help='whether to train on support set while testing')
    parser.add_argument('--eval_train_round_num', default=1, type=int,
            help='number of training round on support set when evaluating')
    parser.add_argument('--full_supervised_train', action='store_true',
            help='whether use supervised training')

    # only for bert / roberta
    parser.add_argument('--pretrain_ckpt', default=None,
            help='bert / roberta pre-trained checkpoint')

    # only for prototypical networks
    parser.add_argument('--dot', action='store_true',
            help='use dot instead of L2 distance for proto')

    # experiment
    parser.add_argument('--use_sgd_for_bert', action='store_true',
            help='use SGD instead of AdamW for BERT.')
    return parser


def _concat_collate(data):
    """Collate a list of dict samples by concatenating each field along dim 0.

    Defined at module level rather than as a lambda so that DataLoader worker
    processes can pickle it under the ``spawn`` start method.
    """
    return {k: torch.cat([d[k] for d in data], 0) for k in data[0]}


def main():
    """Parse CLI options, build data and model, then train and/or evaluate.

    Training is skipped when ``--only_test`` is given. In every case the model
    is finally evaluated on the test split and precision/recall/F1 plus an
    error breakdown are printed.

    Raises:
        FileNotFoundError: if the dataset files are still missing after running
            ``data/download.sh``.
        ValueError: if ``--model`` names a model with no construction branch.
    """
    import subprocess

    opt = _build_arg_parser().parse_args()
    trainN = opt.trainN
    N = opt.N
    K = opt.K
    Q = opt.Q
    batch_size = opt.batch_size
    model_name = opt.model
    max_length = opt.max_length

    print("{}-way-{}-shot Few-Shot NER".format(N, K))
    print("model: {}".format(model_name))
    print("max_length: {}".format(max_length))
    print('mode: {}'.format(opt.mode))
    print(vars(opt))

    set_seed(opt.seed)
    print('loading model and tokenizer...')
    pretrain_ckpt = opt.pretrain_ckpt or 'bert-base-uncased'
    word_encoder = BERTWordEncoder(pretrain_ckpt)
    # NOTE(review): the tokenizer is always 'bert-base-uncased' even when
    # --pretrain_ckpt points elsewhere; confirm the vocabularies match.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    print('loading data...')
    opt.train = f'data/{opt.mode}/train.txt'
    opt.test = f'data/{opt.mode}/test.txt'
    opt.dev = f'data/{opt.mode}/dev.txt'
    data_paths = (opt.train, opt.dev, opt.test)
    if not all(os.path.exists(p) for p in data_paths):
        # argv list instead of a shell string: --mode is passed verbatim.
        subprocess.run(['bash', 'data/download.sh', opt.mode])
        missing = [p for p in data_paths if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(
                f'dataset files still missing after running data/download.sh: {missing}')

    trainsamples = load_data_from_file(opt.train)
    tag2label, label2tag = get_tag_label_mapping(trainsamples)
    print(tag2label)
    if opt.full_supervised_train:
        print('fully supervised training')
        traindataset = NERDataset(opt.train, tokenizer, max_length, ignore_label_id=opt.ignore_index,
                                  tag2label=tag2label, samples=trainsamples)
        train_data_loader = DataLoader(traindataset, batch_size=16, shuffle=True, pin_memory=True,
                                       num_workers=8, collate_fn=_concat_collate)
        print('done with training set')
    else:
        train_data_loader = get_loader(opt.train, tokenizer,
            N=trainN, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index)
    val_data_loader = get_loader(opt.dev, tokenizer,
            N=N, K=K, Q=Q, batch_size=opt.eval_batch_size, max_length=max_length, ignore_index=opt.ignore_index)
    test_data_loader = get_loader(opt.test, tokenizer,
            N=N, K=K, Q=Q, batch_size=opt.eval_batch_size, max_length=max_length, ignore_index=opt.ignore_index)
    print('done')

    prefix = '-'.join([model_name, opt.mode, str(N), str(K), 'seed' + str(opt.seed)])
    if opt.dot:
        prefix += '-dot'
    if opt.ckpt_name:
        prefix += '-' + opt.ckpt_name

    # NNShot / StructShot branches were disabled in this script (their import
    # at the top of the file is commented out).
    if model_name == 'proto':
        print('use proto')
        model = Proto(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
    elif model_name == 'bigproto':
        print('use bigproto')
        # NOTE(review): the keyword is ``label2tag`` but ``tag2label`` is passed;
        # an earlier version passed ``label2tag``. Confirm which mapping
        # BigProto expects.
        model = BigProto(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index, label2tag=tag2label)
    else:
        # Guards against adding a --model choice without a matching branch.
        raise ValueError(f'unsupported model: {model_name}')
    framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader)

    if torch.cuda.is_available():
        model.cuda()

    if opt.full_supervised_train:
        print('init proto...')
        support_set_loader = prepare_initial_tensor_supportset(
            traindataset.samples, tokenizer, opt.max_length,
            ignore_idx=opt.ignore_index, tag2label=tag2label, K=50)
        for batch in tqdm(support_set_loader, desc='init proto'):
            # NOTE(review): nn.Module.register_buffer takes (name, tensor); this
            # presumably relies on a model-specific override taking a batch.
            model.register_buffer(batch)
        model.init_proto()

    os.makedirs('checkpoint', exist_ok=True)
    ckpt = opt.save_ckpt or 'checkpoint/{}.pth.tar'.format(prefix)
    print('model-save-path:', ckpt)

    if not opt.only_test:
        if opt.lr == -1:
            opt.lr = 2e-5

        # Keyword arguments shared by the supervised and episodic training loops.
        # NOTE(review): --grad_iter is parsed but not forwarded to training.
        train_kwargs = dict(
            load_ckpt=opt.load_ckpt, save_ckpt=ckpt,
            val_step=opt.val_step, fp16=opt.fp16,
            train_iter=opt.train_iter, warmup_step=int(opt.train_iter * 0.1),
            val_iter=opt.val_iter, learning_rate=opt.lr, proto_lr=opt.proto_lr,
            support_lr=opt.support_lr, support_proto_lr=opt.support_proto_lr,
            use_sgd_for_bert=opt.use_sgd_for_bert,
            train_on_support=opt.train_on_support,
            eval_train_round_num=opt.eval_train_round_num,
        )
        if opt.full_supervised_train:
            framework.train_full_supervised(model, prefix, **train_kwargs)
        else:
            framework.train(model, prefix, **train_kwargs)
    else:
        ckpt = opt.load_ckpt
        if ckpt is None:
            print("Warning: --load_ckpt is not specified. Will load Hugginface pre-trained checkpoint.")
            ckpt = 'none'

    # test
    precision, recall, f1, fp, fn, within, outer = framework.eval(
        model, opt.test_iter, opt.support_lr, opt.support_proto_lr, ckpt=ckpt,
        train_on_support=opt.train_on_support, round_num=opt.eval_train_round_num)
    print("RESULT: precision: %.4f, recall: %.4f, f1:%.4f" % (precision, recall, f1))
    print('ERROR ANALYSIS: fp: %.4f, fn: %.4f, within:%.4f, outer: %.4f'%(fp, fn, within, outer))

# Run the CLI entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()