import torch
import torch.nn as nn
import torch.nn.functional as F


# TODO allow user to choose device
# Module-wide device used to place models and inputs. The second GPU
# ('cuda:1') is preferred when present; on single-GPU machines that ordinal
# does not exist, so fall back to 'cuda:0' instead of failing at the first
# `.to(device)` call.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'


def init_neural_function(input_type, output_type, input_size, output_size, num_units):
    """Construct the neural function matching an (input_type, output_type) signature.

    Only the ("atom", "atom") signature is currently supported; it is served
    by ``AtomToAtomModule``.

    Args:
        input_type: Type tag of the function input (e.g. "atom").
        output_type: Type tag of the function output (e.g. "atom").
        input_size: Input feature dimension forwarded to the module.
        output_size: Output feature dimension forwarded to the module.
        num_units: Hidden-layer width forwarded to the module.

    Returns:
        A freshly constructed ``AtomToAtomModule``.

    Raises:
        NotImplementedError: if no module exists for the requested signature.
    """
    if (input_type, output_type) == ("atom", "atom"):
        return AtomToAtomModule(input_size, output_size, num_units)
    # Name the rejected signature so the caller can see what was asked for.
    raise NotImplementedError(
        f"No neural function for signature ({input_type!r} -> {output_type!r}); "
        "only ('atom' -> 'atom') is supported."
    )


class HeuristicNeuralFunction:
    """Base wrapper that records a typed signature and builds an underlying model.

    Subclasses must override ``init_model``. The constructor calls it exactly
    once, after every configuration attribute has been stored, so the hook
    may read any of them.
    """

    def __init__(self, input_type, output_type, input_size, output_size, num_units, name):
        # Type signature of the function.
        self.input_type, self.output_type = input_type, output_type
        # Sizing passed on to the concrete model.
        self.input_size, self.output_size = input_size, output_size
        self.num_units = num_units
        # Human-readable label for this function.
        self.name = name

        # Subclass hook; all attributes above are already available.
        self.init_model()

    def init_model(self):
        """Build the concrete model; the base class provides no implementation."""
        raise NotImplementedError

class AtomToAtomModule(HeuristicNeuralFunction):
    """Neural function with an ("atom", "atom") signature.

    Wraps a ``FeedForwardModule`` that is moved to the module-level ``device``.
    """

    def __init__(self, input_size, output_size, num_units):
        super().__init__("atom", "atom", input_size, output_size, num_units, "AtomToAtomModule")

    def init_model(self):
        """Build the feed-forward network and move it to the module-level ``device``."""
        self.model = FeedForwardModule(self.input_size, self.output_size, self.num_units).to(device)

    def execute_on_batch(self, batch, batch_lens=None):
        """Run the wrapped model on a 2-D batch and return its 2-D output.

        Args:
            batch: 2-D tensor; presumably (batch_size, input_size) — TODO confirm.
            batch_lens: Not used by this module; accepted so the call signature
                also takes it.

        Returns:
            The model's output tensor, which must be 2-D.

        Raises:
            ValueError: if ``batch`` is not 2-D.
        """
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently let malformed input through.
        if batch.dim() != 2:
            raise ValueError(
                f"AtomToAtomModule expects a 2-D batch, got shape {tuple(batch.shape)}"
            )
        model_out = self.model(batch)
        # Internal invariant of the wrapped network, so an assert is appropriate.
        assert model_out.dim() == 2
        return model_out


##############################
####### NEURAL MODULES #######
##############################


class FeedForwardModule(nn.Module):
    """Small MLP: input -> hidden (ReLU) [-> hidden (ReLU)] -> output.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        num_units: Width of the hidden layer(s).
        use_second_layer: If True, apply ``second_layer`` (with ReLU) between
            the first hidden layer and the output layer. Defaults to False,
            which gives a single hidden layer.

    ``second_layer`` is always constructed, even when unused, so the parameter
    names in ``state_dict`` are the same regardless of ``use_second_layer``.
    """

    def __init__(self, input_size, output_size, num_units, *, use_second_layer=False):
        super(FeedForwardModule, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = num_units
        self.use_second_layer = use_second_layer
        self.first_layer = nn.Linear(self.input_size, self.hidden_size)
        self.second_layer = nn.Linear(self.hidden_size, self.hidden_size)
        self.out_layer = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, current_input):
        """Apply the MLP and return the output-layer result (no final activation).

        The input is cast to the device and dtype of this module's parameters.
        As a result, integer inputs are accepted, and the module keeps working
        after being moved with ``.to(...)`` or converted with ``.double()``.

        Raises:
            TypeError: if ``current_input`` is not a ``torch.Tensor``.
        """
        if not isinstance(current_input, torch.Tensor):
            raise TypeError(
                f"expected a torch.Tensor, got {type(current_input).__name__}"
            )
        # Follow the parameters rather than a global device so the input always
        # lands where the weights actually live.
        weight = self.first_layer.weight
        current = current_input.to(device=weight.device, dtype=weight.dtype)
        current = F.relu(self.first_layer(current))
        # getattr keeps fully-pickled modules saved before this flag existed working.
        if getattr(self, "use_second_layer", False):
            current = F.relu(self.second_layer(current))
        return self.out_layer(current)
