import torch
import itertools

def mapping(input_lenght, tuple_lenght):
    """Return a random grouping of input indices into fixed-size tuples.

    A random permutation of ``0 .. input_lenght - 1`` is split into rows of
    ``tuple_lenght`` indices. When ``input_lenght`` is not a multiple of
    ``tuple_lenght`` the last index of the permutation is repeated to fill
    the final row.

    Args:
        input_lenght: Number of input indices to permute.
        tuple_lenght: Number of indices per row of the result.

    Returns:
        An ``int32`` tensor of shape ``(ceil(input_lenght / tuple_lenght),
        tuple_lenght)``.
    """
    perm = torch.randperm(input_lenght, dtype=torch.int32)
    remainder = input_lenght % tuple_lenght
    if remainder:
        pad = tuple_lenght - remainder
        # Pad with the last index *in the permutation's own dtype*; building
        # the padding from torch.ones() would promote the result to float32.
        perm = torch.cat([perm, perm[-1:].expand(pad)])
    return perm.view(-1, tuple_lenght)

def mapping_kernel(dims, tuple_lenght):
    """Return a random grouping of grid coordinates into fixed-size tuples.

    Every coordinate of the grid described by ``dims`` (the Cartesian product
    of ``range(d)`` for each ``d`` in ``dims``) is enumerated, shuffled, and
    split into groups of ``tuple_lenght`` coordinates. When the number of
    coordinates is not a multiple of ``tuple_lenght`` the last shuffled
    coordinate is repeated to fill the final group.

    Args:
        dims: Sequence of grid extents, one per axis.
        tuple_lenght: Number of coordinates per group.

    Returns:
        An ``int32`` tensor of shape ``(num_groups, tuple_lenght, len(dims))``.
    """
    coords = list(itertools.product(*[range(dim) for dim in dims]))
    # Plain Python int so it can be used directly for sizes and modulo.
    num_coords = len(coords)
    shuffled = torch.tensor(coords, dtype=torch.int32)
    shuffled = shuffled[torch.randperm(num_coords)]
    remainder = num_coords % tuple_lenght
    if remainder:
        pad = tuple_lenght - remainder
        shuffled = torch.cat([shuffled, shuffled[-1:].expand(pad, -1)])
    # The trailing axis holds one entry per grid axis; it is only 3 for
    # three-dimensional grids, so derive it from dims rather than hard-coding.
    return shuffled.view(-1, tuple_lenght, len(dims))
    

class BinaryMappingFunction(torch.autograd.Function):
    """Select one input column per output via argmax, with a custom gradient.

    Forward picks, for every output, the input whose connection weight is the
    largest. Backward supplies a surrogate gradient for the connection
    weights that rescales the input as ``2 * input - 1`` (presumably the
    input is binary in {0, 1} and this maps it to {-1, +1} -- confirm with
    callers).
    """

    @staticmethod
    def forward(ctx, input, connections_weights):
        """Gather ``input[:, argmax(connections_weights, dim=0)]``.

        Args:
            ctx: Autograd context used to stash tensors for backward.
            input: Tensor expected to be 2-D ``(batch, num_inputs)``; the
                broadcasting in ``backward`` relies on this.
            connections_weights: Tensor of shape ``(num_inputs, num_output)``.

        Returns:
            Tensor of shape ``(batch, num_output)``.
        """
        mapping = connections_weights.argmax(dim=0)
        output = input[:, mapping]
        ctx.save_for_backward(input, mapping)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for ``input`` and ``connections_weights``."""
        input, mapping = ctx.saved_tensors
        # Dense surrogate gradient: every (input, output) pair receives
        # sum over the batch of (2 * input - 1) * output_grad.
        connections_weights_grad = torch.sum(
            (2 * input.unsqueeze(2) - 1) * output_grad.unsqueeze(1), dim=0
        )
        input_grad = torch.zeros_like(input)
        # Several outputs may select the same input column, so the
        # contributions must be accumulated; plain indexed assignment would
        # keep only one of them.
        input_grad.index_add_(1, mapping, output_grad)
        return input_grad, connections_weights_grad

class AutoBinaryMapping(torch.nn.Module):
    """Module wrapper that owns the connection weights for BinaryMappingFunction."""

    def __init__(self, num_inputs, num_output):
        """Create a trainable ``(num_inputs, num_output)`` weight matrix.

        The weights are drawn from ``torch.rand`` as float32.
        """
        super().__init__()
        initial_scores = torch.rand(num_inputs, num_output, dtype=torch.float32)
        self.connections_weights = torch.nn.Parameter(
            initial_scores,
            requires_grad=True,
        )

    def forward(self, x):
        """Apply BinaryMappingFunction to ``x`` with this module's weights."""
        scores = self.connections_weights
        return BinaryMappingFunction.apply(x, scores)
    

