import torch
from modules import GRUCell, Linear, LSTMCell, RNNCell


class GRU(torch.nn.Module):
    """Sequence classifier: a ``GRUCell`` stepped over time plus a ``Linear`` head.

    The cell is applied once per index of the input's second axis, and the
    hidden state remaining after the final step is passed to the linear layer.

    Args:
        glove_dim: Per-timestep input feature size handed to ``GRUCell``.
        hidden_units: Hidden size of the cell and input size of the head.
        num_classes: Output size of the linear head.
    """

    def __init__(self, glove_dim=100, hidden_units=64, num_classes=2):
        super().__init__()
        self.rnn = GRUCell(glove_dim, hidden_units)
        self.linear = Linear(hidden_units, num_classes)

    def forward(self, x, add_noise=False):
        """Unroll the GRU cell over ``x`` and classify the last hidden state.

        Args:
            x: Tensor sliced as ``x[:, t, :]`` at each step; presumably
                (batch, seq_len, glove_dim) -- TODO confirm with callers.
            add_noise: Passed positionally to every cell call and to the head.

        Returns:
            The output of ``self.linear`` on the final hidden state. For an
            empty sequence this is the head applied to the all-zero start state.
        """
        batch_size = x.shape[0]
        # Zero initial state with the same dtype/device as the input.
        state = x.new_zeros((batch_size, self.rnn.hidden_size))
        for step in range(x.shape[1]):
            state = self.rnn(x[:, step, :], state, add_noise)
        return self.linear(state, add_noise)

    def backward(self, loss):
        """Call ``backward(loss)`` on the cell first, then on the linear head."""
        for layer in (self.rnn, self.linear):
            layer.backward(loss)

    def clear_buf(self):
        """Call ``clear_buf()`` on the cell first, then on the linear head."""
        for layer in (self.rnn, self.linear):
            layer.clear_buf()
        
class LSTM(torch.nn.Module):
    """Sequence classifier: an ``LSTMCell`` stepped over time plus a ``Linear`` head.

    Args:
        glove_dim: Per-timestep input feature size handed to ``LSTMCell``.
        hidden_units: Hidden size of the cell and input size of the head.
        num_classes: Output size of the linear head.
    """

    def __init__(self, glove_dim=100, hidden_units=64, num_classes=2):
        super().__init__()
        self.rnn = LSTMCell(glove_dim, hidden_units)
        self.linear = Linear(hidden_units, num_classes)

    def forward(self, x, add_noise=False):
        """Unroll the LSTM cell over ``x`` and classify the final hidden state.

        Args:
            x: Tensor sliced as ``x[:, t, :]`` at each step; presumably
                (batch, seq_len, glove_dim) -- TODO confirm with callers.
            add_noise: Passed positionally to every cell call and to the head.

        Returns:
            The output of ``self.linear`` applied to the last hidden state
            ``h``. For an empty sequence (``x.shape[1] == 0``) the head is
            applied to a zero state, the same as ``GRU`` and ``RNN`` do.
        """
        emb = x
        h = None
        # None for the first step lets LSTMCell pick its own initial (h, c).
        hx = None
        for t in range(emb.shape[1]):
            h, c = self.rnn(emb[:, t, :], hx, add_noise)
            hx = (h, c)
        if h is None:
            # Empty sequence: the loop never ran, so there is no state yet.
            # Fall back to zeros instead of failing on an unbound name.
            # NOTE(review): assumes LSTMCell exposes `hidden_size` like
            # GRUCell/RNNCell do -- confirm in modules.
            h = torch.zeros(
                emb.shape[0], self.rnn.hidden_size,
                dtype=emb.dtype, device=emb.device,
            )
        # Classify on the hidden state h (the cell's output), not the cell
        # state c. This matches GRU/RNN above and standard LSTM classifiers.
        return self.linear(h, add_noise)

    def backward(self, loss):
        """Call ``backward(loss)`` on the cell first, then on the linear head."""
        self.rnn.backward(loss)
        self.linear.backward(loss)

    def clear_buf(self):
        """Call ``clear_buf()`` on the cell first, then on the linear head."""
        self.rnn.clear_buf()
        self.linear.clear_buf()
        
        
class RNN(torch.nn.Module):
    """Sequence classifier: a vanilla ``RNNCell`` unrolled in time, then ``Linear``.

    Args:
        glove_dim: Per-timestep input feature size handed to ``RNNCell``.
        hidden_units: Hidden size of the cell and input size of the head.
        num_classes: Output size of the linear head.
    """

    def __init__(self, glove_dim=100, hidden_units=64, num_classes=2):
        super().__init__()
        self.rnn = RNNCell(glove_dim, hidden_units)
        self.linear = Linear(hidden_units, num_classes)

    def forward(self, x, add_noise=False):
        """Feed each timestep of ``x`` through the cell and classify the result.

        Args:
            x: Tensor sliced as ``x[:, t, :]`` at each step; presumably
                (batch, seq_len, glove_dim) -- TODO confirm with callers.
            add_noise: Passed positionally to every cell call and to the head.

        Returns:
            The output of ``self.linear`` on the hidden state after the last
            step. With zero timesteps the head sees the all-zero initial state.
        """
        hidden = torch.zeros(
            x.shape[0],
            self.rnn.hidden_size,
            dtype=x.dtype,
            device=x.device,
        )
        timesteps = (x[:, t, :] for t in range(x.shape[1]))
        for step_input in timesteps:
            hidden = self.rnn(step_input, hidden, add_noise)
        return self.linear(hidden, add_noise)

    def backward(self, loss):
        """Call ``backward(loss)`` on the cell first, then on the linear head."""
        for part in (self.rnn, self.linear):
            part.backward(loss)

    def clear_buf(self):
        """Call ``clear_buf()`` on the cell first, then on the linear head."""
        for part in (self.rnn, self.linear):
            part.clear_buf()
