from matplotlib import pyplot as plt
import numpy as np
import os
import math


def visualize_imgs_and_preds(numpy_images, labels, topk_preds, n_cols=5, filename=None,
                             figsize=(15, 20)):
    """Draw images in a grid, titled with their label and captioned with top-k predictions.

    Each grid cell shows one image with its ``label`` as title and, underneath,
    one ``"<classname>: <prob*100>%"`` line per prediction. The figure is saved
    as ``filename + '.pdf'`` and then closed.

    Args:
        numpy_images: Sequence of images, each accepted by ``Axes.imshow``.
        labels: Per-image titles; ``len(labels)`` determines the grid size.
        topk_preds: Per-image dicts with parallel ``'probs'`` and ``'classnames'``
            sequences. Probabilities are multiplied by 100 and shown as percentages.
        n_cols: Number of grid columns.
        filename: Output path without extension (``'.pdf'`` is appended); missing
            parent directories are created. If None, nothing is saved, a message
            is printed and the figure is left open.
        figsize: Figure size in inches, passed to ``plt.subplots``.
    """
    count = len(labels)
    # plt.subplots rejects 0 rows, so always allocate at least one row.
    n_rows = max(1, math.ceil(count / n_cols))
    # Scope the monospace font to this figure instead of mutating global rcParams.
    with plt.rc_context({'font.family': 'monospace'}):
        fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        n_used = 0
        # axs.flat walks the grid row-major, i.e. the same cell as axs[i // n_cols, i % n_cols].
        for ax, img, label, topk_pred in zip(axs.flat, numpy_images, labels, topk_preds):
            ax.imshow(img)
            ax.set_title(label, fontsize=12)
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            caption = ''.join(
                f'{name}: {prob * 100:.1f}%\n'
                for name, prob in zip(topk_pred['classnames'], topk_pred['probs'])
            )
            ax.set_xlabel(caption, fontsize=10, loc='left')
            n_used += 1
        # Hide cells left empty when the number of images is not a multiple of n_cols.
        for ax in axs.flat[n_used:]:
            ax.axis('off')

        if filename is None:
            print('Please provide a path for saving visualizations.')
            return
        root = os.path.dirname(filename)
        # dirname of a bare name is '' (current directory); os.makedirs('') would raise.
        if root:
            os.makedirs(root, exist_ok=True)
        fig.savefig(filename + '.pdf')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def plot_activation_rates(model, dataset_name, model_name, save_path=None):
    """Save histograms of each ResNet block's activation statistics, then reset them.

    Inspects ``model.pre_featurizer.model.visual``. If it has a ``layer1``
    attribute (treated as a ResNet), every block of ``layer1``..``layer4`` is
    visited: its ``activation_rate`` and ``activation_val`` are read, then
    ``activation_rate``/``activation_val`` are set to None and ``batch_count``
    to 0. When a rate was recorded, PDFs are written to ``save_path``:

    * ``{model_name}_stage{l}_block{b}.pdf``: histogram of ``activation_rate``.
    * ``{model_name}_stage{l}_block{b}_VAL.pdf``: histogram of ``activation_val``,
      with the x-axis limited to [0, 0.5] and the y-axis to [0, 50000].

    Both use 100 bins over [0, 1]; values outside that range are not counted.
    Backbones without ``layer1`` are left untouched.

    NOTE(review): the statistics are read via ``.size()``/``.numpy()``, so they
    are assumed to be torch tensors. Each element is presumably one neuron's
    value (the y-axis is labelled "Number of neurons"). TODO confirm.

    Args:
        model: Wrapper exposing ``pre_featurizer.model.visual``.
        dataset_name: Dataset key used for the plot titles. Known keys are mapped
            to display names; unknown keys are shown verbatim.
        model_name: Filename prefix for the saved figures.
        save_path: Output directory; created if missing.

    Raises:
        ValueError: If ``save_path`` is None.
    """
    if save_path is None:
        raise ValueError('save_path must be provided to save activation plots.')
    fontsizes = {'title_fontsize': 22, 'label_fontsize': 20, 'tick_fontsize': 18}
    os.makedirs(save_path, exist_ok=True)
    # get activation rates for modified resnet
    inner_model = model.pre_featurizer.model.visual
    dataset_name_map = {
        'ImageNet': 'ImageNet',
        'ImageNetV2': 'ImageNetV2',
        'ImageNetR': 'ImageNet-R',
        'ObjectNet': 'ObjectNet',
        'ImageNetSketch': 'ImageNet Sketch',
        'ImageNetA': 'ImageNet-A'
    }
    # Fall back to the raw key so an unmapped dataset does not raise KeyError.
    title = dataset_name_map.get(dataset_name, dataset_name)
    if not hasattr(inner_model, 'layer1'):  # only ResNet-style backbones are handled
        return
    for stage in range(1, 5):
        layer = getattr(inner_model, f'layer{stage}')
        for b, block in enumerate(layer):
            rate = block.activation_rate
            val = block.activation_val
            # clear the accumulated statistics so the next run starts fresh
            block.activation_rate = None
            block.activation_val = None
            block.batch_count = 0
            # Nothing was recorded for this block. This check must come before
            # any use of `rate` (rate.size() would raise on None).
            if rate is None:
                continue
            print(rate.size())
            print(rate.max(), rate.min(), rate.mean())
            if val is not None:
                print(val.max(), val.min(), val.mean())
            # detach().cpu() lets .numpy() work for CUDA tensors and tensors requiring grad.
            _save_histogram(
                rate.detach().cpu().numpy(), title, 'Activation rate',
                os.path.join(save_path, f'{model_name}_stage{stage}_block{b}.pdf'),
                **fontsizes,
            )
            if val is not None:
                _save_histogram(
                    val.detach().cpu().numpy(), title, 'Average activation value',
                    os.path.join(save_path, f'{model_name}_stage{stage}_block{b}_VAL.pdf'),
                    xlim=(0, 0.5), ylim=(0, 50000), **fontsizes,
                )


