import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from library.LFL_layers import *
import sympy as sp
import numpy as np

class MultiLayerLFL(LFLModule):
    '''Multi-layer LFL-Type1 without NOT neurons.

    Builds a stack of logic layers whose type alternates between
    ConjunctionLayer (AND) and DisjunctionLayer (OR). The first layer is a
    conjunction when ``conjunction_first`` is truthy.

    Args:
        n_input: input width of the first layer.
        n_hiddens: sequence of layer output widths, one entry per layer.
        layer_kwargs: sequence of keyword-argument dicts; entry ``i`` is
            unpacked into the constructor of layer ``i``.
        conjunction_first: selects the type of the first layer.
    '''
    def __init__(self, n_input, n_hiddens, layer_kwargs, conjunction_first=True):
        super().__init__()
        stack = []
        in_width = n_input
        is_and = conjunction_first
        for idx, out_width in enumerate(n_hiddens):
            kwargs = layer_kwargs[idx]
            if is_and:
                stack.append(ConjunctionLayer(in_width, out_width, **kwargs))
            else:
                stack.append(DisjunctionLayer(in_width, out_width, **kwargs))
            # Flip between AND and OR for the next layer.
            is_and ^= True
            # The next layer consumes this layer's outputs.
            in_width = out_width
        self.layers = nn.ModuleList(stack)
        # Output width of the whole stack (equals n_input when there are no layers).
        self.n_output = in_width

    def forward(self, x):
        '''Apply every layer in order; returns ``x`` unchanged if the stack is empty.'''
        out = x
        for logic_layer in self.layers:
            out = logic_layer(out)
        return out

    def expression(self, x_exp):
        '''Thread the symbolic inputs ``x_exp`` through each layer's ``expression``.'''
        symbolic = x_exp
        for logic_layer in self.layers:
            symbolic = logic_layer.expression(symbolic)
        return symbolic

    def reg_loss(self):
        '''Sum of the per-layer ``reg_loss()`` values.'''
        per_layer = [logic_layer.reg_loss() for logic_layer in self.layers]
        return torch.stack(per_layer).sum()

    def set_noise_scale(self, noise_scales):
        '''Forward ``noise_scales[i]`` to the ``set_noise_scale`` of layer ``i``.'''
        for idx, logic_layer in enumerate(self.layers):
            logic_layer.set_noise_scale(noise_scales[idx])
            
class MultiLayerLFLWithNegation(LFLModule):
    '''Multi-layer LFL-Type1 with NOT neurons in the first layer.

    The raw inputs are concatenated with the NegationLayer outputs, and the
    resulting ``2 * n_input`` features are passed to a MultiLayerLFL.

    Args:
        n_input: number of raw input features.
        n_hiddens, layer_kwargs, conjunction_first: forwarded unchanged to
            the wrapped MultiLayerLFL.
    '''
    def __init__(self, n_input, n_hiddens, layer_kwargs, conjunction_first=True):
        super().__init__()
        self.n_input = n_input
        self.negation_layer = NegationLayer()
        # Each input contributes itself plus its negation.
        doubled = 2 * n_input
        self.multiLayerLFL = MultiLayerLFL(doubled, n_hiddens, layer_kwargs, conjunction_first)

    def forward(self, x):
        '''Concatenate ``x`` with its negation along dim 1, then run the LFL stack.

        Presumably ``x`` is (batch, n_input) — TODO confirm.
        '''
        negated = self.negation_layer(x)
        return self.multiLayerLFL(torch.cat([x, negated], dim=1))

    def expression(self, x_exp=None):
        '''Return the symbolic output of the network.

        When ``x_exp`` is missing or empty, SymPy symbols ``x0 .. x{n_input-1}``
        are used as the inputs.
        '''
        x_exp = x_exp or [sp.Symbol(f'x{i}') for i in range(self.n_input)]
        literals = x_exp + self.negation_layer.expression(x_exp)
        return self.multiLayerLFL.expression(literals)

    def reg_loss(self):
        '''Regularization loss of the wrapped MultiLayerLFL.'''
        return self.multiLayerLFL.reg_loss()

