import json
import os
import random

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import set_seed

# Shared torchvision converter instances, built once at import time and reused
# by recover_image to go from its inputs to tensors and back to a PIL image.
totensor = T.ToTensor()
topil = T.ToPILImage()

def recover_image(image, init_image, mask, background=False):
    """Blend ``image`` and ``init_image`` using ``mask`` as per-pixel weight.

    All three inputs are passed through the module-level ``totensor``
    converter (presumably PIL images of matching size -- confirm against
    callers).

    Args:
        image: Image whose pixels are taken where the mask weight is 1
            (or where it is 0 when ``background`` is true).
        init_image: Image supplying the remaining pixels.
        mask: Blending weights; ``mask`` selects one source and
            ``1 - mask`` selects the other.
        background: When true, the roles of ``image`` and ``init_image``
            are swapped, i.e. ``init_image`` is used where the mask is 1.

    Returns:
        The blended result converted back with the module-level ``topil``.
    """
    generated = totensor(image)
    weight = totensor(mask)
    original = totensor(init_image)

    # Decide which source fills the masked region and which fills the rest.
    if background:
        inside, outside = original, generated
    else:
        inside, outside = generated, original

    blended = weight * inside + (1 - weight) * outside
    return topil(blended)

def prepare_mask_and_masked_image(image, mask):
    """Convert an image and its mask into batched tensors.

    Args:
        image: Object with a ``convert("RGB")`` method whose result NumPy can
            turn into a height x width x channel array (e.g. a PIL image).
        mask: Object with a ``convert("L")`` method whose result NumPy can
            turn into a height x width array (e.g. a PIL image).

    Returns:
        A tuple ``(mask, masked_image, image)`` where

        * ``mask`` is a float32 tensor of shape ``(1, 1, H, W)`` holding 0
          where the scaled mask value (value / 255) is below 0.5, else 1;
        * ``masked_image`` is ``image`` with every pixel zeroed where
          ``mask`` is 1;
        * ``image`` is a float32 tensor of shape ``(1, C, H, W)`` with
          values mapped by ``value / 127.5 - 1.0``.
    """
    # Image: HWC array -> 1xCxHxW float tensor, rescaled by /127.5 - 1.
    rgb = np.array(image.convert("RGB"))
    image = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    image = image.to(dtype=torch.float32) / 127.5 - 1.0

    # Mask: scale to [0, 1], then binarise at 0.5 and add batch/channel axes.
    gray = np.array(mask.convert("L")).astype(np.float32) / 255.0
    binary = (gray >= 0.5).astype(np.float32)
    mask = torch.from_numpy(binary[None, None])

    # Keep only the pixels outside the mask (mask value 0).
    keep = mask < 0.5
    masked_image = image * keep

    return mask, masked_image, image

def prepare_image(image):
    """Convert an image into a channel-first float tensor.

    Args:
        image: Object with a ``convert("RGB")`` method whose result NumPy can
            turn into a height x width x channel array (e.g. a PIL image).

    Returns:
        A float32 tensor of shape ``(C, H, W)`` with values mapped by
        ``value / 127.5 - 1.0``.
    """
    pixels = np.array(image.convert("RGB"))
    # Move channels to the front; no batch axis is kept in the result.
    channel_first = torch.from_numpy(pixels).permute(2, 0, 1)
    return channel_first.to(dtype=torch.float32) / 127.5 - 1.0
 
