# Copyright (c) 2020, Zhouxing shi <zhouxingshichn@gmail.com>
# Licenced under the BSD 2-Clause License.

import scipy as sp

import os
# Default to exposing only GPU index 1 unless the caller already chose the
# visible devices. This sits before `import torch` on purpose: the variable
# has to be in the environment before CUDA is initialised to take effect.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch
import numpy as np
import sys, random, time, shutil, copy, nltk, json
from multiprocessing import Pool
from Logger import Logger
from Parser import Parser, update_arguments
from data_utils import load_data, get_batches, set_seeds
from Models import Transformer_pr
from Verifiers import VerifierForward, VerifierBackward, VerifierDiscrete
from eval_words import eval_words
from tqdm import tqdm

# Command-line handling: build the project parser, parse sys.argv and let
# update_arguments() post-process the resulting namespace.
argv = sys.argv[1:]
parser = Parser().getParser()
args = update_arguments(parser.parse_args())

# Seed once before loading the data and once more afterwards
# (presumably so that whatever randomness load_data consumes does not shift
# the random stream seen by the rest of the run -- TODO confirm).
set_seeds(args.seed)
data_train, data_valid, data_test, _, _ = load_data(args)
set_seeds(args.seed)
print(args)

# The output directory name encodes the training approach and whether the
# word embeddings are frozen, so different configurations do not collide.
if args.approach != '':
    args.dir += '_' + args.approach
if args.fix_word_emb:
    args.dir += '_fixemb'
print(args.dir)
#assert 0

#import tensorflow as tf
import tensorflow.compat.v1 as tf
# Run TensorFlow in v1 (graph/session) mode.
tf.disable_v2_behavior()

# The session is created with zero GPU devices (device_count={'GPU': 0}),
# so TensorFlow ops run on the CPU; allow_growth is still set so that, should
# a GPU ever be enabled here, memory is claimed incrementally.
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(device_count={'GPU': 0}, gpu_options=gpu_options)
sess = tf.Session(config=config)