class MultiLayerDNL(LFLModule):
    '''Multi-layer dNL without NOT neurons.

    Builds a stack of dNL logic layers whose type alternates between
    ConjunctionLayerDNL (AND) and DisjunctionLayerDNL (OR). The first layer
    is a conjunction when ``conjunction_first`` is truthy.

    Args:
        n_input: input width of the first layer.
        n_hiddens: sequence of layer output widths, one entry per layer.
        layer_kwargs: sequence of keyword-argument dicts; entry ``i`` is
            unpacked into the constructor of layer ``i``.
        conjunction_first: selects the type of the first layer.
    '''
    def __init__(self, n_input, n_hiddens, layer_kwargs, conjunction_first=True):
        super().__init__()
        stack = []
        in_width = n_input
        is_and = conjunction_first
        for idx, out_width in enumerate(n_hiddens):
            kwargs = layer_kwargs[idx]
            if is_and:
                stack.append(ConjunctionLayerDNL(in_width, out_width, **kwargs))
            else:
                stack.append(DisjunctionLayerDNL(in_width, out_width, **kwargs))
            # Flip between AND and OR for the next layer.
            is_and ^= True
            # The next layer consumes this layer's outputs.
            in_width = out_width
        self.layers = nn.ModuleList(stack)
        # Output width of the whole stack (equals n_input when there are no layers).
        self.n_output = in_width

    def forward(self, x):
        '''Apply every layer in order; returns ``x`` unchanged if the stack is empty.'''
        out = x
        for logic_layer in self.layers:
            out = logic_layer(out)
        return out

    def expression(self, x_exp):
        '''Thread the symbolic inputs ``x_exp`` through each layer's ``expression``.'''
        symbolic = x_exp
        for logic_layer in self.layers:
            symbolic = logic_layer.expression(symbolic)
        return symbolic

    def reg_loss(self):
        '''Sum of the per-layer ``reg_loss()`` values.'''
        per_layer = [logic_layer.reg_loss() for logic_layer in self.layers]
        return torch.stack(per_layer).sum()

class MultiLayerDNLWithNegation(LFLModule):
    '''Multi-layer dNL with NOT neurons in the first layer.

    The raw inputs are concatenated with the NegationLayer outputs, and the
    resulting ``2 * n_input`` features are passed to a MultiLayerDNL.

    Args:
        n_input: number of raw input features.
        n_hiddens, layer_kwargs, conjunction_first: forwarded unchanged to
            the wrapped MultiLayerDNL.
    '''
    def __init__(self, n_input, n_hiddens, layer_kwargs, conjunction_first=True):
        super().__init__()
        self.n_input = n_input
        self.negation_layer = NegationLayer()
        # Each input contributes itself plus its negation.
        doubled = 2 * n_input
        self.multiLayerDNL = MultiLayerDNL(doubled, n_hiddens, layer_kwargs, conjunction_first)

    def forward(self, x):
        '''Concatenate ``x`` with its negation along dim 1, then run the dNL stack.

        Presumably ``x`` is (batch, n_input) — TODO confirm.
        '''
        negated = self.negation_layer(x)
        return self.multiLayerDNL(torch.cat([x, negated], dim=1))

    def expression(self, x_exp=None):
        '''Return the symbolic output of the network.

        When ``x_exp`` is missing or empty, SymPy symbols ``x0 .. x{n_input-1}``
        are used as the inputs.
        '''
        x_exp = x_exp or [sp.Symbol(f'x{i}') for i in range(self.n_input)]
        literals = x_exp + self.negation_layer.expression(x_exp)
        return self.multiLayerDNL.expression(literals)

    def reg_loss(self):
        '''Regularization loss of the wrapped MultiLayerDNL.'''
        return self.multiLayerDNL.reg_loss()

