
import torch as t



class LSTMClassifier(t.nn.Module):
    """Per-position token classifier: embedding -> stacked LSTM -> linear head.

    Both the embedding table and the output layer are sized ``tokens + 1``
    (one extra index on top of ``tokens``; presumably a special/padding
    symbol -- confirm against the data pipeline).

    Args:
        tokens: Vocabulary size before the extra index is added.
        width: Embedding size and LSTM hidden size.
        depth: Number of stacked LSTM layers.
        dropout: Dropout between LSTM layers (passed to ``torch.nn.LSTM``).
    """

    def __init__(self, tokens=10, width=32, depth=4, dropout=.1):
        super().__init__()
        tokens += 1
        # NOTE: creation order is kept stable so seeded initialisation is
        # reproducible across versions of this class.
        self.embedding = t.nn.Embedding(tokens, width)
        self.lstm = t.nn.LSTM(input_size=width, hidden_size=width, num_layers=depth, batch_first=True, dropout=dropout)
        self.output = t.nn.Linear(width, tokens)
        self.width = width
        self.depth = depth

    def _initial_state(self, x):
        """Return zero ``(h, c)`` matching the dtype and device of ``x``."""
        shape = (self.depth, x.shape[0], self.width)
        return x.new_zeros(shape), x.new_zeros(shape)

    def forward(self, input, cut_gradient=False):
        """Compute logits for every position of ``input``.

        Args:
            input: Integer tensor of token indices, indexed as
                ``(batch, sequence)``.
            cut_gradient: When true, the sequence is processed one step at a
                time and the recurrent state is detached after each step, so
                no gradient flows between time steps through ``h``/``c``.

        Returns:
            Tensor of shape ``(batch, sequence, tokens + 1)``.
        """
        x = self.embedding(input)

        if not cut_gradient:
            lstm_output, _ = self.lstm(x)
            return self.output(lstm_output)

        h, c = self._initial_state(x)
        lstm_outputs = []
        for step in range(x.shape[1]):
            lstm_output, (h, c) = self.lstm(x[:, step:step + 1], (h, c))
            lstm_outputs.append(lstm_output)
            # Truncate backpropagation through time at every step.
            h, c = h.detach(), c.detach()

        return self.output(t.cat(lstm_outputs, dim=1))

    def forward_with_loss(self, input, target, loss_fn, cut_gradient=False):
        """Run step by step, calling ``backward`` on each step's loss.

        Gradients are accumulated into the parameters' ``.grad`` as a side
        effect; the caller is responsible for zeroing them and stepping the
        optimizer.

        Args:
            input: Integer tensor of token indices, ``(batch, sequence)``.
            target: Tensor indexed as ``(batch, sequence)``; the slice for
                one step is flattened before being handed to ``loss_fn``.
            loss_fn: Callable ``loss_fn(logits_2d, flat_target)``; ``.mean()``
                is applied to its result.
            cut_gradient: When true, the recurrent state and the collected
                LSTM outputs are detached after each step.

        Returns:
            Logits of shape ``(batch, sequence, tokens + 1)`` recomputed from
            the collected LSTM outputs. With ``cut_gradient`` those outputs
            are detached, so gradients from the returned tensor only reach
            the output layer.
        """
        x = self.embedding(input)
        h, c = self._initial_state(x)
        lstm_outputs = []

        for step in range(x.shape[1]):
            lstm_output, (h, c) = self.lstm(x[:, step:step + 1], (h, c))
            cur_y = self.output(lstm_output)
            cur_loss = loss_fn(
                cur_y.view(-1, cur_y.shape[-1]),
                target[:, step:step + 1].flatten(),
            ).mean()

            # The graph must be retained: later steps backpropagate through
            # pieces shared with this one (the single embedding lookup and,
            # unless cut_gradient is set, the recurrent history).
            cur_loss.backward(retain_graph=True)

            if cut_gradient:
                h, c = h.detach(), c.detach()
                lstm_output = lstm_output.detach()
            lstm_outputs.append(lstm_output)

        return self.output(t.cat(lstm_outputs, dim=1))

    def forward_one(self, input, h, c):
        """Advance the LSTM from an explicit state.

        Args:
            input: Integer tensor of token indices for this call.
            h: Hidden state passed to the LSTM.
            c: Cell state passed to the LSTM.

        Returns:
            Tuple ``(logits, h, c)`` where the returned states are detached
            from the autograd graph.
        """
        x = self.embedding(input)
        lstm_output, (h, c) = self.lstm(x, (h, c))
        return self.output(lstm_output), h.detach(), c.detach()




