import torch
import torch.nn as nn


class E2EDisc(nn.Module):
    """Recurrent discriminator that maps a windowed series to one score.

    Every time step is concatenated with a decay vector computed from the
    matching ``delta`` row, the resulting sequence is run through a single
    ``GRUCell`` and the last hidden state is projected to a scalar per sample.

    Args:
        input_window: Number of time steps consumed per sample.
        input_channels: Number of values per time step.
        hidden_size: Width of the GRU hidden state and of the decay vector.
        output_size: Stored on the instance only; the output layer always
            produces a single value per sample.
    """

    def __init__(self, input_window, input_channels, hidden_size, output_size):
        super().__init__()
        self.input_window = input_window
        self.input_channels = input_channels
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Layer creation order is kept stable so seeded initialisation and
        # parameter ordering are unchanged.
        self.fc_rh = nn.Linear(input_channels, hidden_size)
        self.fc_out = nn.Linear(hidden_size, 1)

        self.gru_cell = nn.GRUCell(
            input_size=input_channels + hidden_size,
            hidden_size=hidden_size,
        )

    def initialize_hidden(self, batch_size):
        """Return a zero hidden state of shape ``(batch_size, hidden_size)``.

        The state is created on the device and with the dtype of the module's
        first parameter, so it stays compatible after ``.to()``, ``.double()``
        or ``.half()`` conversions of the module.
        """
        reference = next(self.parameters())
        return torch.zeros(
            batch_size,
            self.hidden_size,
            dtype=reference.dtype,
            device=reference.device,
        )

    def _temporal_decay(self, delta_rows):
        """Compute ``exp(-max(0, fc_rh(delta_rows)))`` for each row."""
        # relu is max(0, .) without building a CPU scalar tensor on each call.
        return torch.exp(-torch.relu(self.fc_rh(delta_rows)))

    def forward(self, x, mask, delta):
        """Score a batch of windows.

        Args:
            x: Tensor whose first dimension is the batch. Its elements are
                read in memory order as ``input_window`` rows of
                ``input_channels`` values per sample.
            mask: Accepted for interface compatibility; not used.
            delta: Tensor with the same number of elements as ``x``; it is
                flattened the same way and drives the decay vector.

        Returns:
            Tensor of shape ``(batch, 1)`` computed from the final hidden
            state.
        """
        batch_size = x.shape[0]

        features = x.reshape(-1, self.input_channels)
        decay = self._temporal_decay(delta.reshape(-1, self.input_channels))

        steps = torch.cat([features, decay], 1).reshape(
            batch_size,
            self.input_window,
            self.input_channels + self.hidden_size,
        )

        # Only the last hidden state feeds the output layer, so intermediate
        # states are not stored.
        hidden = self.initialize_hidden(batch_size)
        for step in steps.unbind(dim=1):
            hidden = self.gru_cell(step, hidden)

        return self.fc_out(hidden)


