import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from library.mnist_net import MNIST_Net
from library.LFL_modules import *
from library.LFL_layers import *
import random

class MNISTMultiAddDSL(nn.Module):
    '''Our implementation of the DSL recursive NeSy predictor.

    A CNN scores every image, a fuzzifier layer turns the scores into symbol
    vectors, and a DSL logic module maps (left symbol, right symbol, carry)
    to a prediction plus the carry that is fed to the next step.

    Args:
        n_choice: Passed to MNIST_Net as N; also the input width of the
            reconstruction layer and the size of each symbol group given to
            the logic module.
        n_output: Size of the first output group of the logic module (the
            second group has size 2 and is used as the carry).
        epsilon_symbol: Passed as epsilon to EpsilonGreedySoftmaxLayer.
        epsilon_rule: Passed as epsilon_rule to DSLLogicModule.
    '''
    def __init__(self, n_choice=10, n_output=10, epsilon_symbol=0.0, epsilon_rule=0.2):
        super().__init__()
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = EpsilonGreedySoftmaxLayer(epsilon=epsilon_symbol)
        # 784 = 28 * 28: forward() reshapes this output to (-1, 1, 28, 28).
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = DSLLogicModule(
            [n_choice, n_choice, 2],
            [n_output, 2],
            epsilon_rule=epsilon_rule,
        )

    def _reconstruct(self, symbols):
        '''Map detached symbol vectors to tensors shaped (-1, 1, 28, 28).'''
        # detach(): the reconstruction output cannot back-propagate into the
        # fuzzifier or the CNN.
        flat = self.reconstruction_layer(symbols.detach())
        return flat.reshape(-1, 1, 28, 28)

    def forward(self, left_images, right_images):
        '''Run the step-by-step recurrence over two image sequences.

        Args:
            left_images: Tensor indexed as [sample, step, ...]; every
                left_images[:, step] slice is passed to the CNN.
                (Presumably (batch, steps, 1, 28, 28) MNIST digits -- TODO
                confirm against the data loader.)
            right_images: Same layout; indexed with the step range of
                left_images.

        Returns:
            Tuple (preds, carry, left_reconstructions, right_reconstructions).
            preds and both reconstruction tensors are stacked along dim 1
            (one entry per step); carry is the last two columns of the final
            logic-module output.
        '''
        n_samples = left_images.shape[0]
        # Every sample starts with carry [1, 0] (presumably the one-hot code
        # for "no carry" -- confirm against the label encoding).
        carry = torch.zeros(
            (n_samples, 2),
            device=left_images.device,
            dtype=left_images.dtype,
        )
        carry[:, 0] = 1.

        step_preds = []
        left_recs = []
        right_recs = []
        for step in range(left_images.shape[1]):
            left_image = left_images[:, step]
            right_image = right_images[:, step]

            left_scores = self.cnn(left_image)
            right_scores = self.cnn(right_image)
            left_symbol = self.fuzzifier_layer(left_scores)
            right_symbol = self.fuzzifier_layer(right_scores)

            left_recs.append(self._reconstruct(left_symbol))
            right_recs.append(self._reconstruct(right_symbol))

            logic_out = self.lfl(torch.cat([left_symbol, right_symbol, carry], dim=1))
            # The last two columns become the carry for the next step.
            step_preds.append(logic_out[:, :-2])
            carry = logic_out[:, -2:]

        preds = torch.stack(step_preds, dim=1)
        left_reconstructions = torch.stack(left_recs, dim=1)
        right_reconstructions = torch.stack(right_recs, dim=1)
        return preds, carry, left_reconstructions, right_reconstructions