class LSTMClassifierMultipleQueries(t.nn.Module):
    """Per-position token classifier: embedding -> stacked LSTM -> linear head.

    Both the embedding table and the output layer are sized ``tokens + 1``
    (one extra index on top of ``tokens``; presumably a special/padding
    symbol -- confirm against the data pipeline).

    NOTE(review): same architecture and method set as ``LSTMClassifier`` in
    this file; consider sharing a base class.

    Args:
        tokens: Vocabulary size before the extra index is added.
        width: Embedding size and LSTM hidden size.
        depth: Number of stacked LSTM layers.
        dropout: Dropout between LSTM layers (passed to ``torch.nn.LSTM``).
    """

    def __init__(self, tokens=10, width=32, depth=4, dropout=.1):
        super().__init__()
        tokens += 1
        # NOTE: creation order is kept stable so seeded initialisation is
        # reproducible across versions of this class.
        self.embedding = t.nn.Embedding(tokens, width)
        self.lstm = t.nn.LSTM(input_size=width, hidden_size=width, num_layers=depth, batch_first=True, dropout=dropout)
        self.output = t.nn.Linear(width, tokens)
        self.width = width
        self.depth = depth

    def _initial_state(self, x):
        """Return zero ``(h, c)`` matching the dtype and device of ``x``."""
        shape = (self.depth, x.shape[0], self.width)
        return x.new_zeros(shape), x.new_zeros(shape)

    def forward(self, input, cut_gradient=False):
        """Compute logits for every position of ``input``.

        Args:
            input: Integer tensor of token indices, indexed as
                ``(batch, sequence)``.
            cut_gradient: When true, the sequence is processed one step at a
                time and the recurrent state is detached after each step, so
                no gradient flows between time steps through ``h``/``c``.

        Returns:
            Tensor of shape ``(batch, sequence, tokens + 1)``.
        """
        x = self.embedding(input)

        if not cut_gradient:
            lstm_output, _ = self.lstm(x)
            return self.output(lstm_output)

        h, c = self._initial_state(x)
        lstm_outputs = []
        for step in range(x.shape[1]):
            lstm_output, (h, c) = self.lstm(x[:, step:step + 1], (h, c))
            lstm_outputs.append(lstm_output)
            # Truncate backpropagation through time at every step.
            h, c = h.detach(), c.detach()

        return self.output(t.cat(lstm_outputs, dim=1))

    def forward_with_loss(self, input, target, loss_fn, cut_gradient=False):
        """Run step by step, calling ``backward`` on each step's loss.

        Gradients are accumulated into the parameters' ``.grad`` as a side
        effect; the caller is responsible for zeroing them and stepping the
        optimizer.

        Args:
            input: Integer tensor of token indices, ``(batch, sequence)``.
            target: Tensor indexed as ``(batch, sequence)``; the slice for
                one step is flattened before being handed to ``loss_fn``.
            loss_fn: Callable ``loss_fn(logits_2d, flat_target)``; ``.mean()``
                is applied to its result.
            cut_gradient: When true, the recurrent state and the collected
                LSTM outputs are detached after each step.

        Returns:
            Logits of shape ``(batch, sequence, tokens + 1)`` recomputed from
            the collected LSTM outputs. With ``cut_gradient`` those outputs
            are detached, so gradients from the returned tensor only reach
            the output layer.
        """
        x = self.embedding(input)
        h, c = self._initial_state(x)
        lstm_outputs = []

        for step in range(x.shape[1]):
            lstm_output, (h, c) = self.lstm(x[:, step:step + 1], (h, c))
            cur_y = self.output(lstm_output)
            cur_loss = loss_fn(
                cur_y.view(-1, cur_y.shape[-1]),
                target[:, step:step + 1].flatten(),
            ).mean()

            # The graph must be retained: later steps backpropagate through
            # pieces shared with this one (the single embedding lookup and,
            # unless cut_gradient is set, the recurrent history).
            cur_loss.backward(retain_graph=True)

            if cut_gradient:
                h, c = h.detach(), c.detach()
                lstm_output = lstm_output.detach()
            lstm_outputs.append(lstm_output)

        return self.output(t.cat(lstm_outputs, dim=1))

    def forward_one(self, input, h, c):
        """Advance the LSTM from an explicit state.

        Args:
            input: Integer tensor of token indices for this call.
            h: Hidden state passed to the LSTM.
            c: Cell state passed to the LSTM.

        Returns:
            Tuple ``(logits, h, c)`` where the returned states are detached
            from the autograd graph.
        """
        x = self.embedding(input)
        lstm_output, (h, c) = self.lstm(x, (h, c))
        return self.output(lstm_output), h.detach(), c.detach()


