# Copyright (c) 2020, Zhouxing shi <zhouxingshichn@gmail.com>
# Licensed under the BSD 2-Clause License.

import scipy as sp

import os
# Default to GPU 1 unless the caller already chose devices; this must run
# before torch is imported for the setting to take effect.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch
import numpy as np
import sys, random, time, shutil, copy, nltk, json
from multiprocessing import Pool
from Logger import Logger
from Parser import Parser, update_arguments
from data_utils import load_data, get_batches, set_seeds
from Models import Transformer
from Verifiers import VerifierForward, VerifierBackward, VerifierDiscrete
from eval_words import eval_words
from tqdm import tqdm

# Command-line options. parse_args() reads sys.argv itself; `argv` is only a
# convenience copy of the raw arguments.
argv = sys.argv[1:]
args = update_arguments(Parser().getParser().parse_args())

# Seed, load the splits, then re-seed -- presumably so the RNG state seen by
# the rest of the script does not depend on how much randomness load_data
# consumed (confirm).
set_seeds(args.seed)
data_train, data_valid, data_test, _, _ = load_data(args)
set_seeds(args.seed)
print(args)

# Suffix the output directory with the approach name (if any) and with
# '_fixemb' when --fix_word_emb is set.
dir_suffix = ''
if args.approach != '':
    dir_suffix += '_' + args.approach
if args.fix_word_emb:
    dir_suffix += '_fixemb'
args.dir = args.dir + dir_suffix
print(args.dir)
#assert 0

#import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Build the TF1 session that is later made default and handed to Logger.
# device_count={'GPU': 0} hides GPUs from TensorFlow (presumably to leave
# them to the PyTorch model -- confirm); allow_growth is set as well.
session_config = tf.ConfigProto(device_count={'GPU': 0})
session_config.gpu_options.allow_growth = True
sess = tf.Session(config=session_config)

