from typing import Generator, Tuple

from torch_intermediate_layer_getter import IntermediateLayerGetter as MidGetter
import matplotlib.pyplot as plt
import torch


def get_relu_outputs_intermediate_dummy_image(model, h: int, w: int, c: int = 3):
    """Run ``model`` on an all-zero image and return the captured ReLU outputs.

    Args:
        model: the ``torch.nn.Module`` whose ReLU outputs are captured.
        h: height of the dummy image.
        w: width of the dummy image.
        c: number of channels of the dummy image (default 3).

    Returns:
        The result of ``get_relu_outputs_intermediate`` for a zero tensor of
        shape ``(1, c, h, w)``.
    """
    blank_batch = torch.zeros((1, c, h, w))  # a single zero image, (N, C, H, W)
    return get_relu_outputs_intermediate(model, blank_batch)


def get_relu_outputs_intermediate(model, input_image):
    """Run ``model`` on ``input_image`` and return the captured ReLU outputs.

    A fresh getter is built on every call via ``relu_interm_getter_model``.
    The getter's result is indexed with ``[0]``; that element is expected to be
    a dict keyed by ReLU module name. NOTE(review): presumably each value is the
    module's output, or a list of outputs when the module ran more than once —
    confirm against torch_intermediate_layer_getter.
    """
    getter, _relu_names = relu_interm_getter_model(model)
    captured: dict = getter(input_image)[0]
    return captured


def relu_interm_getter_model(model):
    """Build a getter that captures every ``torch.nn.ReLU`` submodule of ``model``.

    Returns:
        A tuple ``(getter, relu_names)``:
        - ``getter`` is a ``MidGetter`` over ``model``. Each ReLU's qualified
          name is mapped to itself, so captured outputs keep their module names.
        - ``relu_names`` lists those qualified names in ``named_modules()`` order.
    """
    relu_names = [
        name
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.ReLU)
    ]
    identity_mapping = dict(zip(relu_names, relu_names))
    return MidGetter(model, identity_mapping), relu_names


def plot_intermidiate_outputs(outputs: list, skip_first: bool = True):
    """Show each captured ReLU activation as a channel-averaged image, one figure per layer.

    Args:
        outputs: the value returned by calling a ``MidGetter``. Only
            ``outputs[0]`` is used: a mapping from layer name to the captured
            output. An entry may be a single tensor, or a list of tensors when
            the layer ran more than once; in that case the last one is shown.
        skip_first: skip the first entry of the mapping. Defaults to ``True``
            to keep the previous behaviour. Why the first entry was skipped is
            not visible here — TODO confirm whether it is still needed.

    Entries that are not 4-D ``(N, C, H, W)`` after unwrapping (for example
    ``(N, features)`` after a fully connected layer) are skipped. For the others,
    the first sample is averaged over channels and replicated to 3 channels
    before ``plt.imshow``.
    """
    for index, (layer_name, captured) in enumerate(outputs[0].items()):
        if skip_first and index == 0:
            continue
        # A layer called several times stores a list of outputs; a layer called
        # once may store the tensor directly. Handle both.
        activation = captured[-1] if isinstance(captured, list) else captured
        print(activation.shape)
        if activation.dim() != 4:
            continue  # not an image-like (N, C, H, W) activation - nothing to draw
        feature_map = activation[0]  # first sample of the batch: (C, H, W)
        _, height, width = feature_map.shape
        image = (
            feature_map.mean(dim=0, keepdim=True)  # (1, H, W)
            .permute(1, 2, 0)                      # (H, W, 1)
            .expand(height, width, 3)              # grey -> RGB layout for imshow
            .detach()
            .cpu()                                 # allow activations captured on GPU
            .numpy()
        )
        plt.title(layer_name)
        plt.imshow(image)
        plt.show()
        plt.close()


def generator_conv_relu_output(model, image: torch.Tensor) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield ``(layer_name, activation)`` for each captured ReLU output of ``model`` on ``image``.

    When a captured entry is a list, only its last element is used. Entries
    whose tensor is 2-D are reported on stdout and skipped; the message calls
    them activations after a fully connected layer. For every other entry, the
    tensor indexed at ``[0]`` (the first item along the leading axis) is yielded.
    The model is only run once the generator is first advanced.
    """
    captured = get_relu_outputs_intermediate(model, image)
    for layer_name, entry in captured.items():
        latest = entry[-1] if isinstance(entry, list) else entry
        if latest.dim() == 2:
            # Keep the local named ``layer_name``: the f-string's ``=`` specifier prints it.
            print(f'skip {layer_name=} because it is activation function after fully connected layer')
            continue
        yield layer_name, latest[0]


def generator_conv_relu_output_dummy_image(
        model, w: int, h: int, c: int, additional_axes: list = None,
) -> Generator[Tuple[str, torch.Tensor, 'matplotlib.image.AxesImage'], None, None]:
    """Run ``model`` on an all-zero ``(1, c, h, w)`` image and draw each conv ReLU output.

    Note the argument order is ``w, h, c``.

    For each ``(layer_name, relu_output)`` from ``generator_conv_relu_output``:
    the output is averaged over its first axis, replicated to 3 channels, and
    drawn with ``plt.imshow`` on the current axes. The current axes are titled
    with ``layer_name``. The same image is also drawn on ``additional_axes[i]``
    when ``additional_axes`` is given.

    Args:
        model: the ``torch.nn.Module`` to probe.
        w, h, c: width, height and channel count of the dummy input.
        additional_axes: optional sequence of matplotlib axes indexed by layer
            position. It must have at least as many entries as yielded layers,
            otherwise ``IndexError`` is raised.

    Yields:
        ``(layer_name, relu_output, image)`` where ``image`` is the object
        returned by ``plt.imshow``.
    """
    dummy_image = torch.zeros(1, c, h, w)
    for i, (layer_name, relu_output) in enumerate(generator_conv_relu_output(model, dummy_image)):
        # Use distinct names so the function's own w/h/c parameters are not shadowed.
        _, out_h, out_w = relu_output.shape  # (C, H, W)
        # Compute the display array once; matplotlib copies input data, so both
        # imshow calls can safely receive the same array.
        image = relu_output.mean(dim=0, keepdim=True).permute(1, 2, 0).expand(out_h, out_w, 3).detach().numpy()
        plt.title(layer_name)
        if additional_axes is not None:
            additional_axes[i].imshow(image)
        yield layer_name, relu_output, plt.imshow(image)

def relu_output_for_imshow(relu_output):
    """Collapse a 3-D ``(C, H, W)`` activation into a single-channel image array.

    Averages over the first (channel) axis and moves it last. The result is a
    detached NumPy array of shape ``(H, W, 1)`` on the CPU.

    Raises:
        ValueError: if ``relu_output`` is not 3-dimensional.
    """
    # Check the rank up front so callers get a clear message instead of an
    # unpacking or permute failure.
    if relu_output.dim() != 3:
        raise ValueError(
            f'expected a 3-D (C, H, W) tensor, got shape {tuple(relu_output.shape)}'
        )
    # .cpu() is a no-op for CPU tensors and lets CUDA tensors be converted too.
    return relu_output.mean(dim=0, keepdim=True).permute(1, 2, 0).detach().cpu().numpy()
