import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import sympy as sp
import numpy as np
import abc

class LFLModule(nn.Module, abc.ABC):
    '''Abstract base for logic layers that can be exported as symbolic formulas.

    Besides being a regular ``nn.Module``, every concrete layer must implement
    :meth:`expression`, which maps the sympy expressions of its inputs to the
    sympy expressions of its outputs.
    '''

    @abc.abstractmethod
    def expression(self, x_exp):
        '''Build the output sympy expressions from the inputs' expressions.

        Args:
            x_exp: sequence of sympy expressions, one per input of the layer.
        '''
        ...

class ConjunctionLayer(LFLModule):
    '''LFL-Type1's AND layer.

    Every (input, neuron) pair owns a real-valued weight; input ``i`` belongs to
    neuron ``j``'s conjunction when ``weight[i, j] >= 0``. In training mode the
    membership is relaxed: logistic noise (the logit of ``u ~ U(0, 1)``, scaled
    by ``noise_scale``) is added to the weight and squashed with a sigmoid, with
    an independent sample per example in the batch. In eval mode the membership
    is the hard 0/1 mask ``weight >= 0``.

    Args:
        n_input: number of inputs.
        n_neuron: number of output neurons (one conjunction each).
        noise_scale: multiplier applied to the logistic noise during training.
        init_mean: mean of the Gaussian weight initialisation.
        init_scale: standard deviation of the Gaussian weight initialisation.
        epsilon: margin used to clip the sampled uniforms and the outputs away
            from 0 and 1 (defaults to the previously hard-coded 1e-8).
    '''
    def __init__(self, n_input: int, n_neuron: int, noise_scale=1., init_mean=0., init_scale=0.25, epsilon=1e-8):
        super().__init__()
        self.n_input = n_input
        self.n_neuron = n_neuron
        self.weight = Parameter(torch.randn(self.n_input, self.n_neuron) * init_scale + init_mean)
        self.epsilon = epsilon
        self.noise_scale = noise_scale

    def get_membership(self, batch_size=None):
        '''Return the (possibly noisy) membership of each input in each neuron.

        Args:
            batch_size: number of independent noise samples to draw. Required in
                training mode (0 is accepted and yields an empty sample); ignored
                in eval mode.

        Returns:
            Training: a ``(batch_size, n_input, n_neuron)`` tensor in (0, 1).
            Eval: a ``(n_input, n_neuron)`` 0/1 tensor in the weight's dtype.
        '''
        if self.training:
            # Test against None rather than truthiness so an empty batch
            # (batch_size == 0) does not trip the assertion.
            assert batch_size is not None, 'batch_size is required for non-deterministic membership calculation.'
            u = torch.rand(size=(batch_size, self.n_input, self.n_neuron), device=self.weight.device, dtype=self.weight.dtype)
            u = torch.clip(u, self.epsilon, 1. - self.epsilon)
            # log(u) - log(1 - u) is the logit of a uniform sample, i.e. standard logistic noise.
            noisy_weight = (torch.log(u) - torch.log(1.0 - u)) * self.noise_scale + self.weight[None, :, :]
            return torch.sigmoid(noisy_weight)
        return torch.ge(self.weight, 0.).type(self.weight.dtype)

    def bool_membership(self):
        '''Return the hard mask ``weight >= 0`` as a NumPy bool array of shape (n_input, n_neuron).'''
        return torch.ge(self.weight, 0.).detach().cpu().numpy()

    def forward(self, x):
        '''Fuzzy AND over the member inputs of each neuron.

        ``x`` is expected to have shape ``(batch, n_input)``; values are
        presumably truth degrees in [0, 1] (TODO confirm with callers). Returns a
        ``(batch, n_neuron)`` tensor clipped to ``[epsilon, 1 - epsilon]``.
        '''
        membership = self.get_membership(batch_size=x.shape[0])
        if not self.training:
            # The eval mask is shared by every row; give it a batch axis.
            membership = membership[None, :, :]
        # For membership m: m == 1 passes x through, m == 0 yields 1 (the identity for min).
        y = 1. - torch.minimum(1. - x[:, :, None], membership)
        y = torch.amin(y, dim=1, keepdim=False)
        y = torch.clip(y, self.epsilon, 1. - self.epsilon)
        return y

    def expression(self, x_exp):
        '''Return one sympy expression per neuron.

        Each is the And of the neuron's member inputs, the lone member itself, or
        ``sympy.true`` for a neuron with no members (the empty conjunction).
        '''
        assert len(x_exp) == self.n_input, 'Number of x_exp does not match number of inputs.'
        membership = self.bool_membership()
        ret = []
        for n_neuron in range(self.n_neuron):
            terms = [e for i, e in enumerate(x_exp) if membership[i, n_neuron]]
            if len(terms) == 0:
                ret.append(sp.true)
            elif len(terms) == 1:
                ret.append(terms[0])
            else:
                ret.append(sp.And(*terms))
        return ret

    def reg_loss(self):
        '''Mean of ``sigmoid(weight)``, the noise-free (median) membership.

        NOTE(review): presumably added to the training loss as a sparsity
        penalty on conjunction size; confirm at the call site.
        '''
        median_membership = torch.sigmoid(self.weight)
        ret = torch.mean(median_membership)
        return ret
    
