# Copyright (c) 2020, Zhouxing shi <zhouxingshichn@gmail.com>
# Licenced under the BSD 2-Clause License.

import scipy as sp

import os
# Pin the process to GPU "1" unless the caller has already chosen devices.
# This runs before `torch` is imported further down.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch
import numpy as np
import sys, random, time, shutil, copy, nltk, json
from multiprocessing import Pool
from Logger import Logger
from Parser import Parser, update_arguments
from data_utils import load_data, get_batches, set_seeds
from Models import Transformer_pr
from Verifiers import VerifierForward, VerifierBackward, VerifierDiscrete
from eval_words import eval_words
from tqdm import tqdm

# Parse the command line, post-process the arguments, and load the datasets.
# `argv` is only needed by the `parser.parse_known_args(argv)` variant of the
# parsing call; `parse_args()` reads sys.argv on its own.
argv = sys.argv[1:]
parser = Parser().getParser()
args = update_arguments(parser.parse_args())

# The seed is applied both before and after loading the data.
# NOTE(review): presumably so that data loading does not shift the RNG state
# seen by the rest of the script -- confirm against data_utils.
set_seeds(args.seed)
data_train, data_valid, data_test, _, _ = load_data(args)
set_seeds(args.seed)
print(args)

# The output directory name encodes the approach and whether the word
# embeddings are fixed.
dir_suffix = ''
if args.approach != '':
    dir_suffix += '_' + args.approach
if args.fix_word_emb:
    dir_suffix += '_fixemb'
if dir_suffix:
    args.dir = args.dir + dir_suffix
print(args.dir)
#assert 0

#import tensorflow as tf
# TensorFlow is used through its 1.x-compatible API (ConfigProto / Session),
# so v2 behavior is switched off right after the import.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


# Session configuration: the GPU device count is set to 0 and allow_growth is
# enabled on the GPU options. The resulting session is what gets handed to
# Logger further down.
# NOTE(review): presumably TensorFlow is kept off the GPU so that PyTorch has
# it to itself -- confirm.
config = tf.ConfigProto(device_count = {'GPU': 0})
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

