from typing import Tuple, Generator, Any

import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn
from path import Path


def heatmap_avg_kernel_filters(model, additional_axs=None):
    """Lazily plot each non-1x1 Conv2d layer's channel-averaged kernel as a heatmap.

    For every ``torch.nn.Conv2d`` in ``model`` whose kernel is not 1x1 (as
    selected by ``conv2d_modules``), the weight tensor -- laid out by PyTorch
    as ``(out_channels, in_channels / groups, kH, kW)`` -- is averaged over its
    two channel dimensions. This leaves a ``kH x kW`` grid, which is drawn with
    ``sns.heatmap`` on the *current* matplotlib axes and titled with the module
    name and kernel size.

    This is a generator: a layer is drawn only when the next item is
    requested. Because every primary plot targets the current axes, callers
    presumably show/close the figure between iterations.

    Args:
        model: a ``torch.nn.Module`` to inspect.
        additional_axs: optional indexable collection of matplotlib ``Axes``.
            If given and non-empty, the i-th plotted layer is also drawn on
            its i-th element. A multi-dimensional numpy array of Axes (e.g.
            from ``plt.subplots(r, c)``) is traversed in flat (row-major)
            order. It must hold at least as many Axes as there are plotted
            layers, otherwise indexing fails.

    Yields:
        ``(name, ax)``: the module's qualified name and the Axes returned by
        the primary ``sns.heatmap`` call.
    """
    # ``if additional_axs:`` raises for numpy arrays holding more than one
    # Axes, so check for None / emptiness explicitly instead.
    extra_axes = None
    if additional_axs is not None and len(additional_axs) > 0:
        # ``.flat`` lets a 2-D grid of Axes be indexed with one running counter;
        # lists have no ``.flat`` and are used as-is.
        extra_axes = getattr(additional_axs, 'flat', additional_axs)

    for i, (name_mod, mod) in enumerate(conv2d_modules(model, skip_1x1_convs=True)):
        # Conv2d weights are (out, in/groups, kH, kW).
        _, _, k_h, k_w = mod.weight.shape
        # Average over output then input channels; .cpu() so CUDA-resident
        # weights can also be converted to numpy.
        avg_kernel = mod.weight.mean(dim=0).mean(dim=0).detach().cpu().numpy()
        ax = sns.heatmap(avg_kernel)
        plt.title(name_mod + f' with shape {k_h}x{k_w}')
        if extra_axes is not None:
            sns.heatmap(avg_kernel, ax=extra_axes[i])
        yield name_mod, ax

def heatmap_avg_kernel_filters_axes(model, axes):
    """Lazily plot each non-1x1 Conv2d layer's channel-averaged kernel on given Axes.

    For every ``torch.nn.Conv2d`` in ``model`` whose kernel is not 1x1 (as
    selected by ``conv2d_modules``), the weight tensor -- laid out by PyTorch
    as ``(out_channels, in_channels / groups, kH, kW)`` -- is averaged over its
    two channel dimensions. The resulting ``kH x kW`` grid is drawn with
    ``sns.heatmap`` on the i-th Axes, and that Axes is titled with the module
    name and kernel size.

    Args:
        model: a ``torch.nn.Module`` to inspect.
        axes: indexable collection of matplotlib ``Axes``, one per plotted
            layer. A multi-dimensional numpy array of Axes (e.g. from
            ``plt.subplots(r, c)``) is traversed in flat (row-major) order.
            Indexing fails if there are fewer Axes than plotted layers.

    Yields:
        ``(name, ax)``: the module's qualified name and the Axes drawn on.
    """
    # ``.flat`` lets a 2-D grid of Axes be indexed with one running counter;
    # for lists and 1-D arrays this selects the same Axes as ``axes[i]``.
    flat_axes = getattr(axes, 'flat', axes)
    for i, (name_mod, mod) in enumerate(conv2d_modules(model, skip_1x1_convs=True)):
        # Conv2d weights are (out, in/groups, kH, kW).
        _, _, k_h, k_w = mod.weight.shape
        ax = flat_axes[i]
        # .cpu() so CUDA-resident weights can also be converted to numpy.
        avg_kernel = mod.weight.mean(dim=0).mean(dim=0).detach().cpu().numpy()
        sns.heatmap(avg_kernel, ax=ax)
        ax.set_title(name_mod + f' with shape {k_h}x{k_w}')
        yield name_mod, ax



def conv2d_modules(
    model, skip_1x1_convs: bool
) -> Generator[Tuple[str, torch.nn.Conv2d], None, None]:
    """Yield ``(name, module)`` for every ``torch.nn.Conv2d`` inside ``model``.

    Modules are visited in ``model.named_modules()`` order. ``model`` itself
    is therefore included, under the name ``''``, if it is a ``Conv2d``.

    Args:
        model: a ``torch.nn.Module`` to search.
        skip_1x1_convs: if True, omit convolutions whose kernel is 1x1 and
            keep only those with a larger spatial extent.

    Yields:
        Tuples of the module's qualified name and the module object.
    """
    for name_mod, mod in model.named_modules():
        if not isinstance(mod, torch.nn.Conv2d):
            continue
        # Conv2d weights are (out, in/groups, kH, kW); the last two dims are
        # the spatial kernel size.
        k_h, k_w = mod.weight.shape[-2:]
        if skip_1x1_convs and k_h == 1 and k_w == 1:
            continue
        yield name_mod, mod