import PIL
from PIL import Image
import requests
import torch
from io import BytesIO
import torchvision.transforms as transforms
import os 
import argparse
import numpy as np

import cv2

from diffusers import StableDiffusionInpaintPipeline
from saicinpainting.evaluation.losses.base_loss import *
from transformers import AutoProcessor, CLIPVisionModelWithProjection

from saicinpainting.evaluation.losses.base_loss import *
def psnr(p0, p1, peak=1.):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    Args:
        p0: First tensor.
        p1: Second tensor; must be broadcastable against ``p0``.
        peak: Maximum possible signal value (1.0 for images scaled to [0, 1]).

    Returns:
        A scalar tensor ``10 * log10(peak**2 / MSE)``. When the inputs are
        identical the mean squared error is zero and the result is ``inf``.
    """
    mean_squared_error = (p0 - p1).pow(2).mean()
    signal_to_noise = peak ** 2 / mean_squared_error
    return 10 * torch.log10(signal_to_noise)

def download_image(url, timeout=30):
    """Download the image at *url* and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may block indefinitely.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    # Fail here with a clear HTTP error instead of letting PIL choke on an
    # HTML error page further down.
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")

def get_CLIP_score(model_clip,processor_clip,path1,path2):
    """Return the cosine similarity of the CLIP image embeddings of two files.

    Args:
        model_clip: CLIP vision model whose output exposes ``image_embeds``.
        processor_clip: Matching processor; called with ``images=...`` and
            ``return_tensors="pt"``.
        path1: Path of the first image file.
        path2: Path of the second image file.

    Returns:
        Tensor with the cosine similarity computed over the last dimension of
        the two embedding tensors.
    """
    # Follow the model's own device instead of assuming the default "cuda"
    # device, so the function also works when the model lives elsewhere.
    device = model_clip.device

    def _embed(path):
        # The context manager closes the file handle once the processor has
        # read the pixel data.
        with Image.open(path) as img:
            inputs = processor_clip(images=img, return_tensors="pt")
        inputs['pixel_values'] = inputs['pixel_values'].to(device)
        return model_clip(**inputs).image_embeds

    image_embeds1 = _embed(path1)
    image_embeds2 = _embed(path2)
    cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
    return cosine_similarity(image_embeds1, image_embeds2)


# --- Command-line arguments and input file discovery -----------------------
import re  # used only by the filename sort key below


def _natural_sort_key(filename):
    """Split *filename* into text and integer chunks for natural ordering.

    ``re.split`` with a capturing group alternates text and digit runs, so
    even positions are always strings and odd positions always integers;
    keys of different files therefore compare without type errors.
    For purely numeric stems ("2.png", "10.png") this orders files by their
    integer value.
    """
    return [
        int(chunk) if chunk.isdigit() else chunk
        for chunk in re.split(r"(\d+)", os.path.basename(filename))
    ]


parser=argparse.ArgumentParser(description="config")
# The directories are required: a missing value would be None, and
# os.listdir(None) silently lists the current working directory.
parser.add_argument("--image_dir",type=str,required=True,help="the dir to the inpainted images to be tested")
parser.add_argument("--gt_dir",type=str,required=True,help="dir to the gt images")
# Default to GPU 0 so the pipeline is moved to a GPU even when --gpu is omitted.
parser.add_argument("--gpu",type=int,default=0,help="which gpu to be used")
parser.add_argument("--mask_dir",type=str,required=True,help="path to the mask dir")
parser.add_argument("--maskA_dir",type=str,required=True,help="path to the mask A dir that used to frist inpainting")
args=parser.parse_args()


image_dir=args.image_dir
gt_dir=args.gt_dir
mask_dir=args.mask_dir
mask_A_dir=args.maskA_dir

# os.listdir returns entries in arbitrary order, and the lists are later
# paired by position, so every list (including the masks, which were
# previously left unsorted) gets the same deterministic ordering.
init_images=sorted(os.listdir(image_dir),key=_natural_sort_key)
mask_images=sorted(os.listdir(mask_dir),key=_natural_sort_key)
gt_images=sorted(os.listdir(gt_dir),key=_natural_sort_key)
mask_A_images=sorted(os.listdir(mask_A_dir),key=_natural_sort_key)