class DisjunctionLayer(ConjunctionLayer):
    '''LFL-Type1 and LFL-Type3's OR layer.

    Inherits the weights, membership sampling (``get_membership``,
    ``bool_membership``) and ``reg_loss`` of ConjunctionLayer; only the
    aggregation differs. Output neuron ``j`` is ``max_i min(x_i, m_ij)``.
    '''
    def __init__(self, *args, **kwargs):
        # Accepts exactly the constructor arguments of ConjunctionLayer.
        super().__init__(*args, **kwargs)

    def forward(self, x):
        '''Fuzzy OR over member inputs.

        ``x`` is expected to be ``(batch, n_input)``; the result is
        ``(batch, n_neuron)`` clipped to ``[epsilon, 1 - epsilon]``.
        '''
        gate = self.get_membership(batch_size=x.shape[0])
        if not self.training:
            # The hard eval mask is shared by every row; add a batch axis.
            gate = gate[None, :, :]
        # A member (m == 1) passes x through; a non-member (m == 0) yields 0,
        # the identity for max.
        strongest = torch.amax(torch.minimum(x[:, :, None], gate), dim=1, keepdim=False)
        return torch.clip(strongest, self.epsilon, 1. - self.epsilon)

    def expression(self, x_exp):
        '''Return one sympy expression per neuron.

        Each is the Or of the neuron's member inputs, the lone member itself, or
        ``sympy.false`` for a neuron with no members (the empty disjunction).
        '''
        assert len(x_exp) == self.n_input, 'Number of x_exp does not match number of inputs.'
        is_member = self.bool_membership()
        outputs = []
        for j in range(self.n_neuron):
            members = [expr for i, expr in enumerate(x_exp) if is_member[i, j]]
            if not members:
                outputs.append(sp.false)
            elif len(members) == 1:
                outputs.append(members[0])
            else:
                outputs.append(sp.Or(*members))
        return outputs
    
class NegationLayer(LFLModule):
    '''The negation layer: fuzzy NOT, ``1 - x`` clipped to ``[epsilon, 1 - epsilon]``.'''
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon
        # Scalar parameter that forward/expression never read.
        # NOTE(review): presumably kept so the module always owns a parameter
        # (e.g. for device/dtype lookups by callers); confirm before removing.
        self.weight = Parameter(torch.tensor(0.))

    def forward(self, x):
        '''Return the element-wise complement of ``x``, clipped away from 0 and 1.'''
        return torch.clip(1. - x, self.epsilon, 1. - self.epsilon)

    def expression(self, x_exp):
        '''Wrap every input expression in ``sympy.Not``.'''
        return list(map(sp.Not, x_exp))
    
