import torch
import torch.nn as nn


class GAN2StageZ(nn.Module):
    """Single square linear layer applied to a latent vector.

    Args:
        z_dim: Size of both the input and the output feature dimension.
    """

    def __init__(self, z_dim):
        super().__init__()
        # Keep the attribute name ``fc`` so state-dict keys stay ``fc.*``.
        latent_map = nn.Linear(in_features=z_dim, out_features=z_dim)
        self.fc = latent_map

    def forward(self, x):
        """Return ``x`` passed through the ``z_dim -> z_dim`` linear layer."""
        projected = self.fc(x)
        return projected


class GAN2StageDisc(nn.Module):
    """GRU-cell sequence scorer with a learned gate computed from ``delta``.

    At every time step the GRU cell receives the concatenation of the
    step's input channels and ``exp(-relu(fc_rh(delta)))``. The hidden state
    after the last step is projected to a single value per sequence.

    Args:
        input_window: Number of time steps per sequence.
        input_channels: Number of features per time step.
        hidden_size: Width of the GRU hidden state.
        output_size: Stored on the instance but not used by any layer of
            this class (the output head always has width 1).
    """

    def __init__(self, input_window, input_channels, hidden_size, output_size):
        super().__init__()
        self.input_window = input_window
        self.input_channels = input_channels
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.fc_rh = nn.Linear(input_channels, hidden_size)
        self.fc_out = nn.Linear(hidden_size, 1)

        self.gru_cell = nn.GRUCell(
            input_size=input_channels + hidden_size, hidden_size=hidden_size
        )

    def initialize_hidden(self, batch_size):
        """Return a zero hidden state of shape ``(batch_size, hidden_size)``.

        The state is created on the device *and* with the dtype of the
        module's parameters, so the module keeps working after
        ``.double()`` / ``.half()`` / ``.to(device)``.
        """
        ref = next(self.parameters())
        return torch.zeros(
            batch_size, self.hidden_size, device=ref.device, dtype=ref.dtype
        )

    def _decay(self, delta):
        """Map ``delta`` to a gate in ``(0, 1]``: ``exp(-max(0, fc_rh(delta)))``."""
        # relu yields the same values as maximum(., 0) without building a
        # CPU-resident scalar tensor on every call.
        return torch.exp(-torch.relu(self.fc_rh(delta)))

    def forward(self, x, mask, delta):
        """Score each sequence in the batch.

        Args:
            x: Tensor with ``x.shape[0]`` sequences and
                ``input_window * input_channels`` values per sequence.
                NOTE(review): the reshapes below treat it as
                ``(batch, input_window, input_channels)`` - confirm with callers.
            mask: Accepted for interface compatibility; not used in the
                computation.
            delta: Tensor with the same number of elements/layout as ``x``.

        Returns:
            Tensor of shape ``(batch, 1)`` computed from the final hidden state.
        """
        batch_size = x.shape[0]
        flat_x = x.reshape(-1, self.input_channels)
        flat_delta = delta.reshape(-1, self.input_channels)

        # One row per (sequence, step): [input channels | gate].
        step_features = torch.cat([flat_x, self._decay(flat_delta)], 1)
        x_in = step_features.reshape(
            batch_size, self.input_window, self.input_channels + self.hidden_size
        )

        hidden = self.initialize_hidden(batch_size)
        # Only the last hidden state feeds the head, so intermediate states
        # are not stored.
        for i in range(self.input_window):
            hidden = self.gru_cell(x_in[:, i, :], hidden)

        return self.fc_out(hidden)