with sess.as_default():
    # Build the model wrapper and dump the config of the first encoder
    # layer's output sub-module (debug output).
    target = Transformer_pr(args, data_train)
    first_output_config = target.model.bert.encoder.layer[0].output.config
    print (first_output_config)
    print (first_output_config.layer_norm)

    # Shuffle validation first, then test, and batch them in the same order,
    # so the random stream is consumed exactly as before.
    for _split in (data_valid, data_test):
        random.shuffle(_split)
    valid_batches = get_batches(data_valid, args.batch_size)
    test_batches = get_batches(data_test, args.batch_size)
    print("Dataset sizes: %d/%d/%d" % tuple(
        len(split) for split in (data_train, data_valid, data_test)))

    # Names of the scalar summaries handed to the logger.
    summary_names = ["loss", "accuracy", "certified radius"]
    summary_num_pre = len(summary_names)

    logger = Logger(sess, args, summary_names, 1)

    print("\n")

    # Replace target.model with a freshly loaded pretrained classifier. The
    # class (and its extra keyword arguments) depends on whether an approach
    # was requested; only the module that is actually needed gets imported.
    if args.approach == '':
        from Models.modeling_withpr import BertForSequenceClassificationWithPretrain
        model_cls = BertForSequenceClassificationWithPretrain
        pretrained_kwargs = dict(cache_dir='cache/bert', num_labels=args.num_labels)
    else:
        from Models.onelip_modeling_withpr import OneLipBertForSequenceClassificationWithPretrain
        model_cls = OneLipBertForSequenceClassificationWithPretrain
        pretrained_kwargs = dict(
            cache_dir='cache/bert',
            num_labels=args.num_labels,
            approach=args.approach,
            last_noreg=args.last_noreg,
        )
    target.model = model_cls.from_pretrained(target.bert_model, **pretrained_kwargs).cuda()

    # Evaluation split selected by --use_dev.
    if args.use_dev:
        data = data_valid
    else:
        data = data_test

    class WrapModel(torch.nn.Module):
        """Adapter exposing a classifier as ``emb -> logits`` for embedding-space attacks.

        The wrapped model needs an attention mask in addition to the
        embeddings. The mask is supplied out of band through ``set_aux`` so
        that ``forward`` keeps a single tensor argument.
        """

        def __init__(self, model):
            """Store ``model``; no auxiliary inputs are set initially."""
            super(WrapModel, self).__init__()
            self.model = model
            self.aux = None

        def set_aux(self, aux):
            """Set the ``(input_ids, segment_ids, input_mask)`` triple for the next calls."""
            self.aux = aux

        def remove_aux(self,):
            """Clear the auxiliary inputs set by ``set_aux``."""
            self.aux = None

        def forward(self, emb):
            """Return the logits of the wrapped model for ``emb``.

            ``emb`` must have at least three dimensions. Every vector along
            dim 2 is rescaled to an L2 norm of (approximately) 2 before it is
            passed on, together with the stored attention mask.

            Raises:
                RuntimeError: if ``set_aux`` has not been called.
            """
            if self.aux is None:
                # Fail with a clear message instead of an unpacking TypeError.
                raise RuntimeError("WrapModel.forward() called before set_aux()")
            # Only the attention mask is consumed; ids and segment ids are
            # carried along but unused here.
            _, _, input_mask = self.aux
            # The 1e-8 term guards against division by zero for all-zero vectors.
            emb = emb / (1e-8 + emb.norm(dim=2, keepdim=True)) * 2
            # Call the wrapped model (not a global), which returns six values,
            # logits first.
            logits, _, _, _, _, _ = self.model(emb, attention_mask=input_mask)
            return logits
    from attack_lib import L2PGDAttack

    # Empirical robustness check: run a 10-step L2 PGD attack with budget EPS
    # on the embeddings fed to the classifier and report the accuracy on the
    # perturbed embeddings.
    EPS = 1.0
    #EPS = 1e-8
    wrap_model = WrapModel(target.model)
    adversary = L2PGDAttack(wrap_model, num_steps=10, epsilon=EPS)
    # Running counts of correctly classified / total examples.
    tot_acc = 0
    tot_num = 0
    for batch in test_batches:
        input_ids, input_mask, segment_ids, features = target.get_input(batch)
        label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        label_ids = label_ids.to(target.device)
        # Clean embeddings come from target.pre_model; they are detached so
        # the attack treats them as constants.
        # NOTE(review): element [0][-1] is presumably the last hidden layer
        # of pre_model -- confirm.
        embeddings = target.pre_model(input_ids, attention_mask=input_mask)[0][-1].detach()

        wrap_model.set_aux((input_ids, segment_ids, input_mask))
        # Kept inside the loop in case the attack changes the module mode.
        target.model.eval()

        adv_embeddings = adversary.perturb(embeddings, label_ids)
        # Scoring needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits = wrap_model(adv_embeddings)

        tot_num += len(label_ids)
        tot_acc += (logits.argmax(1).eq(label_ids)).sum().item()
        # Guard the ratio so an empty batch cannot raise ZeroDivisionError.
        running_acc = tot_acc / tot_num if tot_num else float("nan")
        print ("%d/%d=%.4f"%(tot_acc, tot_num, running_acc))
        wrap_model.remove_aux()
    # With no test batches at all the accuracy is reported as nan.
    final_acc = tot_acc / tot_num if tot_num else float("nan")
    print ("Adv acc @ eps=%s: %d/%d=%.4f"%(EPS, tot_acc, tot_num, final_acc))
    # NOTE(review): deliberate hard stop -- everything after this line is
    # unreachable. Being an assert, it is stripped under `python -O`, in which
    # case the script would continue into the code below.
    assert 0


    # Clean accuracy and certified L2 radii on the test batches.
    #torch.save(target.model.state_dict(), args.dir+'/test_save.ckpt')
    acc = 0
    cert_rads = []
    for batch in test_batches:
        ret = target.step(batch, logger.epoch.eval())
        # NOTE(review): ret[1] is presumably the batch accuracy -- it is
        # weighted by the batch size here and averaged below.
        acc += ret[1] * len(batch)
        logits = ret[3]['logits'].detach().cpu().numpy()
        if args.last_noreg:
            # The classifier is not regularized in this mode, so its 2-norm
            # (largest singular value for a 2-D weight) is divided out.
            w = target.model.classifier.weight
            last_lip = np.linalg.norm(w.detach().cpu().numpy(), 2)
        else:
            last_lip = 1.0
        # Margin between the largest and second-largest logit. For two
        # classes this equals max - min (the previous binary-only formula);
        # for more classes it is the margin to the runner-up class, and for a
        # single logit column it is 0, as before.
        top_two = np.sort(logits, axis=1)[:, -2:]
        margin = top_two[:, -1] - top_two[:, 0]
        # A difference of two logits can change at most sqrt(2) * last_lip
        # times as fast as the input, hence the divisor.
        rad = margin / np.sqrt(2) / last_lip
        cert_rads.append(rad)
    acc = float(acc / len(data_test))
    print("Accuracy: {:.3f}".format(acc))
    cert_rads = np.concatenate(cert_rads)
    print (cert_rads)
    print (len(cert_rads))

    # Fraction of examples whose radius exceeds each threshold; computed once
    # and printed in both the verbose and the compact summary.
    rad_thresholds = [0,0.05,0.1,0.2,0.5,1.0]
    frac_above = " / ".join('%.4f'%( (cert_rads>r).mean() ) for r in rad_thresholds)
    print ("Avg Cert rads:", cert_rads.mean(), end='; ')
    print (frac_above)
    print ("=================")
    print ("{:.3f}".format(acc),end='; ')
    print ("{:.6f}".format(cert_rads.mean()),end='; ')
    print (frac_above)

    # Debug dump of the word / position embedding tables and their mean
    # per-row norms.
    word_emb_weight = target.model.bert.embeddings.word_embeddings.weight
    pos_emb_weight = target.model.bert.embeddings.position_embeddings.weight
    print (word_emb_weight.shape)
    print (word_emb_weight.norm(dim=1).shape)
    print (word_emb_weight)
    print (word_emb_weight.norm(dim=1).mean())
    print (pos_emb_weight)
    print (pos_emb_weight.norm(dim=1).mean())
    #print (target.model.bert.embeddings.token_type_embeddings.weight.norm(dim=1).mean())
