from functools import partial
from typing import Dict, Tuple, Union, Sequence, Optional, Callable, Any

import torch
from torch import Tensor
from torch.nn import Module, BCELoss, Linear


def extract_features_from_layer(model: Module, layer_name: str, x: Tensor) -> Tuple[Tensor, Tensor]:
    """Run ``model`` on ``x`` and capture the output of one of its layers.

    A temporary forward hook is registered on the attribute ``layer_name`` of
    ``model``; the hook is removed as soon as the forward pass completes.

    Note: the model is switched to evaluation mode (``model.eval()``) and is
    NOT switched back afterwards. The forward pass runs under
    ``torch.no_grad()``.

    :param model: The model to run.
    :param layer_name: Name of the attribute of ``model`` holding the layer
        whose output must be captured.
    :param x: The input passed to ``model``.
    :return: A tuple ``(output, features)`` where ``output`` is whatever
        ``model(x)`` returned and ``features`` is a detached CPU copy of the
        output produced by the hooked layer.
    :raises KeyError: If the hooked layer was never executed during the
        forward pass, so no activation was captured.
    """
    activation: Dict[str, Tensor] = {}

    def get_activation(name):
        def hook(model_hook: Module, x_hook: Tensor, out_hook: Tensor):
            # Store a copy that is cut from the graph and lives on the CPU
            activation[name] = out_hook.detach().cpu()
        return hook

    model.eval()
    with torch.no_grad():
        # The handle returned by register_forward_hook is used as a context
        # manager so that the hook is removed even if the forward pass fails
        with getattr(model, layer_name).register_forward_hook(get_activation(layer_name)):
            output = model(x)

    if layer_name not in activation:
        raise KeyError(f"Layer '{layer_name}' was not executed during the forward pass: "
                       f"no activation was captured")

    return output, activation[layer_name]


def get_features_shape_from_layer(model: Module, layer_name: str, example_input: Tensor):
    """Return the shape of the features produced by a layer of ``model``.

    The model is run on ``example_input`` through
    ``extract_features_from_layer`` and the shape of the captured layer output
    is returned with its leading dimension dropped (assumed to be the batch
    dimension - TODO confirm against callers).

    :param model: The model to run.
    :param layer_name: Name of the layer whose output shape is requested.
    :param example_input: An input that ``model`` can process.
    :return: The shape of the layer output without its first dimension.
    """
    _, layer_features = extract_features_from_layer(model, layer_name, example_input)
    shape_without_first_dim = layer_features.shape[1:]
    return shape_without_first_dim


def is_loss_criterion_vector_based(criterion: Module) -> bool:
    """Tell whether ``criterion`` is one of the known vector-based losses.

    Currently only instances of ``BCELoss`` (and of its subclasses) are
    recognized. NOTE(review): "vector-based" presumably means that the loss
    expects target vectors rather than class indices - confirm.

    :param criterion: The loss module to check.
    :return: True if ``criterion`` is an instance of a recognized class.
    """
    # TODO: check other pytorch classes
    vector_based_criteria = (BCELoss,)
    return isinstance(criterion, vector_based_criteria)


