import torch
import numpy as np
import pandas as pd
import os
import math
import warnings
import itertools
import numbers
import torch.utils.data as utils
from neuralfaults.impute_models.grud_layers import GRUD_cell



def grud_model_old(input_size, hidden_size, output_size, num_layers=1, x_mean=0,
                   bias=True, batch_first=False, bidirectional=False, dropout_type='mloss', dropout=0):
    """Assemble a ``torch.nn.Sequential`` stack of ``GRUD_cell`` modules.

    The stack holds ``num_layers - 1`` intermediate cells followed by one
    final cell (so at least one cell is always created). Intermediate cells
    are built with ``output_size=input_size`` and ``return_hidden=True``;
    the final cell uses the requested ``output_size`` and
    ``return_hidden=False``.

    Args:
        input_size: Forwarded to every ``GRUD_cell`` as ``input_size`` (and
            as ``output_size`` of the intermediate cells).
        hidden_size: Forwarded to every ``GRUD_cell``.
        output_size: ``output_size`` of the final cell only.
        num_layers: Total number of cells requested; also forwarded to every
            ``GRUD_cell``.
        x_mean: Forwarded to every ``GRUD_cell``.
        bias, batch_first, bidirectional: Accepted but not used by this
            function.
        dropout_type: Forwarded to every ``GRUD_cell``.
        dropout: Forwarded to every ``GRUD_cell``.

    Returns:
        A ``torch.nn.Sequential`` wrapping the cells in construction order.

    NOTE(review): ``torch.nn.Sequential`` passes a single value from one
    module to the next; confirm that ``GRUD_cell.forward`` is compatible
    with being chained this way.
    """
    # Keyword arguments common to every cell in the stack.
    shared_kwargs = dict(
        input_size=input_size,
        hidden_size=hidden_size,
        dropout=dropout,
        dropout_type=dropout_type,
        x_mean=x_mean,
        num_layers=num_layers,
    )

    # Intermediate cells map back to the input width and expose their hidden state.
    cells = [
        GRUD_cell(output_size=input_size, return_hidden=True, **shared_kwargs)
        for _ in range(num_layers - 1)
    ]

    # The last cell produces the requested output width.
    cells.append(GRUD_cell(output_size=output_size, return_hidden=False, **shared_kwargs))

    return torch.nn.Sequential(*cells)


class GRUD(torch.nn.Module):
    """A ``GRUD_cell`` optionally followed by stacked ``torch.nn.GRU`` layers.

    With ``num_layers == 1`` the model is just the ``GRUD_cell``. With
    ``num_layers > 1`` the second value returned by the cell is fed through
    ``num_layers - 1`` standard GRU layers (``batch_first=True``) and then
    projected by a linear layer from ``hidden_size`` to ``output_size``.

    Args:
        input_size: Forwarded to ``GRUD_cell``.
        hidden_size: Forwarded to ``GRUD_cell``; also the width of the
            stacked GRU layers and the input width of the linear projection.
        output_size: Forwarded to ``GRUD_cell``; also the output width of
            the linear projection.
        num_layers: Total number of recurrent layers (the cell counts as one).
        x_mean: Forwarded to ``GRUD_cell``.
        bias, batch_first, bidirectional: Accepted but not used by this class.
        dropout_type: Forwarded to ``GRUD_cell``.
        dropout: Forwarded to ``GRUD_cell`` and to the stacked ``torch.nn.GRU``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, x_mean=0.,
                 bias=True, batch_first=False, bidirectional=False, dropout_type='mloss', dropout=0):
        super().__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size

        # Submodules are created in a fixed order (cell, projection, GRU stack)
        # so parameter registration and seeded initialisation stay stable.
        self.gru_d = GRUD_cell(
            input_size=input_size,
            hidden_size=hidden_size,
            output_size=output_size,
            dropout=dropout,
            dropout_type=dropout_type,
            x_mean=x_mean,
        )

        # Always constructed, but only applied in forward() when num_layers > 1.
        self.hidden_to_output = torch.nn.Linear(hidden_size, output_size, bias=True)

        if self.num_layers > 1:
            # The GRU-D cell is layer one; the remaining layers are plain GRUs.
            self.gru_layers = torch.nn.GRU(
                input_size=hidden_size,
                hidden_size=hidden_size,
                batch_first=True,
                num_layers=self.num_layers - 1,
                dropout=dropout,
            )

    def initialize_hidden(self, batch_size):
        """Return an all-zero tensor of shape ``(num_layers - 1, batch_size, hidden_size)``.

        The tensor is allocated on the device of the module's first parameter.
        """
        first_param = next(self.parameters())
        shape = (self.num_layers - 1, batch_size, self.hidden_size)
        return torch.zeros(shape, device=first_param.device)

    def forward(self, x, mask, delta):
        """Run the GRU-D cell (and any stacked GRU layers) on the inputs.

        ``x``, ``mask`` and ``delta`` are passed straight to ``GRUD_cell``.
        The result has its last two dimensions swapped before being returned,
        so it must be a 3-D tensor.

        NOTE(review): presumably the cell returns ``(output, hidden)`` with
        ``hidden`` laid out as (batch, seq, hidden_size), matching the
        ``batch_first=True`` GRU stack - confirm against ``GRUD_cell``.
        """
        cell_output, cell_hidden = self.gru_d(x, mask, delta)

        if self.num_layers > 1:
            # Deeper model: refine the cell's hidden sequence, then project it.
            stacked_output, _ = self.gru_layers(cell_hidden)
            projected = self.hidden_to_output(stacked_output)
            return projected.permute(0, 2, 1)

        # Single-layer model: the cell's own output is used directly.
        return cell_output.permute(0, 2, 1)