with sess.as_default():
    # Build the model wrapper from the parsed options and the training split.
    target = Transformer(args, data_train)

    # Debug dump: raw weights and mean per-row L2 norm of the word and
    # position embedding tables, then the embedding config.
    bert_embeddings = target.model.bert.embeddings
    for emb_table in (bert_embeddings.word_embeddings, bert_embeddings.position_embeddings):
        print(emb_table.weight)
        print(emb_table.weight.norm(dim=1).mean())
    print(bert_embeddings.config)
    print(bert_embeddings.config.layer_norm)
    #print (target.model.bert.embeddings.word_embeddings.weight)
    #for name, p in target.model.named_parameters():
    #    print (name, p.shape, p.abs().mean())
    #assert 0
    #print (target)
    #print (target.model)
    #assert 0

    # Shuffle the held-out splits once; the resulting batches are reused for
    # every epoch's validation/test pass and for the final evaluation.
    for held_out in (data_valid, data_test):
        random.shuffle(held_out)
    valid_batches = get_batches(data_valid, args.batch_size)
    test_batches = get_batches(data_test, args.batch_size)
    split_sizes = tuple(len(split) for split in (data_train, data_valid, data_test))
    print("Dataset sizes: %d/%d/%d" % split_sizes)

    # Metrics logged per step; only the first `summary_num_pre` entries of
    # each target.step(...) result are forwarded to the logger.
    summary_names = ["loss", "accuracy", "certified radius"]
    summary_num_pre = len(summary_names)

    logger = Logger(sess, args, summary_names, 1)

    print("\n")

    #if args.train:
    #    for epoch in range(args.num_epochs):
    #        print ("Epoch %d"%epoch)
    #        random.shuffle(data_train)
    #        train_batches = get_batches(data_train, args.batch_size)

    #        with tqdm(enumerate(train_batches)) as pbar:
    #            for i, batch in pbar:
    #                target.step(batch, is_train=True)
    #                pbar.set_description(xxx)

    if args.train:
        def _log_eval_pass(batches, sink):
            # Evaluate every batch at the current epoch and pass the first
            # `summary_num_pre` step outputs to the given logger sink.
            for eval_batch in batches:
                sink(target.step(eval_batch, logger.epoch.eval())[:summary_num_pre])

        best_acc = -1.0
        while logger.epoch.eval() <= args.num_epoches:
            print(logger.epoch.eval(), args.num_epoches)
            if args.last_noreg:
                # ord=2 matrix norm of the classifier weight (largest singular
                # value for a 2-D weight).
                classifier_w = target.model.classifier.weight.detach().cpu().numpy()
                print("last norm:", np.linalg.norm(classifier_w, 2))

            random.shuffle(data_train)
            for batch in get_batches(data_train, args.batch_size):
                step_out = target.step(batch, logger.epoch.eval(), is_train=True)
                logger.next_step(step_out[:summary_num_pre])
            logger.next_epoch()

            _log_eval_pass(valid_batches, logger.add_valid)
            val_acc = logger.save_valid(log=True)[1]
            _log_eval_pass(test_batches, logger.add_test)
            logger.save_test(log=True)
            target.scheduler.step()

            # Always save the latest checkpoint; additionally save it as the
            # best one when validation accuracy improves.
            target.save(logger.epoch.eval(), is_best=False)
            print(val_acc, best_acc)
            if val_acc > best_acc:
                print("BEST SAVED")
                best_acc = val_acc
                target.save(logger.epoch.eval(), is_best=True)
    else:
        # No training: replace target.model with a classifier built via
        # from_pretrained(target.bert_model, ...) and move it to the GPU.
        pretrained_kwargs = dict(cache_dir='cache/bert', num_labels=args.num_labels)
        if args.approach == '':
            from Models.modeling import BertForSequenceClassification as model_cls
        else:
            from Models.onelip_modeling import OneLipBertForSequenceClassification as model_cls
            pretrained_kwargs['approach'] = args.approach
        target.model = model_cls.from_pretrained(target.bert_model, **pretrained_kwargs).cuda()

    # Examples to verify: the dev split when --use_dev is set, else the test split.
    data = data_valid if args.use_dev else data_test

    if args.verify:
        print("Verifying robustness...")
        # Map each accepted --method value (including aliases) to its verifier.
        verifier_classes = {
            "forward": VerifierForward,
            "ibp": VerifierForward,
            "backward": VerifierBackward,
            "baf": VerifierBackward,
            "discrete": VerifierDiscrete,
        }
        if args.method not in verifier_classes:
            # Include the offending method name; the old message had no
            # placeholder, so .format() silently dropped it.
            raise NotImplementedError("Method {} not implemented".format(args.method))
        verifier = verifier_classes[args.method](args, target, logger)
        verifier.run(data)
        # sys.exit rather than the site-module `exit` helper, which is meant
        # for the interactive shell and is missing under `python -S`.
        sys.exit(0)

    if args.word_label:
        # NOTE(review): always evaluates on data_test, even with --use_dev; confirm intended.
        eval_words(args, target, data_test)
        sys.exit(0)

    # ---- Final test-set evaluation: accuracy and certified radii ----
    # `time` is already imported at the top of the file, so it is not
    # re-imported here.
    # Snapshot the evaluated weights into the run directory.
    torch.save(target.model.state_dict(), os.path.join(args.dir, 'test_save.ckpt'))

    # Thresholds r used to report the fraction of test examples whose
    # certified radius is strictly greater than r.
    cert_rad_thresholds = [0, 0.05, 0.1, 0.2, 0.5, 1.0]

    acc = 0
    cert_rads = []
    for batch in test_batches:
        t0 = time.time()
        ret = target.step(batch, logger.epoch.eval())
        print("Time:", time.time() - t0)
        # ret[1] is treated as the batch-mean accuracy (hence the weighting
        # by batch size) -- confirm against Transformer.step.
        acc += ret[1] * len(batch)

        # Certified radius per example:
        #   (top-1 logit - top-2 logit) / (sqrt(2) * last_lip)
        # where last_lip is the classifier weight's ord=2 norm under
        # --last_noreg and 1.0 otherwise (presumably the rest of the network
        # is constrained to be 1-Lipschitz -- confirm). Radii are recorded
        # for every example, whether or not the prediction is correct.
        logits = ret[3]['logits'].detach().cpu()
        if args.last_noreg:
            w = target.model.classifier.weight
            last_lip = np.linalg.norm(w.detach().cpu().numpy(), 2)
        else:
            last_lip = 1.0
        top2_pred = torch.topk(logits, 2, dim=1).values
        rad = (top2_pred[:, 0] - top2_pred[:, 1]) / np.sqrt(2) / last_lip
        cert_rads.append(rad.numpy())

    acc = float(acc / len(data_test))
    print("Accuracy: {:.3f}".format(acc))
    with open(args.log, "w") as log_file:
        log_file.write("{:.3f}".format(acc))

    cert_rads = np.concatenate(cert_rads)
    # Computed once and printed twice (verbose report and one-line summary).
    cert_frac_summary = " / ".join(
        '%.4f' % (cert_rads > r).mean() for r in cert_rad_thresholds
    )
    print(cert_rads)
    print(len(cert_rads))
    print("Avg Cert rads:", cert_rads.mean(), end='; ')
    print(cert_frac_summary)
    print("=================")
    print("{:.3f}".format(acc), end='; ')
    print("{:.6f}".format(cert_rads.mean()), end='; ')
    print(cert_frac_summary)

    # Debug dump of the embedding tables after evaluation.
    word_weight = target.model.bert.embeddings.word_embeddings.weight
    position_weight = target.model.bert.embeddings.position_embeddings.weight
    print(word_weight.shape)
    print(word_weight.norm(dim=1).shape)
    print(word_weight)
    print(word_weight.norm(dim=1).mean())
    print(position_weight)
    print(position_weight.norm(dim=1).mean())


    #print ("========Checking word substitute============")
    #print (list(target.tokenizer.vocab)[:20])
    #embs = target.model.bert.embeddings.word_embeddings.weight
    #embs = embs / (1e-8+embs.norm(dim=1, keepdim=True))


    #####
    ##emb1 = embs[target.tokenizer.vocab.get('movie')]
    ##emb2 = embs[target.tokenizer.vocab.get('film')]
    ##print (emb1.norm())
    ##print (emb2.norm())
    ##print ((emb1-emb2).norm())
    ##assert 0
    #####

    ##with open('./data/synonyms.json') as inf:
    ##    syns = json.load(inf)
    ##syns_toks = {}
    ##for w1 in syns:
    ##    if w1 not in target.tokenizer.vocab:
    ##        continue
    ##    idx1 = target.tokenizer.vocab.get(w1)
    ##    if idx1 not in syns_toks:
    ##        syns_toks[idx1] = []
    ##    for w2 in syns[w1]:
    ##        if w2 not in target.tokenizer.vocab:
    ##            continue
    ##        idx2 = target.tokenizer.vocab.get(w2)
    ##        if idx2 not in syns_toks[idx1]:
    ##            syns_toks[idx1].append(idx2)
    ##with open('./data/synonyms_tok.json', 'w') as outf:
    ##    json.dump(syns_toks, outf)
    #with open('./data/synonyms_tok.json') as inf:
    #    syn_toks = json.load(inf)

    #allowed_delta = []
    #for batch in test_batches:
    #    t0 = time.time()
    #    ret = target.step(batch, logger.epoch.eval())

    #    inp_ids, _, _, fs = target.get_input(batch)
    #    labs = [f.label_id for f in fs]
    #    logits = ret[3]['logits'].detach().cpu().numpy()
    #    pred_c = ret[3]['pred_labels']
    #    #print (logits)
    #    if args.last_noreg:
    #        w = target.model.classifier.weight
    #        last_lip = np.linalg.norm(w.detach().cpu().numpy(), 2)
    #    else:
    #        last_lip = 1.0
    #    rad = (logits.max(axis=1) - logits.min(axis=1)) / np.sqrt(2) / last_lip # TODO: only for binary!!
    #    for one_seq, one_rad, one_pred, one_lab in zip(inp_ids, rad, pred_c, labs):
    #        if one_pred != one_lab:
    #            allowed_delta.append(-1)
    #        else:
    #            cur_delta = 0
    #            change_val = []
    #            for tok in one_seq:
    #                tok = tok.item()
    #                if tok == 0:
    #                    break
    #                if str(tok) not in syn_toks:
    #                    continue
    #                max_diff = 0
    #                for syn in syn_toks[str(tok)]:
    #                    diff = (embs[tok] - embs[syn]).norm().item()
    #                    if diff > max_diff:
    #                        max_diff = diff
    #                change_val.append(max_diff)
    #            change_val = sorted(change_val)[::-1]
    #            while cur_delta < 10:
    #                if np.linalg.norm(change_val[:cur_delta+1]) > one_rad:
    #                    break
    #                cur_delta += 1
    #            allowed_delta.append(cur_delta)
    #allowed_delta = np.array(allowed_delta)
    #for budget in [0,1,2,3,4,5,6]:
    #    print ("Pert %d word, acc: %.4f"%(budget, (allowed_delta>=budget).mean()))
    #assert 0