class E2EGen(torch.nn.Module):
    """Recurrent encoder/decoder generator for windowed series.

    The encoder runs the noise-perturbed input, concatenated with a decay
    vector derived from ``delta``, through ``gru_cell_z`` and projects the
    final hidden state to a latent vector of size ``z_dim``. The decoder
    projects that latent vector back to ``input_channels`` values and then
    unrolls ``gru_cell_g`` for ``input_window`` steps, feeding every
    prediction back in as the next input.

    Args:
        z_dim: Size of the latent vector.
        input_window: Number of time steps per sample.
        input_channels: Number of values per time step.
        hidden_size: Width of both GRU hidden states and of the decay vectors.
        output_size: Stored on the instance only; not used by the layers.
    """

    def __init__(self, z_dim, input_window, input_channels, hidden_size, output_size):
        super().__init__()
        self.z_dim = z_dim
        self.input_window = input_window
        self.input_channels = input_channels
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Layer creation order is kept stable so seeded initialisation and
        # parameter ordering are unchanged.
        self.fc_z_rh = nn.Linear(input_channels, hidden_size)
        self.fc_z_out = nn.Linear(hidden_size, z_dim)

        self.fc_rh = nn.Linear(input_channels, hidden_size)
        self.fc_out = nn.Linear(hidden_size, input_channels)
        self.fc_z = nn.Linear(z_dim, input_channels)

        self.gru_cell_z = nn.GRUCell(
            input_size=input_channels + hidden_size,
            hidden_size=hidden_size,
        )
        self.gru_cell_g = nn.GRUCell(
            input_size=input_channels + hidden_size,
            hidden_size=hidden_size,
        )

    def initialize_hidden(self, batch_size):
        """Return a zero hidden state of shape ``(batch_size, hidden_size)``.

        The state is created on the device and with the dtype of the module's
        first parameter, so it stays compatible after ``.to()``, ``.double()``
        or ``.half()`` conversions of the module.
        """
        reference = next(self.parameters())
        return torch.zeros(
            batch_size,
            self.hidden_size,
            dtype=reference.dtype,
            device=reference.device,
        )

    @staticmethod
    def _temporal_decay(layer, delta_rows):
        """Compute ``exp(-max(0, layer(delta_rows)))`` for each row."""
        # relu is max(0, .) without building a CPU scalar tensor on each call.
        return torch.exp(-torch.relu(layer(delta_rows)))

    def _decode_step(self, step_input, step_delta, hidden):
        """Advance the decoder one step; return ``(prediction, new_hidden)``."""
        decay = self._temporal_decay(self.fc_rh, step_delta)
        hidden = self.gru_cell_g(torch.cat([step_input, decay], 1), hidden)
        return self.fc_out(hidden), hidden

    def forward(self, x, mask, ita, delta):
        """Encode the noisy input to a latent vector and decode a full window.

        Args:
            x: Tensor whose first dimension is the batch. After ``ita`` is
                added, its elements are read in memory order as
                ``input_window`` rows of ``input_channels`` values per sample.
            mask: Accepted for interface compatibility; not used.
            ita: Noise tensor added to ``x`` (must broadcast with it).
            delta: Tensor indexed as ``delta[:, :, i]`` for decoder step
                ``i``; each such slice must hold ``input_channels`` values
                per sample. It is also flattened to rows of
                ``input_channels`` values for the encoder.

        Returns:
            Tensor of shape ``(batch, input_channels, input_window)``.
        """
        # NOTE(review): the decoder reads ``delta`` with time on the last axis
        # and the result is permuted to put time last, yet the encoder
        # flattens ``x``/``delta`` into rows of ``input_channels`` without a
        # permute. Confirm the intended input layout with the caller.
        batch_size = x.shape[0]
        step_width = self.input_channels + self.hidden_size

        # ---- Encoder: sequence -> latent vector -------------------------
        noisy = (x + ita).reshape(-1, self.input_channels)
        encoder_decay = self._temporal_decay(
            self.fc_z_rh, delta.reshape(-1, self.input_channels)
        )
        encoder_steps = torch.cat([noisy, encoder_decay], 1).reshape(
            batch_size, self.input_window, step_width
        )

        # Only the final encoder state is used, so intermediate states are
        # not stored.
        hidden = self.initialize_hidden(batch_size)
        for step in encoder_steps.unbind(dim=1):
            hidden = self.gru_cell_z(step, hidden)
        z = self.fc_z_out(hidden)

        # ---- Decoder: latent vector -> autoregressive sequence ----------
        seed = self.fc_z(z).reshape(-1, self.input_channels)
        # The first decoder step uses an all-zero delta.
        zero_delta = torch.zeros(
            batch_size,
            self.input_channels,
            dtype=seed.dtype,
            device=seed.device,
        )

        hidden = self.initialize_hidden(batch_size)
        prediction, hidden = self._decode_step(seed, zero_delta, hidden)

        # Collect the per-step predictions and join them once at the end
        # instead of re-concatenating a growing tensor every iteration.
        predictions = [prediction]
        for i in range(1, self.input_window):
            step_delta = delta[:, :, i:(i + 1)].reshape(
                batch_size, self.input_channels
            )
            prediction, hidden = self._decode_step(prediction, step_delta, hidden)
            predictions.append(prediction)

        # (batch, window, channels) -> (batch, channels, window)
        return torch.stack(predictions, dim=1).permute(0, 2, 1)