def _save_histogram(values, title, xlabel, filename, xlim=None, ylim=None,
                    title_fontsize=22, label_fontsize=20, tick_fontsize=18):
    """Save a 100-bin histogram of ``values`` over [0, 1] to ``filename`` and close it."""
    counts, bins = np.histogram(values, bins=100, range=(0., 1.))
    fig, ax = plt.subplots()
    # Re-plot the pre-binned counts: one sample per bin at its left edge, weighted by count.
    ax.hist(bins[:-1], bins, weights=counts)
    ax.set_title(title, fontdict={'fontsize': title_fontsize})
    if xlim is not None:
        ax.set_xlim(*xlim)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_xlabel(xlabel, fontdict={'fontsize': label_fontsize})
    ax.set_ylabel('Number of neurons', fontdict={'fontsize': label_fontsize})
    ax.tick_params(labelsize=tick_fontsize)
    fig.tight_layout()
    fig.savefig(filename)
    print(f'Figure saved at {filename}.')
    plt.close(fig)


def visualize_kernel_weights(model, model_name, n_cols=8, filename=None, max_kernels=64):
    """Plot the first convolution's kernels as a grid of small images.

    The first conv layer is ``model.model.visual.conv1`` when ``'clip'`` occurs
    in ``model_name`` (case-insensitive), otherwise ``model.model.conv1``. At
    most ``max_kernels`` kernels are shown. Each kernel is min-max normalised
    per input channel into [0, 1] so it can be displayed as an image. The model's
    weights are never modified.

    Args:
        model: Wrapper exposing the conv layer as described above.
        model_name: Used to pick the backbone and as the figure title.
        n_cols: Number of grid columns.
        filename: Output path without extension (``'.pdf'`` is appended); missing
            parent directories are created. If None, nothing is saved, a message
            is printed and the figure is left open.
        max_kernels: Maximum number of (leading) kernels to plot.
    """
    if 'clip' in model_name.lower():
        first_conv = model.model.visual.conv1
    else:  # ERM resnet
        first_conv = model.model.conv1

    # Take a detached float32 copy. For a CPU tensor `.numpy()` shares storage
    # with the parameter, so normalising that array in place would overwrite the
    # model weights. `.detach()` avoids the error `.numpy()` raises on
    # tensors that require grad.
    weight = first_conv.weight.detach().cpu().numpy().astype(np.float32)
    weight = np.moveaxis(weight, 1, 3)[:max_kernels]  # (N, C, H, W) -> (N, H, W, C)
    n = weight.shape[0]

    # Per-kernel, per-channel min-max normalisation to [0, 1] for imshow.
    lo = weight.min(axis=(1, 2), keepdims=True)
    span = weight.max(axis=(1, 2), keepdims=True) - lo
    span[span == 0] = 1.0  # constant channel: avoid 0/0 (NaN); it renders as 0
    weight = (weight - lo) / span

    # plt.subplots rejects 0 rows, so always allocate at least one row.
    n_rows = max(1, math.ceil(n / n_cols))
    # Scope the monospace font to this figure instead of mutating global rcParams.
    with plt.rc_context({'font.family': 'monospace'}):
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
        fig.suptitle(model_name)
        # axs.flat walks the grid row-major, i.e. the same cell as axs[i // n_cols, i % n_cols].
        for ax, kernel in zip(axs.flat, weight):
            ax.imshow(kernel)
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        # Hide cells left empty when n is not a multiple of n_cols.
        for ax in axs.flat[n:]:
            ax.axis('off')

        if filename is None:
            print('Please provide a path for saving visualizations.')
            return
        root = os.path.dirname(filename)
        # dirname of a bare name is '' (current directory); os.makedirs('') would raise.
        if root:
            os.makedirs(root, exist_ok=True)
        fig.savefig(filename + '.pdf')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# visualization test
if __name__ == '__main__':
    # Smoke test: render random images with fake top-k predictions to ./tmp/test.pdf.
    print('========== Visualization test ==========')
    num_samples, top_k = 56, 3
    images = np.random.rand(num_samples, 224, 224, 3)
    class_labels = [f'class {idx}' for idx in range(num_samples)]
    fake_preds = []
    for _ in range(num_samples):
        fake_preds.append({
            'probs': [np.random.rand() for _ in range(top_k)],
            'classnames': [f'class {c}' for c in range(top_k)],
        })
    out_prefix = './tmp/test'
    visualize_imgs_and_preds(
        images,
        class_labels,
        topk_preds=fake_preds,
        n_cols=8,
        filename=out_prefix,
    )
    print(f'Test figure generated in {os.path.dirname(out_prefix)}.')
    print('Done!')
