import torch as t
import torch.nn as nn
import torch.functional as F
import itertools as it
from models.transformer_layers import *
# Device that TransformerTemplate.input_embeddings moves its inputs to:
# "cuda" when torch reports an available GPU, otherwise "cpu".
device = "cuda" if t.cuda.is_available() else "cpu"

class ModelTemplate(nn.Module):
    """Abstract base class for the models in this module.

    Stores the concrete subclass name in ``self.modelname``. Subclasses must
    override :meth:`forward`.
    """

    def __init__(self, *args, **kwargs):
        # Positional/keyword arguments are accepted for subclass convenience
        # and deliberately ignored here.
        super().__init__()
        self.modelname = type(self).__name__

    def forward(self, triples, state):
        '''
        Predict scores for a batch of triples in a given world state.

        :arg triples integer tensor of size n x 3, rows in (h, r, t) format
        :arg state - current world state representation
        :returns n x 1 tensor of predictions
        :raises NotImplementedError: always; subclasses provide the model
        '''
        raise NotImplementedError("Predict called from ModelTemplate")







def parse_transformer_layer(string):
    """Map a layer-type name from a ``special_mode`` config to a layer class.

    :arg string: one of ``"IndependentTransformerDecoderLayer"``,
        ``"TransformerEncoderLayer"`` (``torch.nn.TransformerEncoderLayer``),
        ``"TransformerDecoderLayer"``, or ``""`` / ``None``.
    :returns: the layer class, or ``None`` for ``""`` / ``None`` (no layer).
    :raises ValueError: for any other name. This makes a config typo fail here
        rather than later, as ``'NoneType' object is not callable``, when the
        returned value is instantiated.
    """
    # The custom layer names are resolved only inside their branch. They are
    # not defined in this file and are presumably provided by the
    # ``models.transformer_layers`` star import.
    if string == "IndependentTransformerDecoderLayer":
        return IndependentTransformerDecoderLayer
    if string == "TransformerEncoderLayer":
        return nn.TransformerEncoderLayer
    if string == "TransformerDecoderLayer":
        return TransformerDecoderLayer
    if string == "" or string is None:
        return None
    raise ValueError(f"Unknown transformer layer type: {string!r}")


