# Copyright (c) 2020, Zhouxing shi <zhouxingshichn@gmail.com>
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function

import argparse, csv, logging, os, random, sys, shutil, pdb
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from torch.nn import CrossEntropyLoss, MSELoss

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertModel
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

from Models.modeling_withpr import BertForSequenceClassificationWithPretrain
from Models.onelip_modeling_withpr import OneLipBertForSequenceClassificationWithPretrain
from Models.utils import truncate_seq_pair, InputExample, \
    InputFeatures, convert_examples_to_features

class BERT_pr:
    def __init__(self, args, data_train):
        """Configure the wrapper, load the classifier and create its optimizer.

        Args:
            args: Run configuration object, forwarded to ``general_init``.
            data_train: Training examples, forwarded to ``general_init``.
        """
        # Order matters: ``load_pretrained`` reads attributes set by
        # ``general_init`` and ``_build_trainer`` needs the loaded model.
        self.general_init(args, data_train)
        self.load_pretrained()
        self._build_trainer()

    def general_init(self, args, data_train):
        """Copy run settings from ``args`` and prepare tokenizer and encoder.

        Args:
            args: Run configuration. Fields read here: ``max_sent_length``,
                ``lr``, ``seed``, ``num_labels``, ``num_epoches``,
                ``batch_size``, ``warmup``, ``weight_decay``, ``device``,
                ``dir``, ``base_dir`` and ``save_all``.
            data_train: Sized collection of training examples; only its
                length is used (to count optimizer steps).

        Raises:
            OSError: If ``args.dir`` does not exist and ``args.base_dir``
                cannot be copied to it.
        """
        self.args = args
        self.max_seq_length = args.max_sent_length
        self.do_lower_case = True
        self.learning_rate = args.lr
        self.gradient_accumulation_steps = 1
        self.seed = args.seed
        self.num_labels = args.num_labels
        self.label_list = range(args.num_labels)
        # One optimizer step per batch. The parentheses matter: `*` and `//`
        # share precedence, so without them the product is floored instead
        # of the per-epoch batch count being rounded up.
        batches_per_epoch = (len(data_train) + args.batch_size - 1) // args.batch_size
        self.num_train_optimization_steps = args.num_epoches * batches_per_epoch
        self.warmup_proportion = args.warmup
        self.weight_decay = args.weight_decay
        self.device = args.device

        self.dir = self.bert_model = args.dir
        self.checkpoint = False
        if not os.path.exists(self.dir):
            # Seed a fresh run directory from the base model directory.
            # copytree avoids the shell, so paths containing spaces or shell
            # metacharacters work and a failed copy is reported immediately.
            shutil.copytree(args.base_dir, self.dir)
        # TODO:xiaojun: Skip checkpoint
        checkpoint_file = os.path.join(self.dir, "checkpoint")
        if os.path.exists(checkpoint_file):
            if args.save_all:
                # The checkpoint file stores the epoch number written by `save`.
                with open(checkpoint_file) as file:
                    self.bert_model = os.path.join(
                        self.dir, "ckpt-%d" % int(file.readline()))
            else:
                self.bert_model = os.path.join(self.dir, "ckpt")
            self.checkpoint = True
            print("BERT checkpoint:", self.bert_model)

        torch.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.vocab = self.tokenizer.vocab
        # Keep the pretrained encoder on the same device as the inputs built
        # by `get_input`, instead of unconditionally on the default GPU.
        self.pre_model = BertModel.from_pretrained('bert-base-uncased').to(self.device)
        self.pre_model.eval()
        self.softmax = torch.nn.Softmax(dim=-1)

    def load_pretrained(self):
        """Load the classifier from ``self.bert_model`` and move it to the device.

        An empty ``args.approach`` selects
        ``BertForSequenceClassificationWithPretrain``; any other value selects
        ``OneLipBertForSequenceClassificationWithPretrain``, which additionally
        receives ``approach`` and ``last_noreg``.
        """
        shared_kwargs = {"cache_dir": "cache/bert", "num_labels": self.num_labels}
        if self.args.approach != '':
            self.model = OneLipBertForSequenceClassificationWithPretrain.from_pretrained(
                self.bert_model,
                **shared_kwargs,
                approach=self.args.approach,
                last_noreg=self.args.last_noreg,
            )
        else:
            self.model = BertForSequenceClassificationWithPretrain.from_pretrained(
                self.bert_model, **shared_kwargs)
        self.model.to(self.device)

    def _build_trainer(self):
        """Create ``self.optimizer`` and ``self.scheduler`` for ``self.model``.

        * ``args.approach == ''``: ``BertAdam`` with weight decay disabled for
          biases and LayerNorm parameters, plus a ``MultiStepLR`` that
          multiplies the learning rate by 0.1 at 50% and 80% of the epochs.
        * otherwise: ``torch.optim.Adam`` over all model parameters with a
          scheduler that chains a linear warm-up (only active when there are
          more than 10 epochs) and a 0.1 decay at 80% of the epochs.

        ``self.scheduler`` only exposes ``step()``, to be called once per epoch.
        """
        num_epoches = self.args.num_epoches

        def epoch_at(fraction):
            # MultiStepLR matches its integer epoch counter against the
            # milestones by equality, so a fractional milestone (e.g. 9.6)
            # would never fire. Truncate to a whole epoch, and never use 0:
            # a milestone at 0 would decay the LR at construction time.
            return max(1, int(num_epoches * fraction))

        if self.args.approach == '':
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            named_params = list(self.model.named_parameters())
            decayed = [p for n, p in named_params
                       if not any(nd in n for nd in no_decay)]
            not_decayed = [p for n, p in named_params
                           if any(nd in n for nd in no_decay)]
            self.optimizer = BertAdam(
                [
                    {'params': decayed, 'weight_decay': self.weight_decay},
                    {'params': not_decayed, 'weight_decay': 0.0},
                ],
                lr=self.learning_rate,
                warmup=self.warmup_proportion,
                t_total=self.num_train_optimization_steps,
            )
            # De-duplicate: for very short runs both milestones can truncate
            # to the same epoch, which would apply gamma twice at once.
            milestones = sorted({epoch_at(0.5), epoch_at(0.8)})
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer, milestones=milestones, gamma=0.1)
            return

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        class WarmupScheduler:
            """Linear warm-up followed by a single step decay at 80% of training."""

            def __init__(self, optimizer, num_epoches):
                # Short runs (<= 10 epochs) skip the warm-up: a start factor
                # of 1.0 makes LinearLR a no-op.
                start_factor = 0.1 if num_epoches > 10 else 1.0
                # A whole number of warm-up epochs guarantees the ramp ends
                # exactly at the base learning rate; with a fractional
                # total_iters the ramp stops one step short of it.
                warmup_epochs = max(1, int(num_epoches * 0.1))
                self.sch1 = torch.optim.lr_scheduler.LinearLR(
                    optimizer, start_factor=start_factor, total_iters=warmup_epochs)
                self.sch2 = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer, milestones=[max(1, int(num_epoches * 0.8))], gamma=0.1)

            def step(self):
                """Advance both chained schedulers by one epoch."""
                self.sch1.step()
                self.sch2.step()

        self.scheduler = WarmupScheduler(self.optimizer, num_epoches)

    def save(self, epoch, is_best=False):
        """Write model weights, config and vocabulary under ``self.dir``.

        Args:
            epoch: Epoch number; names the ``ckpt-<epoch>`` directory when
                ``args.save_all`` is set and is written to the ``checkpoint``
                file when ``is_best`` is true.
            is_best: Whether this is the best model so far. Without
                ``save_all`` it selects ``ckpt-best`` instead of ``ckpt``.
        """
        # Unwrap a container that exposes the real network as `.module`.
        net = getattr(self.model, 'module', self.model)

        if self.args.save_all:
            # One directory per epoch; start from a clean directory each time.
            output_dir = os.path.join(self.dir, "ckpt-%d" % epoch)
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
            os.mkdir(output_dir)
        else:
            output_dir = os.path.join(self.dir, "ckpt-best" if is_best else "ckpt")
            if not os.path.exists(output_dir):
                os.mkdir(output_dir)

        # Using the predefined file names lets `from_pretrained` load the result.
        weights_path = os.path.join(output_dir, WEIGHTS_NAME)
        config_path = os.path.join(output_dir, CONFIG_NAME)
        torch.save(net.state_dict(), weights_path)
        net.config.to_json_file(config_path)
        self.tokenizer.save_vocabulary(output_dir)

        if is_best:
            # Record the epoch of the best model; `general_init` reads it back.
            with open(os.path.join(self.dir, "checkpoint"), "w") as file:
                file.write("%d" % epoch)

        print("BERT saved: %s" % output_dir)

    def get_input(self, batch):
        """Convert raw examples to model input tensors on ``self.device``.

        Args:
            batch: Examples accepted by ``convert_examples_to_features``.

        Returns:
            Tuple ``(input_ids, input_mask, segment_ids, features)``; the
            first three are ``torch.long`` tensors on ``self.device`` and
            ``features`` is the list returned by
            ``convert_examples_to_features``.
        """
        features = convert_examples_to_features(
            batch, self.label_list, self.max_seq_length, self.tokenizer)

        def to_device_tensor(rows):
            # Build on the CPU first, then move to the target device.
            return torch.tensor(rows, dtype=torch.long).to(self.device)

        input_ids = to_device_tensor([f.input_ids for f in features])
        input_mask = to_device_tensor([f.input_mask for f in features])
        segment_ids = to_device_tensor([f.segment_ids for f in features])

        return input_ids, input_mask, segment_ids, features

    def get_embeddings(self, batch):
        """Return the classifier's input embeddings and the tokens of ``batch``.

        Args:
            batch: Examples accepted by ``get_input``.

        Returns:
            Tuple ``(embeddings, tokens)``: the rescaled encoder output after
            ``self.model.bert.emb_transform``, and the ``tokens`` attribute
            of every feature.
        """
        input_ids, input_mask, _token_type_ids, features = self.get_input(batch)

        # Last entry of the encoder's first output
        # (presumably the final encoder layer -- TODO confirm).
        hidden = self.pre_model(input_ids, attention_mask=input_mask)[0][-1]
        # Rescale each vector along dim 2 to an L2 norm of (almost exactly) 2;
        # the epsilon guards against division by zero.
        token_norms = hidden.norm(dim=2, keepdim=True)
        hidden = hidden / (1e-8 + token_norms) * 2
        embeddings = self.model.bert.emb_transform(hidden)

        tokens = [feature.tokens for feature in features]
        return embeddings, tokens

    def step(self, batch, epoch=-1, is_train=False, infer_grad=False):
        """Run one forward pass on ``batch`` and, when training, one update.

        Args:
            batch: Examples accepted by ``get_input``.
            epoch: Current epoch index. Must be non-negative whenever
                ``args.cr_loss > 0`` because the weight of that loss term is
                scheduled on it.
            is_train: Put the model in train mode and apply one optimizer
                step; otherwise the model is evaluated.
            infer_grad: Also return the gradient of the loss with respect to
                the rescaled encoder output.

        Returns:
            List ``[loss, acc, avg_rad, info]``: the loss tensor, the batch
            accuracy, the mean of ``(top logit - runner-up logit) / sqrt(2) /
            last_lip`` and a dict with the keys ``logits``, ``pred_scores``,
            ``pred_labels``, ``gt_labels``, ``embedding_output``,
            ``encoded_layers``, ``attention_scores``, ``attention_probs``,
            ``self_output``, ``pooled_output``, ``features`` and ``gradients``
            (``None`` unless ``infer_grad``).
        """
        input_ids, input_mask, _segment_ids, features = self.get_input(batch)
        label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        label_ids = label_ids.to(self.device)

        if is_train:
            self.model.train()
        else:
            self.model.eval()
        # Autograd is only needed for training or when input gradients are requested.
        grad_mode = torch.enable_grad() if (is_train or infer_grad) else torch.no_grad()
        with grad_mode:
            # NOTE(review): the encoder output is not detached, so backward()
            # also runs through self.pre_model even though _build_trainer only
            # passes self.model parameters to the optimizer -- confirm intended.
            embedding_output = self.pre_model(input_ids, attention_mask=input_mask)[0][-1]
            # Rescale each vector along dim 2 to an L2 norm of (almost exactly) 2.
            embedding_output = embedding_output / (1e-8 + embedding_output.norm(dim=2, keepdim=True)) * 2
            (logits, encoded_layers, attention_scores, attention_probs,
             self_output, pooled_output) = self.model(embedding_output, attention_mask=input_mask)

        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), label_ids.view(-1))

        if self.args.last_noreg:
            # Spectral norm (largest singular value) of the final classifier weight.
            w = self.model.classifier.weight
            last_lip = np.linalg.norm(w.detach().cpu().numpy(), 2)
        else:
            last_lip = 1.0

        if self.args.cr_loss > 0.0:
            assert epoch >= 0, "epoch must be passed to step() when cr_loss > 0"
            logits_v, labels_v = logits.view(-1, self.num_labels), label_ids.view(-1)
            rows = torch.arange(len(labels_v))
            logits_y = logits_v[rows, labels_v]
            # Push the true-class logit far down so max() picks the best other class.
            onehot = torch.zeros_like(logits_v)
            onehot[rows, labels_v] = 1.
            logits_othermax = (logits_v - onehot * 1e6).max(dim=1)[0]
            gap = logits_y - logits_othermax
            # Negative sign: minimising the loss enlarges the positive margins.
            loss_cr = (-torch.nn.functional.relu(gap) / np.sqrt(2)).mean() / last_lip
            # Exponential ramp of the weight from about 0.01 * cr_loss at
            # epoch 0 to cr_loss once epoch >= num_epoches / 1.5.
            num_epoches = self.args.num_epoches
            cr_gamma = self.args.cr_loss * np.exp(
                max(num_epoches - epoch * 1.5, 0) / num_epoches * np.log(0.01))
            loss = loss + cr_gamma * loss_cr

        preds = self.softmax(logits).detach().cpu().numpy()
        # For HANS transfer from MNLI: merge pred1 and pred2!
        if self.args.data == 'hans':
            assert not is_train
            print ("MERGING PRED 1 and 2")
            preds[:, 1] = preds[:, 1] + preds[:, 2]
        pred_labels = np.argmax(preds, axis=1)
        gt_labels = label_ids.cpu().numpy()
        acc = (pred_labels == gt_labels).mean()

        if infer_grad:
            # Keep the graph alive when a backward pass follows below;
            # otherwise loss.backward() would fail on the freed graph.
            gradients = torch.autograd.grad(
                loss, embedding_output, retain_graph=is_train)[0]
        else:
            gradients = None

        # Margin between the top logit and the runner-up, scaled by 1 / (sqrt(2) * last_lip).
        logits_np = logits.detach().cpu().numpy()
        logits_max, indices = logits_np.max(axis=1), logits_np.argmax(axis=1)
        onehot = np.zeros_like(logits_np)
        onehot[np.arange(len(logits_np)), indices] = 1.
        logits_nextmax = (logits_np - onehot * 1e6).max(axis=1)
        rad = (logits_max - logits_nextmax) / np.sqrt(2) / last_lip
        avg_rad = rad.mean()

        ret = [
            loss, acc, avg_rad,
            {
                "logits": logits,
                "pred_scores": preds,
                "pred_labels": pred_labels,
                "gt_labels": gt_labels,
                "embedding_output": embedding_output,
                "encoded_layers": encoded_layers,
                "attention_scores": attention_scores,
                "attention_probs": attention_probs,
                "self_output": self_output,
                "pooled_output": pooled_output,
                "features": features,
                "gradients": gradients
            }
        ]

        if is_train:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        return ret

    def predict(self, X):
        """Return the softmax class scores for the sentences in ``X``.

        Args:
            X: Iterable of sentences; each becomes the ``sent_a`` field of an
                example with a placeholder label of 1.

        Returns:
            The ``pred_scores`` entry produced by ``step`` in evaluation mode.
        """
        batch = [{"sent_a": x, "label": 1} for x in X]
        # `step` asserts a non-negative epoch whenever args.cr_loss > 0, so
        # its default of -1 would make prediction fail. The epoch only
        # affects the loss, which is discarded here, so any valid value works.
        epoch = getattr(self.args, "num_epoches", 0)
        return self.step(batch, epoch=epoch)[-1]["pred_scores"]

    def get_gradients(self, batch, embeddings):
        """Not supported by this model wrapper.

        Gradients of the loss with respect to the rescaled encoder output are
        available from ``step(batch, infer_grad=True)`` under the
        ``"gradients"`` key of the returned info dict.

        Args:
            batch: Unused.
            embeddings: Unused.

        Raises:
            NotImplementedError: Always.
        """
        # The code that used to follow this raise was unreachable and
        # referenced undefined names (`preds`, `pos`), so it was removed.
        raise NotImplementedError()