class ConjunctionLayerDNL(LFLModule):
    '''Our implementation of dNL's AND layer.

    Membership is ``sigmoid(weight)`` in training (no noise) and the hard mask
    ``weight >= 0`` in eval. Output neuron ``j`` is
    ``prod_i (1 - m_ij * (1 - x_i))``.

    Args:
        n_input: number of inputs.
        n_neuron: number of output neurons.
        init_mean: mean of the Gaussian weight initialisation.
        init_scale: standard deviation of the Gaussian weight initialisation.
    '''
    def __init__(self, n_input: int, n_neuron: int, init_mean=0., init_scale=0.25):
        super().__init__()
        self.n_input = n_input
        self.n_neuron = n_neuron
        self.weight = Parameter(torch.randn(self.n_input, self.n_neuron) * init_scale + init_mean)

    def get_membership(self):
        '''Return an ``(n_input, n_neuron)`` membership: soft in training, 0/1 in eval.'''
        if not self.training:
            return torch.ge(self.weight, 0.).type(self.weight.dtype)
        return torch.sigmoid(self.weight)

    def bool_membership(self):
        '''Return the hard mask ``weight >= 0`` as a NumPy bool array.'''
        return (self.weight >= 0.).detach().cpu().numpy()

    def forward(self, x):
        '''Product-based fuzzy AND; ``x`` is expected to be ``(batch, n_input)``, result ``(batch, n_neuron)``.'''
        gate = self.get_membership()[None, :, :]
        complement = 1. - x[:, :, None]
        # Non-members (m == 0) contribute a factor of 1 to the product.
        return torch.prod(1. - complement * gate, dim=1, keepdim=False)

    def expression(self, x_exp):
        '''Return, per neuron, the And of member inputs, the lone member, or ``sympy.true`` if none.'''
        assert len(x_exp) == self.n_input, 'Number of x_exp does not match number of inputs.'
        is_member = self.bool_membership()
        outputs = []
        for j in range(self.n_neuron):
            members = [expr for i, expr in enumerate(x_exp) if is_member[i, j]]
            if not members:
                outputs.append(sp.true)
            elif len(members) == 1:
                outputs.append(members[0])
            else:
                outputs.append(sp.And(*members))
        return outputs

    def reg_loss(self):
        '''Mean soft membership ``sigmoid(weight)`` over all (input, neuron) pairs.'''
        return torch.sigmoid(self.weight).mean()
    
class DisjunctionLayerDNL(ConjunctionLayerDNL):
    '''Our implementation of dNL's OR layer.

    Shares parameters and membership with ConjunctionLayerDNL. Output neuron
    ``j`` is ``1 - prod_i (1 - m_ij * x_i)``.
    '''
    def __init__(self, *args, **kwargs):
        # Accepts exactly the constructor arguments of ConjunctionLayerDNL.
        super().__init__(*args, **kwargs)

    def forward(self, x):
        '''Product-based fuzzy OR; ``x`` is expected to be ``(batch, n_input)``, result ``(batch, n_neuron)``.'''
        gate = self.get_membership()
        contributions = x[:, :, None] * gate[None, :, :]
        # Non-members (m == 0) contribute 0, i.e. a factor of 1 inside the product.
        return 1. - torch.prod(1. - contributions, dim=1, keepdim=False)

    def expression(self, x_exp):
        '''Return, per neuron, the Or of member inputs, the lone member, or ``sympy.false`` if none.'''
        assert len(x_exp) == self.n_input, 'Number of x_exp does not match number of inputs.'
        is_member = self.bool_membership()
        outputs = []
        for j in range(self.n_neuron):
            members = [expr for i, expr in enumerate(x_exp) if is_member[i, j]]
            if not members:
                outputs.append(sp.false)
            elif len(members) == 1:
                outputs.append(members[0])
            else:
                outputs.append(sp.Or(*members))
        return outputs
    
class PredefinedConjunctionLayer(LFLModule):
    '''DSL, LFL-Type2 and LFL-Type3's frozen AND layer.

    The membership matrix is fixed at construction. Entry ``[i, j]`` is the
    membership of input ``i`` in neuron ``j``'s conjunction. It is stored twice:
    - as a non-trainable ``Parameter`` (``requires_grad=False``) in torch's
      default dtype;
    - as the NumPy bool array ``bool_membership`` (nonzero entries are members),
      used for symbolic export.

    Args:
        membership: 2-D array-like of shape ``(n_input, n_neuron)``.

    Raises:
        ValueError: if ``membership`` is not 2-D.
    '''
    def __init__(self, membership: np.ndarray):
        super().__init__()
        # Accept any array-like (e.g. nested lists), not only ndarrays.
        membership = np.asarray(membership)
        if membership.ndim != 2:
            raise ValueError(f'membership must be 2-D (n_input, n_neuron), got shape {membership.shape}.')
        self.n_input = membership.shape[0]
        self.n_neuron = membership.shape[1]
        self.membership = Parameter(torch.tensor(membership, dtype=torch.get_default_dtype()), requires_grad=False)
        # NOTE: an attribute here, whereas the trainable layers expose a bool_membership() method.
        self.bool_membership = membership.astype(np.bool_)

    def forward(self, x):
        '''Fuzzy AND with the fixed membership.

        ``x`` is expected to be ``(batch, n_input)``; returns ``(batch, n_neuron)``.
        Unlike the trainable AND layer, the output is not clipped.
        '''
        membership = self.membership
        # For membership m: m == 1 passes x through, m == 0 yields 1 (the identity for min).
        y = 1. - torch.minimum(1. - x[:, :, None], membership[None, :, :])
        y = torch.amin(y, dim=1, keepdim=False)
        return y

    def expression(self, x_exp):
        '''Return, per neuron, the And of member inputs, the lone member, or ``sympy.true`` if none.'''
        assert len(x_exp) == self.n_input, 'Number of x_exp does not match number of inputs.'
        membership = self.bool_membership
        ret = []
        for n_neuron in range(self.n_neuron):
            terms = [e for i, e in enumerate(x_exp) if membership[i, n_neuron]]
            if len(terms) == 0:
                ret.append(sp.true)
            elif len(terms) == 1:
                ret.append(terms[0])
            else:
                ret.append(sp.And(*terms))
        return ret