class MappingFunction(torch.autograd.Function):
    """Select one input column per output via argmax, with a custom gradient.

    Forward picks, for every output, the input whose connection weight is the
    largest. Because argmax is not differentiable, backward supplies a dense
    surrogate gradient for the connection weights built from the raw input.
    """

    @staticmethod
    def forward(ctx, input, connections_weights):
        """Gather ``input[:, argmax(connections_weights, dim=0)]``.

        Args:
            ctx: Autograd context used to stash tensors for backward.
            input: Tensor expected to be 2-D ``(batch, num_inputs)``; the
                broadcasting in ``backward`` relies on this.
            connections_weights: Tensor of shape ``(num_inputs, num_output)``.

        Returns:
            Tensor of shape ``(batch, num_output)``.
        """
        mapping = connections_weights.argmax(dim=0)
        output = input[:, mapping]
        ctx.save_for_backward(input, mapping)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for ``input`` and ``connections_weights``."""
        input, mapping = ctx.saved_tensors
        # Dense surrogate gradient: every (input, output) pair receives
        # sum over the batch of input * output_grad.
        connections_weights_grad = torch.sum(
            input.unsqueeze(2) * output_grad.unsqueeze(1), dim=0
        )
        input_grad = torch.zeros_like(input)
        # Several outputs may select the same input column, so the
        # contributions must be accumulated; plain indexed assignment would
        # keep only one of them.
        input_grad.index_add_(1, mapping, output_grad)
        return input_grad, connections_weights_grad

class AutoMapping(torch.nn.Module):
    """Module wrapper that owns the connection weights for MappingFunction."""

    def __init__(self, num_inputs, num_output):
        """Create a trainable ``(num_inputs, num_output)`` weight matrix.

        The weights are drawn from ``torch.rand`` as float32.
        """
        super().__init__()
        initial_scores = torch.rand(num_inputs, num_output, dtype=torch.float32)
        self.connections_weights = torch.nn.Parameter(
            initial_scores,
            requires_grad=True,
        )

    def forward(self, x):
        """Apply MappingFunction to ``x`` with this module's weights."""
        scores = self.connections_weights
        return MappingFunction.apply(x, scores)


class MappingFCFunction(torch.autograd.Function):
    """Select one input per (middle, output) slot via argmax, with a custom gradient.

    Forward picks, for every ``(middle, output)`` position, the input whose
    connection weight is the largest. Because argmax is not differentiable,
    backward supplies a dense surrogate gradient for the connection weights
    built from the raw input.
    """

    @staticmethod
    def forward(ctx, input, connections_weights):
        """Gather ``input[:, argmax(connections_weights, dim=0)]``.

        Args:
            ctx: Autograd context used to stash tensors for backward.
            input: Tensor expected to be 2-D ``(batch, num_inputs)``; the
                broadcasting in ``backward`` relies on this.
            connections_weights: Tensor of shape
                ``(num_inputs, num_middle, num_output)``.

        Returns:
            Tensor of shape ``(batch, num_middle, num_output)``.
        """
        mapping = connections_weights.argmax(dim=0)
        output = input[:, mapping]
        ctx.save_for_backward(input, mapping)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        """Return gradients for ``input`` and ``connections_weights``."""
        input, mapping = ctx.saved_tensors
        # Dense surrogate gradient: every (input, middle, output) triple
        # receives sum over the batch of input * output_grad.
        connections_weights_grad = torch.sum(
            input.unsqueeze(-1).unsqueeze(-1) * output_grad.unsqueeze(1), dim=0
        )
        input_grad = torch.zeros_like(input)
        # Many (middle, output) slots select the same input column, so the
        # contributions must be accumulated; plain indexed assignment would
        # keep only one of them. Flatten the slot axes so index_add_ can
        # scatter along the input axis.
        flat_grad = output_grad.reshape(output_grad.shape[0], -1)
        input_grad.index_add_(1, mapping.reshape(-1), flat_grad)
        return input_grad, connections_weights_grad

class AutoMappingFC(torch.nn.Module):
    """Module wrapper that owns the connection weights for MappingFCFunction."""

    def __init__(self, num_inputs, num_middle, num_output):
        """Create a trainable ``(num_inputs, num_middle, num_output)`` weight tensor.

        The weights are drawn from ``torch.rand`` as float32.
        """
        super().__init__()
        initial_scores = torch.rand(
            num_inputs, num_middle, num_output, dtype=torch.float32
        )
        self.connections_weights = torch.nn.Parameter(
            initial_scores,
            requires_grad=True,
        )

    def forward(self, x):
        """Apply MappingFCFunction to ``x`` with this module's weights."""
        scores = self.connections_weights
        return MappingFCFunction.apply(x, scores)
    

class AutoFullyConnected(torch.nn.Module):
    """Learned input selection followed by a per-output weighted sum.

    An ``AutoMappingFC`` layer produces one selected input per
    ``(middle, output)`` slot; the slots are then scaled by ``weights``,
    summed over the middle axis, and shifted by ``bias``.
    """

    def __init__(self, num_inputs, num_middle, num_output):
        """Build the mapping layer, the Xavier-initialised weights and a zero bias."""
        super().__init__()
        self.auto_mapping = AutoMappingFC(num_inputs, num_middle, num_output)

        # The tensor is drawn with torch.rand and then overwritten in place
        # by Xavier-uniform initialisation before being registered.
        weight_init = torch.rand(num_middle, num_output, dtype=torch.float32)
        torch.nn.init.xavier_uniform_(weight_init, gain=1.0)
        self.weights = torch.nn.Parameter(weight_init, requires_grad=True)

        bias_init = torch.zeros(num_output, dtype=torch.float32)
        self.bias = torch.nn.Parameter(bias_init, requires_grad=True)

    def forward(self, x):
        """Map ``x`` through the selection layer and reduce over the middle axis."""
        selected = self.auto_mapping(x)
        weighted = selected * self.weights
        return torch.sum(weighted, dim=1) + self.bias