def expand_or_shrink_linear_layer(original_layer: Module, n_classes_before: int, n_classes_after: int,
                                  reinitialize_existing_weights: bool = False, downsize_when_expanding: bool = False,
                                  partial_initializer: Optional[Callable[[Module, int, int], Any]] = None,
                                  layer_factory: Optional[Callable[[int, int], Module]] = None) -> Module:
    """Create a resized copy of a linear layer, keeping the weights of the retained output units.

    The original layer is not modified: a new layer is created through
    ``layer_factory`` and the relevant rows of ``weight`` (and entries of
    ``bias``, if the original layer has one) are copied into it.

    Expanding (``n_classes_before < n_classes_after``):
      - the new layer has ``n_classes_after`` output units if
        ``downsize_when_expanding`` is True, otherwise
        ``max(original_layer.out_features, n_classes_after)``;
      - the first ``n_classes_before`` units are always copied;
      - if the new layer ends up with more than ``n_classes_after`` units
        (the original layer had spare units and downsizing is disabled), the
        units from ``n_classes_before`` onwards are either copied from the
        original layer or, when ``reinitialize_existing_weights`` is True,
        handed to ``partial_initializer`` (or left as created by the factory
        if no initializer is given);
      - in every other case the units from ``n_classes_before`` onwards keep
        the initialization given by ``layer_factory``.

    Shrinking (``n_classes_before >= n_classes_after``):
      - the new layer has ``n_classes_after`` output units, copied from the
        first ``n_classes_after`` units of the original layer.

    :param original_layer: The layer to resize. Must expose ``in_features``,
        ``out_features``, ``weight`` and ``bias`` (which may be None).
    :param n_classes_before: Number of output units currently in use.
    :param n_classes_after: Number of output units required.
    :param reinitialize_existing_weights: See the expanding rules above.
    :param downsize_when_expanding: See the expanding rules above.
    :param partial_initializer: Optional callable invoked as
        ``partial_initializer(layer, first_unit, last_unit_exclusive)``.
    :param layer_factory: Callable invoked as
        ``layer_factory(in_features, out_features)`` to build the new layer.
        Defaults to ``torch.nn.Linear`` with the same bias setting as the
        original layer.
    :return: The newly created layer.
    """
    use_bias: bool = original_layer.bias is not None
    if layer_factory is None:
        layer_factory = partial(Linear, bias=use_bias)
    original_output_units: int = original_layer.out_features
    result_output_units: int
    result_layer: Module

    def copy_units(target: Module, units: slice) -> None:
        # Copy the selected output units from the original layer into target
        target.weight.data[units] = original_layer.weight.data[units]
        if use_bias:
            target.bias.data[units] = original_layer.bias.data[units]

    if n_classes_before < n_classes_after:
        # Expand
        if downsize_when_expanding:
            result_output_units = n_classes_after
        else:
            result_output_units = max(original_output_units, n_classes_after)

        result_layer = layer_factory(original_layer.in_features, result_output_units)
        copy_units(result_layer, slice(None, n_classes_before))

        if result_output_units > n_classes_after:
            # That is, downsize_when_expanding = False and result_output_units == original_output_units
            if reinitialize_existing_weights:
                if partial_initializer is not None:
                    partial_initializer(result_layer, n_classes_before, result_output_units)
                # else -> keep the initialization given by the layer factory
            else:
                # Note: "n_classes_before:result_output_units" is the same as "n_classes_before:"
                copy_units(result_layer, slice(n_classes_before, result_output_units))
    else:  # n_classes_before >= n_classes_after
        # Shrink
        result_output_units = n_classes_after
        result_layer = layer_factory(original_layer.in_features, result_output_units)

        # Only the first n_classes_after units fit in the new layer. Slicing by
        # n_classes_before would select more source rows than destination rows
        # and fail with a shape mismatch whenever the layer is actually shrunk.
        copy_units(result_layer, slice(None, n_classes_after))

    return result_layer


# TODO: shrunken?
def get_expanded_or_shrunken_head_from_model(model: Module, layer_name: str, n_classes_before: int,
                                             n_classes_after: int, reinitialize_existing_weights: bool = False,
                                             downsize_when_expanding: bool = False,
                                             partial_initializer: Optional[Callable[[Module, int, int], Any]] = None,
                                             layer_factory: Callable[[int, int], Module] = None) -> Module:
    """Build a resized version of the layer stored in ``model`` under ``layer_name``.

    The layer is looked up with ``getattr(model, layer_name)`` and every other
    argument is forwarded unchanged to ``expand_or_shrink_linear_layer``. The
    model attribute itself is not reassigned here.

    :param model: The model that owns the layer.
    :param layer_name: Name of the attribute of ``model`` holding the layer.
    :param n_classes_before: Forwarded to ``expand_or_shrink_linear_layer``.
    :param n_classes_after: Forwarded to ``expand_or_shrink_linear_layer``.
    :param reinitialize_existing_weights: Forwarded as is.
    :param downsize_when_expanding: Forwarded as is.
    :param partial_initializer: Forwarded as is.
    :param layer_factory: Forwarded as is.
    :return: The layer returned by ``expand_or_shrink_linear_layer``.
    """
    head_layer = getattr(model, layer_name)
    resize_options = dict(
        reinitialize_existing_weights=reinitialize_existing_weights,
        downsize_when_expanding=downsize_when_expanding,
        partial_initializer=partial_initializer,
        layer_factory=layer_factory,
    )
    return expand_or_shrink_linear_layer(head_layer, n_classes_before, n_classes_after, **resize_options)


def previous_layer(model: Module, layer: str) -> Optional[str]:
    """Return the name of the direct child of ``model`` that precedes ``layer``.

    Children are visited in the order given by ``model.named_children()``.

    :param model: The model whose direct children are inspected.
    :param layer: Name of the child whose predecessor is requested.
    :return: The name of the child registered right before ``layer``, or None
        if ``layer`` is the first child or is not a child of ``model`` at all.
    """
    prev_name = None

    for layer_name, _ in model.named_children():
        if layer_name == layer:
            return prev_name
        prev_name = layer_name

    # The requested layer was never found: do not report the last child as
    # if it were its predecessor.
    return None