class LSTMUpdater(t.nn.Module):
    """Embedding followed by a stacked LSTM; returns the raw LSTM outputs.

    The embedding table has ``tokens + 1`` rows (one extra index on top of
    ``tokens``; presumably a special/padding symbol -- confirm against the
    data pipeline). There is no output projection in this module.

    Args:
        tokens: Vocabulary size before the extra index is added.
        width: Embedding size and LSTM hidden size.
        depth: Number of stacked LSTM layers.
        dropout: Dropout between LSTM layers (passed to ``torch.nn.LSTM``).
    """

    def __init__(self, tokens=10, width=32, depth=4, dropout=.1):
        super().__init__()
        tokens += 1
        self.embedding = t.nn.Embedding(tokens, width)
        self.lstm = t.nn.LSTM(input_size=width, hidden_size=width, num_layers=depth, batch_first=True, dropout=dropout)
        self.width = width
        self.depth = depth

    def forward(self, input, cut_gradient=False):
        """Encode ``input`` and return the top-layer LSTM output per position.

        Args:
            input: Integer tensor of token indices, indexed as
                ``(batch, sequence)``.
            cut_gradient: When true, the sequence is processed one step at a
                time and the recurrent state is detached after each step, so
                no gradient flows between time steps through ``h``/``c``.

        Returns:
            Tensor of shape ``(batch, sequence, width)``.
        """
        x = self.embedding(input)

        if not cut_gradient:
            lstm_output, _ = self.lstm(x)
            return lstm_output

        # Zero state created from ``x`` so it follows the model's dtype and
        # device (e.g. after ``.double()`` or ``.to(device)``).
        state_shape = (self.depth, x.shape[0], self.width)
        h = x.new_zeros(state_shape)
        c = x.new_zeros(state_shape)

        lstm_outputs = []
        for step in range(x.shape[1]):
            lstm_output, (h, c) = self.lstm(x[:, step:step + 1], (h, c))
            lstm_outputs.append(lstm_output)
            # Truncate backpropagation through time at every step.
            h, c = h.detach(), c.detach()

        return t.cat(lstm_outputs, dim=1)



class MLPExtractor(t.nn.Module):
    """One-hidden-layer MLP over an embedded input joined with a state vector.

    The input is looked up in a two-row embedding table (so its indices must
    be 0 or 1), concatenated with ``world_state`` along the feature axis,
    passed through ``Linear(2 * width, 4 * width)`` + ReLU, and projected to
    ``tokens + 1`` logits.

    Args:
        tokens: Number of output classes before one extra index is added.
        width: Embedding size; ``world_state`` must have the same feature
            size so the concatenation is ``2 * width`` wide.
        depth: Stored on the instance but not used by this module.
        dropout: Accepted for signature parity; not used by this module.
    """

    def __init__(self, tokens=10, width=32, depth=4, dropout=.1):
        super().__init__()

        tokens += 1
        # NOTE: creation order is kept stable so seeded initialisation is
        # reproducible across versions of this class.
        self.embedding = t.nn.Embedding(2, width)
        self.output = t.nn.Linear(width * 4, tokens)
        self.i_to_h = t.nn.Linear(width * 2, width * 4)
        self.width = width
        self.depth = depth

        self.act = t.nn.ReLU()

    def forward(self, input, world_state):
        """Compute logits from ``input`` and ``world_state``.

        Args:
            input: Integer tensor with values in ``{0, 1}``.
            world_state: Tensor whose shape is ``input.shape + (width,)``.

        Returns:
            Tensor of shape ``input.shape + (tokens + 1,)``.
        """
        x = self.embedding(input)

        # Join along the feature (last) axis; for the 3-D case this is the
        # same as dim=2, and it also works for other ranks.
        joined = t.cat([x, world_state], dim=-1)

        hidden = self.act(self.i_to_h(joined))
        return self.output(hidden)

