# Requires the pretrained BERT package pytorch-pretrained-bert 0.6.2; install it with: pip install https://github.com/huggingface/pytorch-pretrained-BERT/releases/download/v0.6.2/pytorch_pretrained_bert-0.6.2-py3-none-any.whl

import numpy as np
import torch as t
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler

import itertools as it

from torch.nn import CrossEntropyLoss
import torch.nn as nn

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertModel, BertForMaskedLM, BertForSequenceClassification, BertConfig
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

from models.transformer_layers import PositionalEncoding, LearnablePositionalEncoding, \
    IndependentTransformerDecoderLayer, TransformerDecoderLayer

# Target device for all modules and tensors in this file: the GPU when PyTorch sees one, else the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def sents_to_tensors(sents, tokenizer, max_length, target_device=None):
    """Tokenize raw sentences into zero-padded BERT input tensors.

    Each sentence is wrapped as ``[CLS] sentence [SEP]``, tokenized with ``tokenizer``,
    converted to ids and written left-aligned into a zero-padded row. A sentence whose
    tokenization exceeds ``max_length`` is truncated, but its final token (the ``[SEP]``
    appended here) is kept so the truncated sequence still ends with it.

    :param sents: list of raw text sentences.
    :param tokenizer: object providing ``tokenize(str) -> list of tokens`` and
        ``convert_tokens_to_ids(list) -> list of ints`` (e.g. a BertTokenizer).
    :param max_length: maximum number of tokens per sentence, ``[CLS]``/``[SEP]`` included.
    :param target_device: device the tensors are moved to; defaults to the module-level ``device``.
    :return: tuple ``(tokens, segments, masks)`` of long tensors, each of size
        ``len(sents) x L`` where L is the length of the longest (truncated) tokenized
        sentence. ``segments`` is all zeros; ``masks`` is 1 on real tokens, 0 on padding.
    """
    if target_device is None:
        target_device = device

    tokens_tensor = t.zeros((len(sents), max_length), dtype=t.long)
    masks_tensor = t.zeros((len(sents), max_length), dtype=t.long)

    max_encountered = 0

    for i, s in enumerate(sents):
        tokenized_text = tokenizer.tokenize("[CLS]" + " " + s + " [SEP]")
        if len(tokenized_text) > max_length:
            # Plain slicing would cut off the trailing [SEP]; keep it as the last token.
            tokenized_text = (tokenized_text[:max(max_length - 1, 0)] + tokenized_text[-1:])[:max_length]
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        n_tokens = len(indexed_tokens)

        max_encountered = max(max_encountered, n_tokens)

        tokens_tensor[i, :n_tokens] = t.tensor(indexed_tokens, dtype=t.long)
        masks_tensor[i, :n_tokens] = 1

    # Single-segment input: every token belongs to segment 0.
    segments_tensor = t.zeros((len(sents), max_length), dtype=t.long)

    tokens_tensor = tokens_tensor.to(target_device)
    segments_tensor = segments_tensor.to(target_device)
    masks_tensor = masks_tensor.to(target_device)

    # Drop padding columns that no sentence uses.
    return tokens_tensor[:, :max_encountered], segments_tensor[:, :max_encountered], masks_tensor[:, :max_encountered]


class SentenceEmbedder(nn.Module):
    """Embed raw sentences with BERT and project the pooled output to ``embedding_size``.

    Sentences are tokenized by ``sents_to_tensors`` (at most ``max_length`` tokens each),
    run through ``bert_model``, and the model's pooled output is mapped through a linear layer.
    """

    def __init__(self, bert_model, bert_tokenizer, max_length, embedding_size):
        super().__init__()

        self.bert_model = bert_model
        self.bert_tokenizer = bert_tokenizer
        self.max_length = max_length

        # Size the projection from the model's config rather than hard-coding 768, so BERT
        # variants with a different hidden size also work; 768 remains the fallback.
        config = getattr(bert_model, "config", None)
        hidden_size = getattr(config, "hidden_size", 768)

        self.fc = t.nn.Linear(hidden_size, embedding_size)
        self.fc.to(device)
        self.bert_model.to(device)

    def forward(self, sents):
        '''

        :param sents: list of raw text sentences.
        :return: tensor of size len(sents) x embedding_size: BERT's pooled output projected by ``self.fc``.
        '''
        tokens_tensor, segments_tensor, masks_tensor = sents_to_tensors(sents, self.bert_tokenizer, self.max_length)
        bert_predictions = self.bert_model(tokens_tensor, segments_tensor, masks_tensor)
        bert_pooled_last_layer = bert_predictions[1]  # 1 holds the "pooled" layer (i.e. CLS token embedding passed through a dense layer)

        transformed = self.fc(bert_pooled_last_layer)

        return transformed

