import os
import numpy as np
from PIL import Image
import glob

from kornia.metrics import psnr, ssim
import lpips
from sifid import SIFID

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from watermark_anything.data.metrics import msg_predict_inference
from notebooks.inference_utils import (
    load_model_from_checkpoint, 
    default_transform, 
    create_random_mask, 
    unnormalize_img,
    plot_outputs,
    msg2str,
    calculate_iou_score
)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
# to load images
def load_img(path):
    """
    Load an image file as a batched RGB tensor on the module-level ``device``.
    Args:
        path (str): Path of the image file to read.
    Returns:
        (torch.Tensor): Output of ``default_transform`` applied to the RGB
            image, with a leading batch dimension of size 1 added.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # releases it deterministically. convert() reads the pixel data and
    # returns an independent image, so it stays valid after the file closes.
    with Image.open(path) as img:
        img_rgb = img.convert("RGB")
    return default_transform(img_rgb).unsqueeze(0).to(device)

# LPIPS networks already built, keyed by (backbone name, device), so the
# pretrained weights are loaded once instead of on every call.
_LPIPS_CACHE = {}


def compute_lpips(x, y, net='alex'):
    """
    Compute LPIPS between two images.

    Runs on CUDA when it is available and on the CPU otherwise.
    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (str or callable): Backbone name passed to ``lpips.LPIPS``, or an
            already constructed callable ``net(x, y)``. A callable is used as
            is and is NOT moved to the compute device by this function.
    Returns:
        (numpy.ndarray): LPIPS values with all size-1 dimensions squeezed out
            (a 0-d array when the network yields a single value).
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(net, str):
        key = (net, str(target))
        if key not in _LPIPS_CACHE:
            _LPIPS_CACHE[key] = lpips.LPIPS(net=net, verbose=False).to(target)
        lpips_fn = _LPIPS_CACHE[key]
    else:
        lpips_fn = net
    x, y = x.to(target), y.to(target)
    # Pure evaluation: no autograd graph is needed for the metric.
    with torch.no_grad():
        return lpips_fn(x, y).detach().cpu().numpy().squeeze()

def compute_sifid(x, y, net=None):
    """
    Compute SIFID for each pair of images in two batches.
    Args:
        x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
        net (callable, optional): Scorer called as ``net(xi, yi)`` on each
            pair; a fresh ``SIFID()`` is created when omitted.
    Returns:
        (numpy.ndarray): One score per pair, in batch order.
    """
    scorer = net if net is not None else SIFID()
    scores = []
    # Iterating the batches walks their first dimension, pairing items up.
    for sample_x, sample_y in zip(x, y):
        scores.append(scorer(sample_x, sample_y))
    return np.array(scores)

# Locate the model configuration and weights inside the checkpoint directory.
exp_dir = "checkpoints"
json_path = os.path.join(exp_dir, "params.json")
ckpt_path = os.path.join(exp_dir, "wam_mit.pth")

# Build the model, move it to the selected device, then switch to eval mode.
wam = load_model_from_checkpoint(json_path, ckpt_path)
wam = wam.to(device)
wam = wam.eval()

def _detect_message(img):
    """
    Run the detector on an image and decode the embedded message.
    Args:
        img (torch.Tensor): Image batch as returned by ``load_img``.
    Returns:
        (tuple): ``(message, mask_probs)``: the decoded message as a float
            tensor, and the sigmoid of the detector's first output channel
            (the predicted watermark mask at detector resolution).
    """
    preds = wam.detect(img)["preds"]  # [1, 33, 256, 256]
    mask_probs = torch.sigmoid(preds[:, 0, :, :])  # [1, 256, 256], predicted mask
    bit_preds = preds[:, 1:, :, :]  # [1, 32, 256, 256], predicted bits
    message = msg_predict_inference(bit_preds, mask_probs).float()  # [1, 32]
    return message, mask_probs


def _bit_accuracy(pred_msg, true_msg):
    """Return the fraction of positions where the two messages are equal."""
    return (pred_msg.to(device) == true_msg.to(device)).float().mean().item()


def _report_quality(img, ref, lpips_net, sifid_net):
    """Print PSNR, SSIM, LPIPS and SIFID of ``img`` measured against ``ref``."""
    img_01, ref_01 = unnormalize_img(img), unnormalize_img(ref)
    psnr_value = psnr(img_01, ref_01, 1).item()
    ssim_value = torch.mean(ssim(img_01, ref_01, window_size=11)).item()
    # NOTE(review): compute_lpips/compute_sifid document inputs in [-1, 1],
    # but the tensors passed here are in default_transform's normalized space.
    # Confirm the two ranges match; kept as in the original to leave the
    # reported numbers unchanged.
    lpips_value = compute_lpips(img, ref, net=lpips_net)
    sifid_value = compute_sifid(img, ref, net=sifid_net)
    print(f"PSNR: {psnr_value}, SSIM: {ssim_value}, LPIPS: {lpips_value}, SIFID: {sifid_value[0]}")