class GAN2StageGen(torch.nn.Module):
    """GRU-cell sequence model with two modes: reconstruct or generate.

    Every GRU step consumes the concatenation of ``input_channels`` values
    and a gate ``exp(-relu(fc_rh(delta)))``; every hidden state is projected
    back to ``input_channels`` values by ``fc_out``.

    * ``gen=False``: the step inputs come from ``x``.
    * ``gen=True``: the first step input is ``fc_z(z)`` and each following
      step input is the previous step's prediction.

    Args:
        z_dim: Width of the latent vector ``z``.
        input_window: Number of time steps per sequence.
        input_channels: Number of features per time step.
        hidden_size: Width of the GRU hidden state.
        output_size: Stored on the instance but not used by any layer of
            this class.
    """

    def __init__(self, z_dim, input_window, input_channels, hidden_size, output_size):
        super().__init__()
        self.z_dim = z_dim
        self.input_window = input_window
        self.input_channels = input_channels
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.fc_z = nn.Linear(z_dim, input_channels)
        self.fc_rh = nn.Linear(input_channels, hidden_size)
        self.fc_out = nn.Linear(hidden_size, input_channels)

        self.gru_cell = nn.GRUCell(
            input_size=input_channels + hidden_size, hidden_size=hidden_size
        )

    def initialize_hidden(self, batch_size):
        """Return a zero hidden state of shape ``(batch_size, hidden_size)``.

        The state is created on the device *and* with the dtype of the
        module's parameters, so the module keeps working after
        ``.double()`` / ``.half()`` / ``.to(device)``.
        """
        ref = next(self.parameters())
        return torch.zeros(
            batch_size, self.hidden_size, device=ref.device, dtype=ref.dtype
        )

    def _decay(self, delta):
        """Map ``delta`` to a gate in ``(0, 1]``: ``exp(-max(0, fc_rh(delta)))``."""
        # relu yields the same values as maximum(., 0) without building a
        # CPU-resident scalar tensor on every call.
        return torch.exp(-torch.relu(self.fc_rh(delta)))

    def _step(self, step_input, step_delta, hidden):
        """Advance the GRU one step; return ``(prediction, new_hidden)``."""
        gru_input = torch.cat([step_input, self._decay(step_delta)], 1)
        hidden = self.gru_cell(gru_input, hidden)
        return self.fc_out(hidden), hidden

    def _generate(self, z, delta):
        """Autoregressively roll out a sequence seeded by ``fc_z(z)``."""
        if z is None:
            raise ValueError("z must be provided when gen=True")

        step_input = self.fc_z(z).reshape(-1, self.input_channels)
        batch_size = step_input.shape[0]
        hidden = self.initialize_hidden(batch_size)

        # The first step uses an all-zero delta with the dtype/device of the
        # projected latent.
        prediction, hidden = self._step(
            step_input, torch.zeros_like(step_input), hidden
        )
        predictions = [prediction]

        for i in range(1, self.input_window):
            # delta is indexed along its last axis here, i.e. time is the
            # third dimension and the second one has ``input_channels`` entries.
            step_delta = delta[:, :, i:(i + 1)].reshape(
                batch_size, self.input_channels
            )
            prediction, hidden = self._step(prediction, step_delta, hidden)
            predictions.append(prediction)

        # Stack once instead of re-concatenating the growing result each step.
        return torch.stack(predictions, dim=1).permute(0, 2, 1)

    def _reconstruct(self, x, delta):
        """Run the GRU over the provided sequence and project every state."""
        # NOTE(review): these reshapes treat x and delta as
        # (batch, input_window, input_channels), while the generate path
        # indexes delta as (batch, input_channels, time) and both paths return
        # (batch, input_channels, input_window) - confirm the layout callers pass.
        flat_x = x.reshape(-1, self.input_channels)
        flat_delta = delta.reshape(-1, self.input_channels)

        step_features = torch.cat([flat_x, self._decay(flat_delta)], 1)
        x_in = step_features.reshape(
            -1, self.input_window, self.input_channels + self.hidden_size
        )

        hidden = self.initialize_hidden(x_in.shape[0])
        states = []
        for i in range(self.input_window):
            hidden = self.gru_cell(x_in[:, i, :], hidden)
            states.append(hidden)

        stacked = torch.stack(states, dim=1).reshape(-1, self.hidden_size)
        output = self.fc_out(stacked).reshape(
            -1, self.input_window, self.input_channels
        )
        return output.permute(0, 2, 1)

    def forward(self, x, mask, delta, z=None, gen=False):
        """Reconstruct from ``x`` or generate from ``z``.

        Args:
            x: Input sequence; only read when ``gen`` is false.
            mask: Accepted for interface compatibility; not used in the
                computation.
            delta: Per-step values fed to the gate. In generate mode it is
                only read for steps ``1 .. input_window - 1``.
            z: Latent tensor whose last dimension is ``z_dim``; required when
                ``gen`` is true.
            gen: Select generate mode (true) or reconstruct mode (false).

        Returns:
            Tensor of shape ``(batch, input_channels, steps)`` where ``steps``
            is ``input_window`` (at least 1 in generate mode).

        Raises:
            ValueError: If ``gen`` is true and ``z`` is ``None``.
        """
        if gen:
            return self._generate(z, delta)
        return self._reconstruct(x, delta)