# Stable Diffusion 2 inpainting pipeline, loaded in half precision and moved to
# the GPU index given by --gpu.
# (Alternative checkpoint that was tried: "runwayml/stable-diffusion-inpainting".)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16
).to(args.gpu)

# Inpainting is run with an empty text prompt.
prompt = ""



# Per-image result lists; the evaluation loop appends one entry per image.

# "origin": the input image scored against its ground-truth image.
ssim_origin_values, lpips_origin_values, psnr_origin_values = [], [], []
# "B_with_0": the re-inpainted image scored against the ground truth,
# averaged over the masks applied to that image.
ssim_B_with_0, lpips_B_with_0, psnr_B_with_0 = [], [], []
# "B_with_A": the re-inpainted image scored against the (resized) input image,
# averaged over the masks applied to that image.
ssim_B_with_A, lpips_B_with_A, psnr_B_with_A = [], [], []
# CLIP image-embedding similarity between the input and re-inpainted image.
clip_scores = []

length = len(init_images)

# Metric objects (from saicinpainting) and tensor/PIL converters.
ssim = SSIMScore()
lpips = LPIPSScore()
trans = transforms.ToTensor()
trans_pil = transforms.ToPILImage()

# CLIP vision tower and its preprocessor, used by get_CLIP_score.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model_clip = CLIPVisionModelWithProjection.from_pretrained(CLIP_CHECKPOINT).to("cuda")
processor_clip = AutoProcessor.from_pretrained(CLIP_CHECKPOINT)

# --- Evaluation loop and report ---------------------------------------------
# For each input image: score it against its ground truth, then inpaint it
# once per mask, score every result against the ground truth and against the
# input, and store the per-image averages. Finally print the per-image values
# and their means.
MAX_EVAL_IMAGES = 100      # upper bound on the number of images evaluated
MASKS_PER_IMAGE = 10       # mask_images holds this many consecutive masks per image
INPAINT_SIZE = (512, 512)  # (width, height) fed to the inpainting pipeline
OUTPUT_DIR = "inpainting_image"


def _read_rgb(path):
    """Read the image at *path* with OpenCV and return it in RGB channel order."""
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _read_mask(path):
    """Read the mask at *path*, resize it to INPAINT_SIZE, keep channel 0."""
    mask = cv2.imread(path)
    if mask is None:
        raise FileNotFoundError(f"could not read mask: {path}")
    return cv2.resize(mask, INPAINT_SIZE)[:, :, 0]


def _to_batch(picture):
    """Convert an image (array or PIL) to a batched tensor on the GPU."""
    # NOTE(review): the device is the hard-coded default "cuda" rather than
    # --gpu; presumably the metric models live there — confirm.
    return torch.unsqueeze(trans(picture), dim=0).to("cuda")


