# Copyright (c) 2020, Zhouxing shi <zhouxingshichn@gmail.com>
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function

import argparse, csv, logging, os, random, sys, shutil, pdb
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from torch.nn import CrossEntropyLoss, MSELoss

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

from Models.modeling import BertForSequenceClassification
from Models.onelip_modeling import OneLipBertForSequenceClassification
from Models.utils import truncate_seq_pair, InputExample, \
    InputFeatures, convert_examples_to_features

class BERT:
    """Fine-tuning / evaluation wrapper around a BERT sequence classifier.

    Uses ``BertForSequenceClassification`` when ``args.approach`` is empty and
    ``OneLipBertForSequenceClassification`` otherwise.  The wrapper owns the
    tokenizer, the model, the optimizer and the LR scheduler, and writes
    checkpoints below ``args.dir``.
    """

    def __init__(self, args, data_train):
        self.general_init(args, data_train)
        self.load_pretrained()
        self._build_trainer()

    @staticmethod
    def _epoch_milestone(num_epoches, fraction):
        """Return the first whole epoch at or after ``fraction * num_epoches`` (at least 1).

        LR schedulers need whole numbers here: ``MultiStepLR`` only fires on a
        milestone equal to its integer epoch counter (so e.g. ``2.5`` never
        fires), and ``LinearLR`` with a fractional ``total_iters`` ends its
        warm-up short of the base learning rate.
        """
        import math
        return max(1, math.ceil(num_epoches * fraction))

    class _WarmupScheduler:
        """Linear LR warm-up chained with a single x0.1 step decay.

        ``step()`` advances both schedulers; milestones are expressed in
        epochs, so presumably the caller steps it once per epoch (confirm).
        """

        def __init__(self, optimizer, start_factor, warmup_iters, decay_epoch):
            self.sch1 = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=start_factor, total_iters=warmup_iters)
            self.sch2 = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[decay_epoch], gamma=0.1)

        def step(self):
            self.sch1.step()
            self.sch2.step()

    def general_init(self, args, data_train):
        """Copy hyper-parameters from *args*, locate the weights to load, seed RNGs, build the tokenizer.

        If ``args.dir`` does not exist it is created as a copy of
        ``args.base_dir``.  If ``args.dir`` contains a ``checkpoint`` file
        (written by ``save(..., is_best=True)``), weights are loaded from
        ``ckpt-<n>`` (``args.save_all``, ``n`` read from that file) or from
        ``ckpt`` (otherwise), and ``self.checkpoint`` is set to True.
        """
        self.args = args
        self.max_seq_length = args.max_sent_length
        self.do_lower_case = True
        self.learning_rate = args.lr
        self.gradient_accumulation_steps = 1
        self.seed = args.seed
        self.num_labels = args.num_labels
        self.label_list = range(args.num_labels)
        # ceil(len(data_train) / batch_size) batches per epoch; assumes the caller
        # takes one optimizer step per batch of args.batch_size examples.
        steps_per_epoch = (len(data_train) + args.batch_size - 1) // args.batch_size
        self.num_train_optimization_steps = args.num_epoches * steps_per_epoch
        self.warmup_proportion = args.warmup
        self.weight_decay = args.weight_decay
        self.device = args.device

        self.dir = self.bert_model = args.dir
        self.checkpoint = False
        if not os.path.exists(self.dir):
            # Seed the working directory from the base model directory.
            shutil.copytree(args.base_dir, self.dir)
        # TODO(xiaojun): optionally skip resuming from the checkpoint.
        checkpoint_file = os.path.join(self.dir, "checkpoint")
        if os.path.exists(checkpoint_file):
            if args.save_all:
                with open(checkpoint_file) as fh:
                    self.bert_model = os.path.join(self.dir, "ckpt-%d" % (int(fh.readline())))
            else:
                # Resumes from the latest "ckpt", not from "ckpt-best".
                self.bert_model = os.path.join(self.dir, "ckpt")
            self.checkpoint = True
            print("BERT checkpoint:", self.bert_model)

        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_model, do_lower_case=self.do_lower_case)
        self.vocab = self.tokenizer.vocab
        self.softmax = torch.nn.Softmax(dim=-1)

    def load_pretrained(self):
        """Load the classifier from ``self.bert_model`` and move it to ``self.device``."""
        cache_dir = "cache/bert"
        if self.args.approach == '':
            self.model = BertForSequenceClassification.from_pretrained(
                self.bert_model, cache_dir=cache_dir, num_labels=self.num_labels)
        else:
            self.model = OneLipBertForSequenceClassification.from_pretrained(
                self.bert_model, cache_dir=cache_dir, num_labels=self.num_labels,
                approach=self.args.approach, last_noreg=self.args.last_noreg)
        self.model.to(self.device)

    def _build_trainer(self):
        """Create ``self.optimizer`` and ``self.scheduler``.

        * ``args.fix_word_emb``: freeze every parameter whose name contains
          ``'_embedding'`` and train the rest with BertAdam.
        * ``args.approach == ''``: BertAdam over all parameters.
        * otherwise: ``torch.optim.Adam`` with a linear warm-up over the first
          10% of epochs (only when ``num_epoches > 10``) and a x0.1 decay at 80%.

        The BertAdam variants exclude bias/LayerNorm parameters from weight
        decay and use a MultiStepLR decay (x0.1) at 50% and 80% of the epochs.
        """
        num_epoches = self.args.num_epoches

        if self.args.fix_word_emb or self.args.approach == '':
            named_params = list(self.model.named_parameters())
            if self.args.fix_word_emb:
                for name, param in named_params:
                    if '_embedding' in name:
                        print("Excluding", name)
                        print(param.requires_grad)
                        param.requires_grad = False
                named_params = [(n, p) for n, p in named_params if '_embedding' not in n]

            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            grouped_params = [
                {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
                 'weight_decay': self.weight_decay},
                {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0},
            ]
            self.optimizer = BertAdam(grouped_params,
                lr=self.learning_rate,
                warmup=self.warmup_proportion,
                t_total=self.num_train_optimization_steps
            )
            milestones = [self._epoch_milestone(num_epoches, 0.5),
                          self._epoch_milestone(num_epoches, 0.8)]
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, milestones=milestones, gamma=0.1)
            return

        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        self.scheduler = self._WarmupScheduler(
            self.optimizer,
            # start_factor=1.0 makes the warm-up a no-op for short runs.
            start_factor=0.1 if num_epoches > 10 else 1.0,
            warmup_iters=self._epoch_milestone(num_epoches, 0.1),
            decay_epoch=self._epoch_milestone(num_epoches, 0.8),
        )

    def save(self, epoch, is_best=False):
        """Write weights, config and vocabulary to a checkpoint directory under ``self.dir``.

        With ``args.save_all`` each call writes a fresh ``ckpt-<epoch>``
        directory (replacing an existing one); otherwise ``ckpt-best`` (when
        *is_best*) or ``ckpt`` is created if needed and overwritten.  When
        *is_best*, ``<dir>/checkpoint`` is overwritten with *epoch*; the file is
        written with the predefined names so ``from_pretrained`` can reload it.
        """
        # Unwrap a wrapper exposing ``.module`` (e.g. DataParallel) and save the inner model.
        model_to_save = getattr(self.model, 'module', self.model)

        if self.args.save_all:
            output_dir = os.path.join(self.dir, "ckpt-%d" % epoch)
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
        else:
            output_dir = os.path.join(self.dir, "ckpt-best" if is_best else "ckpt")
        os.makedirs(output_dir, exist_ok=True)

        torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
        model_to_save.config.to_json_file(os.path.join(output_dir, CONFIG_NAME))
        self.tokenizer.save_vocabulary(output_dir)

        if is_best:
            with open(os.path.join(self.dir, "checkpoint"), "w") as fh:
                fh.write("%d" % epoch)

        print("BERT saved: %s" % output_dir)

    def get_input(self, batch):
        """Tokenize *batch* and return ``(input_ids, input_mask, segment_ids, features)``.

        The three tensors are ``torch.long`` and live on ``self.device``;
        ``features`` is the list returned by ``convert_examples_to_features``.
        """
        features = convert_examples_to_features(
            batch, self.label_list, self.max_seq_length, self.tokenizer)

        def as_long_tensor(field):
            values = [getattr(f, field) for f in features]
            return torch.tensor(values, dtype=torch.long).to(self.device)

        return (as_long_tensor("input_ids"), as_long_tensor("input_mask"),
                as_long_tensor("segment_ids"), features)

    def get_embeddings(self, batch):
        """Return ``(embeddings, tokens)`` for *batch*.

        ``embeddings`` is the plain sum of the word and position embedding
        lookups; token-type embeddings, LayerNorm and dropout are not applied.
        ``tokens`` is the per-example token list from the features.
        """
        input_ids, _, _, features = self.get_input(batch)

        position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        emb_layer = self.model.bert.embeddings
        embeddings = emb_layer.word_embeddings(input_ids) + emb_layer.position_embeddings(position_ids)
        tokens = [feature.tokens for feature in features]
        return embeddings, tokens

    def _cr_regularizer(self, logits, label_ids, epoch, last_lip):
        """Return the weighted certified-radius (CR) term to add to the loss (binary only).

        The unweighted term is ``-mean(relu(logit_true - logit_other) / sqrt(2)) / last_lip``.
        Its weight grows geometrically from ``0.01 * args.cr_loss`` at epoch 0
        to ``args.cr_loss`` once ``1.5 * epoch >= args.num_epoches``.
        """
        logits_v = logits.view(-1, self.num_labels)
        labels_v = label_ids.view(-1)
        rows = torch.arange(len(labels_v), device=logits_v.device)
        # With exactly two classes, ``1 - label`` indexes the other class.
        gap = logits_v[rows, labels_v] - logits_v[rows, 1 - labels_v]
        loss_cr = (-torch.nn.functional.relu(gap) / np.sqrt(2)).mean() / last_lip

        num_epoches = self.args.num_epoches
        remaining = max(num_epoches - epoch * 1.5, 0) / num_epoches
        cr_gamma = self.args.cr_loss * np.exp(remaining * np.log(0.01))
        return cr_gamma * loss_cr

    def step(self, batch, epoch=-1, is_train=False, infer_grad=False):
        """Run *batch* through the model; when *is_train*, also take one optimizer step.

        Args:
            batch: examples accepted by ``convert_examples_to_features``.
            epoch: current epoch. It is only used for the CR-loss weight and
                must be >= 0 when training with ``args.cr_loss > 0``. Inference
                calls without an epoch (e.g. ``predict``) omit the CR term.
            is_train: use train mode, back-propagate the loss and update weights.
            infer_grad: also return d(loss)/d(embedding_output) as ``info["gradients"]``.

        Returns:
            ``[loss, acc, avg_rad, info]``:
            - ``acc`` is the batch accuracy.
            - ``avg_rad`` is the batch mean of ``(top1 - top2 logit) / sqrt(2) / last_lip``.
            - ``info`` holds the model outputs, predictions, labels and features.

        Raises:
            ValueError: ``args.cr_loss > 0`` with ``num_labels != 2``, or when
                training with ``args.cr_loss > 0`` and a negative *epoch*.
        """
        input_ids, input_mask, segment_ids, features = self.get_input(batch)
        label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long).to(self.device)

        if is_train:
            self.model.train()
            outputs = self.model(input_ids, segment_ids, input_mask, labels=None)
        else:
            self.model.eval()
            # Evaluation skips autograd unless gradients were requested.
            grad_mode = torch.enable_grad() if infer_grad else torch.no_grad()
            with grad_mode:
                outputs = self.model(input_ids, segment_ids, input_mask, labels=None)
        (logits, embedding_output, encoded_layers, attention_scores,
         attention_probs, self_output, pooled_output) = outputs

        loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), label_ids.view(-1))

        if self.args.last_noreg:
            # Spectral norm (largest singular value) of the final classifier weight.
            last_lip = np.linalg.norm(self.model.classifier.weight.detach().cpu().numpy(), 2)
        else:
            last_lip = 1.0

        if self.args.cr_loss > 0.0:
            if self.num_labels != 2:
                raise ValueError("cr_loss is only implemented for binary classification")
            if epoch >= 0:
                loss = loss + self._cr_regularizer(logits, label_ids, epoch, last_lip)
            elif is_train:
                raise ValueError("step(): epoch must be >= 0 when training with cr_loss > 0")
            # else: inference without an epoch -> CR weight undefined, term omitted.

        preds = self.softmax(logits).detach().cpu().numpy()
        pred_labels = np.argmax(preds, axis=1)
        gt_labels = label_ids.cpu().numpy()
        acc = (pred_labels == gt_labels).mean()

        gradients = torch.autograd.grad(loss, embedding_output)[0] if infer_grad else None

        top2_pred = torch.topk(logits.detach().cpu(), 2, dim=1).values
        rad = ((top2_pred[:, 0] - top2_pred[:, 1]) / np.sqrt(2) / last_lip).numpy()
        avg_rad = rad.mean()

        ret = [
            loss, acc, avg_rad,
            {
                "logits": logits,
                "pred_scores": preds,
                "pred_labels": pred_labels,
                "gt_labels": gt_labels,
                "embedding_output": embedding_output,
                "encoded_layers": encoded_layers,
                "attention_scores": attention_scores,
                "attention_probs": attention_probs,
                "self_output": self_output,
                "pooled_output": pooled_output,
                "features": features,
                "gradients": gradients
            }
        ]

        if is_train:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        return ret

    def predict(self, X):
        """Return softmax class scores for the raw sentences *X*, one row per sentence.

        Each sentence is wrapped as ``{"sent_a": x, "label": 1}``; the dummy
        label only affects the loss/accuracy, which are discarded here.
        """
        examples = [{"sent_a": x, "label": 1} for x in X]
        return self.step(examples)[-1]["pred_scores"]

    def get_gradients(self, batch, embeddings):
        """Unfinished helper kept for interface compatibility; always raises.

        An earlier draft referenced the undefined names ``preds`` and ``pos``.
        It also differentiated w.r.t. *embeddings*, which are not an input of
        the model's forward pass (the model is fed ``input_ids``), so it could
        never succeed. NOTE(review): it appears to have been meant to return a
        ``(batch, 2, emb_dim)`` tensor with the gradient of the class-0
        probability at one token position (negated for class 1). Confirm the
        intent before reviving it.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "BERT.get_gradients is unfinished; use step(batch, infer_grad=True) "
            "to obtain d(loss)/d(embedding_output).")
