import torch
import numpy as np
from PIL import Image
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
                                       display_in_terminal)

from custom_logging import get_logger

# Logger used by save_as_images to report each image file it writes.
ut_logger = get_logger("unittest")


def convert_to_images(obj):
    """ Convert an output tensor from BigGAN into a list of images.
        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width).
                Values are mapped from [-1, 1] to 8-bit pixels; anything outside
                that range is clipped. Tensors may live on any device.
        Output:
            list of Pillow Images, one per element of the batch
        Raises:
            ImportError: if Pillow is not installed
    """
    try:
        # Import the name that is actually used below so the check is meaningful.
        from PIL import Image
    except ImportError as err:
        raise ImportError("Please install Pillow to use images: pip install Pillow") from err

    if not isinstance(obj, np.ndarray):
        # .numpy() only works on CPU tensors, so move the data to host memory
        # first; .detach() drops any autograd history.
        obj = obj.detach().cpu().numpy()

    # Channels-first (N, C, H, W) -> channels-last (N, H, W, C) as Pillow expects.
    pixels = obj.transpose((0, 2, 3, 1))
    # Rescale [-1, 1] to [0, 256], clamp to the valid 8-bit range, then
    # truncate to uint8 in a single cast.
    pixels = np.clip(((pixels + 1) / 2.0) * 256, 0, 255).astype(np.uint8)

    return [Image.fromarray(frame) for frame in pixels]


def save_as_images(obj, file_name='output'):
    """ Convert an output tensor from BigGAN to images and write each one to disk.
        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width)
            file_name: path and beginning of the filename to save.
                Images will be saved as `file_name_{image_number}.png`
    """
    for index, picture in enumerate(convert_to_images(obj)):
        # Each image gets its position in the batch as a numeric suffix.
        suffix = '_%d.png' % index
        target = file_name + suffix
        ut_logger.info("Saving image to {}".format(target))
        picture.save(target, 'png')


if __name__ == '__main__':
    # Load the pre-trained BigGAN model
    model = BigGAN.from_pretrained('biggan-deep-256')

    # Prepare an input: one class vector and one noise vector per image
    truncation = 0.4
    class_vector = one_hot_from_names(['soap bubble', 'coffee', 'mushroom'], batch_size=3)
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=3)

    # All in tensors
    noise_vector = torch.from_numpy(noise_vector)
    class_vector = torch.from_numpy(class_vector)

    # Use the GPU when one is available, otherwise stay on the CPU so the
    # script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    noise_vector = noise_vector.to(device)
    class_vector = class_vector.to(device)
    model.to(device)

    # Generate an image (no gradients are needed for inference)
    with torch.no_grad():
        output = model(noise_vector, class_vector, truncation)

    # Bring the result back to the CPU (a no-op if it is already there)
    output = output.to('cpu')

    # If you have a sixel compatible terminal you can display the images in the terminal
    # (see https://github.com/saitoha/libsixel for details)
    # display_in_terminal(output)

    # Save results as png images
    save_as_images(output)