class BertExtractor(nn.Module):
    """Score text queries against world states with a transformer decoder.

    The queries, embedded by ``sentence_embedder``, form the decoder's target sequence.
    The world state serves as the decoder's memory after two steps: it is reshaped into
    ``num_world_state_tokens`` tokens of size ``query_emb_size``, then passed through
    ``PositionalEncoding``. A linear head maps each decoded query to a single value.
    """

    def __init__(self, sentence_embedder, query_emb_size, num_world_state_tokens, nhead, nlayers):
        super().__init__()

        self.sentence_embedder = sentence_embedder
        self.query_emb_size = query_emb_size
        self.num_world_state_tokens = num_world_state_tokens

        decoder_layer = IndependentTransformerDecoderLayer(d_model=query_emb_size, nhead=nhead)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=nlayers)

        self.pos_enc = PositionalEncoding(d_model=query_emb_size)

        # One scalar output per decoded query.
        self.fc = nn.Linear(self.query_emb_size, 1)

        self.to(device)

    def forward(self, world_states, queries):
        '''
        :param world_states: tensor of size N x (num_world_state_tokens * query_emb_size).
        :param queries: list (length N) of lists (each of length K) of strings; queries[n] holds all queries to the n-th world.
            sentence_embedder must embed each query into query_emb_size values.
        :return: tensor of size N x K x 1, one prediction for every query.
        '''

        assert world_states.shape[1] % self.num_world_state_tokens == 0, "Number of world state tokens is inconsistent with the world state size"
        # Exact check: a floor-division comparison also accepted sizes with a remainder,
        # which then failed later inside view().
        assert world_states.shape[1] == self.num_world_state_tokens * self.query_emb_size, "World state size is inconsistent with query embedding size"

        N = world_states.shape[0]
        # The flat embeddings are regrouped with view(N, K, -1). That only lines up when
        # there is exactly one query list per world and all lists have the same length.
        assert len(queries) == N, "Number of query lists must match the number of world states"
        K = len(queries[0])
        assert all(len(q) == K for q in queries), "Every world must receive the same number of queries"

        query_embeddings = self.sentence_embedder(list(it.chain.from_iterable(queries))).view((N, K, -1))
        world_states_pos = self.pos_enc(world_states.view(N, self.num_world_state_tokens, self.query_emb_size))

        # nn.TransformerDecoder takes sequence-first tensors (seq, batch, emb), hence the permutes:
        # queries are the target sequence, world-state tokens the memory.
        query_answers = self.fc(self.decoder(query_embeddings.permute([1, 0, 2]), world_states_pos.permute([1, 0, 2])))

        return query_answers.permute([1, 0, 2])



