import torch as t
import torch.nn as nn
import torch.functional as F
import itertools as it
from models.transformer_layers import *
# Module-wide default device: "cuda" when torch reports CUDA as available, otherwise "cpu".
# Tensors allocated inside the models below are created on this device.
device = "cuda" if t.cuda.is_available() else "cpu"

class ModelTemplate(nn.Module):
    '''
    Abstract base class for the models in this module.

    The ``forward`` and ``update`` methods defined here only raise
    ``NotImplementedError``; concrete models are expected to override them.
    '''

    def __init__(self, *args, **kwargs):
        # Any positional / keyword arguments are accepted and ignored here.
        super().__init__()
        # Name of the concrete class of this instance.
        self.modelname = type(self).__name__

    def forward(self, triples, state):
        '''
        Predict an answer for every triple given a world state.

        :arg triples integer tensor of size n x 3, in (h, r, t) format
        :arg state current world state representation
        :returns n x 1 tensor of predictions
        '''
        raise NotImplementedError("Predict called from ModelTemplate")

    def update(self, triples, targets):
        '''
        Perform a zero-shot update so that the model gives the answers in
        ``targets`` for the statements in ``triples``.

        :arg triples integer tensor of size n x 3
        :arg targets float tensor of size n x 1, values between 0 and 1
        '''
        raise NotImplementedError("Update called from ModelTemplate")


class MLP(ModelTemplate):
    '''
    Feed-forward model over one-hot encoded triples and a flat world state.

    ``NET`` maps [head | relation | tail | state] rows to a single prediction;
    ``updater`` maps [head | relation | tail | target | state] rows to a new
    world state.
    '''

    def __init__(self, world_state_size, hidden_size, num_entities, num_relations, n_hidden=4):
        '''
        :arg world_state_size width of the world state vector
        :arg hidden_size width of every hidden layer
        :arg num_entities number of distinct entities (size of the entity one-hot)
        :arg num_relations number of distinct relations (size of the relation one-hot)
        :arg n_hidden number of (Linear, LeakyReLU) pairs between the input and output layers of NET
        '''

        print("Initializing a simple MLP model... ")
        super().__init__()

        # Constructor arguments, used to recreate the model from a checkpoint.
        # Built explicitly rather than from locals(): because this method uses
        # the zero-argument super(), locals() would also contain "__class__",
        # which cannot be passed back to __init__ as a keyword argument.
        self.initparams = {
            "world_state_size": world_state_size,
            "hidden_size": hidden_size,
            "num_entities": num_entities,
            "num_relations": num_relations,
            "n_hidden": n_hidden,
        }

        self.world_state_size = world_state_size
        self.hidden_size = hidden_size
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.n_hidden = n_hidden

        triple_width = num_entities + num_relations + num_entities

        # Hidden layers are instantiated before the input layer, as before, so
        # that the order of random parameter initialisation is unchanged.
        hidden_layer_components = []
        for _ in range(n_hidden):
            hidden_layer_components += [nn.Linear(hidden_size, hidden_size), nn.LeakyReLU()]

        self.NET = nn.Sequential(nn.Linear(triple_width + world_state_size, hidden_size),
                                 nn.LeakyReLU(),
                                 nn.Sequential(*hidden_layer_components),
                                 nn.Linear(hidden_size, 1))

        self.updater = nn.Sequential(nn.Linear(triple_width + 1 + world_state_size, hidden_size),
                                     nn.LeakyReLU(),
                                     nn.Linear(hidden_size, world_state_size))

        print("... done.")

    def _triple_width(self):
        '''Number of one-hot columns used to encode one (head, relation, tail) triple.'''
        # Computed from attributes (not cached in __init__) so that subclasses
        # which set these attributes in their own constructor can reuse it.
        return self.num_entities * 2 + self.num_relations

    def _triple_one_hot(self, triples):
        '''
        :arg triples integer tensor of size n x 3
        :returns integer tensor of size n x (2 * num_entities + num_relations):
        the concatenated one-hot encodings [head | relation | tail]
        '''
        one_hot = nn.functional.one_hot
        return t.cat([one_hot(triples[:, 0], num_classes=self.num_entities),
                      one_hot(triples[:, 1], num_classes=self.num_relations),
                      one_hot(triples[:, 2], num_classes=self.num_entities)], dim=1)

    def forward(self, triples, state):

        '''
        :arg triples integer tensor of size n x 3
        :arg state either a vector of a size 1 x world_state_size, or a matrix of a size n x world_state_size,
        which is interpreted as providing different world states for different examples in triples
        :returns n x 1 tensor of predictions
        '''

        width = self._triple_width()

        # Every column is written below, so an uninitialised buffer is safe here.
        X = t.empty(triples.shape[0], width + self.world_state_size, device=device)
        X[:, :width] = self._triple_one_hot(triples)
        X[:, width:] = state  # a single 1 x world_state_size state is broadcast over all rows

        return self.NET(X)

    def update(self, triples, targets, state):

        '''
        :arg triples - a matrix of size (n, 3) - statements for which the answers are provided
        :arg targets - a matrix of size (n, 1) - desired answers for the corresponding triples
        :arg state - either a vector of size (1, world_state_size), or a matrix of a size (n, world_state_size),
        representing world states that need to be updated
        Outputs the new world state. The gradients should flow through the new world states into model parameters,
        but not into previous world states.
        :returns a new state for every triple-target-state triple: a matrix of a size (n, world_state_size)
        '''

        width = self._triple_width()

        # Every column is written below, so an uninitialised buffer is safe here.
        X = t.empty(triples.shape[0], width + 1 + self.world_state_size, device=device)
        X[:, :width] = self._triple_one_hot(triples)
        X[:, width:width + 1] = targets
        X[:, width + 1:] = state

        return self.updater(X)

    def update_recurrent(self, triples, targets, states):

        '''
        :arg triples - a matrix of size (n, num_updates, 3) - statements for which the answers are provided
        :arg targets - a matrix of size (n, num_updates, 1) - desired answers for the corresponding triples
        :arg states - (n, world_state_size), starting states for every update sequence
        representing world states that need to be updated

        Applies the updater once per update in the sequence, followed by one
        extra step whose triple and target columns are all zeros.

        :returns the final state of every sequence: a matrix of a size (n, world_state_size)
        '''

        seq_num, seq_len, _ = triples.shape
        width = self._triple_width()
        cur_states = states.detach().clone()

        for i in range(seq_len + 1):

            # Zero-filled on purpose: on the extra final step the triple and
            # target columns are never written, and an uninitialised buffer
            # would feed arbitrary memory contents into the updater.
            X = t.zeros(seq_num, width + 1 + self.world_state_size, device=device)

            if i != seq_len:
                X[:, :width] = self._triple_one_hot(triples[:, i])
                X[:, width:width + 1] = targets[:, i]

            X[:, width + 1:] = cur_states

            # NOTE(review): the residual term is the original (non-detached)
            # ``states`` tensor rather than ``cur_states``, so gradients can
            # reach the caller's states through it - confirm this is intended.
            cur_states = self.updater(X) + states

        return cur_states

