import numpy
import torch
from fairseq.models.roberta_custom.collections import DropBlock1D, DropBlockChannel1D, ReverseAdaptiveDropBlockChannel1D
from matplotlib import pyplot as plt

# Side length of the square image each 1-D sequence is reshaped into for
# plotting; sequences are generated with IMAGE_SIDE**2 elements to match.
IMAGE_SIDE = 32


def nonzero_fraction(x, reference=None):
    """Return the count of non-zero entries of tensor ``x`` as a Python float ratio.

    By default the count is divided by ``x.numel()``. If ``reference`` is given,
    it is divided by the number of non-zero entries in ``reference`` instead.
    """
    kept = (x != 0).sum()
    total = x.numel() if reference is None else (reference != 0).sum()
    return (kept / total).item()


def to_image(x, side=IMAGE_SIDE):
    """Reshape a tensor/array holding ``side * side`` values into a 2-D numpy image."""
    if isinstance(x, torch.Tensor):
        # detach/cpu so matplotlib can consume outputs that track grads or live on GPU.
        x = x.detach().cpu().numpy()
    return numpy.asarray(x).reshape(side, side)


def show_side_by_side(*images):
    """Plot ``images`` left to right in one figure and block until it is closed."""
    # squeeze=False keeps `axs` 2-D even when only one image is passed.
    _, axs = plt.subplots(1, len(images), figsize=(12, 12), squeeze=False)
    for img, ax in zip(images, axs[0]):
        ax.imshow(img)
    plt.show()


def demo_sequence(block=None):
    """Apply a 1-D DropBlock module to random (batch, head, seq_length) data.

    Prints the non-zero fraction of the input and of the output, then plots the
    first (batch, head) slice before and after dropping.

    Args:
        block: callable applied to the tensor. Defaults to
            ``DropBlockChannel1D(block_size=4, drop_prob=0.3)``; ``DropBlock1D`` or
            ``ReverseAdaptiveDropBlockChannel1D`` (imported at the top of this
            file) can be passed instead.
    """
    if block is None:
        block = DropBlockChannel1D(block_size=4, drop_prob=0.3)
    # batch size, head, seq_length
    seq = torch.randn(8, 16, IMAGE_SIDE * IMAGE_SIDE)
    # Independent copy: `seq.numpy()` shares memory with `seq`, so an in-place
    # drop would silently overwrite the "original" used for comparison.
    org_seq = seq.clone()
    post_seq = block(seq)

    print(nonzero_fraction(org_seq))
    print(nonzero_fraction(post_seq))
    show_side_by_side(to_image(org_seq[0][0]), to_image(post_seq[0][0]))


def demo_image(blocks=None):
    """Apply 2-D DropBlock modules to random (batch, channel, H, W) images.

    Each module gets its own copy of the same random input, so the outputs are
    directly comparable. Prints each output's non-zero count relative to the
    input's, then plots the first (batch, channel) slice of the input and of
    every output.

    Args:
        blocks: iterable of callables to compare. Defaults to two
            ``ReverseAdaptiveDropBlockChannel2D`` modules built with
            ``(0.1, 4)`` and ``(3.0, 4)``.
    """
    if blocks is None:
        # NOTE(review): the original script used this class without importing it;
        # it is assumed to live next to the 1-D variants -- confirm the path.
        from fairseq.models.roberta_custom.collections import (
            ReverseAdaptiveDropBlockChannel2D,
        )

        blocks = [
            ReverseAdaptiveDropBlockChannel2D(0.1, 4),
            ReverseAdaptiveDropBlockChannel2D(3.0, 4),
        ]
    org_image = torch.randn(8, 3, IMAGE_SIDE, IMAGE_SIDE)
    # Feed every module the untouched input (the original chained fn2 onto fn's output).
    outputs = [fn(org_image.clone()) for fn in blocks]

    for out in outputs:
        print(nonzero_fraction(out, reference=org_image))
    show_side_by_side(
        to_image(org_image[0][0]), *(to_image(out[0][0]) for out in outputs)
    )


if __name__ == "__main__":
    import sys

    demo_sequence()
    # The 2-D demo was unreachable in the original script; opt in with --image.
    if "--image" in sys.argv[1:]:
        demo_image()
