import numpy
import torch
from feature_aug.collections.block_drop import DropBlock2D, DropBlockChannel2D, AdaptiveDropBlockChannel2D, ReverseAdaptiveDropBlockChannel2D
# from fairseq.fairseq.models.roberta_custom.collections import DropBlock1D, DropBlockChannel1D, ReverseAdaptiveDropBlockChannel1D
from matplotlib import pyplot as plt

# # batch size, head, seq_length
# seq = torch.randn(8, 16, 32)
# org_seq = seq.numpy()
# # fn = ReverseAdaptiveDropBlockChannel1D(block_size=12, drop_prob=0.1, block_prob=0.1)
# post_seq= fn(seq)

# print((seq != 0).sum() / (org_seq != 0).sum())
# print((post_seq != 0).sum() / (org_seq != 0).sum())
# # exit()
# # plot two numpy images side by side
# _, axs = plt.subplots(1, 3, figsize=(12, 12))
# axs = axs.flatten()
# # for img, ax in zip([org_image, image], axs):
# # axs[0].imshow(org_image[0][0].reshape(32, 32, 1))
# # axs[1].imshow(image[0][0].reshape(32, 32, 1))
# # axs[2].imshow(image2[0][0].reshape(32, 32, 1))
# plt.show()
# exit()
# Visual / numeric sanity check for DropBlock2D: run two drop modules on the
# SAME random batch and compare how many non-zero entries each one keeps.
# NOTE(review): assumes the modules drop only in training mode (the nn.Module
# default after construction) — TODO confirm in feature_aug.collections.block_drop.
x = torch.randn(8, 3, 32, 32)
# Tensor.numpy() returns a view that shares storage with `x`; copy it so the
# reference stays intact even if a module modifies its input in place.
org_image = x.numpy().copy()

# Alternative modules under test (swap in to compare):
# fn = ReverseAdaptiveDropBlockChannel2D(0.1, 4)
# fn2 = ReverseAdaptiveDropBlockChannel2D(0.1, 16)
# presumably (drop_prob, block_size) — confirm against the module signature.
fn = DropBlock2D(0.1, 4)
fn2 = DropBlock2D(0.1, 18)

with torch.no_grad():
    # Both modules see the ORIGINAL batch. Feeding fn's output into fn2 would
    # compound the two masks and make the second ratio meaningless.
    image = fn(x).numpy()
    image2 = fn2(x).numpy()

# Fraction of entries that are still non-zero after each drop.
org_nonzero = (org_image != 0).sum()
print((image != 0).sum() / org_nonzero)
print((image2 != 0).sum() / org_nonzero)

# Plot the first channel of the first sample side by side.
_, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, img, title in zip(axs, (org_image, image, image2), ("original", "fn", "fn2")):
    # img[0, 0] is already a 2-D (H, W) array, which imshow renders directly.
    ax.imshow(img[0, 0])
    ax.set_title(title)
plt.show()