def set_seed_lib(seed):
    """Seed every random number generator this module relies on.

    Calls, in order, ``np.random.seed``, ``torch.manual_seed``,
    ``torch.cuda.manual_seed``, ``random.seed`` and the ``set_seed`` helper
    imported from ``transformers``, each with the same ``seed``.

    Args:
        seed: Value forwarded unchanged to every seeding function.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
        set_seed,
    )
    for seeder in seeders:
        seeder(seed)

def load_image(image_name, is_mask = False):
    """Open a dataset image or its mask as a 512x512 RGB image.

    Args:
        image_name: Base name of the image, without directory or extension.
        is_mask: When true, read ``../dataset/masks/mask_<image_name>.png``;
            otherwise read ``../dataset/images/<image_name>.png``.

    Returns:
        The opened image converted to RGB and resized to ``(512, 512)``.
    """
    if is_mask:
        path = f'../dataset/masks/mask_{image_name}.png'
    else:
        path = f'../dataset/images/{image_name}.png'

    opened = Image.open(path)
    return opened.convert('RGB').resize((512,512))

def save_image(img, img_path):
    """Write ``img`` to ``img_path`` in PNG format via its ``save`` method.

    Args:
        img: Image object exposing ``save`` (presumably a PIL image).
        img_path: Destination handed to ``img.save``.
    """
    img.save(img_path, format="PNG")

def get_train_val_image_prompt_list(file_loc="../dataset/image_prompt_pairs.json"):
    """Load image/prompt records from JSON and split them into train and validation.

    The file must contain a JSON list of objects, each with the keys
    ``"image"``, ``"prompts"`` and ``"is_validation"``. Any other keys are
    ignored.

    Args:
        file_loc: Path of the JSON file. Defaults to the dataset location
            used by the rest of this module.

    Returns:
        A tuple ``(train_pairs, val_pairs)``. Both are lists of
        ``{"image": ..., "prompts": ...}`` dicts; records whose
        ``"is_validation"`` value is truthy go to ``val_pairs``, all others
        to ``train_pairs``. Input order is preserved within each list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    # Decode explicitly as UTF-8 so prompts with non-ASCII characters load
    # the same way regardless of the platform's default encoding.
    with open(file_loc, "r", encoding="utf-8") as json_file:
        records = json.load(json_file)

    train_image_prompt_pairs = []
    val_image_prompt_pairs = []

    for record in records:
        # Keep only the two fields downstream code needs.
        pair = {"image": record["image"], "prompts": record["prompts"]}
        if record["is_validation"]:
            val_image_prompt_pairs.append(pair)
        else:
            train_image_prompt_pairs.append(pair)
    return train_image_prompt_pairs, val_image_prompt_pairs

def get_train_lists(image_prompt_list, device="cuda"):
    """Build per-prompt lists of masked-image tensors, mask tensors and prompts.

    For every record the image and its mask are loaded with ``load_image``
    and converted with ``prepare_mask_and_masked_image``. The resulting
    tensors are cast to half precision, moved to ``device`` and then
    repeated once for each prompt of that record, so the three returned
    lists always have equal length.

    Args:
        image_prompt_list: Iterable of dicts with the keys ``"image"`` (a
            file name; its extension is stripped before loading) and
            ``"prompts"`` (an iterable of prompts).
        device: Target device for the tensors. Defaults to ``"cuda"``.

    Returns:
        A tuple ``(image_torch_list, mask_torch_list, prompt_train_list)``.
        ``image_torch_list`` holds the *masked* image (the second value
        returned by ``prepare_mask_and_masked_image``) and
        ``mask_torch_list`` the mask, both with the leading batch axis
        squeezed away.
    """
    image_torch_list = []
    mask_torch_list = []
    prompt_train_list = []

    for image_prompt in image_prompt_list:
        # splitext copes with any extension length (or none), unlike a
        # fixed four-character slice.
        image_name = os.path.splitext(image_prompt["image"])[0]

        image = load_image(image_name)
        image_mask = load_image(image_name, is_mask=True)

        # The unmasked image is not needed here, so it is discarded instead
        # of being converted and copied to the device.
        mask_torch, image_torch, _ = prepare_mask_and_masked_image(image, image_mask)
        image_torch = image_torch.half().to(device)
        mask_torch = mask_torch.half().to(device)

        # One entry per prompt; each entry is its own view of the shared
        # tensor data.
        for prompt in image_prompt["prompts"]:
            image_torch_list.append(image_torch.squeeze(0))
            mask_torch_list.append(mask_torch.squeeze(0))
            prompt_train_list.append(prompt)
    return image_torch_list, mask_torch_list, prompt_train_list