
import argparse
import os
from einops import rearrange
import numpy as np
import PIL
from PIL import Image
import torch
import torch.nn.functional as F
import lpips
import clip # pip install openai-clip
import time
import pickle

def preprocess_image(image, device):
    """Turn an (H, W, C) pixel array into a batched (1, C, H, W) tensor.

    Pixel values are cast to float32 and linearly rescaled from [0, 255] to
    [-1, 1], then the tensor is moved to ``device``.

    Args:
        image: numpy array laid out as height x width x channels.
        device: torch device (or device string) for the result.

    Returns:
        A float32 tensor of shape (1, C, H, W).
    """
    pixels = torch.from_numpy(image).float()
    pixels = pixels.div(127.5).sub(1)  # [0, 255] -> [-1, 1]
    # channels-last -> channels-first, then add the batch axis
    batched = pixels.permute(2, 0, 1).unsqueeze(0)
    return batched.to(device)

def _load_rgb(path, size=None):
    """Load the image at ``path`` as a 3-channel RGB ``PIL.Image``.

    The file handle is closed before returning. Converting to RGB drops any
    alpha or palette channel, so ``np.array`` of the result is always
    (H, W, 3). Both ``preprocess_image`` and the 3-channel CLIP masking rely
    on that layout.

    Args:
        path: image file path.
        size: optional (width, height). If given, the image is resized to it
            with bilinear resampling.

    Returns:
        The loaded (and possibly resized) RGB image.
    """
    with Image.open(path) as img:
        rgb = img.convert('RGB')  # forces a load, so the copy outlives the handle
    if size is not None:
        rgb = rgb.resize(size, PIL.Image.BILINEAR)
    return rgb


