import os
import pickle
import numpy as np
import fire
import matplotlib.pyplot as plt
from data.MNIST_test import LitMNIST
from data.KMNIST_test import LitKMNIST
from data.FashionMNIST_test import LitFashionMNIST
from scripts.magnitude_layers import *
from mpl_toolkits.axes_grid1 import ImageGrid

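'''
Utility script to visualise the magnitude-layer outputs ("magnitude vectors")
for the first few test images of a dataset, saving each plot as a PDF in the
`output` directory.
'''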

def min_max(img):
    """Linearly rescale *img* so its values span [0, 1].

    Args:
        img: array-like of numbers (any shape).

    Returns:
        ``(img - min) / (max - min)`` as a NumPy array. If every element is
        equal (``max == min``), an all-zeros array is returned instead of the
        all-NaN result that the 0/0 division would produce. NaNs in the input
        still propagate, as before.

    Raises:
        ValueError: if *img* is empty (from ``np.min``).
    """
    img = np.asarray(img)
    lo = np.min(img)
    span = np.max(img) - lo
    if span == 0:
        # Constant image: avoid 0/0 -> NaN (imshow would render nothing).
        # Keep the dtype the normal branch would produce (float32 stays float32).
        return np.zeros_like(img, dtype=np.result_type(img, 1.0))
    return (img - lo) / span

def _as_image(tensor):
    """Squeeze a single-sample tensor to a 28x28 NumPy array for plotting."""
    return tensor.squeeze().view(28, 28).numpy()


def _save_image(img, path):
    """Save *img* as an axis-free figure at *path*, always closing the figure."""
    fig = plt.figure()
    try:
        plt.imshow(img)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(path)
    finally:
        # Close even on failure so repeated errors don't accumulate open figures.
        plt.close(fig)


def main(dataset='KMNIST', n_images=8, output_dir='output'):
    """Save input images and their magnitude-layer outputs as PDF files.

    For each of the first *n_images* test samples, writes the raw image to
    ``{output_dir}/{dataset}_img_{i}.pdf`` and the min-max normalised output
    of every magnitude layer to ``{output_dir}/{dataset}_mag_{i}_{suffix}.pdf``.
    Here *suffix* is one of ``p_0``, ``p_1``, ``p_2``, ``p_prod_0``,
    ``p_prod_1`` or ``p_prod_2``.

    Args:
        dataset: one of ``'MNIST'``, ``'KMNIST'``, ``'FashionMNIST'``.
        n_images: number of test samples to visualise (default 8, matching
            the previous hard-coded limit).
        output_dir: directory for the PDFs; created if missing.

    Raises:
        ValueError: if *dataset* is not a supported name.
    """
    # Explicit mapping instead of eval(): the name comes from the command line.
    datasets = {
        'MNIST': LitMNIST,
        'KMNIST': LitKMNIST,
        'FashionMNIST': LitFashionMNIST,
    }
    try:
        data_cls = datasets[dataset]
    except KeyError:
        raise ValueError(
            f'Unknown dataset {dataset!r}; expected one of {sorted(datasets)}'
        ) from None

    data = data_cls(val_set=True)
    data.setup()

    # The different magnitude layers, keyed by the filename suffix they produce.
    layers = {
        'p_0': MagnitudeLayer(p=0.),
        'p_1': MagnitudeLayer(p=1.),
        'p_2': MagnitudeLayer(p=2.),
        'p_prod_0': MagnitudeLayerProduct(p=0., hamming=True),
        'p_prod_1': MagnitudeLayerProduct(p=1.),
        'p_prod_2': MagnitudeLayerProduct(p=2.),
    }

    # savefig does not create directories.
    os.makedirs(output_dir, exist_ok=True)

    data_loader = data.test_dataloader(bs=1)

    # range() comes first in zip so the loader is not advanced past n_images.
    for i, (x, _) in zip(range(n_images), data_loader):
        _save_image(_as_image(x), os.path.join(output_dir, f'{dataset}_img_{i}.pdf'))
        for suffix, layer in layers.items():
            magnitude = min_max(_as_image(layer.forward(x)))
            _save_image(
                magnitude,
                os.path.join(output_dir, f'{dataset}_mag_{i}_{suffix}.pdf'),
            )

if __name__ == '__main__':
    # Run as a script: python-fire turns main's keyword arguments into CLI flags.
    fire.Fire(main)
