import torch

from net.memory.base import BaseMemory


class RNNMemory(BaseMemory):
    """
    Series of latents -> history representation, using a single-layer Elman RNN.

    The RNN output at every step is passed through a ReLU to form the context.

    Args:
        latent_size: feature size of each input latent (the RNN ``input_size``).
        output_size: RNN hidden size; also the feature size of the returned context.
        nonlinearity: RNN cell nonlinearity, ``'tanh'`` or ``'relu'``. Defaults to
            ``'tanh'`` (torch's default) so the class can be constructed with the
            same ``(latent_size, output_size)`` arguments as the other memories.
        frozen_weights: if True, the RNN parameters get ``requires_grad=False``
            and are therefore not trained.
    """

    def __init__(self, latent_size, output_size, nonlinearity='tanh', frozen_weights=False):
        super().__init__(latent_size)

        self.output_size = output_size
        self.frozen_weights = frozen_weights

        self.rnn = torch.nn.RNN(latent_size, output_size, batch_first=True, nonlinearity=nonlinearity)
        if frozen_weights:
            # Freeze every RNN parameter; gradients can still flow back to `z`.
            self.rnn.requires_grad_(False)
        self.relu = torch.nn.ReLU()

    def forward(self, z, h=None):
        """
        Run the RNN over a sequence of latents.

        Args:
            z: latent sequence; because ``batch_first=True``, a batched input is
                (batch, seq, latent_size).
            h: initial hidden state, or None to start from zeros. ``batch_first``
                does NOT apply to the hidden state: its shape is
                (num_layers, batch, output_size).

        Returns:
            (ctx, h): ReLU of the per-step RNN outputs ((batch, seq, output_size)
            for batched input) and the final hidden state.
        """
        o, h = self.rnn(z, h)
        ctx = self.relu(o)
        return ctx, h


class LSTMMemory(BaseMemory):
    """
    Series of latents -> history representation, using a single-layer LSTM.

    The LSTM output at every step is passed through a ReLU to form the context.

    Args:
        latent_size: feature size of each input latent (the LSTM ``input_size``).
        output_size: LSTM hidden size; also the feature size of the returned context.
        frozen_weights: if True, the LSTM parameters get ``requires_grad=False``
            and are therefore not trained.
    """

    def __init__(self, latent_size, output_size, frozen_weights=False):
        super().__init__(latent_size)

        self.output_size = output_size
        self.frozen_weights = frozen_weights

        self.lstm = torch.nn.LSTM(latent_size, output_size, batch_first=True)
        if frozen_weights:
            # Freeze every LSTM parameter; gradients can still flow back to `z`.
            self.lstm.requires_grad_(False)
        self.relu = torch.nn.ReLU()

    def forward(self, z, h=None):
        """
        Run the LSTM over a sequence of latents.

        Args:
            z: latent sequence; because ``batch_first=True``, a batched input is
                (batch, seq, latent_size).
            h: initial state as a tuple ``(h_0, c_0)``, or None to start from
                zeros. ``batch_first`` does NOT apply to the state: each tensor
                is (num_layers, batch, output_size).

        Returns:
            (ctx, h): ReLU of the per-step LSTM outputs ((batch, seq, output_size)
            for batched input) and the final state tuple ``(h_n, c_n)``.
        """
        o, h = self.lstm(z, h)
        ctx = self.relu(o)
        return ctx, h


class GRUMemory(BaseMemory):
    """
    Series of latents -> history representation, using a single-layer GRU.

    The GRU output at every step is passed through a ReLU to form the context.

    Args:
        latent_size: feature size of each input latent (the GRU ``input_size``).
        output_size: GRU hidden size; also the feature size of the returned context.
        frozen_weights: if True, the GRU parameters get ``requires_grad=False``
            and are therefore not trained.
    """

    def __init__(self, latent_size, output_size, frozen_weights=False):
        super().__init__(latent_size)

        self.output_size = output_size
        self.frozen_weights = frozen_weights

        self.gru = torch.nn.GRU(latent_size, output_size, batch_first=True)
        if frozen_weights:
            # Freeze every GRU parameter; gradients can still flow back to `z`.
            self.gru.requires_grad_(False)
        self.relu = torch.nn.ReLU()

    def forward(self, z, h=None):
        """
        Run the GRU over a sequence of latents.

        Args:
            z: latent sequence; because ``batch_first=True``, a batched input is
                (batch, seq, latent_size).
            h: initial hidden state, or None to start from zeros. ``batch_first``
                does NOT apply to the hidden state: its shape is
                (num_layers, batch, output_size).

        Returns:
            (ctx, h): ReLU of the per-step GRU outputs ((batch, seq, output_size)
            for batched input) and the final hidden state.
        """
        o, h = self.gru(z, h)
        ctx = self.relu(o)
        return ctx, h
