from typing import List, Optional

import torch


class TorchModuleWithFullyConnectedLayersAndHelperFunctions(torch.nn.Module):
    """
    A helper base module that bundles three optional features for its subclasses:
    a stack of fully connected layers, attention over a set of values, and a regularization loss.

    Each feature has to be switched on through its ``initialize_*`` method before the matching
    ``apply_*`` / ``add_*`` method may be called.
    """

    def __init__(self):
        super().__init__()
        # State of the fully connected layers
        self.fnn_has_been_initialized = False
        self.fnn_layers = []
        self.fnn_last_layer_uses_residual = False
        self.fnn_activation_function = torch.nn.LeakyReLU(negative_slope=0.01)
        # State of the attention helper
        self.attention_has_been_initialized = False
        self.attention_type = None
        self.attention_gumbel_softmax_tau = None
        # State of the regularization helper
        self.regularization_loss_has_been_initialized = False
        self.regularization_loss_accumulator = None
        self.regularization_loss_boundary = None

    def initialize_fully_connected_layers_based_on_sizes(self, list_of_tensor_sizes, add_residuals_for_input: bool) -> None:
        """
        Take a list of sizes and create fully-connected layers that connect them.

        :param list_of_tensor_sizes: Sequence of layer widths (list, tuple, ...). Entry ``k`` is the input width
            of layer ``k`` and entry ``k + 1`` is its output width.
        :param add_residuals_for_input: If true, the last layer is built wide enough to additionally receive
            the original input, which apply_fully_connected_layers() concatenates to its regular input.

        Calling this method again replaces the previously created layers.
        """
        # Unregister layers left over from an earlier call. Without this, a second call that creates
        # fewer layers would leave stale 'layer{i}' submodules (and their parameters) on the module.
        for index, stale_layer in enumerate(self.fnn_layers):
            name = f'layer{index}'
            if getattr(self, name, None) is stale_layer:
                delattr(self, name)
        # Copy into a list so that tuples (and other sequences) can be sliced and concatenated below.
        sizes = list(list_of_tensor_sizes)
        self.fnn_has_been_initialized = True
        self.fnn_layers = []
        self.fnn_last_layer_uses_residual = add_residuals_for_input
        ins = sizes[:-1]
        outs = sizes[1:]
        if self.fnn_last_layer_uses_residual:
            # The last layer receives its regular input concatenated with the original input.
            ins = ins[:-1] + [ins[-1] + ins[0]]
        for i, (input_size, output_size) in enumerate(zip(ins, outs)):
            layer = torch.nn.Linear(input_size, output_size)
            self.fnn_layers.append(layer)
            # fnn_layers is a plain list, so each layer is registered explicitly as a submodule.
            self.add_module(f'layer{i}', layer)

    def apply_fully_connected_layers(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the fully connected layers generated by initialize_fully_connected_layers_based_on_sizes().

        Every layer except the last one is followed by the activation function.
        The input is concatenated along dim 1, so it is expected to be two-dimensional.

        :param x: The input tensor.
        :return: The output of the last layer, without an activation function applied to it.
        """
        assert self.fnn_has_been_initialized, \
            "initialize_fully_connected_layers_based_on_sizes() must be called first"
        original_input = x
        for layer in self.fnn_layers[:-1]:
            x = layer(x)
            x = self.fnn_activation_function(x)
        if self.fnn_last_layer_uses_residual:
            x = torch.cat([x, original_input], dim=1)
        # Do not apply a RELU after the last layer.
        # This is because the last layer could be used to compute the logits for a softmax or for some other purpose
        # where we do not want to limit its range of values.
        x = self.fnn_layers[-1](x)
        return x

    def initialize_attention(self, attention_type, gumbel_softmax_tau_start_value=None) -> None:
        """
        Remember that we want to use attention, and of what type.

        :param attention_type: 'softmax' or 'gumbel'. Other values are stored as well,
            but apply_attention() raises NotImplementedError for them.
        :param gumbel_softmax_tau_start_value: If given, a learnable temperature parameter is created
            and initialized with this value.
        """
        self.attention_has_been_initialized = True
        self.attention_type = attention_type
        if gumbel_softmax_tau_start_value is not None:
            self.attention_gumbel_softmax_tau = torch.nn.Parameter(torch.tensor([float(gumbel_softmax_tau_start_value)]))

    def apply_attention(self, weights: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """
        Apply attention using a weight matrix and a set of values, based on initialize_attention().

        :param weights: Unnormalized weights of shape (values.shape[0], values.shape[1], 1).
        :param values: The values to combine. The weights are normalized over dim 1.
        :return: The weighted sum of the values over dim 1.
        :raises NotImplementedError: If the attention type is neither 'softmax' nor 'gumbel'.
        """
        assert self.attention_has_been_initialized, "initialize_attention() must be called first"
        if self.attention_type == 'softmax':
            weights = torch.nn.functional.softmax(weights, dim=1)
        elif self.attention_type == 'gumbel':
            tau = self.attention_gumbel_softmax_tau
            if tau is None:
                # No temperature was configured: use 1.0, the default of gumbel_softmax,
                # instead of failing inside the library call on a None temperature.
                tau = 1.0
            weights = torch.nn.functional.gumbel_softmax(weights, dim=1, tau=tau, hard=True)
        else:
            raise NotImplementedError(self.attention_type)
        assert weights.shape == (values.shape[0], values.shape[1], 1), weights.shape
        val = (weights * values).sum(dim=1)
        return val

    def initialize_regularization(self, regularization_loss_accumulator, regularization_loss_boundary: Optional[float]) -> None:
        """
        Remember an accumulator for regularization loss, and a boundary for the values of the tensors.

        :param regularization_loss_accumulator: A list to which add_regularization_loss() appends its losses.
        :param regularization_loss_boundary: Values within [-boundary, boundary] are not penalized.
            None disables the regularization loss.
        """
        self.regularization_loss_has_been_initialized = True
        self.regularization_loss_accumulator = regularization_loss_accumulator
        self.regularization_loss_boundary = regularization_loss_boundary

    def add_regularization_loss(self, tensor: torch.Tensor) -> None:
        """
        Apply regularization loss to a tensor, based on initialize_regularization(),
        and store the loss in the accumulator.

        The loss is the mean squared error between the tensor and a copy of it that is clamped
        to [-boundary, boundary]. Nothing is stored if the boundary is None.
        """
        assert self.regularization_loss_has_been_initialized, "initialize_regularization() must be called first"
        if self.regularization_loss_boundary is None:
            return
        # The target is detached so that gradients only flow through the unclamped tensor.
        target = tensor.clamp(min=self.regularization_loss_boundary * -1, max=self.regularization_loss_boundary).detach()
        loss = torch.nn.MSELoss()(tensor, target)
        self.regularization_loss_accumulator.append(loss)


class FnnWithBlocks(TorchModuleWithFullyConnectedLayersAndHelperFunctions):
    """
    A fully connected neural network that takes a list of blocks as inputs and produces a list of blocks as outputs.
    It concatenates all of its inputs into a single tensor and then applies fully connected layers to that tensor.
    The output tensor is then split into a list of tensors.
    """

    def __init__(
            self,
            input_sizes: List[int],
            output_sizes: List[int],
            intermediate_sizes: List[int],
    ):
        """
        :param input_sizes: Width (size of dim 1) of each input block, in the order they are passed to forward().
        :param output_sizes: Width of each output block, in the order they are returned by forward().
        :param intermediate_sizes: Widths of the hidden layers between the concatenated input and output.
        """
        super().__init__()
        self.input_sizes = list(input_sizes)
        self.output_sizes = list(output_sizes)
        # Sum the stored copies rather than the arguments: a one-shot iterable would already be
        # exhausted by the list() calls above. list() also lets intermediate_sizes be a tuple.
        fnn_sizes = [sum(self.input_sizes)] + list(intermediate_sizes) + [sum(self.output_sizes)]
        self.initialize_fully_connected_layers_based_on_sizes(fnn_sizes, add_residuals_for_input=False)

    def forward(self, xs: List[torch.FloatTensor]):
        """
        Concatenate the input blocks along dim 1, apply the fully connected layers,
        and cut the result into one tensor per entry of output_sizes.

        :param xs: One tensor per entry of input_sizes.
        :return: A list with one tensor per entry of output_sizes, each a slice of the network output along dim 1.
        """
        assert len(xs) == len(self.input_sizes), (len(xs), len(self.input_sizes))
        x = torch.cat(xs, dim=1)
        x = self.apply_fully_connected_layers(x)
        res = []
        start = 0
        for size in self.output_sizes:
            # Consecutive, non-overlapping column ranges of the network output
            res.append(x[:, start:start + size])
            start += size
        assert start == sum(self.output_sizes)
        return res


class Multiplexer(TorchModuleWithFullyConnectedLayersAndHelperFunctions):
    """
    The Multiplexer, as described in the paper.
    Every output block is an attention-weighted combination of the input blocks.
    The attention weights are computed by a fully connected network from a list of context tensors
    (the context tensors may be the same tensors as the input blocks).
    """

    def __init__(
            self,
            attention_type: str,
            block_size: int,
            num_input_blocks: int,
            context_input_sizes: List[int],
            num_output_blocks: int,
            intermediate_sizes_of_fnn: List[int],
            regularization_loss_accumulator: List[torch.FloatTensor],
            regularization_loss_boundary: Optional[float],
    ):
        """
        :param attention_type: Passed on to initialize_attention(); 'gumbel' additionally creates
            a temperature parameter that starts at 1.0.
        :param block_size: Width of every input and output block.
        :param num_input_blocks: Number of blocks to choose between.
        :param context_input_sizes: Width of each context tensor.
        :param num_output_blocks: Number of blocks to produce.
        :param intermediate_sizes_of_fnn: Hidden layer widths of the network that computes the weights.
        :param regularization_loss_accumulator: List that collects the regularization losses.
        :param regularization_loss_boundary: Boundary for the regularization loss, or None to disable it.
        """
        super().__init__()
        self.attention_type = attention_type
        self.block_size = block_size
        self.num_input_blocks = num_input_blocks
        self.num_output_blocks = num_output_blocks
        self.context_input_sizes = list(context_input_sizes)
        # The routing network emits one group of weights per output block,
        # and each group holds one weight per input block.
        sizes_of_weight_groups = [self.num_input_blocks] * self.num_output_blocks
        self.fnn_module = FnnWithBlocks(
            input_sizes=self.context_input_sizes,
            output_sizes=sizes_of_weight_groups,
            intermediate_sizes=intermediate_sizes_of_fnn,
        )
        # Only the gumbel variant gets a temperature
        tau_start_value = 1.0 if attention_type == 'gumbel' else None
        self.initialize_attention(attention_type, gumbel_softmax_tau_start_value=tau_start_value)
        self.initialize_regularization(regularization_loss_accumulator, regularization_loss_boundary)

    def forward(
            self,
            input_blocks: List[torch.FloatTensor],
            context_inputs: List[torch.FloatTensor],
    ):
        """
        :param input_blocks: num_input_blocks tensors of shape (batch_size, block_size).
        :param context_inputs: One tensor per entry of context_input_sizes.
        :return: A list of num_output_blocks tensors of shape (batch_size, block_size).
        """
        assert len(input_blocks) == self.num_input_blocks
        assert len(context_inputs) == len(self.context_input_sizes)
        batch_size = input_blocks[0].shape[0]
        # One group of raw weights per output block
        weight_groups = self.fnn_module(xs=context_inputs)
        assert len(weight_groups) == self.num_output_blocks
        # Stack the input blocks along a new dim 1, so that attention can be applied over that dim
        stacked_input_blocks = torch.stack(list(input_blocks), dim=1)
        assert stacked_input_blocks.shape == (batch_size, self.num_input_blocks, self.block_size)
        output_blocks = []
        for raw_weights in weight_groups:
            # The regularization loss is applied to the raw weights, before they are normalized
            self.add_regularization_loss(raw_weights)
            # Add a trailing dim so that the weights broadcast over the block dimension
            column_weights = raw_weights.unsqueeze(-1)
            assert column_weights.shape == (batch_size, len(input_blocks), 1), column_weights.shape
            combined = self.apply_attention(column_weights, stacked_input_blocks)
            assert combined.shape == (batch_size, self.block_size), combined.shape
            output_blocks.append(combined)
        assert len(output_blocks) == self.num_output_blocks
        return output_blocks


class FNNR(TorchModuleWithFullyConnectedLayersAndHelperFunctions):
    """
    The FNNR module, as described in the paper.
    It takes a list of input blocks and a list of additional context tensors.
    A fully connected network computes a candidate replacement for every input block plus one gating weight
    per block; each output is a gated interpolation between the input block and its candidate.
    """

    def __init__(
            self,
            num_inputs_and_outputs: int,
            block_size: int,
            sizes_of_context_tensors: List[int],
            intermediate_sizes_of_fnn: List[int],
            regularization_loss_accumulator: List[torch.FloatTensor],
            regularization_loss_boundary: Optional[float],
    ):
        """
        :param num_inputs_and_outputs: Number of input blocks, which is also the number of output blocks.
        :param block_size: Width of every input and output block.
        :param sizes_of_context_tensors: Width of each context tensor.
        :param intermediate_sizes_of_fnn: Hidden layer widths of the fully connected network.
        :param regularization_loss_accumulator: List that collects the regularization losses.
        :param regularization_loss_boundary: Boundary for the regularization loss, or None to disable it.
        """
        super().__init__()
        self.num_inputs_and_outputs = num_inputs_and_outputs
        self.block_size = block_size
        self.sizes_of_context_tensors = list(sizes_of_context_tensors)
        # The network sees all input blocks and all context tensors.
        # It emits one candidate block per input block, followed by one tensor holding all gating weights.
        self.fnn_module = FnnWithBlocks(
            input_sizes=[block_size] * num_inputs_and_outputs + self.sizes_of_context_tensors,
            output_sizes=[block_size] * num_inputs_and_outputs + [num_inputs_and_outputs],
            intermediate_sizes=intermediate_sizes_of_fnn,
        )
        # Regularization
        self.initialize_regularization(regularization_loss_accumulator, regularization_loss_boundary)

    def forward(
            self,
            input_blocks: List[torch.FloatTensor],
            context_inputs: List[torch.FloatTensor],
    ):
        """
        :param input_blocks: num_inputs_and_outputs tensors of shape (batch_size, block_size).
        :param context_inputs: One tensor of shape (batch_size, size) per entry of sizes_of_context_tensors.
            These must be different tensor objects than the input blocks.
        :return: A list with one tensor of shape (batch_size, block_size) per input block.
        """
        assert len(input_blocks) == self.num_inputs_and_outputs, (len(input_blocks), self.num_inputs_and_outputs)
        assert len(context_inputs) == len(self.sizes_of_context_tensors), (
            len(context_inputs), len(self.sizes_of_context_tensors))
        batch_size = input_blocks[0].shape[0]
        for block in input_blocks:
            assert block.shape == (batch_size, self.block_size), block.shape
        for context, size in zip(context_inputs, self.sizes_of_context_tensors):
            assert context.shape == (batch_size, size), context.shape
        # Tensors hash by object identity, so the union shrinks exactly when the same tensor object occurs twice.
        assert len(input_blocks) + len(context_inputs) == len(set(input_blocks) | set(context_inputs)), \
            "Inputs and contexts should not overlap, since all inputs will be used in the same way " \
            "as contexts as well, so this would be redundant."
        # FNN for generating new blocks and a gating weight for each of them.
        # Both arguments are copied into lists so that a mix of list and tuple arguments can be concatenated.
        fnn_outputs = self.fnn_module(xs=list(input_blocks) + list(context_inputs))
        outputs_without_residuals = fnn_outputs[:-1]
        raw_gating_weights = fnn_outputs[-1]
        # Apply the regularization loss to the gating weights
        assert raw_gating_weights.shape == (batch_size, self.num_inputs_and_outputs)
        self.add_regularization_loss(raw_gating_weights)
        # Apply a sigmoid to the gating weights
        gating_weights = torch.sigmoid(raw_gating_weights)
        gating_weights_list = list(gating_weights.split(1, dim=1))
        assert len(outputs_without_residuals) == len(input_blocks)
        # Report the number of gating weights, not the batch size of the unsplit tensor
        assert len(gating_weights_list) == len(input_blocks), (len(gating_weights_list), len(input_blocks))
        assert all(gating_weight.shape == (batch_size, 1) for gating_weight in gating_weights_list)
        res = []
        # Loop over the blocks
        for input_block, output_block, gating_weight in zip(
                input_blocks, outputs_without_residuals, gating_weights_list):
            assert gating_weight.shape == (batch_size, 1)
            assert input_block.shape == (batch_size, self.block_size)
            assert output_block.shape == (batch_size, self.block_size)
            # Combine each input block with its corresponding output block using the gating weight:
            # a weight of 1 keeps only the new block, a weight of 0 keeps only the input block.
            output_with_residual = gating_weight * output_block + (1.0 - gating_weight) * input_block
            res.append(output_with_residual)
        return res


class MFNNR(torch.nn.Module):
    """
    The MFNNR module, as described in the paper.
    It chains a Multiplexer and an FNNR: the Multiplexer builds intermediate blocks out of the input blocks,
    and the FNNR then refines those intermediate blocks. Both use the input blocks as their context.
    """

    def __init__(
            self,
            block_size: int,
            num_input_blocks: int,
            num_output_blocks: int,
            multiplexer_attention_type: str,
            multiplexer_intermediate_sizes_of_fnn: List[int],
            fnnr_intermediate_sizes_of_fnn: List[int],
            regularization_loss_accumulator: List[torch.FloatTensor],
            regularization_loss_boundary: Optional[float],
    ):
        """
        :param block_size: Width of every block.
        :param num_input_blocks: Number of blocks passed to forward().
        :param num_output_blocks: Number of blocks returned by forward().
        :param multiplexer_attention_type: Attention type of the Multiplexer.
        :param multiplexer_intermediate_sizes_of_fnn: Hidden layer widths of the Multiplexer's network.
        :param fnnr_intermediate_sizes_of_fnn: Hidden layer widths of the FNNR's network.
        :param regularization_loss_accumulator: List that collects the regularization losses of both parts.
        :param regularization_loss_boundary: Boundary for the regularization loss, or None to disable it.
        """
        super().__init__()
        self.block_size = block_size
        # Both parts share the same regularization settings
        regularization_settings = dict(
            regularization_loss_accumulator=regularization_loss_accumulator,
            regularization_loss_boundary=regularization_loss_boundary,
        )
        # Multiplexer: goes from num_input_blocks blocks to num_output_blocks blocks
        self.multiplexer_module = Multiplexer(
            attention_type=multiplexer_attention_type,
            block_size=block_size,
            num_input_blocks=num_input_blocks,
            context_input_sizes=[block_size] * num_input_blocks,
            num_output_blocks=num_output_blocks,
            intermediate_sizes_of_fnn=multiplexer_intermediate_sizes_of_fnn,
            **regularization_settings,
        )
        # FNNR: operates on the Multiplexer's outputs, with the original input blocks as context
        self.fnnr_module = FNNR(
            num_inputs_and_outputs=num_output_blocks,
            block_size=block_size,
            sizes_of_context_tensors=[block_size] * num_input_blocks,
            intermediate_sizes_of_fnn=fnnr_intermediate_sizes_of_fnn,
            **regularization_settings,
        )

    def forward(self,
                input_blocks: List[torch.FloatTensor],
                ):
        """
        :param input_blocks: The blocks to process; they also serve as context for both parts.
        :return: The list of blocks produced by the FNNR.
        """
        multiplexed_blocks = self.multiplexer_module(input_blocks=input_blocks, context_inputs=input_blocks)
        return self.fnnr_module(input_blocks=multiplexed_blocks, context_inputs=input_blocks)


class SMFR(torch.nn.Module):
    """
    A stack of MFNNR modules that are applied one after the other.
    The first MFNNR takes num_input_tensors blocks, the last one produces num_output_tensors blocks,
    and all connections in between carry stack_width blocks. The stack holds stack_depth + 1 MFNNRs.
    """

    def __init__(
            self,
            block_size: int,
            num_input_tensors: int,
            num_output_tensors: int,
            stack_width: int,
            stack_depth: int,
            multiplexer_attention_type: str,
            sizes_of_contained_fnns: List[int],
            regularization_loss_accumulator: List[torch.FloatTensor],
            regularization_loss_boundary: Optional[float],
    ):
        """
        :param block_size: Width of every block.
        :param num_input_tensors: Number of blocks passed to forward().
        :param num_output_tensors: Number of blocks returned by forward().
        :param stack_width: Number of blocks passed between consecutive MFNNRs.
        :param stack_depth: Number of intermediate stages of width stack_width.
        :param multiplexer_attention_type: Attention type of every Multiplexer.
        :param sizes_of_contained_fnns: Hidden layer widths used for every network inside the MFNNRs.
        :param regularization_loss_accumulator: List that collects the regularization losses of all MFNNRs.
        :param regularization_loss_boundary: Boundary for the regularization loss, or None to disable it.
        """
        super().__init__()
        self.mfnnrs = []
        # Number of blocks before the first MFNNR, between consecutive MFNNRs, and after the last one
        block_counts = [num_input_tensors] + [stack_width] * stack_depth + [num_output_tensors]
        for index in range(len(block_counts) - 1):
            stage = MFNNR(
                block_size=block_size,
                num_input_blocks=block_counts[index],
                num_output_blocks=block_counts[index + 1],
                multiplexer_attention_type=multiplexer_attention_type,
                # Every stage gets its own copy of the size list
                multiplexer_intermediate_sizes_of_fnn=list(sizes_of_contained_fnns),
                fnnr_intermediate_sizes_of_fnn=list(sizes_of_contained_fnns),
                regularization_loss_accumulator=regularization_loss_accumulator,
                regularization_loss_boundary=regularization_loss_boundary,
            )
            self.mfnnrs.append(stage)
            # self.mfnnrs is a plain list, so each stage is registered explicitly as a submodule
            self.add_module(f"MFNNR_{index}", stage)

    def forward(
            self,
            input_blocks: List[torch.FloatTensor],
    ):
        """
        Feed the blocks through all MFNNRs in order.

        :param input_blocks: The blocks for the first MFNNR.
        :return: The blocks produced by the last MFNNR.
        """
        current_blocks = input_blocks
        for stage in self.mfnnrs:
            current_blocks = stage(input_blocks=current_blocks)
        return current_blocks