class EpsilonGreedySoftmaxLayer(nn.Module):
    '''Our implementation of DSL's fuzzifier layer.

    Training: for each row of ``x`` a single column is kept. It is a uniformly
    random column with probability ``epsilon``, otherwise the row's argmax. The
    output is ``softmax(x, dim=1)`` masked to that column, with every other entry
    0. Eval: the one-hot encoding of the row-wise argmax, in ``x``'s dtype.

    Args:
        epsilon: probability of choosing a random column instead of the argmax
            during training.
    '''
    def __init__(self, epsilon=0.2807344052335263):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        '''Apply the layer; assumes a 2-D ``x`` of shape ``(batch, n_symbols)`` and returns the same shape.'''
        n_rows, n_cols = x.shape[0], x.shape[1]
        if not self.training:
            return nn.functional.one_hot(torch.argmax(x, dim=1), num_classes=n_cols).type(x.dtype)
        # Sample directly on x's device instead of sampling on the host and
        # copying the tensors over on every forward pass.
        random_selection = torch.rand((n_rows,), device=x.device) < self.epsilon
        symbol_index_random = torch.randint(n_cols, (n_rows,), device=x.device)
        symbol_index_max = torch.argmax(x, dim=1)
        chosen_symbols = torch.where(random_selection, symbol_index_random, symbol_index_max)
        # The mask is a constant selector; gradients flow only through the softmax.
        mask = nn.functional.one_hot(chosen_symbols, num_classes=n_cols).detach()
        return mask * torch.softmax(x, dim=1)
    
class EpsilonGreedySoftmaxMembershipDisjunctionLayer(LFLModule):
    '''Our implementation of DSL's OR layer.

    Each input belongs to exactly one neuron: row ``i`` of the membership is
    one-hot over neurons. In training the chosen neuron is random with
    probability ``epsilon`` and otherwise ``argmax(weight[i])``. The kept entry
    carries the softmax of the row (over neurons) so gradients reach the
    weights. In eval the membership is the one-hot argmax.

    Args:
        n_input: number of inputs.
        n_neuron: number of output neurons.
        epsilon: per-input probability of a random neuron choice during training.
    '''
    def __init__(self, n_input: int, n_neuron: int, epsilon=0.1077119516324264):
        super().__init__()
        self.n_input = n_input
        self.n_neuron = n_neuron
        self.weight = Parameter(torch.randn(self.n_input, self.n_neuron))
        self.epsilon = epsilon

    def get_membership(self):
        '''Return the ``(n_input, n_neuron)`` membership (soft-valued one-hot in training, 0/1 in eval).'''
        one_hot = nn.functional.one_hot
        if not self.training:
            return one_hot(torch.argmax(self.weight, dim=1), num_classes=self.n_neuron).type(self.weight.dtype)
        device = self.weight.device
        explore = torch.rand((self.n_input,), device=device) < self.epsilon
        random_pick = torch.randint(self.n_neuron, (self.n_input,), device=device)
        greedy_pick = torch.argmax(self.weight, dim=1)
        picks = torch.where(explore, random_pick, greedy_pick)
        selector = one_hot(picks, num_classes=self.n_neuron).detach()
        return selector * torch.softmax(self.weight, dim=1)

    def bool_membership(self):
        '''Return the hard one-hot argmax membership as a NumPy bool array.'''
        winners = torch.argmax(self.weight, dim=1)
        return nn.functional.one_hot(winners, num_classes=self.n_neuron).type(torch.bool).detach().cpu().numpy()

    def forward(self, x):
        '''Fuzzy OR (max of min); ``x`` is expected to be ``(batch, n_input)``, result ``(batch, n_neuron)``.'''
        gate = self.get_membership()[None, :, :]
        return torch.amax(torch.minimum(x[:, :, None], gate), dim=1, keepdim=False)

    def expression(self, x_exp):
        '''Return, per neuron, the Or of member inputs, the lone member, or ``sympy.false`` if none.'''
        assert len(x_exp) == self.n_input, 'Number of x_exp does not match number of inputs.'
        is_member = self.bool_membership()
        outputs = []
        for j in range(self.n_neuron):
            members = [expr for i, expr in enumerate(x_exp) if is_member[i, j]]
            if not members:
                outputs.append(sp.false)
            elif len(members) == 1:
                outputs.append(members[0])
            else:
                outputs.append(sp.Or(*members))
        return outputs