class ResidualBlock(nn.Module):
    '''
    ``block_depth`` consecutive (Linear, LeakyReLU) pairs of width
    ``hidden_size`` whose output is added to the block's input.
    '''

    def __init__(self, hidden_size, block_depth=4):
        super().__init__()

        stack = []
        for _ in range(block_depth):
            stack.append(nn.Linear(hidden_size, hidden_size))
            stack.append(nn.LeakyReLU())
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        '''
        :arg x tensor whose last dimension is hidden_size
        :returns x plus the output of the layer stack applied to x
        '''
        transformed = self.layers.forward(x)
        return x + transformed


class Residual(MLP):
    '''
    Variant of MLP whose extractor (``NET``) and ``updater`` are both built
    from ResidualBlock stacks. The forward / update methods are inherited
    from MLP.
    '''

    def __init__(self, world_state_size, hidden_size, num_entities, num_relations, n_blocks=2, block_depth=4):
        '''
        :arg world_state_size width of the world state vector
        :arg hidden_size width of every hidden layer
        :arg num_entities number of distinct entities
        :arg num_relations number of distinct relations
        :arg n_blocks number of residual blocks in NET and in the updater
        :arg block_depth number of (Linear, LeakyReLU) pairs inside each residual block
        '''

        print("Initializing a residual MLP model... ")
        # Only the base template is initialised: MLP.__init__ is not called,
        # the networks are built below instead.
        ModelTemplate.__init__(self)

        # Constructor arguments, used to recreate the model from a checkpoint.
        # Stored once as an explicit dict (the previous code captured and
        # edited the locals() dict twice).
        self.initparams = {
            "world_state_size": world_state_size,
            "hidden_size": hidden_size,
            "num_entities": num_entities,
            "num_relations": num_relations,
            "n_blocks": n_blocks,
            "block_depth": block_depth,
        }

        self.world_state_size = world_state_size
        self.hidden_size = hidden_size
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.n_blocks = n_blocks
        self.block_depth = block_depth

        triple_width = num_entities + num_relations + num_entities

        # Blocks are created before the surrounding Linear layers, as before,
        # so the order of random parameter initialisation is unchanged.
        hidden_layer_components = [ResidualBlock(hidden_size, block_depth=block_depth)
                                   for _ in range(n_blocks)]

        self.NET = nn.Sequential(nn.Linear(triple_width + world_state_size, hidden_size),
                                 nn.LeakyReLU(),
                                 nn.Sequential(*hidden_layer_components),
                                 nn.Linear(hidden_size, 1))

        updater_hidden_layer_components = [ResidualBlock(hidden_size, block_depth=block_depth)
                                           for _ in range(n_blocks)]

        self.updater = nn.Sequential(nn.Linear(triple_width + 1 + world_state_size, hidden_size),
                                     nn.LeakyReLU(),
                                     nn.Sequential(*updater_hidden_layer_components),
                                     nn.Linear(hidden_size, world_state_size))

        print("... done.")






