import torch
import matplotlib.pyplot as plt
import torchvision


def log_datasets(data_module):
    """Log preview grids of the data module's dataloaders to its logger.

    Returns immediately when ``data_module.logger`` is None.

    For SSL runs (``data_module.is_ssl_run``), three loaders are previewed under the key "ssl datasets":
    the train loader as two views, the kNN-train loader, and the validation loader.

    Otherwise, the train and validation loaders are previewed under the key "probe_datasets".
    The test loader is added when ``data_module.has_test_set`` is truthy.
    """
    logger = data_module.logger
    if logger is None:
        return

    if not data_module.is_ssl_run:
        train_grid = visualize_singleview_dataset(data_module.train_dataloader())
        val_grid = visualize_singleview_dataset(data_module.val_dataloader())
        test_grid = None
        if data_module.has_test_set:
            test_grid = visualize_singleview_dataset(data_module.test_dataloader())
        log_dataset_imgs(logger, train_grid, None, val_grid, test_grid, label="probe_datasets")
        return

    train_grid = visualize_multiview_dataset(data_module.train_dataloader())
    knn_grid = visualize_singleview_dataset(data_module.train_dataloader_knn())
    val_grid = visualize_singleview_dataset(data_module.val_dataloader())
    log_dataset_imgs(logger, train_grid, knn_grid, val_grid, None, label="ssl datasets")

def log_dataset_imgs(wandb_logger, trainset_grid, knn_trainset_grid, valset_grid, testset_grid, label="datasets"):
    """Log dataset preview grids as a single captioned image panel.

    Args:
        wandb_logger: object exposing ``log_image(key=..., images=..., caption=...)``.
        trainset_grid: list of training-set grids. More than one entry is treated as a
            multiview set, and its first two grids are logged as "trainset view 1" and
            "trainset view 2". Otherwise its first grid is logged as "trainset".
        knn_trainset_grid: list whose first entry is the kNN-train grid, or None to skip it.
        valset_grid: list whose first entry is the validation grid.
        testset_grid: list whose first entry is the test grid, or None to skip it.
        label: key under which the panel is logged.
    """
    # Multiview detection: more than one grid for the training set.
    if len(trainset_grid) > 1:
        images = [trainset_grid[0], trainset_grid[1]]
        captions = ["trainset view 1", "trainset view 2"]
    else:
        images = [trainset_grid[0]]
        captions = ["trainset"]

    # Optional splits are only added when present. A missing kNN grid no longer
    # crashes the multiview path, and a test grid is no longer dropped there.
    if knn_trainset_grid is not None:
        images.append(knn_trainset_grid[0])
        captions.append("knn trainset")

    images.append(valset_grid[0])
    captions.append("valset")

    if testset_grid is not None:
        images.append(testset_grid[0])
        captions.append("testset")

    wandb_logger.log_image(key=label, images=images, caption=captions)
    
def visualize_multiview_dataset(data_loader, sidesize=3, name="multiview_dataset", save_img=False):
    """
    Build image grids from the first batch of a two-view dataloader.

    The first ``sidesize**2`` samples of each view are arranged into a grid with
    ``sidesize`` images per row. Optionally the two grids are plotted side by side
    and saved to ``{name}.png``.

    Args:
        data_loader: iterable whose first item unpacks as ``((x0, x1), y, _)``.
            x0 and x1 are presumably image batches shaped (batch, C, H, W); TODO confirm.
        sidesize: images per grid row; ``sidesize**2`` images are used per view.
        name: file stem for the saved figure when ``save_img`` is True.
        save_img: if True, also plot both grids with matplotlib and save the figure.

    Returns:
        list: ``[view1_grid, view2_grid]`` as produced by ``torchvision.utils.make_grid``.

    Raises:
        AssertionError: if either view has fewer than ``sidesize**2`` samples
            (skipped when Python runs with ``-O``).
    """
    (x0, x1), _, _ = next(iter(data_loader))
    num_imgs = sidesize ** 2

    # Check both views up front. A short second view previously surfaced as an
    # IndexError partway through the copy loop.
    assert min(x0.shape[0], x1.shape[0]) >= num_imgs, "sidesize**2 must be at most the batch size"

    # Detached CPU copies in the default float dtype. This matches what the old
    # per-sample copy into a torch.zeros buffer produced, without the Python loop.
    view1_imgs = x0[:num_imgs].detach().to(device="cpu", dtype=torch.get_default_dtype(), copy=True)
    view2_imgs = x1[:num_imgs].detach().to(device="cpu", dtype=torch.get_default_dtype(), copy=True)

    # NOTE(review): make_grid only honours scale_each when normalize=True, so this
    # flag currently has no effect. It is left as-is to keep the logged images unchanged.
    view1_grid = torchvision.utils.make_grid(view1_imgs, nrow=sidesize, scale_each=True)
    view2_grid = torchvision.utils.make_grid(view2_imgs, nrow=sidesize, scale_each=True)

    if save_img:
        fig, axes = plt.subplots(1, 2, figsize=[10, 10])
        try:
            for ax, grid, title in zip(axes, (view1_grid, view2_grid), ("View 1", "View 2")):
                ax.imshow(grid.permute((1, 2, 0)))
                ax.set_title(title)
                ax.tick_params(bottom=False)
                ax.tick_params(left=False)
                ax.set(xticklabels=[])
                ax.set(yticklabels=[])
            fig.savefig(f'{name}.png')
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

    return [view1_grid, view2_grid]


def visualize_singleview_dataset(data_loader, sidesize=3, name="multiview_dataset", save_img=False):
    """
    Build an image grid from the first batch of a single-view dataloader.

    The first ``sidesize**2`` samples are arranged into a grid with ``sidesize``
    images per row. Optionally the grid is plotted and saved to ``{name}.png``.

    Args:
        data_loader: iterable whose first item unpacks as ``(x, y, _)``.
            x is presumably an image batch shaped (batch, C, H, W); TODO confirm.
        sidesize: images per grid row; ``sidesize**2`` images are used.
        name: file stem for the saved figure when ``save_img`` is True.
            NOTE(review): this default equals visualize_multiview_dataset's, so
            saving from both with default names writes to the same file.
        save_img: if True, also plot the grid with matplotlib and save the figure.

    Returns:
        list: ``[grid]`` as produced by ``torchvision.utils.make_grid``.

    Raises:
        AssertionError: if the batch has fewer than ``sidesize**2`` samples
            (skipped when Python runs with ``-O``).
    """
    x, _, _ = next(iter(data_loader))
    num_imgs = sidesize ** 2

    assert x.shape[0] >= num_imgs, "sidesize**2 must be at most the batch size"

    # Detached CPU copy in the default float dtype. This matches what the old
    # per-sample copy into a torch.zeros buffer produced, without the Python loop.
    imgs = x[:num_imgs].detach().to(device="cpu", dtype=torch.get_default_dtype(), copy=True)

    # NOTE(review): make_grid only honours scale_each when normalize=True, so this
    # flag currently has no effect. It is left as-is to keep the logged images unchanged.
    grid = torchvision.utils.make_grid(imgs, nrow=sidesize, scale_each=True)

    if save_img:
        fig, ax = plt.subplots(1, 1, figsize=[5, 5])
        try:
            ax.imshow(grid.permute((1, 2, 0)))
            ax.set_title("View 1")
            ax.tick_params(bottom=False)
            ax.tick_params(left=False)
            ax.set(xticklabels=[])
            ax.set(yticklabels=[])
            fig.savefig(f'{name}.png')
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)

    return [grid]