# Copyright (c) 2020, Zhouxing shi <zhouxingshichn@gmail.com>
# Licenced under the BSD 2-Clause License.

import scipy as sp

import os
# Fall back to "1" only when the caller has not already set the variable.
# Kept ahead of the torch import further down.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch
import numpy as np
import sys, random, time, shutil, copy, nltk, json
from multiprocessing import Pool
from Logger import Logger
from Parser import Parser, update_arguments
from data_utils import load_data, get_batches, set_seeds
from Models import Transformer
from Verifiers import VerifierForward, VerifierBackward, VerifierDiscrete
from eval_words import eval_words
from tqdm import tqdm

# Parse the command line, post-process the options and load the data splits.
argv = sys.argv[1:]
parser = Parser().getParser()
args = update_arguments(parser.parse_args())

# The seed is applied once before loading the data and once more afterwards.
set_seeds(args.seed)
data_train, data_valid, data_test, _, _ = load_data(args)
set_seeds(args.seed)
print(args)

# Extend the directory name with the approach and the fixed-embedding marker.
if args.approach != '':
    args.dir += '_' + args.approach
if args.fix_word_emb:
    args.dir += '_fixemb'
print(args.dir)
#assert 0

#import tensorflow as tf
import tensorflow.compat.v1 as tf

# Run TensorFlow through its v1 compatibility API.
tf.disable_v2_behavior()

# Session config: the GPU device count for TensorFlow is set to zero and
# gpu_options.allow_growth is switched on.
config = tf.ConfigProto(device_count={"GPU": 0})
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