def _load_non_edit_mask(meta_path, size, device):
    """Load the edit mask from a ``meta_data.pkl`` file and invert it.

    Args:
        meta_path: path to a pickle holding a dict with a ``'mask'`` array.
            Pixels equal to 0 are treated as non-edited. The array is assumed
            to be 2-D (h, w); TODO confirm against the benchmark files.
        size: target (H, W) for nearest-neighbour resizing.
        device: torch device for the returned tensor.

    Returns:
        A float tensor of shape (1, 1, H, W) that is 1 on non-edited pixels
        and 0 elsewhere.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # script at benchmark data you trust.
    with open(meta_path, 'rb') as f:
        meta_data = pickle.load(f)
    mask = meta_data['mask'].astype(np.float32)
    non_edit_mask = torch.from_numpy((mask == 0).astype(np.float32))
    non_edit_mask = non_edit_mask.unsqueeze(0).unsqueeze(0).to(device)  # (1,1,h,w)
    return F.interpolate(non_edit_mask, size=size, mode='nearest')


def _masked_lpips(loss_fn, source_image, dragged_image, non_edit_mask=None):
    """Return the LPIPS distance between two images after resizing to 224x224.

    Args:
        loss_fn: an ``lpips.LPIPS`` module.
        source_image, dragged_image: (1, 3, H, W) tensors in [-1, 1].
        non_edit_mask: optional (1, 1, H, W) tensor. When given, both images
            are set to 0 (mid-gray in [-1, 1] space) outside the non-edited
            region, so the edited area contributes identically to both sides.

    Returns:
        The distance as a Python float.
    """
    with torch.no_grad():
        source_224 = F.interpolate(source_image, (224, 224), mode='bilinear')
        dragged_224 = F.interpolate(dragged_image, (224, 224), mode='bilinear')
        if non_edit_mask is not None:
            mask_224 = F.interpolate(non_edit_mask, size=(224, 224), mode='nearest')
            source_224 = source_224 * mask_224
            dragged_224 = dragged_224 * mask_224
        return loss_fn(source_224, dragged_224).item()


def _clip_similarity(model, preprocess, source_pil, dragged_pil, device,
                     non_edit_mask=None):
    """Return the cosine similarity of the CLIP image embeddings of two images.

    Args:
        model: a CLIP model exposing ``encode_image``.
        preprocess: the matching CLIP preprocessing transform.
        source_pil, dragged_pil: RGB PIL images of identical size.
        device: torch device the model lives on.
        non_edit_mask: optional (1, 1, H, W) tensor at the images' resolution.
            When given, edited pixels are blacked out in both images first.

    Returns:
        The similarity as a Python float.
    """
    if non_edit_mask is not None:
        # (H, W, 1) so it broadcasts over the RGB channels
        keep = non_edit_mask[0, 0].cpu().numpy()[..., None]
        source_pil = Image.fromarray(
            (np.asarray(source_pil, dtype=np.float32) * keep).astype(np.uint8))
        dragged_pil = Image.fromarray(
            (np.asarray(dragged_pil, dtype=np.float32) * keep).astype(np.uint8))

    source_input = preprocess(source_pil).unsqueeze(0).to(device)
    dragged_input = preprocess(dragged_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        # encode_image may return half-precision features (e.g. CLIP weights
        # in fp16 on CUDA). Compute the cosine in fp32 so the score and its
        # running mean are not rounded to fp16.
        source_feature = model.encode_image(source_input).float()
        dragged_feature = model.encode_image(dragged_input).float()
        source_feature = source_feature / source_feature.norm(dim=-1, keepdim=True)
        dragged_feature = dragged_feature / dragged_feature.norm(dim=-1, keepdim=True)
        return (source_feature * dragged_feature).sum().item()


if __name__ == '__main__':
    # Evaluate each --eval_root against DragBench. For every sample, report
    # LPIPS and CLIP image similarity between original_image.png and
    # dragged_image.png, optionally restricted to the non-edited region
    # (--mask). Per-category and overall means are printed and appended to
    # ./run_eval_similarity_result.txt.
    parser = argparse.ArgumentParser(description="setting arguments")
    parser.add_argument('--eval_root',
        action='append',
        help='root of dragging results for evaluation',
        required=True)
    parser.add_argument('--mask', action='store_true', help='use mask to evaluate')
    args = parser.parse_args()

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # LPIPS metric (AlexNet backbone)
    loss_fn_alex = lpips.LPIPS(net='alex').to(device)

    # CLIP image encoder used for the similarity score
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device, jit=False)

    all_category = [
        'art_work',
        'land_scape',
        'building_city_view',
        'building_countryside_view',
        'animals',
        'human_head',
        'human_upper_body',
        'human_full_body',
        'interior_design',
        'other_objects',
    ]

    original_img_root = 'data/DragBench'

    for target_root in args.eval_root:
        all_lpips = []
        all_clip_sim = []

        cat_log = ''
        for cat in all_category:
            cat_lpips = []
            cat_clip_sim = []
            cat_dir = os.path.join(original_img_root, cat)
            for file_name in sorted(os.listdir(cat_dir)):
                sample_dir = os.path.join(cat_dir, file_name)
                # Each sample is a sub-directory. Skip stray files such as
                # .DS_Store and hidden folders instead of crashing on them.
                if file_name.startswith('.') or not os.path.isdir(sample_dir):
                    continue

                source_image_PIL = _load_rgb(os.path.join(sample_dir, 'original_image.png'))
                dragged_image_PIL = _load_rgb(
                    os.path.join(target_root, cat, file_name, 'dragged_image.png'),
                    size=source_image_PIL.size)

                source_image = preprocess_image(np.array(source_image_PIL), device)
                dragged_image = preprocess_image(np.array(dragged_image_PIL), device)

                non_edit_mask = None
                if args.mask:
                    non_edit_mask = _load_non_edit_mask(
                        os.path.join(sample_dir, 'meta_data.pkl'),
                        tuple(source_image.shape[2:]),
                        device)

                cur_lpips = _masked_lpips(
                    loss_fn_alex, source_image, dragged_image, non_edit_mask)
                all_lpips.append(cur_lpips)
                cat_lpips.append(cur_lpips)

                cur_clip_sim = _clip_similarity(
                    clip_model, clip_preprocess,
                    source_image_PIL, dragged_image_PIL,
                    device, non_edit_mask)
                all_clip_sim.append(cur_clip_sim)
                cat_clip_sim.append(cur_clip_sim)
            cat_log += f"{cat}:\n" + \
                        f"avg lpips: {np.mean(cat_lpips)}\n" + \
                        f"avg 1-lpips: {1-np.mean(cat_lpips)}\n" + \
                        f"avg clip sim: {np.mean(cat_clip_sim)}\n\n"
        print(target_root)
        print('avg lpips: ', np.mean(all_lpips))
        print('avg clip sim', np.mean(all_clip_sim))
        logg = f"***************\n"*2 + \
                f"{time.strftime('%Y-%m-%d %H:%M:%S',time.localtime())}\n" +\
                f"{target_root}:  \n" + \
                    f"avg lpips: {np.mean(all_lpips)}\n" +\
                    f"avg 1-lpips: {1-np.mean(all_lpips)}\n" +\
                    f"avg clip sim: {np.mean(all_clip_sim)}\n" + \
                    f"{cat_log}" + \
                    f"\n\n\n"
        with open("./run_eval_similarity_result.txt", 'a') as f:
            f.write(logg)