# Main driver. Depending on the flags in `args` this
#   * trains the model (args.train), validating/testing after every epoch and
#     checkpointing both the latest and the best-on-validation model, or
#     otherwise loads a pretrained BERT-based classifier;
#   * optionally runs a robustness verifier (args.verify) or the word-level
#     evaluation (args.word_label) and exits;
#   * otherwise evaluates test accuracy and per-example certified radii,
#     printing a summary and writing the accuracy to args.log.
with sess.as_default():
    target = Transformer_pr(args, data_train)

    # Validation/test batches are built once, from a shuffled copy order.
    random.shuffle(data_valid)
    random.shuffle(data_test)
    valid_batches = get_batches(data_valid, args.batch_size)
    test_batches = get_batches(data_test, args.batch_size)
    print("Dataset sizes: %d/%d/%d" % (len(data_train), len(data_valid), len(data_test)))

    # Only the leading entries of target.step()'s result are logged, one per
    # summary name.
    summary_names = ["loss", "accuracy", "certified radius"]
    summary_num_pre = len(summary_names)

    logger = Logger(sess, args, summary_names, 1)

    print("\n")

    if args.train:
        best_acc = -1.0
        while logger.epoch.eval() <= args.num_epoches:
            print (logger.epoch.eval(), args.num_epoches)
            if args.last_noreg:
                # Report the spectral norm (matrix 2-norm) of the final
                # classifier weight.
                w = target.model.classifier.weight
                print ("last norm:",np.linalg.norm(w.detach().cpu().numpy(), 2))

            # Fresh shuffle and re-batching of the training data every epoch.
            random.shuffle(data_train)
            train_batches = get_batches(data_train, args.batch_size)

            for batch in train_batches:
                logger.next_step(target.step(batch, logger.epoch.eval(), is_train=True)[:summary_num_pre])
            logger.next_epoch()

            for batch in valid_batches:
                logger.add_valid(target.step(batch, logger.epoch.eval())[:summary_num_pre])
            # Index 1 of the saved validation summary is the "accuracy" entry.
            val_acc = logger.save_valid(log=True)[1]
            for batch in test_batches:
                logger.add_test(target.step(batch, logger.epoch.eval())[:summary_num_pre])
            logger.save_test(log=True)
            target.scheduler.step()

            # Always keep the latest checkpoint; additionally mark it as best
            # when validation accuracy improves.
            target.save(logger.epoch.eval(), is_best=False)
            print (val_acc, best_acc)
            if val_acc > best_acc:
                print ("BEST SAVED")
                best_acc = val_acc
                target.save(logger.epoch.eval(), is_best=True)
    else:
        # No training: replace the model with pretrained weights. The plain
        # model is used when no approach is given, the "OneLip" variant
        # otherwise. Both are moved to the GPU via .cuda().
        if args.approach == '':
            from Models.modeling_withpr import BertForSequenceClassificationWithPretrain
            target.model = BertForSequenceClassificationWithPretrain.from_pretrained(target.bert_model, cache_dir='cache/bert', num_labels=args.num_labels).cuda()
        else:
            from Models.onelip_modeling_withpr import OneLipBertForSequenceClassificationWithPretrain
            target.model = OneLipBertForSequenceClassificationWithPretrain.from_pretrained(target.bert_model, cache_dir='cache/bert', num_labels=args.num_labels, approach=args.approach, last_noreg=args.last_noreg).cuda()

    # Data handed to the verifier: dev split if requested, test split otherwise.
    data = data_valid if args.use_dev else data_test

    if args.verify:
        print("Verifying robustness...")
        if args.method == "forward" or args.method == "ibp":
            verifier = VerifierForward(args, target, logger)
        elif args.method == "backward" or args.method == "baf":
            verifier = VerifierBackward(args, target, logger, with_pr=True)
        elif args.method == "discrete":
            verifier = VerifierDiscrete(args, target, logger)
        else:
            # Include the offending method name in the error message.
            raise NotImplementedError("Method {} not implemented".format(args.method))
        verifier.run(data)
        sys.exit(0)

    if args.word_label:
        eval_words(args, target, data_test)
        sys.exit(0)

    # ---- Test accuracy and certified radii ----
    torch.save(target.model.state_dict(), os.path.join(args.dir, 'test_save.ckpt'))
    acc = 0
    cert_rads = []
    for batch in test_batches:
        ret = target.step(batch, logger.epoch.eval())
        # ret[1] is weighted by the batch size so the final division by
        # len(data_test) yields a per-example average.
        acc += ret[1] * len(batch)
        logits = ret[3]['logits'].detach().cpu().numpy()

        # Scale factor of the last layer: its spectral norm when the
        # classifier is left unregularised (last_noreg), otherwise 1.
        if args.last_noreg:
            w = target.model.classifier.weight
            last_lip = np.linalg.norm(w.detach().cpu().numpy(), 2)
        else:
            last_lip = 1.0

        # Margin between the largest and the second-largest logit per row.
        # The top entry is masked with -inf (rather than subtracting a large
        # finite constant) so the runner-up is correct for any logit scale.
        rows = np.arange(len(logits))
        top_idx = logits.argmax(axis=1)
        logits_max = logits[rows, top_idx]
        runner_up = logits.copy()
        runner_up[rows, top_idx] = -np.inf
        logits_nextmax = runner_up.max(axis=1)

        # NOTE: the radius is computed around the *predicted* class; rows
        # with a wrong prediction are not zeroed out here.
        rad = (logits_max - logits_nextmax) / np.sqrt(2) / last_lip
        cert_rads.append(rad)

    acc = float(acc / len(data_test))
    print("Accuracy: {:.3f}".format(acc))
    # args.log is overwritten with just the accuracy.
    with open(args.log, "w") as log_file:
        log_file.write("{:.3f}".format(acc))

    cert_rads = np.concatenate(cert_rads)
    mean_rad = cert_rads.mean()
    # Fraction of examples whose radius strictly exceeds each threshold.
    thresholds = [0,0.05,0.1,0.2,0.5,1.0]
    frac_summary = " / ".join(['%.4f'%( (cert_rads>r).mean() ) for r in thresholds])

    print (cert_rads)
    print (len(cert_rads))
    print ("Avg Cert rads:", mean_rad, end='; ')
    print (frac_summary)
    print ("=================")
    print ("{:.3f}".format(acc),end='; ')
    print ("{:.6f}".format(mean_rad),end='; ')
    print (frac_summary)