def _remove_watermark(img_w, message, mask, scales):
    """
    Estimate the watermark by re-embedding ``message`` and subtract it.

    Each pass embeds the message into the latest restored image (the
    watermarked image itself on the first pass), takes the difference as the
    watermark estimate, and subtracts that estimate, scaled, from ``img_w``.
    Args:
        img_w (torch.Tensor): Watermarked image, [1, 3, H, W].
        message (torch.Tensor): Message decoded from ``img_w``.
        mask (torch.Tensor): Binary mask of the watermarked region, [1, 1, H, W].
        scales (list of float): One scale factor per pass.
    Returns:
        (torch.Tensor): Restored image; ``img_w`` itself if ``scales`` is empty.
    """
    img_res = img_w
    for scale in scales:
        outputs = wam.embed(img_res, message)
        # Double-watermarked image, restricted to the detected region.
        img_dw = outputs['imgs_w'] * mask + img_res * (1 - mask)  # [1, 3, H, W]
        img_res = img_w - (img_dw - img_res) * scale  # [1, 3, H, W]
    return img_res


@torch.no_grad()  # inference only: skip building autograd graphs
def main():
    """
    Watermark every image in the input folder, then try to remove the mark.

    For each image: embed a random message inside a random mask, save and
    reload the watermarked image, decode the message and print bit accuracy
    and quality metrics; then restore the image by iterated subtraction,
    save and reload it, and print the same figures for the restored image.
    Raises:
        ValueError: If fewer scale factors than restoration loops are configured.
    """
    # Parameters
    proportion_masked = 0.25  # Proportion of the image to be watermarked (0.5 means 50% of the image)
    loop_time = 5             # Number of passes of the loop strategy
    alpha = [0.5, 0.6, 0.8, 0.9, 1.0]  # scale factor per pass (experimentally predetermined)
    if loop_time > len(alpha):
        raise ValueError(
            f"loop_time ({loop_time}) exceeds the number of scale factors ({len(alpha)})"
        )

    orig_dir = '../coco2017/val2017_512'  # 512x512 original images folder (input)
    wm_dir = './wm_imgs'     # watermarked images folder (output)
    res_dir = './res_imgs'   # restored images folder (output)
    os.makedirs(wm_dir, exist_ok=True)
    os.makedirs(res_dir, exist_ok=True)

    # glob returns files in arbitrary order; sort so that the per-image seed
    # below maps to the same image on every run.
    img_paths = sorted(glob.glob(f'{orig_dir}/*.*'))

    # Build the metric networks once instead of twice per image.
    lpips_net = lpips.LPIPS(net='alex', verbose=False).to(device)
    sifid_net = SIFID()

    # Iterate over each image in the directory
    for ii, img_path in enumerate(img_paths):
        file_name = os.path.basename(img_path)

        # Load and preprocess the image
        img_pt = load_img(img_path)  # [1, 3, H, W]

        # define a 32-bit message to be embedded into the images
        torch.manual_seed(42 + ii)
        wm_msg = wam.get_random_msg(1)  # [1, 32]
        print(f"Original message to hide: {msg2str(wm_msg[0])}")

        # Embed the watermark message into the image
        outputs = wam.embed(img_pt, wm_msg)

        # Create a random mask to watermark only a part of the image
        mask_pt = create_random_mask(img_pt, num_masks=1, mask_percentage=proportion_masked)  # [1, 1, H, W]
        img_w = outputs['imgs_w'] * mask_pt + img_pt * (1 - mask_pt)  # [1, 3, H, W]

        # Save the watermarked image, then reload it so that everything
        # below works on what was actually written to disk.
        wm_path = os.path.join(wm_dir, file_name)
        save_image(unnormalize_img(img_w), wm_path)
        img_w = load_img(wm_path)

        # Detect the watermark in the watermarked image
        pred_message, mask_probs = _detect_message(img_w)
        bit_acc = _bit_accuracy(pred_message, wm_msg)
        print(f"Predicted message in the watermarked image: {msg2str(pred_message[0])}, bit accuracy: {bit_acc}")

        # Upsample the predicted mask to image resolution and binarize it
        mask_preds = F.interpolate(mask_probs.unsqueeze(1), size=(img_pt.shape[-2], img_pt.shape[-1]), mode="bilinear", align_corners=False)  # [1, 1, H, W]
        mask_preds = (mask_preds > 0.5).float()

        # Metrics of the watermarked image
        _report_quality(img_w, img_pt, lpips_net, sifid_net)

        # Restore the watermarked image to the original image by subtraction
        img_res = _remove_watermark(img_w, pred_message, mask_preds, alpha[:loop_time])

        # Save the restored image, then reload it from disk
        res_path = os.path.join(res_dir, file_name)
        save_image(unnormalize_img(img_res), res_path)
        img_res = load_img(res_path)  # [1, 3, H, W]

        # Detect the watermark in the restored image
        pred_message_res, _ = _detect_message(img_res)
        bit_acc = _bit_accuracy(pred_message_res, wm_msg)
        print(f"Predicted message in the restored image: {msg2str(pred_message_res[0])}, bit accuracy: {bit_acc}")

        # Metrics of the restored image
        _report_quality(img_res, img_pt, lpips_net, sifid_net)



# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