class MNISTMultiAddConcreteDSL(nn.Module):
    '''Recursive NeSy predictor with LFL-Type2.

    A CNN scores every image, a ConcreteLayer turns the scores into symbol
    vectors, and a ConcreteDSLLogicModule maps (left symbol, right symbol,
    carry) to a prediction plus the carry that is fed to the next step.

    Args:
        n_choice: Passed to MNIST_Net as N; also the input width of the
            reconstruction layer and the size of each symbol group given to
            the logic module.
        n_output: Size of the first output group of the logic module (the
            second group has size 2 and is used as the carry).
        fuzzifier_layer_kwargs: Keyword arguments forwarded to ConcreteLayer.
            None (the default) means no extra arguments.
        lfl_layer_kwargs: Passed as layer_kwargs to ConcreteDSLLogicModule.
            None (the default) means [{}, {}].
    '''
    def __init__(self, n_choice=10, n_output=10, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Build fresh containers per instance: mutable default arguments
        # would be shared by every model created with the defaults.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        # 784 = 28 * 28: forward() reshapes this output to (-1, 1, 28, 28).
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = ConcreteDSLLogicModule([n_choice, n_choice, 2], [n_output, 2], layer_kwargs=lfl_layer_kwargs)

    def _reconstruct(self, symbols):
        '''Map detached symbol vectors to tensors shaped (-1, 1, 28, 28).'''
        # detach(): the reconstruction output cannot back-propagate into the
        # fuzzifier or the CNN.
        return self.reconstruction_layer(symbols.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_images, right_images):
        '''Run the step-by-step recurrence over two image sequences.

        Args:
            left_images: Tensor indexed as [sample, step, ...]; every
                left_images[:, step] slice is passed to the CNN.
                (Presumably (batch, steps, 1, 28, 28) MNIST digits -- TODO
                confirm against the data loader.)
            right_images: Same layout; indexed with the step range of
                left_images.

        Returns:
            Tuple (preds, carry, left_reconstructions, right_reconstructions).
            preds and both reconstruction tensors are stacked along dim 1
            (one entry per step); carry is the last two columns of the final
            logic-module output.
        '''
        batch_size = left_images.shape[0]
        # Every sample starts with carry [1, 0] (presumably the one-hot code
        # for "no carry" -- confirm against the label encoding).
        carry = torch.zeros((batch_size, 2), device=left_images.device, dtype=left_images.dtype)
        carry[:, 0] = 1.
        preds = []
        left_reconstructions, right_reconstructions = [], []
        for i in range(left_images.shape[1]):
            left_image, right_image = left_images[:, i], right_images[:, i]
            left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
            left_label = self.fuzzifier_layer(left_pred)
            right_label = self.fuzzifier_layer(right_pred)
            left_reconstructions.append(self._reconstruct(left_label))
            right_reconstructions.append(self._reconstruct(right_label))
            lfl_output = self.lfl(torch.cat([left_label, right_label, carry], dim=1))
            # The last two columns become the carry for the next step.
            pred, carry = lfl_output[:, :-2], lfl_output[:, -2:]
            preds.append(pred)
        preds = torch.stack(preds, dim=1)
        left_reconstructions = torch.stack(left_reconstructions, dim=1)
        right_reconstructions = torch.stack(right_reconstructions, dim=1)
        return preds, carry, left_reconstructions, right_reconstructions

class MNISTMultiAddLFLType3(nn.Module):
    '''Recursive NeSy predictor with LFL-Type3.

    A CNN scores every image, a ConcreteLayer turns the scores into symbol
    vectors, and an LFLType3 module maps (left symbol, right symbol, carry)
    to n_output + 2 values: the prediction followed by the two-column carry
    that is fed to the next step.

    Args:
        n_choice: Passed to MNIST_Net as N; also the input width of the
            reconstruction layer and the size of each symbol group given to
            the logic module.
        n_output: Number of prediction columns; the logic module is built
            with n_output + 2 outputs.
        fuzzifier_layer_kwargs: Keyword arguments forwarded to ConcreteLayer.
            None (the default) means no extra arguments.
        lfl_layer_kwargs: Passed as layer_kwargs to LFLType3. None (the
            default) means [{}, {}].
    '''
    def __init__(self, n_choice=10, n_output=10, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Build fresh containers per instance: mutable default arguments
        # would be shared by every model created with the defaults.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        # 784 = 28 * 28: forward() reshapes this output to (-1, 1, 28, 28).
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = LFLType3([n_choice, n_choice, 2], n_output + 2, layer_kwargs=lfl_layer_kwargs)

    def _reconstruct(self, symbols):
        '''Map detached symbol vectors to tensors shaped (-1, 1, 28, 28).'''
        # detach(): the reconstruction output cannot back-propagate into the
        # fuzzifier or the CNN.
        return self.reconstruction_layer(symbols.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_images, right_images):
        '''Run the step-by-step recurrence over two image sequences.

        Args:
            left_images: Tensor indexed as [sample, step, ...]; every
                left_images[:, step] slice is passed to the CNN.
                (Presumably (batch, steps, 1, 28, 28) MNIST digits -- TODO
                confirm against the data loader.)
            right_images: Same layout; indexed with the step range of
                left_images.

        Returns:
            Tuple (preds, carry, left_reconstructions, right_reconstructions).
            preds and both reconstruction tensors are stacked along dim 1
            (one entry per step); carry is the last two columns of the final
            logic-module output.
        '''
        batch_size = left_images.shape[0]
        # Every sample starts with carry [1, 0] (presumably the one-hot code
        # for "no carry" -- confirm against the label encoding).
        carry = torch.zeros((batch_size, 2), device=left_images.device, dtype=left_images.dtype)
        carry[:, 0] = 1.
        preds = []
        left_reconstructions, right_reconstructions = [], []
        for i in range(left_images.shape[1]):
            left_image, right_image = left_images[:, i], right_images[:, i]
            left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
            left_label = self.fuzzifier_layer(left_pred)
            right_label = self.fuzzifier_layer(right_pred)
            left_reconstructions.append(self._reconstruct(left_label))
            right_reconstructions.append(self._reconstruct(right_label))
            lfl_output = self.lfl(torch.cat([left_label, right_label, carry], dim=1))
            # The last two columns become the carry for the next step.
            pred, carry = lfl_output[:, :-2], lfl_output[:, -2:]
            preds.append(pred)
        preds = torch.stack(preds, dim=1)
        left_reconstructions = torch.stack(left_reconstructions, dim=1)
        right_reconstructions = torch.stack(right_reconstructions, dim=1)
        return preds, carry, left_reconstructions, right_reconstructions

class MNISTMultiAddLFLType3_MLP(nn.Module):
    '''Recursive NeSy predictor with LFL-Type3 and MLP gradient shortcut.

    At every step both an LFLType3 module and a plain MLP map
    (left symbol, right symbol, carry) to n_output + 2 values. While
    training, the carry handed to the next step is picked per sample from
    either branch; in eval mode it always comes from the LFL branch.

    Args:
        n_choice: Passed to MNIST_Net as N; also the input width of the
            reconstruction layer and the size of each symbol group.
        n_output: Number of prediction columns; both branches produce
            n_output + 2 outputs.
        n_hidden_mlp: Hidden width of the MLP branch.
        fuzzifier_layer_kwargs: Keyword arguments forwarded to ConcreteLayer.
            None (the default) means no extra arguments.
        lfl_layer_kwargs: Passed as layer_kwargs to LFLType3. None (the
            default) means [{}, {}].
    '''
    def __init__(self, n_choice=10, n_output=10, n_hidden_mlp=512, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Build fresh containers per instance: mutable default arguments
        # would be shared by every model created with the defaults.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        self.softmax_layer = nn.Softmax(dim=1)
        # 784 = 28 * 28: forward() reshapes this output to (-1, 1, 28, 28).
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = LFLType3([n_choice, n_choice, 2], n_output + 2, layer_kwargs=lfl_layer_kwargs)
        self.mlp = nn.Sequential(
            nn.Linear(n_choice * 2 + 2, n_hidden_mlp),
            nn.ReLU(),
            nn.Linear(n_hidden_mlp, n_output + 2),
            nn.Sigmoid(),
        )

    def _reconstruct(self, symbols):
        '''Map detached symbol vectors to tensors shaped (-1, 1, 28, 28).'''
        # detach(): the reconstruction output cannot back-propagate into the
        # fuzzifier or the CNN.
        return self.reconstruction_layer(symbols.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_images, right_images):
        '''Run the step-by-step recurrence over two image sequences.

        Args:
            left_images: Tensor indexed as [sample, step, ...]; every
                left_images[:, step] slice is passed to the CNN.
                (Presumably (batch, steps, 1, 28, 28) MNIST digits -- TODO
                confirm against the data loader.)
            right_images: Same layout; indexed with the step range of
                left_images.

        Returns:
            Tuple (preds_lfl, preds_mlp, carry_lfl, carry_mlp,
            left_reconstructions, right_reconstructions, left_label_mean,
            right_label_mean). Predictions and reconstructions are stacked
            along dim 1 (one entry per step). carry_lfl / carry_mlp are the
            last two output columns of each branch at the final step. The
            label means are the softmax of the CNN scores averaged over
            dim 0 and then over the steps.
        '''
        batch_size = left_images.shape[0]
        # Every sample starts with carry [1, 0] (presumably the one-hot code
        # for "no carry" -- confirm against the label encoding).
        carry = torch.zeros((batch_size, 2), device=left_images.device, dtype=left_images.dtype)
        carry[:, 0] = 1.
        preds_lfl, preds_mlp = [], []
        left_reconstructions, right_reconstructions = [], []
        left_label_means, right_label_means = [], []
        for i in range(left_images.shape[1]):
            left_image, right_image = left_images[:, i], right_images[:, i]
            left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
            left_label = self.fuzzifier_layer(left_pred)
            right_label = self.fuzzifier_layer(right_pred)
            left_label_means.append(torch.mean(self.softmax_layer(left_pred), dim=0))
            right_label_means.append(torch.mean(self.softmax_layer(right_pred), dim=0))
            left_reconstructions.append(self._reconstruct(left_label))
            right_reconstructions.append(self._reconstruct(right_label))
            labels = torch.cat([left_label, right_label, carry], dim=1)
            lfl_output = self.lfl(labels)
            mlp_output = self.mlp(labels)
            pred_lfl, carry_lfl = lfl_output[:, :-2], lfl_output[:, -2:]
            preds_lfl.append(pred_lfl)
            pred_mlp, carry_mlp = mlp_output[:, :-2], mlp_output[:, -2:]
            preds_mlp.append(pred_mlp)
            if self.training:
                # Per sample, continue with the LFL carry when the uniform
                # draw is >= 0.5 and with the MLP carry otherwise.
                use_lfl = torch.rand(batch_size, 1, device=carry_lfl.device) >= 0.5
                carry = torch.where(use_lfl, carry_lfl, carry_mlp)
            else:
                carry = carry_lfl
        preds_lfl = torch.stack(preds_lfl, dim=1)
        preds_mlp = torch.stack(preds_mlp, dim=1)
        left_reconstructions = torch.stack(left_reconstructions, dim=1)
        right_reconstructions = torch.stack(right_reconstructions, dim=1)
        left_label_mean = torch.mean(torch.stack(left_label_means, dim=0), dim=0)
        right_label_mean = torch.mean(torch.stack(right_label_means, dim=0), dim=0)
        return preds_lfl, preds_mlp, carry_lfl, carry_mlp, left_reconstructions, right_reconstructions, left_label_mean, right_label_mean

class MNISTMultiAddLFL_MLP(nn.Module):
    '''Recursive NeSy predictor with LFL-Type1 and MLP gradient shortcut.

    At every step both a MultiLayerLFL module and a plain MLP map
    (left symbol, right symbol, carry) to n_output + 2 values. While
    training, the carry handed to the next step is picked per sample from
    either branch; in eval mode it always comes from the LFL branch.

    Args:
        n_choice: Passed to MNIST_Net as N; also the input width of the
            reconstruction layer and the size of each symbol vector.
        n_output: Number of prediction columns; both branches produce
            n_output + 2 outputs.
        n_hidden: Hidden size used for both the MultiLayerLFL and the MLP.
        fuzzifier_layer_kwargs: Keyword arguments forwarded to ConcreteLayer.
            None (the default) means no extra arguments.
        lfl_layer_kwargs: Passed as layer_kwargs to MultiLayerLFL. None (the
            default) means [{}, {}].
    '''
    def __init__(self, n_choice=10, n_output=10, n_hidden=512, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Build fresh containers per instance: mutable default arguments
        # would be shared by every model created with the defaults.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        self.softmax_layer = nn.Softmax(dim=1)
        # 784 = 28 * 28: forward() reshapes this output to (-1, 1, 28, 28).
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = MultiLayerLFL(n_choice + n_choice + 2, [n_hidden, n_output + 2], layer_kwargs=lfl_layer_kwargs)
        self.mlp = nn.Sequential(
            nn.Linear(n_choice * 2 + 2, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_output + 2),
            nn.Sigmoid(),
        )

    def _reconstruct(self, symbols):
        '''Map detached symbol vectors to tensors shaped (-1, 1, 28, 28).'''
        # detach(): the reconstruction output cannot back-propagate into the
        # fuzzifier or the CNN.
        return self.reconstruction_layer(symbols.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_images, right_images):
        '''Run the step-by-step recurrence over two image sequences.

        Args:
            left_images: Tensor indexed as [sample, step, ...]; every
                left_images[:, step] slice is passed to the CNN.
                (Presumably (batch, steps, 1, 28, 28) MNIST digits -- TODO
                confirm against the data loader.)
            right_images: Same layout; indexed with the step range of
                left_images.

        Returns:
            Tuple (preds_lfl, preds_mlp, carry_lfl, carry_mlp,
            left_reconstructions, right_reconstructions, left_label_mean,
            right_label_mean). Predictions and reconstructions are stacked
            along dim 1 (one entry per step). carry_lfl / carry_mlp are the
            last two output columns of each branch at the final step. The
            label means are the softmax of the CNN scores averaged over
            dim 0 and then over the steps.
        '''
        batch_size = left_images.shape[0]
        # Every sample starts with carry [1, 0] (presumably the one-hot code
        # for "no carry" -- confirm against the label encoding).
        carry = torch.zeros((batch_size, 2), device=left_images.device, dtype=left_images.dtype)
        carry[:, 0] = 1.
        preds_lfl, preds_mlp = [], []
        left_reconstructions, right_reconstructions = [], []
        left_label_means, right_label_means = [], []
        for i in range(left_images.shape[1]):
            left_image, right_image = left_images[:, i], right_images[:, i]
            left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
            left_label = self.fuzzifier_layer(left_pred)
            right_label = self.fuzzifier_layer(right_pred)
            left_label_means.append(torch.mean(self.softmax_layer(left_pred), dim=0))
            right_label_means.append(torch.mean(self.softmax_layer(right_pred), dim=0))
            left_reconstructions.append(self._reconstruct(left_label))
            right_reconstructions.append(self._reconstruct(right_label))
            labels = torch.cat([left_label, right_label, carry], dim=1)
            lfl_output = self.lfl(labels)
            mlp_output = self.mlp(labels)
            pred_lfl, carry_lfl = lfl_output[:, :-2], lfl_output[:, -2:]
            preds_lfl.append(pred_lfl)
            pred_mlp, carry_mlp = mlp_output[:, :-2], mlp_output[:, -2:]
            preds_mlp.append(pred_mlp)
            if self.training:
                # Per sample, continue with the LFL carry when the uniform
                # draw is >= 0.5 and with the MLP carry otherwise.
                use_lfl = torch.rand(batch_size, 1, device=carry_lfl.device) >= 0.5
                carry = torch.where(use_lfl, carry_lfl, carry_mlp)
            else:
                carry = carry_lfl
        preds_lfl = torch.stack(preds_lfl, dim=1)
        preds_mlp = torch.stack(preds_mlp, dim=1)
        left_reconstructions = torch.stack(left_reconstructions, dim=1)
        right_reconstructions = torch.stack(right_reconstructions, dim=1)
        left_label_mean = torch.mean(torch.stack(left_label_means, dim=0), dim=0)
        right_label_mean = torch.mean(torch.stack(right_label_means, dim=0), dim=0)
        return preds_lfl, preds_mlp, carry_lfl, carry_mlp, left_reconstructions, right_reconstructions, left_label_mean, right_label_mean

class MNISTMultiAddLFL_NoMLP(nn.Module):
    '''Recursive NeSy predictor with LFL-Type1 and no MLP gradient shortcut.

    A CNN scores every image, a ConcreteLayer turns the scores into symbol
    vectors, and a MultiLayerLFL module maps (left symbol, right symbol,
    carry) to n_output + 2 values: the prediction followed by the two-column
    carry that is fed to the next step.

    Args:
        n_choice: Passed to MNIST_Net as N; also the input width of the
            reconstruction layer and the size of each symbol vector.
        n_output: Number of prediction columns; the logic module produces
            n_output + 2 outputs.
        n_hidden: Hidden size given to MultiLayerLFL.
        fuzzifier_layer_kwargs: Keyword arguments forwarded to ConcreteLayer.
            None (the default) means no extra arguments.
        lfl_layer_kwargs: Passed as layer_kwargs to MultiLayerLFL. None (the
            default) means [{}, {}].
    '''
    def __init__(self, n_choice=10, n_output=10, n_hidden=512, fuzzifier_layer_kwargs=None, lfl_layer_kwargs=None):
        super().__init__()
        # Build fresh containers per instance: mutable default arguments
        # would be shared by every model created with the defaults.
        if fuzzifier_layer_kwargs is None:
            fuzzifier_layer_kwargs = {}
        if lfl_layer_kwargs is None:
            lfl_layer_kwargs = [{}, {}]
        self.cnn = MNIST_Net(N=n_choice)
        self.fuzzifier_layer = ConcreteLayer(**fuzzifier_layer_kwargs)
        self.softmax_layer = nn.Softmax(dim=1)
        # 784 = 28 * 28: forward() reshapes this output to (-1, 1, 28, 28).
        self.reconstruction_layer = nn.Linear(n_choice, 784)
        self.lfl = MultiLayerLFL(n_choice + n_choice + 2, [n_hidden, n_output + 2], layer_kwargs=lfl_layer_kwargs)

    def _reconstruct(self, symbols):
        '''Map detached symbol vectors to tensors shaped (-1, 1, 28, 28).'''
        # detach(): the reconstruction output cannot back-propagate into the
        # fuzzifier or the CNN.
        return self.reconstruction_layer(symbols.detach()).reshape(-1, 1, 28, 28)

    def forward(self, left_images, right_images):
        '''Run the step-by-step recurrence over two image sequences.

        Args:
            left_images: Tensor indexed as [sample, step, ...]; every
                left_images[:, step] slice is passed to the CNN.
                (Presumably (batch, steps, 1, 28, 28) MNIST digits -- TODO
                confirm against the data loader.)
            right_images: Same layout; indexed with the step range of
                left_images.

        Returns:
            Tuple (preds_lfl, carry_lfl, left_reconstructions,
            right_reconstructions, left_label_mean, right_label_mean).
            Predictions and reconstructions are stacked along dim 1 (one
            entry per step). carry_lfl is the last two columns of the final
            logic-module output. The label means are the softmax of the CNN
            scores averaged over dim 0 and then over the steps.
        '''
        batch_size = left_images.shape[0]
        # Every sample starts with carry [1, 0] (presumably the one-hot code
        # for "no carry" -- confirm against the label encoding).
        carry = torch.zeros((batch_size, 2), device=left_images.device, dtype=left_images.dtype)
        carry[:, 0] = 1.
        preds_lfl = []
        left_reconstructions, right_reconstructions = [], []
        left_label_means, right_label_means = [], []
        for i in range(left_images.shape[1]):
            left_image, right_image = left_images[:, i], right_images[:, i]
            left_pred, right_pred = self.cnn(left_image), self.cnn(right_image)
            left_label = self.fuzzifier_layer(left_pred)
            right_label = self.fuzzifier_layer(right_pred)
            left_label_means.append(torch.mean(self.softmax_layer(left_pred), dim=0))
            right_label_means.append(torch.mean(self.softmax_layer(right_pred), dim=0))
            left_reconstructions.append(self._reconstruct(left_label))
            right_reconstructions.append(self._reconstruct(right_label))
            lfl_output = self.lfl(torch.cat([left_label, right_label, carry], dim=1))
            # The last two columns become the carry for the next step.
            pred_lfl, carry_lfl = lfl_output[:, :-2], lfl_output[:, -2:]
            preds_lfl.append(pred_lfl)
            carry = carry_lfl
        preds_lfl = torch.stack(preds_lfl, dim=1)
        left_reconstructions = torch.stack(left_reconstructions, dim=1)
        right_reconstructions = torch.stack(right_reconstructions, dim=1)
        left_label_mean = torch.mean(torch.stack(left_label_means, dim=0), dim=0)
        right_label_mean = torch.mean(torch.stack(right_label_means, dim=0), dim=0)
        return preds_lfl, carry_lfl, left_reconstructions, right_reconstructions, left_label_mean, right_label_mean