with sess.as_default():
    target = Transformer(args, data_train)

    # Diagnostics: dump the word and position embedding tables together with
    # the mean L2 norm of their rows, then the embedding module's config.
    bert_embeddings = target.model.bert.embeddings
    for table in (bert_embeddings.word_embeddings,
                  bert_embeddings.position_embeddings):
        print(table.weight)
        print(table.weight.norm(dim=1).mean())
    print(bert_embeddings.config)
    print(bert_embeddings.config.layer_norm)
    #print (target.model.bert.embeddings.word_embeddings.weight)
    #for name, p in target.model.named_parameters():
    #    print (name, p.shape, p.abs().mean())
    #assert 0
    #print (target)
    #print (target.model)
    #assert 0

    # Shuffle the held-out splits in place (validation first, then test)
    # before cutting them into batches.
    for split in (data_valid, data_test):
        random.shuffle(split)
    valid_batches = get_batches(data_valid, args.batch_size)
    test_batches = get_batches(data_test, args.batch_size)
    print("Dataset sizes: %d/%d/%d" % (
        len(data_train), len(data_valid), len(data_test)))

    # Quantities handed to the logger.
    summary_names = ["loss", "accuracy", "certified radius"]
    summary_num_pre = len(summary_names)

    logger = Logger(sess, args, summary_names, 1)

    print("\n")

    if args.approach != '':
        from Models.onelip_modeling import OneLipBertForSequenceClassification
        # NOTE(review): the freshly loaded model is bound to `test_model`
        # and is not referenced again in the visible part of this script;
        # `target.model` is left untouched in this branch -- confirm intent.
        test_model = OneLipBertForSequenceClassification.from_pretrained(
            target.bert_model,
            cache_dir='cache/bert',
            num_labels=args.num_labels,
            approach=args.approach,
            last_noreg=args.last_noreg,
        ).cuda()
    else:
        from Models.modeling import BertForSequenceClassification
        # Without an approach, `target.model` is replaced by a model loaded
        # with `from_pretrained` and moved to the GPU.
        target.model = BertForSequenceClassification.from_pretrained(
            target.bert_model,
            cache_dir='cache/bert',
            num_labels=args.num_labels,
        ).cuda()

    # NOTE(review): `data` is not read again in the visible part of this
    # script; the loops below iterate over `test_batches` directly.
    if args.use_dev:
        data = data_valid
    else:
        data = data_test

    class WrapModel(torch.nn.Module):
        """Expose a classifier as a function of the word embeddings alone.

        ``forward`` receives only ``word_emb``; every other model input is
        stored beforehand with :meth:`set_aux`.  The wrapped model is the one
        passed to the constructor.
        """

        def __init__(self, model):
            super().__init__()
            self.model = model
            # (input_ids, segment_ids, input_mask, position_embeddings), or
            # None while no batch is attached.
            self.aux = None

        def set_aux(self, aux):
            """Attach the per-batch inputs consumed by ``forward``.

            ``aux`` is unpacked as
            ``(input_ids, segment_ids, input_mask, position_embeddings)``.
            """
            self.aux = aux

        def remove_aux(self):
            """Detach the per-batch inputs set by :meth:`set_aux`."""
            self.aux = None

        def forward(self, word_emb):
            """Return the logits of the wrapped model for ``word_emb``.

            With an empty ``args.approach`` the word and position embeddings
            are summed and passed through the wrapped model's embedding
            LayerNorm and dropout.  Otherwise both are rescaled to L2 norm 2
            along dim 2 before being summed.

            Raises:
                RuntimeError: if called before :meth:`set_aux`.
            """
            if self.aux is None:
                raise RuntimeError(
                    "WrapModel.forward() called without auxiliary inputs; "
                    "call set_aux() first")
            input_ids, segment_ids, input_mask, position_embeddings = self.aux
            if args.approach == '':
                bert_embeddings = self.model.bert.embeddings
                emb = word_emb + position_embeddings
                emb = bert_embeddings.LayerNorm(emb)
                emb = bert_embeddings.dropout(emb)
            else:
                # The small constant guards against division by zero.
                word_emb = word_emb / (1e-8 + word_emb.norm(dim=2, keepdim=True)) * 2
                position_embeddings = position_embeddings / (
                    1e-8 + position_embeddings.norm(dim=2, keepdim=True)) * 2
                emb = word_emb + position_embeddings
            # The model returns seven values; only the first (logits) is used.
            logits, _, _, _, _, _, _ = self.model(
                input_ids, segment_ids, input_mask, labels=None, embeddings=emb)
            return logits
    from attack_lib import L2PGDAttack

    # Empirical attack: perturb the word embeddings of every test batch with
    # L2PGDAttack and report the accuracy on the perturbed inputs.
    EPS = 1.0
    wrap_model = WrapModel(target.model)
    adversary = L2PGDAttack(wrap_model, num_steps=10, epsilon=EPS)

    num_correct = 0
    num_seen = 0
    for batch in test_batches:
        input_ids, input_mask, segment_ids, features = target.get_input(batch)
        label_ids = torch.tensor(
            [f.label_id for f in features], dtype=torch.long
        ).to(target.device)

        # Look up position and word embeddings separately so that only the
        # word embeddings are handed to the attack.
        bert_embeddings = target.model.bert.embeddings
        position_ids = torch.arange(
            input_ids.size(1), dtype=torch.long, device=input_ids.device
        ).unsqueeze(0).expand_as(input_ids)
        position_embeddings = bert_embeddings.position_embeddings(position_ids)
        word_embeddings = bert_embeddings.word_embeddings(input_ids)
        if args.approach != '':
            # Rescale both embeddings to L2 norm 2 along dim 2.
            word_embeddings = word_embeddings / (
                1e-8 + word_embeddings.norm(dim=2, keepdim=True)) * 2
            position_embeddings = position_embeddings / (
                1e-8 + position_embeddings.norm(dim=2, keepdim=True)) * 2

        wrap_model.set_aux(
            (input_ids, segment_ids, input_mask, position_embeddings))
        target.model.eval()

        perturbed = adversary.perturb(word_embeddings, label_ids)
        logits = wrap_model(perturbed)

        num_seen += len(label_ids)
        num_correct += (logits.argmax(1) == label_ids).sum().item()
        print ("%d/%d=%.4f"%(num_correct, num_seen, num_correct/num_seen))
        wrap_model.remove_aux()
    print ("Adv acc @ eps=%s: %d/%d=%.4f"%(EPS, num_correct, num_seen, num_correct/num_seen))
    # NOTE(review): deliberate hard stop after the attack; everything below
    # only runs when assertions are disabled (python -O) -- confirm intent.
    assert 0


    # test the accuracy   
    #torch.save(target.model.state_dict(), args.dir+'/test_save.ckpt')
    # Clean accuracy and per-example radii over the test batches.
    # NOTE(review): the `assert 0` that closes the attack section above stops
    # the script before this point unless assertions are disabled.
    acc = 0
    cert_rads = []
    for batch in test_batches:
        ret = target.step(batch, logger.epoch.eval())
        acc += ret[1] * len(batch)
        logits = ret[3]['logits'].detach().cpu().numpy()
        if args.last_noreg:
            # 2-norm of the classifier weight, used as the divisor below.
            last_lip = np.linalg.norm(
                target.model.classifier.weight.detach().cpu().numpy(), 2)
        else:
            last_lip = 1.0
        # TODO: only for binary!!
        margin = logits.max(axis=1) - logits.min(axis=1)
        cert_rads.append(margin / np.sqrt(2) / last_lip)
    acc = float(acc / len(data_test))
    print("Accuracy: {:.3f}".format(acc))

    cert_rads = np.concatenate(cert_rads)
    # Fraction of examples whose radius exceeds each threshold.
    frac_above = " / ".join(
        ['%.4f'%( (cert_rads>r).mean() ) for r in [0,0.05,0.1,0.2,0.5,1.0]])
    print(cert_rads)
    print(len(cert_rads))
    print("Avg Cert rads:", cert_rads.mean(), end='; ')
    print(frac_above)
    print("=================")
    print("{:.3f}".format(acc), end='; ')
    print("{:.6f}".format(cert_rads.mean()), end='; ')
    print(frac_above)

    # Embedding-table diagnostics.
    bert_embeddings = target.model.bert.embeddings
    word_table = bert_embeddings.word_embeddings.weight
    print(word_table.shape)
    print(word_table.norm(dim=1).shape)
    print(word_table)
    print(word_table.norm(dim=1).mean())
    pos_table = bert_embeddings.position_embeddings.weight
    print(pos_table)
    print(pos_table.norm(dim=1).mean())
    #print (target.model.bert.embeddings.token_type_embeddings.weight.norm(dim=1).mean())
