import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from library.mnist_net import MNIST_Net
from library.LFL_modules import *
from library.LFL_layers import *
    
class MNISTAddLFL_MLP(nn.Module):
    '''Direct NeSy predictor with LFL-Type1 and MLP gradient shortcut.

    Both images go through the same ``MNIST_Net`` instance. Each output is
    passed through a ``ConcreteLayer``, and the two results are concatenated
    along dim 1. That feature vector feeds both a ``MultiLayerLFL`` and a
    plain MLP (Linear -> ReLU -> Linear -> Sigmoid), the gradient shortcut.

    Args:
        n_choice: number of outputs of the CNN per image (input size of the
            fuzzifier / reconstruction layer).
        n_hidden: hidden width used by both the LFL and the MLP.
        n_output: number of output units (presumably one per possible digit
            sum 0-18 -- TODO confirm).
        fuzzifier_layer_kwargs: keyword arguments forwarded to
            ``ConcreteLayer``. ``None`` (default) means no extra arguments.
        lfl_layer_kwargs: per-layer keyword arguments forwarded to
            ``MultiLayerLFL``. ``None`` (default) means ``[{}, {}]``.
    '''
    def __init__(self, n_choice=10, n_hidden=256, n_output=19, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Fresh containers per instance: a mutable default ({} / [{}, {}]) would
        # be one object shared by every instance and by any callee that keeps
        # or mutates it.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        # Construction order is kept stable so seeded parameter initialisation
        # and state_dict ordering do not change.
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        self.softmax_layer = nn.Softmax(dim=1)
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = MultiLayerLFL(n_choice * 2, [n_hidden, n_output], lfl_layer_kwargs)
        self.mlp = nn.Sequential(
            nn.Linear(n_choice * 2, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output),
            nn.Sigmoid(),
        )

    def _reconstruct(self, label):
        '''Map a fuzzified label to a ``(-1, 1, 28, 28)`` reconstruction.

        The label is detached, so gradients from the reconstruction do not
        flow back into the CNN or the fuzzifier.
        '''
        return self.reconstruction_layer(label.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_image: torch.Tensor, right_image: torch.Tensor):
        '''Predict from a pair of images.

        Returns:
            ``(pred, left_reconstruction, right_reconstruction, pred_mlp,
            left_label_mean, right_label_mean)``: LFL output, the two
            reconstructions, MLP output, and for each image the mean over
            dim 0 of the softmax (dim 1) of the raw CNN output.
        '''
        # Call order (CNN on both, then fuzzifier on both) is preserved in case
        # these submodules are stochastic and consume RNG state.
        left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
        left_label, right_label = self.fuzzifier_layer(left_pred), self.fuzzifier_layer(right_pred)
        left_reconstruction = self._reconstruct(left_label)
        right_reconstruction = self._reconstruct(right_label)
        labels = torch.cat([left_label, right_label], dim=1)
        pred = self.lfl(labels)
        pred_mlp = self.mlp(labels)
        left_label_mean = torch.mean(self.softmax_layer(left_pred), dim=0)
        right_label_mean = torch.mean(self.softmax_layer(right_pred), dim=0)
        return pred, left_reconstruction, right_reconstruction, pred_mlp, left_label_mean, right_label_mean

class MNISTAddDSL(nn.Module):
    '''Our implementation of DSL NeSy predictor.

    A single ``MNIST_Net`` scores both images. An
    ``EpsilonGreedySoftmaxLayer`` turns each score into a symbol, and the two
    symbols (concatenated along dim 1) are evaluated by a ``DSLLogicModule``.

    Args:
        n_choice: number of outputs of the CNN per image.
        n_output: number of outputs of the logic module.
        epsilon_symbol: ``epsilon`` passed to ``EpsilonGreedySoftmaxLayer``.
        epsilon_rule: ``epsilon_rule`` passed to ``DSLLogicModule``.
    '''
    def __init__(self, n_choice=10, n_output=19, epsilon_symbol=0.2807344052335263, epsilon_rule=0.1077119516324264):
        super().__init__()
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = EpsilonGreedySoftmaxLayer(epsilon=epsilon_symbol)
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        symbol_sizes = [n_choice, n_choice]
        self.lfl = DSLLogicModule(symbol_sizes, [n_output], epsilon_rule=epsilon_rule)

    def _decode(self, symbol):
        '''Reconstruct a ``(-1, 1, 28, 28)`` image from a detached symbol.'''
        flat = self.reconstruction_layer(symbol.detach())
        return flat.reshape(-1, 1, 28, 28)

    def forward(self, left_image: torch.Tensor, right_image: torch.Tensor):
        '''Return ``(pred, left_reconstruction, right_reconstruction)``.'''
        # Each stage runs left then right, and all CNN calls happen before any
        # fuzzifier call.
        scores = [self.cnn(image) for image in (left_image, right_image)]
        symbols = [self.fuzzifier_layer(score) for score in scores]
        reconstructions = [self._decode(symbol) for symbol in symbols]
        pred = self.lfl(torch.cat(symbols, dim=1))
        return (pred, *reconstructions)

class MNISTAddConcreteDSL(nn.Module):
    '''Direct NeSy predictor with LFL-Type2.

    Both images go through the same ``MNIST_Net``. Each output is passed
    through a ``ConcreteLayer``, and the two results are concatenated along
    dim 1 and evaluated by a ``ConcreteDSLLogicModule``.

    Args:
        n_choice: number of outputs of the CNN per image.
        n_output: number of outputs of the logic module.
        fuzzifier_layer_kwargs: keyword arguments forwarded to
            ``ConcreteLayer``. ``None`` (default) means no extra arguments.
        lfl_layer_kwargs: forwarded as ``layer_kwargs`` to
            ``ConcreteDSLLogicModule``. ``None`` (default) means ``[{}, {}]``.
    '''
    def __init__(self, n_choice=10, n_output=19, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Fresh containers per instance instead of shared mutable defaults.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = ConcreteDSLLogicModule([n_choice, n_choice], [n_output], layer_kwargs=lfl_layer_kwargs)

    def _reconstruct(self, label):
        '''Map a detached fuzzified label to a ``(-1, 1, 28, 28)`` image.'''
        return self.reconstruction_layer(label.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_image: torch.Tensor, right_image: torch.Tensor):
        '''Return ``(pred, left_reconstruction, right_reconstruction)``.'''
        # CNN on both images first, then fuzzifier on both (original order).
        left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
        left_label, right_label = self.fuzzifier_layer(left_pred), self.fuzzifier_layer(right_pred)
        left_reconstruction = self._reconstruct(left_label)
        right_reconstruction = self._reconstruct(right_label)
        labels = torch.cat([left_label, right_label], dim=1)
        pred = self.lfl(labels)
        return pred, left_reconstruction, right_reconstruction

class MNISTAddLFLType3(nn.Module):
    '''Direct NeSy predictor with LFL-Type3.

    Both images go through the same ``MNIST_Net``. Each output is passed
    through a ``ConcreteLayer``, and the two results are concatenated along
    dim 1 and evaluated by ``LFLType3``.

    Args:
        n_choice: number of outputs of the CNN per image.
        n_output: number of outputs of ``LFLType3``.
        fuzzifier_layer_kwargs: keyword arguments forwarded to
            ``ConcreteLayer``. ``None`` (default) means no extra arguments.
        lfl_layer_kwargs: forwarded as ``layer_kwargs`` to ``LFLType3``.
            ``None`` (default) means ``[{}, {}]``.
    '''
    def __init__(self, n_choice=10, n_output=19, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Fresh containers per instance instead of shared mutable defaults.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = LFLType3([n_choice, n_choice], n_output, layer_kwargs=lfl_layer_kwargs)

    def _reconstruct(self, label):
        '''Map a detached fuzzified label to a ``(-1, 1, 28, 28)`` image.'''
        return self.reconstruction_layer(label.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_image: torch.Tensor, right_image: torch.Tensor):
        '''Return ``(pred, left_reconstruction, right_reconstruction)``.'''
        # CNN on both images first, then fuzzifier on both (original order).
        left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
        left_label, right_label = self.fuzzifier_layer(left_pred), self.fuzzifier_layer(right_pred)
        left_reconstruction = self._reconstruct(left_label)
        right_reconstruction = self._reconstruct(right_label)
        labels = torch.cat([left_label, right_label], dim=1)
        pred = self.lfl(labels)
        return pred, left_reconstruction, right_reconstruction
        
class MNISTAddLFL_NoMLP(nn.Module):
    '''Direct NeSy predictor with LFL-Type1 and no MLP gradient shortcut.

    Both images go through the same ``MNIST_Net``. Each output is passed
    through a ``ConcreteLayer``, and the two results are concatenated along
    dim 1 and fed to a ``MultiLayerLFL``.

    Args:
        n_choice: number of outputs of the CNN per image.
        n_hidden: hidden width of the LFL.
        n_output: number of outputs of the LFL.
        fuzzifier_layer_kwargs: keyword arguments forwarded to
            ``ConcreteLayer``. ``None`` (default) means no extra arguments.
        lfl_layer_kwargs: per-layer keyword arguments forwarded to
            ``MultiLayerLFL``. ``None`` (default) means ``[{}, {}]``.
    '''
    def __init__(self, n_choice=10, n_hidden=256, n_output=19, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Fresh containers per instance instead of shared mutable defaults.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        self.softmax_layer = nn.Softmax(dim=1)
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = MultiLayerLFL(n_choice * 2, [n_hidden, n_output], lfl_layer_kwargs)

    def _reconstruct(self, label):
        '''Map a detached fuzzified label to a ``(-1, 1, 28, 28)`` image.'''
        return self.reconstruction_layer(label.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_image: torch.Tensor, right_image: torch.Tensor):
        '''Predict from a pair of images.

        Returns:
            ``(pred, left_reconstruction, right_reconstruction,
            left_label_mean, right_label_mean)``. The label means are the
            mean over dim 0 of the softmax (dim 1) of each raw CNN output.
        '''
        # CNN on both images first, then fuzzifier on both (original order).
        left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
        left_label, right_label = self.fuzzifier_layer(left_pred), self.fuzzifier_layer(right_pred)
        left_reconstruction = self._reconstruct(left_label)
        right_reconstruction = self._reconstruct(right_label)
        labels = torch.cat([left_label, right_label], dim=1)
        pred = self.lfl(labels)
        left_label_mean = torch.mean(self.softmax_layer(left_pred), dim=0)
        right_label_mean = torch.mean(self.softmax_layer(right_pred), dim=0)
        return pred, left_reconstruction, right_reconstruction, left_label_mean, right_label_mean