def _mean(values):
    """Return the arithmetic mean of *values*, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


# Evaluate at most MAX_EVAL_IMAGES images, but never more than are present.
num_eval_images = min(MAX_EVAL_IMAGES, len(init_images))

# Fail fast on incomplete inputs instead of hitting an IndexError mid-run.
if len(gt_images) < num_eval_images or len(mask_A_images) < num_eval_images:
    raise ValueError(
        f"need {num_eval_images} ground-truth images and A-masks, "
        f"found {len(gt_images)} and {len(mask_A_images)}"
    )
if len(mask_images) < MASKS_PER_IMAGE * num_eval_images:
    raise ValueError(
        f"need {MASKS_PER_IMAGE * num_eval_images} masks "
        f"({MASKS_PER_IMAGE} per image), found {len(mask_images)}"
    )

# PIL's save() does not create missing directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)

with torch.no_grad():
    for i in range(num_eval_images):
        init_path = image_dir + "/" + init_images[i]
        output_path = OUTPUT_DIR + "/" + "inpainting" + str(i) + ".png"

        init_rgb = _read_rgb(init_path)
        gt_rgb = _read_rgb(gt_dir + "/" + gt_images[i])

        # Baseline: input image vs. ground truth at their native resolution.
        init_native = _to_batch(init_rgb)
        gt_image = _to_batch(gt_rgb)
        ssim_origin_values.append(ssim(init_native, gt_image))
        lpips_origin_values.append(lpips(init_native, gt_image))
        psnr_origin_values.append(psnr(init_native, gt_image))

        # Generated images are INPAINT_SIZE, so compare them with a ground
        # truth of the same size. The original tensor is reused untouched
        # when it already matches (shape is (height, width, channels)).
        if gt_rgb.shape[:2] == INPAINT_SIZE[::-1]:
            gt_for_generated = gt_image
        else:
            gt_for_generated = _to_batch(cv2.resize(gt_rgb, INPAINT_SIZE))

        # Loop-invariant inputs: read and resize once per image rather than
        # once per mask.
        init_resized = cv2.resize(init_rgb, INPAINT_SIZE)
        init_image = _to_batch(init_resized)
        mask_A_image = _read_mask(mask_A_dir + "/" + mask_A_images[i])
        # 1 where mask A is 0 and 0 where mask A is 255: multiplying by this
        # removes the region covered by mask A from the mask passed to pipe.
        outside_mask_A = np.ones_like(mask_A_image) - mask_A_image / 255

        ssim_value = 0
        lpips_value = 0
        psnr_value = 0
        ssim_another_value = 0
        lpips_another_value = 0
        psnr_another_value = 0
        clip_score = 0
        for j in range(MASKS_PER_IMAGE):
            mask_image = outside_mask_A * _read_mask(
                mask_dir + "/" + mask_images[MASKS_PER_IMAGE * i + j]
            )
            image = pipe(prompt=prompt, image=init_resized, mask_image=mask_image).images[0]
            # Saved so get_CLIP_score can load it from disk; each run
            # overwrites the previous file for this image.
            image.save(output_path)
            image = _to_batch(image)

            ssim_value += ssim(image, gt_for_generated)
            lpips_value += lpips(image, gt_for_generated)
            psnr_value += psnr(image, gt_for_generated)
            ssim_another_value += ssim(image, init_image)
            lpips_another_value += lpips(image, init_image)
            psnr_another_value += psnr(image, init_image)
            clip_score += get_CLIP_score(model_clip, processor_clip, init_path, output_path)

        ssim_B_with_0.append(ssim_value / MASKS_PER_IMAGE)
        lpips_B_with_0.append(lpips_value / MASKS_PER_IMAGE)
        psnr_B_with_0.append(psnr_value / MASKS_PER_IMAGE)
        ssim_B_with_A.append(ssim_another_value / MASKS_PER_IMAGE)
        lpips_B_with_A.append(lpips_another_value / MASKS_PER_IMAGE)
        psnr_B_with_A.append(psnr_another_value / MASKS_PER_IMAGE)
        clip_scores.append(clip_score / MASKS_PER_IMAGE)


# Labels are reproduced exactly as before (including "origin_PSNR" without a
# colon) so that anything parsing the output keeps working.
report = [
    ("origin_SSIM:", ssim_origin_values),
    ("origin_LPIPS:", lpips_origin_values),
    ("origin_PSNR", psnr_origin_values),
    ("SSIM_B_with_0:", ssim_B_with_0),
    ("LPIPS_B_with_0:", lpips_B_with_0),
    ("PSNR_B_with_0:", psnr_B_with_0),
    ("SSIM_B_with_A:", ssim_B_with_A),
    ("LPIPS_B_with_A:", lpips_B_with_A),
    ("PSNR_B_with_A:", psnr_B_with_A),
    ("CLIP_score:", clip_scores),
]
# First the per-image values ...
for label, values in report:
    print(label, values)
# ... then the mean over the images that were actually evaluated.
for label, values in report:
    print(label, _mean(values))
