# Copyright (c) 2020, Zhouxing shi <zhouxingshichn@gmail.com>
# Licenced under the BSD 2-Clause License.

import torch
import os
import numpy as np
from Models.BERT import BERT
from Models.modeling import BertForSequenceClassification, BertConfig
from Models.onelip_modeling import OneLipBertForSequenceClassification
from pytorch_pretrained_bert.tokenization import BertTokenizer

class Transformer(BERT):
    """BERT-architecture Transformer trained from scratch on a task-specific vocabulary.

    The vocabulary is rebuilt from the training data (see ``update_vocabulary``)
    and a fresh ``BertForSequenceClassification`` (or, when ``args.approach`` is
    non-empty, ``OneLipBertForSequenceClassification``) is built from the
    hyper-parameters in ``args`` unless a checkpoint is being restored.
    """

    # Checkpoints of previously trained vanilla models whose word/position
    # embeddings are copied into a freshly built OneLip model, keyed by
    # ``args.data``. Paths are relative to the current working directory.
    _VANILLA_EMBEDDING_CKPTS = {
        'sst': './model_sst_1/ckpt-3/pytorch_model.bin',
        'yelp': './model_yelp_1/ckpt/pytorch_model.bin',
        'qqp': './model_qqp_1/ckpt-best/pytorch_model.bin',
        'fever': './model_fever_1/ckpt-best/pytorch_model.bin',
    }

    def __init__(self, args, data_train):
        """Build (or restore) the model, then build its trainer.

        Args:
            args: parsed options. Fields read here: ``min_word_freq``,
                ``num_layers``, ``hidden_size``, ``intermediate_size``,
                ``hidden_act``, ``num_attention_heads``, ``layer_norm``,
                ``approach``, ``last_noreg`` and ``data``; ``fix_word_emb``
                and ``dir`` are read through ``self.args`` (presumably set to
                ``args`` by ``general_init`` -- confirm in ``BERT``).
            data_train: training examples; each example's ``"sent_a"`` entry is
                iterated as a token sequence by ``update_vocabulary``.

        Raises:
            NotImplementedError: if a OneLip model is built from scratch for an
                ``args.data`` with no registered vanilla-embedding checkpoint.
        """
        # NOTE(review): general_init (defined in BERT, not shown) presumably sets
        # self.args, self.checkpoint, self.device, self.num_labels,
        # self.bert_model and self.do_lower_case used below -- confirm.
        self.general_init(args, data_train)

        self.min_word_freq = args.min_word_freq
        self.update_vocabulary(data_train)

        config = self._build_config(args)

        if not self.checkpoint:
            if args.approach == '':
                self.model = BertForSequenceClassification(config, self.num_labels)
            else:
                self.model = OneLipBertForSequenceClassification(
                    config, self.num_labels,
                    approach=args.approach, last_noreg=args.last_noreg)
            self.model.to(self.device)

            if args.approach != '':
                # OneLip models start from embeddings of a trained vanilla model.
                self._load_vanilla_embeddings(args.data)
        else:
            self.load_pretrained() # TODO:xiaojun: Skip pretrained

        if self.args.fix_word_emb:
            self._load_fixed_embeddings()

        self._build_trainer()

    def _build_config(self, args):
        """Return a ``BertConfig`` sized to the rebuilt vocabulary and ``args``."""
        config = BertConfig(self.vocab_size)
        config.num_hidden_layers = args.num_layers
        config.hidden_size = args.hidden_size
        config.intermediate_size = args.intermediate_size
        config.hidden_act = args.hidden_act
        config.num_attention_heads = args.num_attention_heads
        config.layer_norm = args.layer_norm
        return config

    @staticmethod
    def _assign_weight(module, tensor):
        """Replace ``module.weight`` data with ``tensor``, moved to the weight's device."""
        module.weight.data = tensor.to(module.weight.device)

    def _load_vanilla_embeddings(self, data):
        """Copy word and position embeddings from the vanilla checkpoint for ``data``.

        Raises:
            NotImplementedError: if no checkpoint is registered for ``data``.
        """
        print ("Using vanilla embedding!")
        print ("Using vanilla embedding!")
        print ("Using vanilla embedding!")
        try:
            path = self._VANILLA_EMBEDDING_CKPTS[data]
        except KeyError:
            raise NotImplementedError(
                "no vanilla embedding checkpoint registered for data=%r" % (data,)) from None

        # Load once onto the CPU; _assign_weight moves each tensor to the
        # parameter's device, so GPU-saved checkpoints also load on CPU hosts.
        state_dict = torch.load(path, map_location='cpu')
        embeddings = self.model.bert.embeddings
        self._assign_weight(embeddings.word_embeddings,
                            state_dict['bert.embeddings.word_embeddings.weight'])
        self._assign_weight(embeddings.position_embeddings,
                            state_dict['bert.embeddings.position_embeddings.weight'])

    def _load_fixed_embeddings(self):
        """Load OneLip word/position embeddings and rescale every row to L2 norm 2.

        The checkpoint is read from
        ``self.args.dir[:-7] + '_onelip-softmax/ckpt/pytorch_model.bin'``.
        """
        # NOTE(review): [:-7] assumes args.dir ends with a fixed 7-character
        # suffix that is swapped for '_onelip-softmax' -- confirm with callers.
        path = self.args.dir[:-7] + '_onelip-softmax/ckpt/pytorch_model.bin'
        state_dict = torch.load(path, map_location='cpu')
        embeddings = self.model.bert.embeddings
        for name in ('word_embeddings', 'position_embeddings'):
            w = state_dict['bert.embeddings.%s.weight' % name]
            # Row-normalise (1e-8 guards all-zero rows), then scale to norm 2.
            self._assign_weight(getattr(embeddings, name),
                                w / (1e-8 + w.norm(dim=1, keepdim=True)) * 2)

    def update_vocabulary(self, data_train):
        """Build a vocabulary from the training data instead of using BERT's.

        The Transformer is trained from scratch, so its vocabulary is derived
        from the task data. If ``<bert_model>/vocab_base.txt`` does not exist,
        the existing ``vocab.txt`` is kept and only ``self.vocab_size`` is set
        to its line count.

        Otherwise a word is kept when it is a line of ``vocab_base.txt``,
        starts with ``#`` or ``[``, or occurs at least ``self.min_word_freq``
        times across the ``"sent_a"`` tokens of ``data_train``. Empty tokens
        are ignored. ``[PAD]`` gets an artificial count of 1e8 so that it sorts
        first. Kept words are written to ``<bert_model>/vocab.txt`` (UTF-8,
        overwriting it) in descending count order, with ties keeping
        first-seen order. The tokenizer is then reloaded from that file.

        Args:
            data_train: iterable of examples whose ``"sent_a"`` entry is an
                iterable of token strings.
        """
        vocab_base = os.path.join(self.bert_model, "vocab_base.txt")
        vocab_path = os.path.join(self.bert_model, "vocab.txt")
        if not os.path.exists(vocab_base):
            with open(vocab_path, encoding="utf-8") as file:
                self.vocab_size = len(file.readlines())
            return

        # Insertion order matters: it is the tie-break order of the stable sort.
        cnt = {}
        in_bert = set()
        with open(vocab_base, encoding="utf-8") as file:
            for line in file:
                # rstrip instead of [:-1]: a final line without "\n" keeps its last char.
                word = line.rstrip("\n")
                cnt[word] = 0
                in_bert.add(word)
        for example in data_train:
            for token in example["sent_a"]:
                cnt[token] = cnt.get(token, 0) + 1
        cnt["[PAD]"] = 1e8

        words = [
            w for w in cnt
            if w and (w.startswith(("#", "[")) or w in in_bert
                      or cnt[w] >= self.min_word_freq)
        ]
        words.sort(key=cnt.get, reverse=True)
        # UTF-8 to match how the BERT tokenizer reads vocab.txt back.
        with open(vocab_path, "w", encoding="utf-8") as file:
            file.writelines("%s\n" % w for w in words)

        self.vocab_size = len(words)
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_model, do_lower_case=self.do_lower_case)
        self.vocab = self.tokenizer.vocab