class TransformerTemplate(ModelTemplate):
    """Shared construction and embedding logic for the transformer models.

    Builds:
      * an input embedding. It is a positional encoding over (head, tail) ids
        when ``input_encoding`` is a dict, ``1`` or ``2``; otherwise it is
        learned entity/relation embeddings projected to ``hidden_size``;
      * an optional class embedding and class transformer
        (``num_single_entities > 0``);
      * the main ``nn.TransformerDecoder`` stack (``n_blocks`` layers);
      * optional ``nn.TransformerEncoder`` pre/post stacks
        (``n_pre_blocks`` / ``n_post_blocks`` layers);
      * an optional world-state offset module applied to the state tokens.

    ``hidden_size`` is ``embedding_size * nheads``.
    """

    def __init__(self, world_state_size, embedding_size, num_entities, num_relations, num_single_entities=0, n_pre_blocks=0,
                 n_blocks=2, n_post_blocks=0, world_token_count=2, nheads=8, dim_feedforward=2048, world_offset=1.0,
                 input_encoding=None, special_mode=None):
        """
        :arg world_state_size: size of a flattened world state
        :arg embedding_size: per-head embedding size (hidden = embedding_size * nheads)
        :arg num_entities / num_relations: vocabulary sizes for triple ids
        :arg num_single_entities: number of class ids; 0 disables the class path
        :arg n_pre_blocks / n_blocks / n_post_blocks: layers in the pre-encoder,
            main decoder and post-encoder stacks (0 disables pre/post)
        :arg world_token_count: number of hidden_size tokens a state is split into
        :arg world_offset: falsy to disable; a dict of LearnablePositionalEncoding
            kwargs; otherwise the second argument of PositionalEncoding
        :arg input_encoding: dict / 1 / 2 for positional encodings of triples,
            anything else for learned embeddings
        :arg special_mode: dict of optional switches (defaults to an empty dict)
        """
        ModelTemplate.__init__(self)
        if special_mode is None:  # fresh dict per instance instead of a shared mutable default
            special_mode = {}
        # Snapshot the constructor arguments, which are used to recreate the
        # model from a checkpoint. This must run before any other local is
        # assigned. dict() copies: on CPython <= 3.12 locals() returns the
        # frame's live namespace, which debuggers/tracers refresh, so later
        # locals could otherwise leak into initparams.
        self.initparams = dict(locals())
        del self.initparams["self"]
        self.world_state_size = world_state_size
        self.embedding_size = embedding_size
        self.hidden_size = embedding_size * nheads
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.num_single_entities = num_single_entities
        self.n_pre_blocks = n_pre_blocks
        self.n_blocks = n_blocks
        self.n_post_blocks = n_post_blocks
        self.world_token_count = world_token_count
        self.nheads = nheads
        self.dim_feedforward = dim_feedforward
        self.input_encoding = input_encoding
        self.special_mode = special_mode

        # Input embedding. The chain uses elif: a dict encoding must not also
        # build the entity/relation/triple embeddings below, because
        # input_embeddings never uses them for a dict encoding.
        if isinstance(input_encoding, dict):
            self.input_pos_enc = LearnablePositionalEncoding(self.hidden_size, **input_encoding)
        elif input_encoding == 1:
            self.input_pos_enc = PositionalEncoding(self.hidden_size, 1.0)
        elif input_encoding == 2:
            assert self.hidden_size % 2 == 0, "hidden size should be even for 2d embeddings"
            self.input_pos_enc = PositionalEncoding(self.hidden_size // 2, 1.0)
        else:
            self.entity_embedding = nn.Embedding(self.num_entities, embedding_size)
            self.relation_embedding = nn.Embedding(self.num_relations, embedding_size)
            self.triple_embedding = nn.Linear(3 * embedding_size, self.hidden_size)

        # Class-id path. The class transformer is a separate decoder stack
        # configured by special_mode["class_transformer_layer"].
        if self.num_single_entities:
            self.class_embedding = nn.Embedding(self.num_single_entities, self.hidden_size)
            class_args = special_mode.get("class_transformer_layer")
            if class_args:
                class_layer_cls = parse_transformer_layer(class_args.get("layer",
                                                                         "IndependentTransformerDecoderLayer"))
                layer = class_layer_cls(self.hidden_size, nhead=class_args.get("heads", self.nheads),
                                        dim_feedforward=class_args.get("dim_feedforward", self.dim_feedforward),
                                        special_mode=special_mode)
                self.class_transformer = nn.TransformerDecoder(layer,
                                                               num_layers=class_args.get("n_blocks", self.n_blocks))

        self.triple_norm = nn.LayerNorm([self.hidden_size]) if special_mode.get("triple_norm") else None

        # Main decoder stack.
        transformer_layer = parse_transformer_layer(special_mode.get("transformer_layer",
                                                                     "IndependentTransformerDecoderLayer"))
        layer = transformer_layer(self.hidden_size, nhead=self.nheads, dim_feedforward=self.dim_feedforward,
                                  special_mode=special_mode)
        self.transformer = nn.TransformerDecoder(layer, num_layers=self.n_blocks)

        # Optional pre-encoder. Its depth is n_pre_blocks; previously it was
        # n_blocks, so n_pre_blocks only acted as an on/off switch.
        if n_pre_blocks > 0:
            pre_layer_cls = parse_transformer_layer(special_mode.get("pre_transformer_layer",
                                                                     "TransformerEncoderLayer"))
            layer = pre_layer_cls(self.hidden_size, nhead=self.nheads, dim_feedforward=self.dim_feedforward)
            self.pre_transformer = nn.TransformerEncoder(layer, num_layers=n_pre_blocks)
        else:
            self.pre_transformer = None

        # Optional post-encoder, n_post_blocks deep. The default
        # torch.nn.TransformerEncoderLayer has no special_mode argument, so
        # special_mode is only forwarded to layer classes that accept it.
        if n_post_blocks > 0:
            post_layer_cls = parse_transformer_layer(special_mode.get("post_transformer_layer",
                                                                      "TransformerEncoderLayer"))
            post_kwargs = {"nhead": self.nheads, "dim_feedforward": self.dim_feedforward}
            if self._accepts_kwarg(post_layer_cls, "special_mode"):
                post_kwargs["special_mode"] = special_mode
            layer = post_layer_cls(self.hidden_size, **post_kwargs)
            self.post_transformer = nn.TransformerEncoder(layer, num_layers=n_post_blocks)
        else:
            self.post_transformer = None

        if world_offset:
            if isinstance(world_offset, dict):
                self.world_offset = LearnablePositionalEncoding(self.hidden_size, **world_offset)
            else:
                self.world_offset = PositionalEncoding(self.hidden_size, world_offset)
        else:
            self.world_offset = None

    @staticmethod
    def _accepts_kwarg(layer_cls, name):
        """Return True if ``layer_cls(...)`` accepts keyword argument ``name``.

        An explicit parameter with that name, or a ``**kwargs`` catch-all,
        counts as accepting it. Uninspectable callables count as not accepting.
        """
        import inspect
        try:
            params = inspect.signature(layer_cls).parameters.values()
        except (TypeError, ValueError):
            return False
        return any(p.kind is inspect.Parameter.VAR_KEYWORD
                   or (p.name == name and p.kind is not inspect.Parameter.POSITIONAL_ONLY)
                   for p in params)

    def load_state_dict(self, state_dict, strict=False, **kwargs):
        """Load parameters, defaulting to ``strict=False`` (unlike nn.Module).

        Non-strict loading lets checkpoints whose keys do not exactly match
        this model load anyway, for example when optional sub-modules differ.
        Extra keyword arguments are forwarded to ``nn.Module.load_state_dict``.

        :returns: the ``(missing_keys, unexpected_keys)`` result of
            ``nn.Module.load_state_dict``.
        """
        return nn.Module.load_state_dict(self, state_dict, strict=strict, **kwargs)

    def input_embeddings(self, inputs, targets=None):
        """Embed a batch of queries or updates.

        :arg inputs: integer tensor, moved to ``device``. If its last dimension
            is 3 it is read as (head, relation, tail) via ``inputs[:, :, k]``,
            so it is presumably (batch, n, 3). Otherwise its entries are class
            ids embedded with ``class_embedding``.
        :arg targets: optional tensor. It is cast to float, passed through
            ``self.target_embedding`` and added to the embeddings.
            ``target_embedding`` is not created here; subclasses such as
            TransformerUpdater define it.
        :returns: the embeddings, layer-normalised when ``triple_norm`` is enabled.
        """
        inputs = inputs.to(device)
        if inputs.shape[-1] == 3:  # triples
            heads = inputs[:, :, 0]
            tails = inputs[:, :, 2]
            if isinstance(self.input_encoding, dict) or self.input_encoding == 1:
                # Flatten (head, tail) into a single position index.
                query_enc = self.input_pos_enc.embed(heads * self.num_entities + tails)
            elif self.input_encoding == 2:
                query_enc = self.input_pos_enc.embed2d(heads, tails)
            else:
                e1_enc = self.entity_embedding.forward(heads)
                e2_enc = self.entity_embedding.forward(tails)
                rel_enc = self.relation_embedding.forward(inputs[:, :, 1])
                query_enc = self.triple_embedding(t.cat([e1_enc, rel_enc, e2_enc], dim=-1))
        else:
            query_enc = self.class_embedding.forward(inputs)
        if targets is not None:
            t_emb = self.target_embedding.forward(targets.float().to(device))
            # Out-of-place add, so the tensor returned by the embedding module
            # is not modified in place.
            query_enc = query_enc + t_emb
        if self.triple_norm is not None:
            query_enc = self.triple_norm.forward(query_enc)
        return query_enc

    def apply_transformer(self, transformer, input):  # batch first mode
        """Run a sequence-first module on a batch-first 3-D tensor.

        Dims 0 and 1 are swapped before the call and swapped back afterwards.
        """
        return transformer.forward(input.permute([1, 0, 2])).permute([1, 0, 2])

    def apply_transformer_decoder(self, transformer, input, memory):  # batch first mode
        """Like :meth:`apply_transformer`, for a decoder taking (tgt, memory).

        ``input`` is the target sequence and ``memory`` the attended sequence;
        both are batch-first 3-D tensors.
        """
        return transformer.forward(input.permute([1, 0, 2]), memory.permute([1, 0, 2])).permute([1, 0, 2])


class TransformerExtractor(TransformerTemplate):
    """Scores queries (triples or class ids) against a world state.

    The embedded queries are the decoder input (target sequence). The world
    state, split into ``hidden_size`` tokens, is the decoder memory. A linear
    layer maps every output token to one score.
    """

    def __init__(self, world_state_size, embedding_size, num_entities, num_relations, num_single_entities=0, n_blocks=2,
                 world_token_count=2, nheads=8, dim_feedforward=2048, world_offset=1.0, input_encoding=None,
                 special_mode=None, n_pre_blocks=0, n_post_blocks=0):
        """See TransformerTemplate for the arguments.

        ``n_pre_blocks`` / ``n_post_blocks`` are accepted (trailing, default 0)
        because the base class records them in ``self.initparams``. Without
        them, passing ``initparams`` back to this constructor (presumably what
        ``reinit_model`` does with ``checkpoint["modelinits"]``) would raise
        TypeError. :meth:`forward` does not use the pre/post stacks.
        """
        print("Initializing a Transformer Extractor model... ")
        if special_mode is None:  # fresh dict per instance instead of a shared mutable default
            special_mode = {}
        # Forward every argument (including self) to the base constructor.
        # This must run before any other local variable is created.
        TransformerTemplate.__init__(**locals())
        self.query_output_layer = nn.Linear(self.hidden_size, 1)
        if self.special_mode.get("class_transformer_layer"):
            self.class_output_layer = nn.Linear(self.hidden_size, 1)
        print("... done.")

    def forward(self, inputs,  state):
        '''Score ``inputs`` against ``state``.

        :arg inputs: triples (last dim 3) or class ids; see input_embeddings
        :arg state: either a matrix of size 1 x world_state_size, which is
            repeated for every example, or n x world_state_size, giving a
            different world state for each example
        :returns: one score per query token, shaped (batch, n_queries, 1)
        '''
        embeddings = self.input_embeddings(inputs)
        batch_size = embeddings.shape[0]
        state_enc = state
        if self.world_offset is not None:
            state_enc = self.world_offset.forward(state_enc.view(-1, self.world_token_count, self.hidden_size)).\
                view(-1, self.world_token_count * self.hidden_size)
        if state_enc.shape[0] == 1:
            # A single shared state is broadcast over the batch.
            state_enc = state_enc.repeat(batch_size, 1)
        queries = embeddings.view(batch_size, -1, self.hidden_size)
        memory = state_enc.view(batch_size, -1, self.hidden_size)
        if inputs.shape[-1] != 3 and self.special_mode.get("class_transformer_layer"):
            transformer_output = self.apply_transformer_decoder(self.class_transformer, queries, memory)
            return self.class_output_layer(transformer_output)
        transformer_output = self.apply_transformer_decoder(self.transformer, queries, memory)
        return self.query_output_layer(transformer_output)



class TransformerUpdater(TransformerTemplate):
    """Produces an updated world state from (triple, target) update pairs.

    The state tokens, optionally passed through the pre-encoder, are the
    decoder input. The embedded updates are the decoder memory. The decoder
    output is optionally refined by the post-encoder and flattened back to
    ``world_state_size``.
    """

    def __init__(self, world_state_size, embedding_size, num_entities, num_relations, n_pre_blocks=2, n_blocks=2,
                 n_post_blocks=2, world_token_count=2, nheads=8, dim_feedforward=2048, world_offset=1.0,
                 input_encoding=None, special_mode=None, num_single_entities=0):
        """See TransformerTemplate for the arguments.

        ``num_single_entities`` is accepted (trailing, default 0) because the
        base class records it in ``self.initparams``. Without it, passing
        ``initparams`` back to this constructor (presumably what
        ``reinit_model`` does) would raise TypeError.
        """
        print("Initializing a Transformer Updater model... ")
        if special_mode is None:  # fresh dict per instance instead of a shared mutable default
            special_mode = {}
        # Forward every argument (including self) to the base constructor.
        # This must run before any other local variable is created.
        TransformerTemplate.__init__(**locals())
        # Embeds the scalar target of each update (used by input_embeddings).
        self.target_embedding = nn.Linear(1, self.hidden_size)
        self.transformer_output_size = self.hidden_size * self.world_token_count
        print("... done.")

    def forward(self, inputs, targets, state):
        '''
        Compute the new world state for each example.

        :arg inputs - a matrix of size (n, num_updates, 3) - statements for which the answers are provided
        :arg targets - a matrix of size (n, num_updates, 1) - desired answers for the corresponding triples
        :arg state - (n, world_state_size), the starting state of every update sequence.
            A single (1, world_state_size) state is NOT broadcast over the batch.
        Gradients flow through the new world states into the model parameters,
        but not into the previous world states (the input state is detached).
        :returns a new state for every example: a matrix of size (n, world_state_size)
        '''
        update_enc = self.input_embeddings(inputs, targets)
        batch_size = update_enc.shape[0]
        state_enc = state.detach()
        if self.world_offset is not None:
            state_enc = self.world_offset.forward(state_enc.view(-1, self.world_token_count, self.hidden_size)). \
                view(-1, self.world_token_count * self.hidden_size)
        state_tokens = state_enc.view(-1, self.world_token_count, self.hidden_size)
        if self.pre_transformer is not None:
            state_tokens = self.apply_transformer(self.pre_transformer, state_tokens)
        update_enc = update_enc.view(batch_size, -1, self.hidden_size)
        transformer_output = self.apply_transformer_decoder(self.transformer, state_tokens, update_enc)
        result = transformer_output[:, :self.world_token_count]
        if self.post_transformer is not None:
            # The tokens are kept rather than averaged. Averaging them left
            # hidden_size values per example, which cannot be reshaped to
            # world_state_size (the state is viewed as world_token_count tokens
            # of hidden_size) unless world_token_count == 1, where the mean was
            # a no-op.
            result = self.apply_transformer(self.post_transformer, result)
        if self.special_mode.get("subtract_world_offset") and self.world_offset is not None:
            # Without a world_offset nothing was added, so nothing is subtracted.
            result = self.world_offset.forward(result, negative=True)
        return result.reshape(batch_size, self.world_state_size)

class Classifier(nn.Module):
    """MLP producing a single output per example from a flattened state.

    Architecture: Linear(world_state_size, hidden) + ReLU, then
    (num_layers - 2) x [Linear(hidden, hidden) + ReLU], then Linear(hidden, 1).
    ``num_layers`` counts the Linear layers when it is >= 2. Smaller values
    still build two Linear layers. Extra keyword arguments are ignored.
    """

    def __init__(self, world_state_size, hidden_size, num_layers, **kwargs):
        super().__init__()
        stack = [nn.Linear(world_state_size, hidden_size), nn.ReLU()]
        # A new Linear/ReLU pair is built on every iteration.
        stack.extend(module
                     for _ in range(num_layers - 2)
                     for module in (nn.Linear(hidden_size, hidden_size), nn.ReLU()))
        stack.append(nn.Linear(hidden_size, 1))
        self.layers = nn.Sequential(*stack)

    def forward(self, state):
        """Flatten everything after the batch dimension and apply the MLP."""
        flat_state = state.view(state.shape[0], -1)
        return self.layers.forward(flat_state)





def _resolve_model_class(name):
    """Return the ``nn.Module`` subclass bound to ``name`` in this module.

    This replaces ``eval(name)``. Checkpoint and config files are external
    data, and ``eval`` would execute any expression stored in them. A plain
    lookup in this module's globals resolves the same class names (including
    those pulled in by the star import) without running code.

    :raises ValueError: if ``name`` is not bound to an ``nn.Module`` subclass.
    """
    model_class = globals().get(name) if isinstance(name, str) else None
    if not (isinstance(model_class, type) and issubclass(model_class, nn.Module)):
        raise ValueError(f"Unknown model class: {name!r}")
    return model_class


def reinit_model(checkpoint):
    """Rebuild a model from ``checkpoint["modelname"]`` and ``checkpoint["modelinits"]``.

    The class named by ``"modelname"`` is called with ``"modelinits"`` as
    keyword arguments. Weights are not loaded here.
    """
    model_class = _resolve_model_class(checkpoint["modelname"])
    return model_class(**checkpoint["modelinits"])


def init_model(params, arg_name="model_vars"):
    """Build a model from ``params[arg_name]``.

    ``params[arg_name]["model_name"]`` names the class, and
    ``params[arg_name]["model_inits"]`` holds the constructor keyword
    arguments. The keyword arguments are printed first.
    """
    model_cfg = params[arg_name]
    model_class = _resolve_model_class(model_cfg["model_name"])
    print(model_cfg["model_inits"])
    return model_class(**model_cfg["model_inits"])


