import os
import torch
import random
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
import sys
from LaVIT.models import build_model
# Local directory holding the LaVIT checkpoint (loaded with `local_files_only=True` below).
model_path = 'YOUR_ROOT_PATH/model/LaVIT-7B-v2'
# Precision passed to `build_model`; also selects `torch_dtype` at the end of this section.
model_dtype = 'bf16'
# Root of the LaVIT source checkout; demo inputs and outputs are resolved relative to it.
LaVIT_path = 'YOUR_ROOT_PATH/MLLM/src/LaVIT'

# Uncomment to seed the torch / numpy / random RNGs.
# seed = 42
# torch.manual_seed(seed)
# np.random.seed(seed)
# random.seed(seed)

device_id = 0
torch.cuda.set_device(device_id)
# Bind explicitly to `device_id` instead of relying on the current-device default.
device = torch.device('cuda', device_id)

# For multi-modal image generation, `load_tokenizer=True` is required so the input image can be tokenized.
# If xformers is installed, `use_xformers=True` saves GPU memory (xformers is not supported on V100 GPUs).
# If the checkpoint is already downloaded, `local_files_only=True` avoids auto-downloading from the remote hub.
model = build_model(model_path=model_path, model_dtype=model_dtype, check_safety=False,
                    device_id=device_id, use_xformers=True, understanding=False,
                    load_tokenizer=True, local_files_only=True)
model = model.to(device)
print("Building Model Finished")
# Autocast dtype for inference: bfloat16 when model_dtype is 'bf16', float16 for any other value.
torch_dtype = torch.bfloat16 if model_dtype == "bf16" else torch.float16


def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB grid, filled row-major.

    Every cell is sized after the first image; each image is pasted at its
    cell's top-left corner without resizing.

    Args:
        imgs: Sequence of PIL images, exactly ``rows * cols`` long.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows*cols={rows * cols} images, got {len(imgs)}")
    if not imgs:
        # rows*cols == 0 would otherwise fail below with IndexError on imgs[0].
        raise ValueError("imgs must contain at least one image")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Row-major placement: index i -> (row i // cols, column i % cols).
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

# Image + Text => Image
image_prompt = os.path.join(LaVIT_path, 'demo/dog.jpg')
text_prompt = 'It is swimming in the river'
input_prompts = [(image_prompt, 'image'), (text_prompt, 'text')]

# display(Image.open(image_prompt).resize((256, 256)))
print(text_prompt)

# LaVIT supports 6 image aspect ratios, mapped to (height, width) in pixels.
ratio_dict = {
    '1:1' : (1024, 1024),
    '4:3' : (896, 1152),
    '3:2' : (832, 1216),
    '16:9' : (768, 1344),
    '2:3' : (1216, 832),
    '3:4' : (1152, 896),
}

# The image aspect ratio you want to generate
ratio = '1:1'
height, width = ratio_dict[ratio]

# Standard synthesis call, currently disabled. The `generate_image` call below
# passes a hard-coded `image_tokens` sequence instead.
# NOTE(review): looks like a debugging experiment — confirm before shipping.
# with torch.autocast(device_type='cuda', enabled=True, dtype=torch_dtype):
#     images = model.multimodal_synthesis(input_prompts, width=width, height=height,
#         guidance_scale_for_llm=4.0, num_return_images=1, num_inference_steps=25, top_k=50)

image_tokens = [[45764, 43410, 32718, 41980, 47852, 42645, 42310, 34848, 41505, 45950, 39446, 47654, 34009, 32626, 39671, 45624, 34772, 38860, 47654, 35827, 34488, 43676, 47504, 46137, 41455, 35866, 45191, 41585, 40349, 45628, 33058, 42177, 48091, 44840, 43533, 43729, 34062, 43538, 37710, 45939, 42177, 39708, 35773, 47198, 34656, 33115, 32026, 37004, 40565, 35965, 48199, 43777, 44393, 32954, 35596, 33344, 34030, 48301, 32460, 46374, 36145, 34589, 34030, 47340, 47087, 35139, 37064, 36259, 46104, 34772, 41339, 43124, 42933, 35773, 35097, 34283, 39289, 41157, 45663, 40496, 34003, 37588, 42425, 36259, 48180, 37409, 37024, 33013, 39293, 37698, 42933, 41920]]

# A single sequence of token ids -> LongTensor of shape (1, N) on the model's device.
image_tokens = torch.tensor(image_tokens, dtype=torch.long, device=device)

# Create the output directory before the first save so it exists on a fresh checkout.
output_dir = os.path.join(LaVIT_path, 'output')
os.makedirs(output_dir, exist_ok=True)

# `torch.autocast` replaces the deprecated `torch.cuda.amp.autocast`.
with torch.autocast(device_type='cuda', enabled=True, dtype=torch_dtype):
    images = model.generate_image(input_prompts, width=width, height=height,
        guidance_scale_for_llm=4.0, num_return_images=1, num_inference_steps=25, top_k=50, image_tokens=image_tokens)
images[0].save(os.path.join(output_dir, 'test.jpg'))
# NOTE(review): debugging early exit. Nothing after this line (including the
# image + image demo) runs while it is here. `sys.exit` is used because the
# bare `exit()` helper is installed by `site` for interactive use only.
sys.exit()

# display(images[0])
images[0].save(os.path.join(output_dir, 'it2i_output.jpg'))

# Image + Image => Image
image1 = os.path.join(LaVIT_path, 'demo/image_input1.jpg')
image2 = os.path.join(LaVIT_path, 'demo/image_input2.jpg')
input_prompts = [(image1, 'image'), (image2, 'image')]
# display(image_grid([Image.open(image1).resize((256, 256)), Image.open(image2).resize((256, 256))], 1, 2))

# LaVIT supports 6 image aspect ratios, mapped to (height, width) in pixels.
ratio_dict = {
    '1:1' : (1024, 1024),
    '4:3' : (896, 1152),
    '3:2' : (832, 1216),
    '16:9' : (768, 1344),
    '2:3' : (1216, 832),
    '3:4' : (1152, 896),
}

# The image aspect ratio you want to generate
ratio = '1:1'
height, width = ratio_dict[ratio]

# `torch.autocast` replaces the deprecated `torch.cuda.amp.autocast`.
with torch.autocast(device_type='cuda', enabled=True, dtype=torch_dtype):
    images = model.multimodal_synthesis(input_prompts, width=width, height=height,
        guidance_scale_for_llm=8.0, num_return_images=1, num_inference_steps=25, top_k=50)


# display(images[0])
# Ensure the output directory exists so this section does not depend on an
# earlier section having created it.
output_dir = os.path.join(LaVIT_path, 'output')
os.makedirs(output_dir, exist_ok=True)
images[0].save(os.path.join(output_dir, 'ii2i_output.jpg'))
