import numpy as np
import os
import torch
from torchvision.utils import save_image

from models.load import load_model
from data.dataloader import get_svhn_loaders, get_imagenet_loaders


def main():
    """Render the style grid for the hard-coded checkpoint path."""
    model_path = './training/models/snow-2-model-90k.pt'
    # NOTE(review): save_model_grid forwards its first argument to load_model()
    # and the data-loader factory as `args`; confirm they accept a bare path
    # string rather than a parsed argument namespace.
    save_model_grid(model_path)

def save_model_grid(args, reverse=False, lower=-2, upper=2, num_pts=5,
                    get_loaders=get_imagenet_loaders):
    """Load a model and save its style grid for the first training image.

    The first batch of the training loader is taken, and its first image is
    rendered by ``save_img_grid`` over a ``num_pts`` x ``num_pts`` grid of
    style codes in ``[lower, upper]`` on both axes. The input image and the
    grid are written to ``original.png`` and ``grid.png``.

    Args:
        args: Forwarded unchanged to ``load_model`` and ``get_loaders``.
        reverse: Forwarded to ``load_model``.
        lower, upper: Inclusive range of each style coordinate.
        num_pts: Number of style samples per axis.
        get_loaders: Factory returning a 4-tuple whose first element is the
            training loader (e.g. ``get_imagenet_loaders`` or
            ``get_svhn_loaders``).
    """
    model = load_model(args, reverse=reverse)
    trn_loader, _, _, _ = get_loaders(args)

    img, _ = next(iter(trn_loader))

    # Pure visualisation: no autograd graph is needed while rendering the grid.
    # The grid-building logic lives in save_img_grid, so it is not duplicated here.
    with torch.no_grad():
        save_img_grid(img, model, lower=lower, upper=upper, num_pts=num_pts)

def save_img_grid(img, model, lower=-1, upper=1, num_pts=4,
                  orig_fname='original.png', grid_fname='grid.png',
                  device='cuda'):
    """Render ``model`` over a 2-D grid of style codes and save the result.

    Only ``img[0]`` is used. It is passed to ``model(image, style)`` once for
    every style code ``(x, y)``, where ``x`` and ``y`` each range over
    ``np.linspace(lower, upper, num_pts)``. The outputs are tiled so that ``x``
    increases left-to-right and ``y`` increases bottom-to-top: the row for
    ``upper`` is the top row of the saved image.

    Args:
        img: Batched image tensor. Presumably (N, C, H, W); TODO confirm
            against the data loader.
        model: Callable taking ``(image, style)``, where ``style`` has shape
            (1, 2).
        lower, upper: Inclusive range of each style coordinate.
        num_pts: Number of style samples per axis; must be >= 1.
        orig_fname: Path the selected input image is written to.
        grid_fname: Path the tiled grid image is written to.
        device: Device the image and style codes are moved to.

    Raises:
        ValueError: If ``num_pts`` is less than 1. This is checked before
            anything is written.
    """
    if num_pts < 1:
        raise ValueError(f"num_pts must be >= 1, got {num_pts}")

    img = img[0].unsqueeze(0).to(device)
    # .tolist() gives plain Python floats, so the style tensor gets torch's
    # default float dtype rather than one inferred from np.float64 scalars.
    coords = np.linspace(lower, upper, num=num_pts).tolist()

    rows = []
    with torch.no_grad():  # inference only; avoid building an autograd graph
        for y in coords:
            tiles = []
            for x in coords:
                style = torch.tensor([[x, y]], device=device)
                tiles.append(model(img, style))
            rows.append(torch.cat(tiles, dim=-1))

    # Later rows (larger y) go on top. A single concatenation over the
    # reversed list avoids re-copying the growing grid once per row.
    img_samples = torch.cat(rows[::-1], dim=-2)

    save_image(img, orig_fname)
    save_image(img_samples, grid_fname)


if __name__ == '__main__':
    # Script entry point: only runs when executed directly, not on import.
    main()