def conjunction_membership(n_choices):
    '''Membership of frozen AND layers. Each neuron represents one of the cartesian products of input concepts.

    Args:
        n_choices: number of one-hot choices (concepts) for each input variable.

    Returns:
        An int64 numpy array of shape ``(sum(n_choices), prod(n_choices))``.
        Rows follow the concatenated one-hot input layout: variable 0's
        choices come first, then variable 1's, and so on. Each column is one
        AND neuron. It has a 1 in exactly one row per variable, and every
        combination of choices appears in exactly one column.
    '''
    # Every combination of choice indices, one row per combination:
    # shape (prod(n_choices), len(n_choices)).
    grids = np.meshgrid(*[np.arange(n_choice) for n_choice in n_choices])
    true_indexes = np.array(grids).T.reshape(-1, len(n_choices))
    # one_hot only accepts int64 index tensors. np.arange can produce int32
    # (e.g. on Windows with NumPy < 2), so cast explicitly.
    index_tensor = torch.as_tensor(true_indexes, dtype=torch.long)
    one_hots = [
        torch.nn.functional.one_hot(index_tensor[:, i], num_classes=int(n_choice))
        for i, n_choice in enumerate(n_choices)
    ]
    # (prod, sum) -> (sum, prod): rows are input features, columns are AND neurons.
    return torch.cat(one_hots, dim=1).T.numpy()

class DSLLogicModule(LFLModule):
    '''Our implementation of DSL's logic module.

    A frozen AND layer (PredefinedConjunctionLayer built from
    ``conjunction_membership(n_choices)``) is followed by one
    EpsilonGreedySoftmaxMembershipDisjunctionLayer per entry of ``n_outputs``.
    The outputs of the disjunction layers are concatenated.

    Args:
        n_choices: number of one-hot choices per input variable.
        n_outputs: output width of each disjunction layer.
        epsilon_rule: value passed as ``epsilon`` to every disjunction layer.
    '''
    def __init__(self, n_choices: list, n_outputs: list, epsilon_rule=0.1077119516324264):
        super().__init__()
        membership = conjunction_membership(n_choices)
        # expression() builds default symbols from self.n_input, which was
        # previously never set. membership is (sum(n_choices), prod(n_choices)):
        # shape[1] feeds the disjunction layers, so shape[0] is the input width.
        self.n_input = membership.shape[0]
        self.conjunction_layer = PredefinedConjunctionLayer(membership)
        self.disjunction_layers = nn.ModuleList([EpsilonGreedySoftmaxMembershipDisjunctionLayer(membership.shape[1], n_output, epsilon=epsilon_rule) for n_output in n_outputs])

    def forward(self, x, soft=None, deterministic=None):
        '''Run the AND layer, then every OR layer, and concatenate along dim 1.

        ``soft`` and ``deterministic`` are accepted but not used here
        (presumably kept so callers can pass them uniformly — TODO confirm).
        '''
        hidden = self.conjunction_layer(x)
        out = torch.cat([disjunction_layer(hidden) for disjunction_layer in self.disjunction_layers], dim=1)
        return out

    def expression(self, x_exp=None):
        '''Return the concatenated expression lists of all disjunction layers.

        When ``x_exp`` is missing or empty, SymPy symbols ``x0 .. x{n_input-1}``
        are used as the inputs.
        '''
        if not x_exp:
            x_exp = [sp.Symbol(f'x{i}') for i in range(self.n_input)]
        hidden_exp = self.conjunction_layer.expression(x_exp)
        out_exp = []
        for disjunction_layer in self.disjunction_layers:
            out_exp += disjunction_layer.expression(hidden_exp)
        return out_exp
        
