import argparse
import torch
from torchvision.utils import save_image, make_grid
import os

from data.dataloader import get_imagenet_loaders, get_svhn_loaders, get_cure_tsr_loaders, get_gtsrb_loaders
from models.load import MUNIT_Model

# Dataset name (as looked up by main()) -> factory that builds that dataset's loaders.
# Each factory is called with the parsed CLI args and returns a 4-tuple whose first
# two items main() uses as the train and validation loaders.
LOADER_DICT = dict(
    svhn=get_svhn_loaders,
    gtsrb=get_gtsrb_loaders,
    cure_tsr=get_cure_tsr_loaders,
    imagenet=get_imagenet_loaders,
)

def _sample(G, img, style_dim, device):
    """Translate one image with G using a freshly drawn random style code.

    Args:
        G: the generator, called as ``G(images, style_codes)`` as in ``main``.
        img: a single image tensor (no batch dimension) already on ``device``.
        style_dim: channel count of the ``(1, style_dim, 1, 1)`` style code.
        device: device the style code is moved to.

    Returns:
        The first (and only) item of G's output batch.
    """
    # Drawn on the CPU generator and then moved, matching the random stream of
    # the original ``torch.randn(...).cuda()`` calls.
    delta = torch.randn(1, style_dim, 1, 1).to(device)
    # Index the batch dimension instead of ``.squeeze()``, which would also drop
    # any other size-1 dimension of the output (e.g. a single channel).
    return G(img.unsqueeze(0), delta)[0]


def _style_grid(G, img, rows, cols, style_dim, device):
    """Return a rows x cols mosaic of translations of ``img``, one random style per tile.

    Tiles are concatenated along the last dim within a row, and rows along the
    second-to-last dim (top to bottom in the order they were drawn).
    """
    grid_rows = [
        torch.cat([_sample(G, img, style_dim, device) for _ in range(cols)], dim=-1)
        for _ in range(rows)
    ]
    return torch.cat(grid_rows, dim=-2)


def _multimodal_strip(G, img, n_samples, style_dim, device):
    """Return ``make_grid`` of [img, an all-ones tile, n_samples translations of img].

    The all-ones tile is a visual separator (white, assuming images are in [0, 1]).
    """
    tiles = [img, torch.ones_like(img)]
    tiles.extend(_sample(G, img, style_dim, device) for _ in range(n_samples))
    return make_grid(tiles)


def main(args):
    """Save dataset samples and MUNIT model-based translations of them as PNGs.

    Settings default to the values this script always used, but each may be
    overridden by setting the same-named attribute on ``args``: ``dataset``,
    ``challenge``, ``model_path``, ``cuda`` (defaults to whether CUDA is
    available), ``style_dim``, ``num_samples``, ``train_data_dir`` and
    ``val_data_dir``. ``args`` is also passed to the dataset's loader factory.

    Writes into ``../gallery/<dataset>/<challenge>``:
      * ``train.png`` / ``val.png`` -- the first train / validation batch;
      * ``model_based.png`` -- the first train batch passed through G, each image
        with its own random style code;
      * ``grid-<j>.png`` -- a 3x3 mosaic of translations of train image j;
      * ``grid-orig-<j>.png`` -- train image j itself;
      * ``sampled-<j>.png`` -- train image j, a separator tile, then 6 translations.
    ``j`` ranges over the first ``num_samples`` images, capped at the batch size.
    """
    dataset = getattr(args, 'dataset', 'imagenet')
    challenge = getattr(args, 'challenge', 'snow')
    path = getattr(args, 'model_path', 'training/models/snow-3-model.pt')
    # Previously hard-coded to True, which crashed on CPU-only machines.
    cuda = getattr(args, 'cuda', torch.cuda.is_available())
    # Width of the random style code fed to G (8 in the original script);
    # presumably MUNIT's style_dim -- TODO confirm against the checkpoint.
    style_dim = getattr(args, 'style_dim', 8)
    num_samples = getattr(args, 'num_samples', 10)
    device = torch.device('cuda' if cuda else 'cpu')

    # Only fill in the data directories when the caller has not supplied them.
    if getattr(args, 'train_data_dir', None) is None:
        args.train_data_dir = '../../Model-Based-ImageNet/datasets/imagenet/tmp/train_first_50'
    if getattr(args, 'val_data_dir', None) is None:
        args.val_data_dir = f'../../Model-Based-ImageNet/datasets/imagenet_c/full/tmp_{challenge}/first_fifty/5'
    args.half_prec = False

    save_path = f'../gallery/{dataset}/{challenge}'
    loader_fn = LOADER_DICT[dataset]
    os.makedirs(save_path, exist_ok=True)

    trn_loader, val_loader, _, _ = loader_fn(args)
    trn_imgs, _ = next(iter(trn_loader))
    save_image(trn_imgs, os.path.join(save_path, 'train.png'))

    val_imgs, _ = next(iter(val_loader))
    save_image(val_imgs, os.path.join(save_path, 'val.png'))

    G = MUNIT_Model(path, reverse=False)
    if cuda:
        G = G.cuda()

    trn_imgs = trn_imgs.to(device)
    # Guard against batches smaller than the number of images we want to show.
    n_show = min(num_samples, trn_imgs.size(0))

    # Pure inference: no autograd graph is needed for any of the forward passes.
    with torch.no_grad():
        delta = torch.randn(trn_imgs.size(0), style_dim, 1, 1).to(device)
        mb_images = G(trn_imgs, delta)
        save_image(mb_images, os.path.join(save_path, 'model_based.png'))

        for j in range(n_show):
            img = trn_imgs[j]
            full = _style_grid(G, img, 3, 3, style_dim, device)
            save_image(full, os.path.join(save_path, f'grid-{j}.png'))
            save_image(img, os.path.join(save_path, f'grid-orig-{j}.png'))

        for j in range(n_show):
            strip = _multimodal_strip(G, trn_imgs[j], 6, style_dim, device)
            save_image(strip, os.path.join(save_path, f'sampled-{j}.png'))


    
    

if __name__ == '__main__':
    # CLI entry point: parse loader-related options and hand them to main().
    parser = argparse.ArgumentParser(description='Sample MUNIT models')
    parser.add_argument('--data-size', type=int, default=224, help="Size of each image")
    parser.add_argument('--batch-size', type=int, default=24, help='Training/validation batch size')
    parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                        help='number of data loading workers (default: 8)')
    # action='store_true' means this flag is False unless passed; the help text
    # previously claimed the default was True.
    parser.add_argument('--distributed', action='store_true', help='Run distributed training (default: False)')
    # NOTE: args.train_data_dir / args.val_data_dir are filled in by main(),
    # so they are not exposed as command-line options here.
    args = parser.parse_args()
    main(args)