import torch
import domafilter_cpu
if torch.cuda.is_available():
    import domafilter_cuda
from .mapping import mapping

class DoMaFilterFunction(torch.autograd.Function):
    """Autograd bridge to the compiled ``domafilter`` extension.

    The forward and backward passes are delegated to ``domafilter_cuda`` when
    the relevant tensor lives on a CUDA device and to ``domafilter_cpu``
    otherwise. Only ``memory`` receives a gradient; ``input``, ``mappings`` and
    ``keys`` are treated as non-differentiable.
    """

    @staticmethod
    def forward(ctx, input, memory, mappings, keys):
        """Run the extension forward kernel and stash every input for backward."""
        backend = domafilter_cuda if input.is_cuda else domafilter_cpu
        result = backend.forward(input, memory, mappings, keys)
        # The extension's backward kernel receives all four tensors back.
        ctx.save_for_backward(input, memory, mappings, keys)
        return result

    @staticmethod
    def backward(ctx, output_grad):
        """Return ``(None, d_memory, None, None)`` computed by the extension."""
        backend = domafilter_cuda if output_grad.is_cuda else domafilter_cpu
        saved = ctx.saved_tensors
        # The kernels are handed a contiguous gradient buffer.
        grad_memory = backend.backward(*saved, output_grad.contiguous())
        return None, grad_memory, None, None

class DoMaFilter(torch.nn.Module):
    """A single bank of learnable lookup-table filters evaluated by ``DoMaFilterFunction``.

    Args:
        input_lenght: Input length forwarded to ``mapping`` when building the
            index table.
        tuple_lenght: Tuple length forwarded to ``mapping``; also the number of
            columns of ``keys``.
        num_keys: Number of rows of ``keys``.
        filter_tuple_lenght: Key bit width; each filter table has
            ``2**filter_tuple_lenght`` entries and keys are drawn from
            ``[0, 2**filter_tuple_lenght)``.
        num_output: Size of the last dimension of ``memory``.

    Only ``memory`` is trainable; ``keys`` and ``mappings`` are stored as
    frozen int32 parameters so they follow the module across devices and
    into ``state_dict``.
    """

    def __init__(self, input_lenght, tuple_lenght, num_keys, filter_tuple_lenght, num_output):
        super().__init__()

        self.input_lenght = input_lenght
        self.tuple_lenght = tuple_lenght
        self.num_keys = num_keys
        self.filter_tuple_lenght = filter_tuple_lenght
        self.num_output = num_output

        self.keys = torch.nn.Parameter(
            self._random_keys(num_keys, tuple_lenght, filter_tuple_lenght),
            requires_grad=False,
        )
        self.mappings = torch.nn.Parameter(
            torch.tensor(mapping(input_lenght, tuple_lenght), dtype=torch.int32),
            requires_grad=False,
        )
        # memory: (mappings.size(0), 2**filter_tuple_lenght, num_output), initialised uniformly in [-1, 1).
        self.memory = torch.nn.Parameter(
            torch.rand(self.mappings.size(0), 2**filter_tuple_lenght, num_output, dtype=torch.float32) * 2 - 1,
            requires_grad=True,
        )
        # Marker attribute not read in this module; presumably consumed by an
        # external optimizer / training loop — verify against the caller.
        self.memory.bnn = True

    @staticmethod
    def _random_keys(num_keys, tuple_lenght, filter_tuple_lenght):
        """Draw a ``(num_keys, tuple_lenght)`` int32 key tensor uniformly from ``[0, 2**filter_tuple_lenght)``.

        ``torch.randint``'s ``high`` bound is exclusive, so it must be
        ``2**filter_tuple_lenght`` for every table index (including the
        all-ones value) to be reachable; with ``high = 2**f - 1`` and ``f == 1``
        every key would be 0.
        """
        return torch.randint(0, 2**filter_tuple_lenght, (num_keys, tuple_lenght), dtype=torch.int32)

    def forward(self, input):
        """Evaluate the filters on ``input`` and sum the extension output over dim 1.

        Presumably reduces a per-mapping result (batch, num_mappings,
        num_output) to (batch, num_output) — confirm against the extension.
        """
        return torch.sum(DoMaFilterFunction.apply(input, self.memory, self.mappings, self.keys), dim=1)