class ConcreteLayer(nn.Module):
    '''LFL-Type1, LFL-Type2 and LFL-Type3's fuzzifier layer.

    Training: ``softmax(x + noise_scale * g, dim=1)`` with Gumbel noise
    ``g = -log(-log(u))``, ``u ~ U(0, 1)`` clipped to ``[epsilon, 1 - epsilon]``.
    Eval: the one-hot encoding of the row-wise argmax, in ``x``'s dtype.
    '''
    def __init__(self, noise_scale=1.):
        super().__init__()
        self.noise_scale = noise_scale
        self.epsilon = 1e-8

    def forward(self, x):
        '''Relax (training) or harden (eval) the row-wise choice encoded by the logits ``x``.'''
        if not self.training:
            winners = torch.argmax(x, dim=1)
            return nn.functional.one_hot(winners, num_classes=x.shape[1]).type(x.dtype)
        uniform = torch.rand(size=x.shape, device=x.device, dtype=x.dtype)
        uniform = torch.clip(uniform, self.epsilon, 1. - self.epsilon)
        gumbel = -torch.log(-torch.log(uniform))
        return torch.softmax(x + gumbel * self.noise_scale, dim=1)
    
class ConcreteMembershipDisjunctionLayer(LFLModule):
    '''LFL-Type2's OR layer.

    Each input's membership is a distribution over neurons. In training it is a
    Gumbel-softmax sample (``softmax`` over the neuron axis of ``weight`` plus
    ``noise_scale``-scaled Gumbel noise), drawn independently per example. In
    eval it is the one-hot argmax over neurons.

    Args:
        n_input: number of inputs.
        n_neuron: number of output neurons.
        noise_scale: multiplier applied to the Gumbel noise during training.
    '''
    def __init__(self, n_input: int, n_neuron: int, noise_scale=1.):
        super().__init__()
        self.n_input = n_input
        self.n_neuron = n_neuron
        self.weight = Parameter(torch.randn(self.n_input, self.n_neuron))
        self.noise_scale = noise_scale
        self.epsilon = 1e-8

    def get_membership(self, batch_size):
        '''Return ``(batch_size, n_input, n_neuron)`` soft memberships in training, ``(n_input, n_neuron)`` 0/1 in eval.'''
        if not self.training:
            winners = torch.argmax(self.weight, dim=1)
            return nn.functional.one_hot(winners, num_classes=self.n_neuron).type(self.weight.dtype)
        sample_shape = (batch_size, self.n_input, self.n_neuron)
        uniform = torch.rand(size=sample_shape, device=self.weight.device, dtype=self.weight.dtype)
        uniform = torch.clip(uniform, self.epsilon, 1. - self.epsilon)
        gumbel = -torch.log(-torch.log(uniform))
        return torch.softmax(self.weight[None, :, :] + gumbel * self.noise_scale, dim=2)

    def bool_membership(self):
        '''Return the hard one-hot argmax membership as a NumPy bool array.'''
        winners = torch.argmax(self.weight, dim=1)
        return nn.functional.one_hot(winners, num_classes=self.n_neuron).type(torch.bool).detach().cpu().numpy()

    def forward(self, x):
        '''Fuzzy OR (max of min); ``x`` is expected to be ``(batch, n_input)``, result ``(batch, n_neuron)``.'''
        gate = self.get_membership(x.shape[0])
        if not self.training:
            # The hard eval mask is shared by every row; add a batch axis.
            gate = gate[None, :, :]
        return torch.amax(torch.minimum(x[:, :, None], gate), dim=1, keepdim=False)

    def expression(self, x_exp):
        '''Return, per neuron, the Or of member inputs, the lone member, or ``sympy.false`` if none.'''
        assert len(x_exp) == self.n_input, 'Number of x_exp does not match number of inputs.'
        is_member = self.bool_membership()
        outputs = []
        for j in range(self.n_neuron):
            members = [expr for i, expr in enumerate(x_exp) if is_member[i, j]]
            if not members:
                outputs.append(sp.false)
            elif len(members) == 1:
                outputs.append(members[0])
            else:
                outputs.append(sp.Or(*members))
        return outputs