import numpy
import torch
from feature_aug.collections.block_drop import DropBlock2D, DropBlockChannel2D, AdaptiveDropBlockChannel2D, ReverseAdaptiveDropBlockChannel2D
# from fairseq.fairseq.models.roberta_custom.collections import DropBlock1D, DropBlockChannel1D, ReverseAdaptiveDropBlockChannel1D
from matplotlib import pyplot as plt

# # batch size, head, seq_length
# seq = torch.randn(8, 16, 32)
# org_seq = seq.numpy()
# # fn = ReverseAdaptiveDropBlockChannel1D(block_size=12, drop_prob=0.1, block_prob=0.1)
# post_seq= fn(seq)

# print((seq != 0).sum() / (seq!= 0).sum())
# print((post_seq != 0).sum() / (post_seq != 0).sum())
# # exit()
# # plot two numpy images side by side
# _, axs = plt.subplots(1, 3, figsize=(12, 12))
# axs = axs.flatten()
# # for img, ax in zip([org_image, image], axs):
# # axs[0].imshow(org_image[0][0].reshape(32, 32, 1))
# # axs[1].imshow(image[0][0].reshape(32, 32, 1))
# # axs[2].imshow(image2[0][0].reshape(32, 32, 1))
# plt.show()
# exit()
import pickle

def plot_and_save(image, p, block_size, n_channels, save_path=None):
    """Draw an input tensor next to three DropBlock-transformed copies of it.

    The original ``image`` and the outputs of ``DropBlock2D``,
    ``DropBlockChannel2D`` and ``ReverseAdaptiveDropBlockChannel2D`` (each
    constructed as ``Cls(p, block_size)``) are drawn in one figure: one
    column per variant, one row per displayed channel of sample 0.

    Args:
        image: Tensor indexed as ``image[sample][channel]`` -> 2-D map;
            presumably (batch, channels, H, W) -- TODO confirm with callers.
        p: First positional argument passed to every DropBlock constructor
            (presumably the drop probability -- confirm in feature_aug).
        block_size: Second positional argument passed to every DropBlock
            constructor.
        n_channels: How many leading channels of sample 0 to display, one
            row each; clamped to ``[1, image.shape[1]]``.
        save_path: Optional output path. When given, the figure is written
            with ``savefig`` and closed; otherwise ``plt.show()`` is called.
    """
    variants = [
        ('Original', None),
        ('DropBlock', DropBlock2D(p, block_size)),
        ('Channel-wise DropBlock', DropBlockChannel2D(p, block_size)),
        ('Adaptive DropBlock', ReverseAdaptiveDropBlockChannel2D(p, block_size)),
    ]

    # no_grad + detach so .numpy() also works on activations that track
    # gradients; .copy() guards the snapshots in case a variant ever
    # modifies its input in place (CPU .numpy() shares memory).
    titled_arrays = []
    with torch.no_grad():
        for title, fn in variants:
            out = image if fn is None else fn(image)
            titled_arrays.append((title, out.detach().cpu().numpy().copy()))

    n_rows = max(1, min(int(n_channels), image.shape[1]))
    n_cols = len(titled_arrays)
    # squeeze=False keeps `axs` 2-D even when n_rows == 1.
    fig, axs = plt.subplots(n_rows, n_cols,
                            figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    for col, (title, arr) in enumerate(titled_arrays):
        axs[0][col].set_title(title)
        for row in range(n_rows):
            # Sample 0, channel `row`, shown at its native spatial size.
            axs[row][col].imshow(arr[0][row])
            axs[row][col].axis('off')
    fig.tight_layout()

    if save_path is not None:
        fig.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()

# Load the cached tuple stored in resnet_out.pkl; presumably the input batch
# and ResNet conv1 / layer1 activations (inferred from names -- confirm).
# NOTE(review): pickle.load can execute arbitrary code -- only load files
# you produced yourself.
with open('resnet_out.pkl', mode='rb') as f:
    (x,
     out_conv1,
     out_layer1) = pickle.load(f)

