import os

import numpy as np
import torch

from omegaconf import OmegaConf
from utils import instantiate_from_config
from tqdm import tqdm
from model.ImagePreprocesser import ImagePreprocesser
from PIL import Image

def build_phase(device='cuda:0', mapping_config='./options/ours_test.yaml'):
    """Instantiate the watermark mapping module once and discard the instance.

    The call is made only for whatever side effects ``instantiate_from_config``
    has when building the ``MappingModule``. The script's own comment says this
    step logs a signature file name; confirm that in the mapping module itself.

    Args:
        device: Device string written into ``MappingModule.opts.device``.
        mapping_config: Path to the YAML file containing a ``MappingModule``
            section. The default matches the path previously hard-coded here
            and the default used by ``load_config``.
    """
    mapping_conf = OmegaConf.load(mapping_config)
    mapping_method_opt = mapping_conf.MappingModule
    mapping_method_opt.opts.device = device
    # Fixed latent shape / batch size, identical to the values used in load_config.
    mapping_method_opt.opts.latent_shape = (4, 64, 64)
    mapping_method_opt.opts.batch_size = 1
    instantiate_from_config(mapping_method_opt)

def load_config(device='cuda:0', model_config='./options/sd2.1.yaml', mapping_config='./options/ours_test.yaml'):
    """Build the diffusion model and the watermark mapping module from YAML configs.

    Args:
        device: Device string written into both configs' ``opts.device``.
        model_config: YAML file with an ``SDModel`` section.
        mapping_config: YAML file with a ``MappingModule`` section.

    Returns:
        ``(model, mapping_method)``. The model has already had
        ``set_gen_scheduler()`` called on it.
    """
    # --- diffusion model ---
    sd_conf = OmegaConf.load(model_config)
    sd_opt = sd_conf.SDModel
    sd_opt.opts.device = device
    sd_model = instantiate_from_config(sd_opt)
    sd_model.set_gen_scheduler()

    # --- watermark mapping module ---
    wm_conf = OmegaConf.load(mapping_config)
    wm_opt = wm_conf.MappingModule
    wm_settings = wm_opt.opts
    wm_settings.device = device
    wm_settings.latent_shape = (4, 64, 64)
    wm_settings.batch_size = 1
    watermarker = instantiate_from_config(wm_opt)

    return sd_model, watermarker

def generate(prompt, model, mapping_method, save_path, device='cuda:0'):
    """Generate a watermarked image for ``prompt`` and save it to ``save_path``.

    If ``mapping_method`` exposes a truthy ``lm`` attribute, a random {0, 1}
    message of shape ``(1, lm)`` is drawn and printed. Otherwise the message
    is ``None``. The message is passed to ``mapping_method.embed_watermark``,
    and the result is cast to float32 and handed to ``model.generate``
    together with the prompt. The returned object's ``save`` method is then
    called (presumably a PIL image; confirm against the model class).

    Args:
        prompt: Text prompt forwarded to ``model.generate``.
        model: Object providing ``generate(latents, prompt)``.
        mapping_method: Object providing ``embed_watermark(message)`` and
            optionally an ``lm`` attribute (message length).
        save_path: Destination path passed to ``image.save``.
        device: Device the message tensor is moved to.

    Returns:
        The embedded message tensor, or ``None`` when no message was drawn.
        Callers that ignore the result behave as before.
    """
    lm = getattr(mapping_method, 'lm', None)
    with torch.no_grad():
        # Drawn with torch's default generator and then moved, so a given seed
        # yields the same bits as before regardless of ``device``.
        message = torch.randint(0, 2, (1, lm,)).to(device) if lm else None
        # Guard: printing a None message used to raise AttributeError here.
        if message is not None:
            print(message.flatten().cpu().numpy().tolist())
        latents = mapping_method.embed_watermark(message).to(torch.float32)
        image = model.generate(latents, prompt)
        image.save(save_path)
    return message


def extract(model, preprocesser, mapping_method, image_path, device='cuda:0'):
    """Recover watermark bits from the image stored at ``image_path``.

    The image is loaded at size 512 via ``preprocesser.load_image``. The
    model's inversion scheduler is selected, the image is inverted back to
    latents, and the mapping module decodes those latents into a message.

    Args:
        model: Object providing ``set_inv_scheduler()`` and ``invert(image)``.
        preprocesser: Object providing ``load_image(path, image_size, device)``.
        mapping_method: Object providing ``extract_watermark(latents)``, which
            returns a tensor.
        image_path: Path of the image to read.
        device: Device forwarded to ``load_image``.

    Returns:
        The recovered message as a flat ``np.uint8`` array, the same values
        that are printed. Returning it lets callers compare against the
        embedded message.

    NOTE(review): the model is left with its inversion scheduler selected;
    a later ``generate`` may need ``set_gen_scheduler()`` again. Confirm
    against the model class.
    """
    image_pt = preprocesser.load_image(image_path, image_size=512, device=device)
    model.set_inv_scheduler()
    inv_latents = model.invert(image_pt)
    rec_message = mapping_method.extract_watermark(inv_latents).cpu().numpy()
    bits = rec_message.astype(np.uint8).flatten()
    print(bits.tolist())
    return bits


if __name__ == '__main__':
    # Demo: embed a random watermark while generating an image, then read it back.

    # build phase, will log signature file name
    # build_phase()

    model, mapping_method = load_config()

    # generate: create a watermarked image and save it to disk
    prompt = 'a photo of an astronaut riding a horse on mars'
    save_path = './test.png'
    generate(prompt, model, mapping_method, save_path)
    print(f'Watermarked image has been saved to {save_path}')

    # extract: recover the watermark bits from the saved image
    preprocesser = ImagePreprocesser()
    extract(model, preprocesser, mapping_method, save_path)