class MultiDoMaFilterC(torch.nn.Module):
    """Ensemble of independently configured ``DoMaFilter`` modules whose outputs are summed.

    Args:
        configs: Iterable of positional-argument tuples, one per filter, each
            unpacked as ``DoMaFilter(*config)``.
        dropout: Dropout probability applied to each filter's output before
            summation; a falsy value (``0`` / ``None``) disables dropout.
    """

    def __init__(self, configs, dropout=0.5):
        super().__init__()

        # Materialise once: configs is consumed twice below, and a one-shot
        # iterator would otherwise leave the dropout list silently empty.
        configs = list(configs)

        self.filters = torch.nn.ModuleList([DoMaFilter(*config) for config in configs])
        # An empty ModuleList is falsy, which forward() uses to skip dropout.
        self.dropouts = torch.nn.ModuleList(
            [torch.nn.Dropout(dropout) for _ in configs] if dropout else []
        )

    def forward(self, input):
        """Run every filter on ``input`` and return the elementwise sum of their outputs."""
        outputs = []

        for idx, filt in enumerate(self.filters):
            out = filt(input)
            if self.dropouts:
                out = self.dropouts[idx](out)
            # Only the per-filter result is collected: torch.stack requires all
            # entries to share one shape, so no differently-reduced variant may
            # be mixed in.
            outputs.append(out)

        return torch.sum(torch.stack(outputs, 0), 0)


class MultiDoMaFilter(torch.nn.Module):
    """Several filter banks packed into one memory tensor and evaluated in a single extension call.

    Args:
        input_lenght: Input length forwarded to ``mapping``.
        tuple_lenght: Tuple length forwarded to ``mapping``; also the number of
            columns of ``keys``.
        num_keys: Number of rows of ``keys``.
        filter_tuple_lenght: Key bit width; each filter table has
            ``2**filter_tuple_lenght`` entries and keys are drawn from
            ``[0, 2**filter_tuple_lenght)``.
        num_output: Size of the last dimension of ``memory``.
        num_models: Number of ``mapping`` tables generated and concatenated
            along dim 0 (presumably each call yields a different random
            mapping — verify against ``mapping``).
        dropout: Dropout probability applied to the extension output before
            the dim-1 sum; a falsy value disables dropout.
    """

    def __init__(self, input_lenght, tuple_lenght, num_keys, filter_tuple_lenght, num_output, num_models=1, dropout=0.5):
        super().__init__()

        self.num_models = num_models

        self.keys = torch.nn.Parameter(
            self._random_keys(num_keys, tuple_lenght, filter_tuple_lenght),
            requires_grad=False,
        )
        mappings = torch.cat(
            [torch.tensor(mapping(input_lenght, tuple_lenght), dtype=torch.int32) for _ in range(num_models)],
            dim=0,
        )
        self.mappings = torch.nn.Parameter(mappings, requires_grad=False)
        # memory: (mappings.size(0), 2**filter_tuple_lenght, num_output), initialised uniformly in [-1, 1).
        self.memory = torch.nn.Parameter(
            torch.rand(self.mappings.size(0), 2**filter_tuple_lenght, num_output, dtype=torch.float32) * 2 - 1,
            requires_grad=True,
        )
        # Marker attribute not read in this module; presumably consumed by an
        # external optimizer / training loop — verify against the caller.
        self.memory.bnn = True
        self.dropout = torch.nn.Dropout(dropout) if dropout else None

    @staticmethod
    def _random_keys(num_keys, tuple_lenght, filter_tuple_lenght):
        """Draw a ``(num_keys, tuple_lenght)`` int32 key tensor uniformly from ``[0, 2**filter_tuple_lenght)``.

        ``torch.randint``'s ``high`` bound is exclusive, so it must be
        ``2**filter_tuple_lenght`` for every table index (including the
        all-ones value) to be reachable.
        """
        return torch.randint(0, 2**filter_tuple_lenght, (num_keys, tuple_lenght), dtype=torch.int32)

    def forward(self, x):
        """Evaluate all packed filters, optionally apply dropout, then sum over dim 1."""
        x = DoMaFilterFunction.apply(x, self.memory, self.mappings, self.keys)
        if self.dropout is not None:
            x = self.dropout(x)
        x = torch.sum(x, dim=1)
        return x