def parse_transformer_layer(string):
    '''
    Map a layer name to the corresponding transformer layer class.

    :arg string one of "IndependentTransformerDecoderLayer",
    "TransformerEncoderLayer" or "TransformerDecoderLayer"
    :returns the matching class, or None for the empty string and for any
    other value
    '''
    # Each name is only looked up when its branch is taken.
    if string == "IndependentTransformerDecoderLayer":
        return IndependentTransformerDecoderLayer
    if string == "TransformerEncoderLayer":
        return nn.TransformerEncoderLayer
    if string == "TransformerDecoderLayer":
        return TransformerDecoderLayer
    # The empty string and unrecognised names select no layer class.
    return None


class TransformerTemplate(ModelTemplate):
    '''
    Shared construction and input-embedding logic for the transformer-based
    extractor and updater models.

    The constructor builds the input embeddings, the main transformer stack,
    the optional pre / post transformer stacks and the positional encoding.
    Subclasses add their own output layers and ``forward``.
    '''

    def __init__(self, world_state_size, embedding_size, num_entities, num_relations, num_single_entities=0, n_pre_blocks=0,
                 n_blocks=2, n_post_blocks=0, world_token_count=2, nheads=8, dim_feedforward=2048, pos_enc=1.0, special_mode=None):
        '''
        :arg world_state_size width of the flat world state vector
        :arg embedding_size embedding width; the transformer width (hidden_size) is embedding_size * nheads
        :arg num_entities number of distinct entities
        :arg num_relations number of distinct relations
        :arg num_single_entities size of the class-embedding table used for non-triple inputs (0 disables it)
        :arg n_pre_blocks number of layers of the pre-transformer (0 disables it)
        :arg n_blocks number of layers of the main transformer
        :arg n_post_blocks number of layers of the post-transformer (0 disables it)
        :arg world_token_count number of hidden_size-wide tokens used for the world state
        :arg nheads number of attention heads
        :arg dim_feedforward feed-forward width passed to every transformer layer
        :arg pos_enc falsy - no positional encoding; dict - LearnablePositionalEncoding;
        anything else is passed to PositionalEncoding as its second argument
        :arg special_mode dict of optional switches; keys read here: "positional_input",
        "positional_input_scale", "triple_embedding", "triple_norm", "decoder",
        "transformer_layer", "pre_transformer_layer", "post_transformer_layer"
        '''
        ModelTemplate.__init__(self)
        if not special_mode:
            # None, an empty dict and the 0 placeholder some callers pass all
            # mean "no special switches". A fresh dict also avoids sharing one
            # mutable default between instances.
            special_mode = {}
        # Constructor arguments, used to recreate the model from a checkpoint.
        self.initparams = {
            "world_state_size": world_state_size,
            "embedding_size": embedding_size,
            "num_entities": num_entities,
            "num_relations": num_relations,
            "num_single_entities": num_single_entities,
            "n_pre_blocks": n_pre_blocks,
            "n_blocks": n_blocks,
            "n_post_blocks": n_post_blocks,
            "world_token_count": world_token_count,
            "nheads": nheads,
            "dim_feedforward": dim_feedforward,
            "pos_enc": pos_enc,
            "special_mode": special_mode,
        }
        self.world_state_size = world_state_size
        self.embedding_size = embedding_size
        self.hidden_size = embedding_size * nheads
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.num_single_entities = num_single_entities
        self.n_pre_blocks = n_pre_blocks
        self.n_blocks = n_blocks
        self.n_post_blocks = n_post_blocks
        self.world_token_count = world_token_count
        self.nheads = nheads
        self.dim_feedforward = dim_feedforward
        pos_input = special_mode.get("positional_input")
        if pos_input == 1:
            self.input_pos_enc = PositionalEncoding(self.hidden_size, special_mode.get("positional_input_scale", 1.0))
        elif pos_input == 2:
            assert self.hidden_size % 2 == 0, "hidden size should be even for 2d embeddings"
            self.input_pos_enc = PositionalEncoding(self.hidden_size // 2, special_mode.get("positional_input_scale", 1.0))
        else:
            if special_mode.get("triple_embedding"):
                # Index-based lookup tables.
                self.entity_embedding = nn.Embedding(self.num_entities, embedding_size)
                self.relation_embedding = nn.Embedding(self.num_relations, embedding_size)
            else:
                # Linear maps applied to one-hot vectors.
                self.entity_embedding = nn.Linear(num_entities, embedding_size)
                self.relation_embedding = nn.Linear(num_relations, embedding_size, bias=False)
            self.triple_embedding = nn.Linear(3 * embedding_size, self.hidden_size)
            # NOTE(review): world_embedding is only created when positional
            # input is off, yet subclasses use it whenever the
            # "world_embedding" switch is set - confirm the two switches are
            # never combined.
            self.world_embedding = nn.Linear(world_state_size, self.hidden_size * world_token_count)
        if self.num_single_entities:
            self.class_embedding = nn.Embedding(self.num_single_entities, self.hidden_size)
        if special_mode.get("triple_norm"):
            self.triple_norm = nn.LayerNorm([self.hidden_size])
        else:
            self.triple_norm = None
        if special_mode.get("decoder"):
            self.decoder = True
            transformer_layer = parse_transformer_layer(special_mode.get("transformer_layer",
                                                                         "IndependentTransformerDecoderLayer"))
            transformer_stack = nn.TransformerDecoder
        else:
            self.decoder = False
            transformer_layer = parse_transformer_layer(special_mode.get("transformer_layer",
                                                                         "TransformerEncoderLayer"))
            transformer_stack = nn.TransformerEncoder
        pre_transformer_layer = parse_transformer_layer(special_mode.get("pre_transformer_layer",
                                                                         "TransformerEncoderLayer"))
        pre_transformer_stack = nn.TransformerEncoder
        post_transformer_layer = parse_transformer_layer(special_mode.get("post_transformer_layer",
                                                                          "TransformerEncoderLayer"))
        post_transformer_stack = nn.TransformerEncoder

        layer = self._make_layer(transformer_layer, special_mode)
        self.transformer = transformer_stack(layer, num_layers=self.n_blocks)
        if n_pre_blocks > 0:
            # The pre-transformer layer has never received special_mode.
            layer = pre_transformer_layer(self.hidden_size, nhead=self.nheads, dim_feedforward=self.dim_feedforward)
            # Depth is n_pre_blocks; it used to be n_blocks, which made
            # n_pre_blocks a mere on/off switch.
            self.pre_transformer = pre_transformer_stack(layer, num_layers=self.n_pre_blocks)
        else:
            self.pre_transformer = None
        if n_post_blocks > 0:
            layer = self._make_layer(post_transformer_layer, special_mode)
            # Depth is n_post_blocks (previously n_blocks, see above).
            self.post_transformer = post_transformer_stack(layer, num_layers=self.n_post_blocks)
        else:
            self.post_transformer = None

        if pos_enc:
            if isinstance(pos_enc, dict):
                self.positional_encoding = LearnablePositionalEncoding()
            else:
                # The missing ``else`` used to overwrite the learnable
                # encoding with a PositionalEncoding built from the dict.
                self.positional_encoding = PositionalEncoding(self.hidden_size, pos_enc)
        else:
            self.positional_encoding = None
        self.special_mode = special_mode

    def _make_layer(self, layer_cls, special_mode):
        '''
        Instantiate one transformer layer of width hidden_size.

        ``special_mode`` is forwarded to every layer class except the stock
        torch.nn.TransformerEncoderLayer, whose constructor does not accept
        that keyword.
        '''
        kwargs = {"nhead": self.nheads, "dim_feedforward": self.dim_feedforward}
        if layer_cls is not nn.TransformerEncoderLayer:
            kwargs["special_mode"] = special_mode
        return layer_cls(self.hidden_size, **kwargs)

    def load_state_dict(self, state_dict, strict=False):
        '''
        Load parameters, ignoring every entry whose key contains
        "positional_encoding". Note that ``strict`` defaults to False here.

        :arg state_dict mapping from parameter names to tensors; it is not modified
        :arg strict forwarded to nn.Module.load_state_dict
        :returns the result of nn.Module.load_state_dict
        '''
        from collections import OrderedDict

        # NOTE(review): this also skips the weights of a learnable positional
        # encoding - confirm that is acceptable.
        filtered = OrderedDict((k, v) for k, v in state_dict.items() if "positional_encoding" not in k)
        # Keep the version metadata torch attaches to state dicts, if present.
        metadata = getattr(state_dict, "_metadata", None)
        if metadata is not None:
            filtered._metadata = metadata
        return nn.Module.load_state_dict(self, filtered, strict=strict)

    def input_embeddings(self, inputs, targets=None):
        '''
        Embed a batch of inputs as hidden_size-wide vectors.

        :arg inputs integer tensor. If its last dimension is 3 it is indexed as
        (batch, sequence, 3) triples; otherwise it is looked up in class_embedding
        :arg targets optional tensor; when given, its target_embedding (an
        attribute provided by subclasses) is added to the input embedding
        :returns the embedded inputs, layer-normalised if "triple_norm" is set
        '''
        inputs = inputs.to(device)
        if inputs.shape[-1] == 3: # triples
            pos_input = self.special_mode.get("positional_input")
            if pos_input == 1:
                # One position per (head, tail) pair.
                positions = inputs[:, :, 0] * self.num_entities + inputs[:, :, 2]
                query_enc = self.input_pos_enc.embed(positions)
            elif pos_input == 2:
                query_enc = self.input_pos_enc.embed2d(inputs[:, :, 0], inputs[:, :, 2])
            else:
                if self.special_mode.get("triple_embedding"):
                    # Embedding tables take the raw indices.
                    e1 = inputs[:, :, 0]
                    e2 = inputs[:, :, 2]
                    rel = inputs[:, :, 1]
                else:
                    # Linear embeddings take float one-hot vectors.
                    e1 = nn.functional.one_hot(inputs[:, :, 0], num_classes=self.num_entities).float()
                    rel = nn.functional.one_hot(inputs[:, :, 1], num_classes=self.num_relations).float()
                    e2 = nn.functional.one_hot(inputs[:, :, 2], num_classes=self.num_entities).float()
                e1_enc = self.entity_embedding.forward(e1)
                e2_enc = self.entity_embedding.forward(e2)
                rel_enc = self.relation_embedding.forward(rel)
                query_enc = self.triple_embedding(t.cat([e1_enc, rel_enc, e2_enc], dim=-1))
        else:
            # ``inputs`` is already on ``device`` (moved above).
            query_enc = self.class_embedding.forward(inputs)
        if targets is not None:
            t_emb = self.target_embedding.forward(targets.float().to(device))
            query_enc += t_emb
        if self.special_mode.get("triple_norm"):
            query_enc = self.triple_norm.forward(query_enc)
        return query_enc

    def apply_transformer(self, transformer, input):  # batch first mode
        '''Run ``transformer`` on a batch-first tensor: the first two dimensions are swapped before and after the call.'''
        return transformer.forward(input.permute([1, 0, 2])).permute([1, 0, 2])

    def apply_transformer_decoder(self, transformer, input, memory):  # batch first mode
        '''Same as apply_transformer, for a decoder that additionally receives a batch-first ``memory`` tensor.'''
        return transformer.forward(input.permute([1, 0, 2]), memory.permute([1, 0, 2])).permute([1, 0, 2])



class TransformerExtractor(TransformerTemplate):
    '''
    Transformer that scores triples against a world state: the embedded world
    state tokens and the embedded triple form one sequence, and the flattened
    transformer output is mapped to a single value per example.
    '''

    def __init__(self, world_state_size, embedding_size, num_entities, num_relations, num_single_entities=0, n_blocks=2, world_token_count=2, nheads=8, dim_feedforward=2048, pos_enc=1.0, special_mode=None):
        '''
        All arguments are forwarded to TransformerTemplate; see there for
        their meaning. ``special_mode`` defaults to an empty dict.
        '''
        print("Initializing a Transformer Extractor model... ")
        if not special_mode:
            # Fresh dict per instance instead of a shared mutable default.
            special_mode = {}
        init_args = {
            "world_state_size": world_state_size,
            "embedding_size": embedding_size,
            "num_entities": num_entities,
            "num_relations": num_relations,
            "num_single_entities": num_single_entities,
            "n_blocks": n_blocks,
            "world_token_count": world_token_count,
            "nheads": nheads,
            "dim_feedforward": dim_feedforward,
            "pos_enc": pos_enc,
            "special_mode": special_mode,
        }
        TransformerTemplate.__init__(self, **init_args)
        # The template records its own argument set, which contains keys this
        # constructor does not accept (n_pre_blocks, n_post_blocks). Store the
        # arguments of this class so the model can be recreated from them.
        self.initparams = init_args
        # In decoder mode only the query token reaches the output layer;
        # otherwise the world tokens and the query token are flattened together.
        query_sequence_length = 1 if self.decoder else 1 + self.world_token_count
        self.query_output_layer = nn.Linear(query_sequence_length * self.hidden_size, 1)
        print("... done.")

    def forward(self, triples, state):
        '''
        :arg triples integer tensor of size n x 3
        :arg state either a vector of a size 1 x world_state_size, or a matrix of a size n x world_state_size,
        which is interpreted as providing different world states for different examples in triples
        :returns the output of query_output_layer for every triple
        '''
        batch_size = triples.shape[0]
        e1 = nn.functional.one_hot(triples[:, 0], num_classes=self.num_entities).float().to(device)
        rel = nn.functional.one_hot(triples[:, 1], num_classes=self.num_relations).float().to(device)
        e2 = nn.functional.one_hot(triples[:, 2], num_classes=self.num_entities).float().to(device)
        e1_enc = self.entity_embedding.forward(e1)
        e2_enc = self.entity_embedding.forward(e2)
        rel_enc = self.relation_embedding.forward(rel)
        query_enc = self.triple_embedding(t.cat([e1_enc, rel_enc, e2_enc], dim=1))
        state_enc = self.world_embedding(state)
        if state_enc.shape[0] == 1:
            # Share the single world state between all examples.
            state_enc = state_enc.repeat(batch_size, 1)
        # Sequence layout per example: world tokens first, then the query token.
        full_enc = (t.cat([state_enc, query_enc], dim=1)).view(batch_size, -1, self.hidden_size)
        transformer_output = self.transformer.forward(full_enc.permute([1, 0, 2])).permute([1, 0, 2]).reshape(batch_size, -1)
        return self.query_output_layer(transformer_output)

class TransformerExtractorV2(TransformerExtractor):
    '''
    Extractor that embeds its inputs through ``input_embeddings`` and
    supports both the encoder and the decoder configuration of the template.
    '''

    def forward(self, inputs,  state):
        '''
        :arg inputs integer tensor accepted by input_embeddings
        :arg state either a vector of a size 1 x world_state_size, or a matrix of a size n x world_state_size,
        which is interpreted as providing different world states for different examples in inputs.
        Unless the "world_embedding" switch is set, the state is used as is and must therefore
        hold world_token_count * hidden_size values per example.
        :returns the output of query_output_layer
        '''
        embeddings = self.input_embeddings(inputs)
        batch_size = embeddings.shape[0]
        if self.special_mode.get("world_embedding"):
            state_enc = self.world_embedding(state)
        else:
            state_enc = state
        if self.positional_encoding is not None:
            # The encoding is applied per world token, then flattened again.
            state_enc = self.positional_encoding.forward(state_enc.view(-1, self.world_token_count, self.hidden_size)).\
                view(-1, self.world_token_count * self.hidden_size)
        if state_enc.shape[0] == 1:
            # Share the single world state between all examples.
            state_enc = state_enc.repeat(batch_size, 1)
        if self.decoder:
            # Queries attend to the world tokens, which act as decoder memory.
            transformer_output = self.apply_transformer_decoder(self.transformer,
                                                                embeddings.view(batch_size, -1, self.hidden_size),
                                                                state_enc.view(batch_size, -1, self.hidden_size))
        else:
            # This branch used to reference an undefined name (``query_enc``).
            # The query embeddings are flattened so they can be appended to
            # the flat state encoding, then everything is split back into
            # hidden_size-wide tokens: world tokens first, query tokens last.
            query_flat = embeddings.reshape(batch_size, -1)
            full_enc = (t.cat([state_enc, query_flat], dim=1)).view(batch_size, -1, self.hidden_size)
            transformer_output = self.apply_transformer(self.transformer, full_enc).reshape(batch_size, -1)
        return self.query_output_layer(transformer_output)



class TransformerUpdater(TransformerTemplate):
    '''
    Transformer that produces a new world state from a sequence of
    (triple, target) updates and a starting world state.
    '''

    def __init__(self, world_state_size, embedding_size, num_entities, num_relations, n_blocks=2, world_token_count=2, nheads=8, dim_feedforward=2048, pos_enc=1.0, special_mode=None):
        '''
        All arguments are forwarded to TransformerTemplate; see there for
        their meaning. ``special_mode`` defaults to an empty dict (the former
        default of 0 failed as soon as a switch was looked up on it).
        '''
        print("Initializing a Transformer Updater model... ")
        if not special_mode:
            # Covers None as well as the former default value 0.
            special_mode = {}
        init_args = {
            "world_state_size": world_state_size,
            "embedding_size": embedding_size,
            "num_entities": num_entities,
            "num_relations": num_relations,
            "n_blocks": n_blocks,
            "world_token_count": world_token_count,
            "nheads": nheads,
            "dim_feedforward": dim_feedforward,
            "pos_enc": pos_enc,
            "special_mode": special_mode,
        }
        TransformerTemplate.__init__(self, **init_args)
        # The template records its own argument set, which contains keys this
        # constructor does not accept. Store the arguments of this class so
        # the model can be recreated from them.
        self.initparams = init_args
        self.target_embedding = nn.Linear(1, self.hidden_size)
        self.transformer_output_size = self.hidden_size * self.world_token_count
        if self.special_mode.get("world_embedding"):
            self.updater_output_layer = nn.Linear(self.hidden_size, self.world_state_size)
        else:
            # World tokens flattened, plus one pooled hidden_size-wide vector.
            self.updater_output_layer = nn.Linear(self.hidden_size * (self.world_token_count + 1), self.world_state_size)
        print("... done.")

    def forward(self, triples, targets, state):

        '''
        :arg triples - a matrix of size (n, num_updates, 3) - statements for which the answers are provided
        :arg targets - a matrix of size (n, num_updates, 1) - desired answers for the corresponding triples
        :arg state - (n, world_state_size), starting states for every update sequence
        representing world states that need to be updated
        Outputs the new world state. The state is detached first, so gradients flow into model
        parameters but not into the previous world states.
        :returns a new state for every update sequence: a matrix of a size (n, world_state_size)
        '''
        batch_size = triples.shape[0]
        e1 = nn.functional.one_hot(triples[:, :, 0], num_classes=self.num_entities).float().to(device)
        rel = nn.functional.one_hot(triples[:, :, 1], num_classes=self.num_relations).float().to(device)
        e2 = nn.functional.one_hot(triples[:, :, 2], num_classes=self.num_entities).float().to(device)
        targets = targets.float().to(device)
        e1_enc = self.entity_embedding.forward(e1)
        e2_enc = self.entity_embedding.forward(e2)
        rel_enc = self.relation_embedding.forward(rel)
        target_enc = self.target_embedding.forward(targets)
        triple_enc = self.triple_embedding(t.cat([e1_enc, rel_enc, e2_enc], dim=2))
        # One token per update: triple embedding plus target embedding.
        update_enc = triple_enc + target_enc
        state = state.detach()
        state_enc = self.world_embedding(state)
        if self.positional_encoding is not None:
            state_enc = self.positional_encoding.forward(state_enc)
        if state_enc.shape[0] == 1:
            # Share the single world state between all sequences.
            state_enc = state_enc.repeat(batch_size, 1)
        # Sequence layout per example: world tokens first, then update tokens.
        full_enc = (t.cat([state_enc, update_enc.view(batch_size, -1)], dim=1)).view(batch_size, -1, self.hidden_size)
        transformer_output = self.transformer.forward(full_enc.permute([1, 0, 2])).permute([1, 0, 2])
        if self.special_mode.get("triple_norm"):
            # Flattened world tokens plus the sum of the update tokens.
            updater_features = t.cat([transformer_output[:, :self.world_token_count].view(batch_size, -1),
                                      transformer_output[:, self.world_token_count:].sum(dim=1)], dim=1)
        elif self.special_mode.get("world_embedding"):
            updater_features = transformer_output.mean(dim=1)
        else:
            # Flattened world tokens plus the mean over all output tokens.
            update_tokens = transformer_output.mean(dim=1)
            updater_features = t.cat([transformer_output[:, :self.world_token_count].view(batch_size, -1),
                                      update_tokens], dim=1)
        return self.updater_output_layer(updater_features)


class TransformerUpdaterV2(TransformerUpdater):
    '''
    Updater that embeds its inputs through ``input_embeddings`` and can wrap
    the main transformer with optional pre- and post-transformers applied to
    the world state tokens.
    '''

    def __init__(self, world_state_size, embedding_size, num_entities, num_relations, n_pre_blocks=2, n_blocks=2,
                 n_post_blocks=2, world_token_count=2, nheads=8, dim_feedforward=2048, pos_enc=1.0, special_mode=None):
        '''
        All arguments are forwarded to TransformerTemplate; see there for
        their meaning. ``special_mode`` defaults to an empty dict.
        '''
        print("Initializing a Transformer Updater model... ")
        if not special_mode:
            # Fresh dict per instance instead of a shared mutable default.
            special_mode = {}
        init_args = {
            "world_state_size": world_state_size,
            "embedding_size": embedding_size,
            "num_entities": num_entities,
            "num_relations": num_relations,
            "n_pre_blocks": n_pre_blocks,
            "n_blocks": n_blocks,
            "n_post_blocks": n_post_blocks,
            "world_token_count": world_token_count,
            "nheads": nheads,
            "dim_feedforward": dim_feedforward,
            "pos_enc": pos_enc,
            "special_mode": special_mode,
        }
        TransformerTemplate.__init__(self, **init_args)
        # The template records its own argument set, which contains a key this
        # constructor does not accept (num_single_entities). Store the
        # arguments of this class so the model can be recreated from them.
        self.initparams = init_args
        self.target_embedding = nn.Linear(1, self.hidden_size)
        self.transformer_output_size = self.hidden_size * self.world_token_count
        if self.special_mode.get("world_embedding"):
            self.updater_output_layer = nn.Linear(self.hidden_size, self.world_state_size)
        else:
            self.updater_output_layer = nn.Linear(self.hidden_size * (self.world_token_count + 1),
                                                  self.world_state_size)
        print("... done.")

    def forward(self, inputs, targets, state):

        '''
        :arg inputs - integer tensor accepted by input_embeddings, e.g. triples of size (n, num_updates, 3)
        :arg targets - a matrix of size (n, num_updates, 1) - desired answers for the corresponding inputs
        :arg state - starting states for every update sequence; a single state (first dimension 1)
        is shared between all sequences. Unless the "world_embedding" switch is set, the state is used
        as is and must hold world_token_count * hidden_size values per example.
        Outputs the new world state. The state is detached first, so gradients flow into model
        parameters but not into the previous world states.
        :returns a new state for every update sequence
        '''
        update_enc = self.input_embeddings(inputs, targets)
        batch_size = update_enc.shape[0]
        state = state.detach()
        if self.special_mode.get("world_embedding"):
            state_enc = self.world_embedding(state)
        else:
            state_enc = state
        if self.positional_encoding is not None:
            # The encoding is applied per world token, then flattened again.
            state_enc = self.positional_encoding.forward(state_enc.view(-1, self.world_token_count, self.hidden_size)). \
                view(-1, self.world_token_count * self.hidden_size)
        state_tokens = state_enc.view(-1, self.world_token_count, self.hidden_size)
        if self.pre_transformer is not None:
            state_pre_features = self.apply_transformer(self.pre_transformer, state_tokens)
        else:
            state_pre_features = state_tokens
        if state_pre_features.shape[0] == 1 and batch_size > 1:
            # Share a single world state between all update sequences.
            state_pre_features = state_pre_features.expand(batch_size, -1, -1)
        update_enc = update_enc.view(batch_size, -1, self.hidden_size)
        if self.decoder:
            # World tokens attend to the update tokens, which act as memory.
            transformer_output = self.apply_transformer_decoder(self.transformer, state_pre_features, update_enc)
        else:
            # Sequence layout per example: world tokens first, then update tokens.
            full_enc = (t.cat([state_pre_features, update_enc], dim=1))
            transformer_output = self.apply_transformer(self.transformer, full_enc)
        if self.post_transformer is not None:
            transformer_state_output = transformer_output[:, :self.world_token_count]
            updater_features = self.apply_transformer(self.post_transformer, transformer_state_output).mean(dim=1)
            if not self.special_mode.get("world_embedding"):
                # Without a world embedding the pooled features are the new state.
                return updater_features.view(batch_size, self.world_state_size)
        else:
            if not self.special_mode.get("world_embedding"):
                # Without a world embedding the world tokens are the new state.
                return transformer_output[:, :self.world_token_count].reshape(batch_size, self.world_state_size)
            else:
                updater_features = transformer_output.mean(dim=1)
        return self.updater_output_layer(updater_features)

def reinit_model(checkpoint):
    '''
    Recreate a model from a checkpoint dictionary.

    :arg checkpoint mapping with the keys "modelname" (expression naming the
    model class) and "modelinits" (keyword arguments for its constructor)
    :returns a new instance of the named class
    '''
    # NOTE(security): eval executes whatever expression the checkpoint
    # contains - only load checkpoints from trusted sources.
    constructor = eval(checkpoint["modelname"])
    init_kwargs = checkpoint["modelinits"]
    return constructor(**init_kwargs)

def init_model(params, arg_name="model_vars"):
    '''
    Create a model from a parameter dictionary.

    :arg params mapping whose entry ``arg_name`` holds "model_name"
    (expression naming the model class) and "model_inits" (keyword arguments
    for its constructor)
    :arg arg_name key of the model section inside params
    :returns a new instance of the named class

    The constructor arguments are printed before the model is built.
    '''
    # NOTE(security): eval executes whatever expression the parameters
    # contain - only use with trusted configuration.
    constructor = eval(params[arg_name]["model_name"])
    init_kwargs = params[arg_name]["model_inits"]
    print(init_kwargs)
    return constructor(**init_kwargs)


def train_updater():
    '''
    Smoke-test training loop: runs 300 SGD steps of a TransformerUpdater on
    random triples, random binary targets and random world states, using the
    squared difference between the output and the input state as loss.
    Prints the mean and the summed loss of every step.
    '''
    b = TransformerUpdater(world_state_size=50, embedding_size=50, num_entities=28, num_relations=1)
    if t.cuda.is_available():
        b.cuda()

    optimizer_net = t.optim.SGD(b.parameters(), lr=1e-6)
    for i in range(300):
        tst_triples = t.randint(0, 28, (100, 2, 3), device=device)
        tst_triples[:, :, 1] = 0 # modeling the situation with only one relation
        # The upper bound of randint is exclusive: (0, 2) yields targets in
        # {0, 1}, whereas the former (0, 1) only ever produced zeros.
        tst_targets = t.randint(0, 2, (100, 2, 1), device=device)
        world_state = nn.Parameter(t.rand(100, 50, device=device))
        optimizer_net.zero_grad()
        res = b(tst_triples, tst_targets, world_state)
        loss = ((res - world_state) ** 2)
        ave_loss = loss.sum()  # summed (not averaged) loss, despite the name
        ave_loss.backward()
        optimizer_net.step()
        print(loss.abs().mean(), ave_loss)




# Script entry point: runs the updater smoke-training loop and exits.
if __name__ == "__main__":
    train_updater()
    exit()
    # NOTE: everything below follows the exit() call above and is therefore
    # never executed; it is kept as scratch code for manual experiments.
    ## Testing the MLP

    a = Residual(50, 50, 28, 1)
    # ``a`` is rebound here, so the Residual instance above is discarded.
    a = TransformerExtractor(world_state_size=50, embedding_size=50, num_entities=28, num_relations=1)

    if t.cuda.is_available():
        a.cuda()

    # A single world state shared by all 100 test triples.
    world_state = nn.Parameter(t.rand(1, 50, device=device))
    tst_triples = t.randint(0, 28, (100, 3), device=device)
    tst_triples[:, 1] = 0 # modeling the situation with only one relation
    tst_target = t.rand((100, 1), device=device)
    res = a.forward(tst_triples, world_state)
    res.sum().backward()

    b = TransformerUpdater(world_state_size=50, embedding_size=50, num_entities=28, num_relations=1)
    if t.cuda.is_available():
        b.cuda()


    tst_triples = t.randint(0, 28, (100, 2, 3), device=device)
    tst_triples[:, :, 1] = 0 # modeling the situation with only one relation
    # NOTE(review): the upper bound of randint is exclusive, so (0, 1) only produces zeros.
    tst_targets = t.randint(0, 1, (100, 2, 1), device=device)
    b.forward(tst_triples, tst_targets, world_state)

    # NOTE(review): at this point ``a`` is a TransformerExtractor; in this file
    # only MLP defines a three-argument update() and update_recurrent(), so
    # these two calls look like they were written for an MLP/Residual model - confirm.
    new_states = a.update(tst_triples, tst_targets, world_state)

    tst_triples_updater = t.randint(0, 28, (5, 15, 3), device=device)
    tst_triples_updater[:, :, 1] = 0
    tst_targets_updater = t.randint(0, 1, (5, 15, 1), device=device)
    world_states = nn.Parameter(t.rand((5, 50), device=device))

    new_states = a.update_recurrent(tst_triples_updater, tst_targets_updater, world_states)