class ConcreteDSLLogicModule(LFLModule):
    '''LFL-Type2.

    A frozen AND layer (PredefinedConjunctionLayer built from
    ``conjunction_membership(n_choices)``) is followed by one
    ConcreteMembershipDisjunctionLayer per entry of ``n_outputs``. The
    outputs of the disjunction layers are concatenated.

    Args:
        n_choices: number of one-hot choices per input variable.
        n_outputs: output width of each disjunction layer.
        layer_kwargs: two keyword-argument dicts. ``layer_kwargs[0]`` goes to
            the conjunction layer and ``layer_kwargs[1]`` goes to every
            disjunction layer. ``None`` (the default) means ``[{}, {}]``.
    '''
    def __init__(self, n_choices: list, n_outputs: list, layer_kwargs=None):
        super().__init__()
        # Avoid a shared mutable default argument.
        if layer_kwargs is None:
            layer_kwargs = [{}, {}]
        membership = conjunction_membership(n_choices)
        # expression() builds default symbols from self.n_input, which was
        # previously never set. membership is (sum(n_choices), prod(n_choices)):
        # shape[1] feeds the disjunction layers, so shape[0] is the input width.
        self.n_input = membership.shape[0]
        self.conjunction_layer = PredefinedConjunctionLayer(membership, **layer_kwargs[0])
        self.disjunction_layers = nn.ModuleList([ConcreteMembershipDisjunctionLayer(membership.shape[1], n_output, **layer_kwargs[1]) for n_output in n_outputs])

    def forward(self, x):
        '''Run the AND layer, then every OR layer, and concatenate along dim 1.'''
        hidden = self.conjunction_layer(x)
        out = torch.cat([disjunction_layer(hidden) for disjunction_layer in self.disjunction_layers], dim=1)
        return out

    def expression(self, x_exp=None):
        '''Return the concatenated expression lists of all disjunction layers.

        When ``x_exp`` is missing or empty, SymPy symbols ``x0 .. x{n_input-1}``
        are used as the inputs.
        '''
        if not x_exp:
            x_exp = [sp.Symbol(f'x{i}') for i in range(self.n_input)]
        hidden_exp = self.conjunction_layer.expression(x_exp)
        out_exp = []
        for disjunction_layer in self.disjunction_layers:
            out_exp += disjunction_layer.expression(hidden_exp)
        return out_exp

class LFLType3(LFLModule):
    '''LFL-Type3.

    A frozen AND layer (PredefinedConjunctionLayer built from
    ``conjunction_membership(n_choices)``) is followed by a single
    DisjunctionLayer with ``n_output`` outputs.

    Args:
        n_choices: number of one-hot choices per input variable.
        n_output: output width of the disjunction layer.
        layer_kwargs: two keyword-argument dicts. ``layer_kwargs[0]`` goes to
            the conjunction layer and ``layer_kwargs[1]`` goes to the
            disjunction layer. ``None`` (the default) means ``[{}, {}]``.
    '''
    def __init__(self, n_choices: list, n_output: int, layer_kwargs=None):
        super().__init__()
        # Avoid a shared mutable default argument.
        if layer_kwargs is None:
            layer_kwargs = [{}, {}]
        membership = conjunction_membership(n_choices)
        # expression() builds default symbols from self.n_input, which was
        # previously never set. membership is (sum(n_choices), prod(n_choices)):
        # shape[1] feeds the disjunction layer, so shape[0] is the input width.
        self.n_input = membership.shape[0]
        self.conjunction_layer = PredefinedConjunctionLayer(membership, **layer_kwargs[0])
        self.disjunction_layer = DisjunctionLayer(membership.shape[1], n_output, **layer_kwargs[1])

    def forward(self, x):
        '''Run the AND layer followed by the OR layer.'''
        hidden = self.conjunction_layer(x)
        hidden = self.disjunction_layer(hidden)
        return hidden

    def expression(self, x_exp=None):
        '''Return the symbolic output of the disjunction layer.

        When ``x_exp`` is missing or empty, SymPy symbols ``x0 .. x{n_input-1}``
        are used as the inputs.
        '''
        if not x_exp:
            x_exp = [sp.Symbol(f'x{i}') for i in range(self.n_input)]
        hidden_exp = self.conjunction_layer.expression(x_exp)
        hidden_exp = self.disjunction_layer.expression(hidden_exp)
        return hidden_exp

    def reg_loss(self):
        '''Regularization loss; only the disjunction layer contributes.'''
        return self.disjunction_layer.reg_loss()