class BertUpdater(nn.Module):
    """Update world states from text instructions with a transformer decoder.

    World-state tokens (with positional encoding) and instruction embeddings are
    concatenated into one sequence that is used as both the decoder target and memory;
    the decoder outputs at the world-state positions become the new world state.
    """

    def __init__(self, sentence_embedder, instr_emb_size, num_world_state_tokens, nhead, nlayers, unembed=False):
        super().__init__()

        self.sentence_embedder = sentence_embedder
        self.instr_emb_size = instr_emb_size
        self.num_world_state_tokens = num_world_state_tokens

        decoder_layer = TransformerDecoderLayer(d_model=instr_emb_size, nhead=nhead)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=nlayers)

        self.pos_enc = PositionalEncoding(d_model=instr_emb_size)
        # When True, forward() passes its output through self.pos_enc.unembed.
        self.unembed = unembed

        self.to(device)

    def forward(self, world_states, instructions):
        '''
        :param world_states: tensor of size N x (num_world_state_tokens * instr_emb_size).
        :param instructions: list (length N) of lists (each of length K) of strings; instructions[n] holds all instructions for the n-th world.
        :return: the new world states, a tensor of size N x num_world_state_tokens x instr_emb_size.
            When ``unembed`` is set, the tensor is first passed through ``self.pos_enc.unembed``.
            The output shape then depends on that method (presumably unchanged; confirm in
            models.transformer_layers).
        '''

        assert world_states.shape[1] % self.num_world_state_tokens == 0, "Number of world state tokens is inconsistent with the world state size"
        # Exact check: a floor-division comparison also accepted sizes with a remainder,
        # which then failed later inside view().
        assert world_states.shape[1] == self.num_world_state_tokens * self.instr_emb_size, "World state size is inconsistent with query embedding size"

        N = world_states.shape[0]
        # The flat embeddings are regrouped with view(N, K, -1). That only lines up when
        # there is exactly one instruction list per world and all lists have the same length.
        assert len(instructions) == N, "Number of instruction lists must match the number of world states"
        K = len(instructions[0])
        assert all(len(ins) == K for ins in instructions), "Every world must receive the same number of instructions"

        instruction_embeddings = self.sentence_embedder(list(it.chain.from_iterable(instructions))).view((N, K, -1))
        world_states_pos = self.pos_enc(world_states.view(N, self.num_world_state_tokens, self.instr_emb_size))

        # One joint sequence per world: world-state tokens first, then instruction tokens.
        all_inputs = t.cat([world_states_pos, instruction_embeddings], dim=1)

        # nn.TransformerDecoder takes sequence-first tensors (seq, batch, emb); the joint
        # sequence is both target and memory, and only the world-state positions are kept.
        all_transformed = self.decoder(all_inputs.permute([1, 0, 2]), all_inputs.permute([1, 0, 2]))
        new_ws = all_transformed[0:self.num_world_state_tokens, :, :].permute([1, 0, 2])

        # Unembed
        if self.unembed:
            new_ws = self.pos_enc.unembed(new_ws)

        return new_ws



if __name__ == "__main__":
    # Smoke test: build a BERT sentence embedder, then run one extractor and one updater
    # forward pass on random world states and report the output shapes.
    bert_model = 'bert-base-uncased'

    max_seq_length = 128
    do_lower_case = True

    tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)
    model = BertModel.from_pretrained(bert_model)
    model.to(device)

    qemb_size = 32
    num_world_state_tokens = 5

    se = SentenceEmbedder(model, tokenizer, max_seq_length, qemb_size)

    tst_extractor = BertExtractor(se, qemb_size, num_world_state_tokens=num_world_state_tokens, nhead=4, nlayers=2)
    tst_updater = BertUpdater(se, qemb_size, num_world_state_tokens=num_world_state_tokens, nhead=4, nlayers=2)

    # Only a forward-pass check: eval() disables dropout, no_grad() skips autograd bookkeeping.
    tst_extractor.eval()
    tst_updater.eval()

    with t.no_grad():
        res = tst_extractor(t.randn((2, qemb_size * num_world_state_tokens)).to(device),
                            [["This is query one to book one", "Query two to book one.", "Query three to book one."],
                             ["This is query one to book two.", "This is query number two to book number two.", "This is query three to book two."]])

        new_ws = tst_updater(t.randn((2, qemb_size * num_world_state_tokens)).to(device),
                             [["Instruction 1 1", "Instruction 1 2", "Instruction 1 3"],
                              ["Instruction 2 1", "Instruction 2 2", "Instruction 2 3"]])

    print("extractor output shape:", tuple(res.shape))
    print("updater output shape:", tuple(new_ws.shape))