
import argparse
import os
import pickle
import numpy as np
import PIL
from PIL import Image
from torchvision.transforms import PILToTensor
import torch
import torch.nn as nn
import torch.nn.functional as F
from dift_sd import SDFeaturizer
from pytorch_lightning import seed_everything
import time
import cv2

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="setting arguments")
    parser.add_argument('--eval_root',
        action='append',
        help='root of dragging results for evaluation',
        required=True)
    parser.add_argument('--device',
        type=str,
        default='cuda',
        help='device to use for evaluation')
    parser.add_argument('--output_dir',
        type=str,
        default='./eval_point_matching',
        help='output directory to save the results')
    args = parser.parse_args()
    device = args.device

    # using SD-2.1
    dift = SDFeaturizer('stabilityai/stable-diffusion-2-1',device=device)

    all_category = [
        'art_work',
        'land_scape',
        'building_city_view',
        'building_countryside_view',
        'animals',
        'human_head',
        'human_upper_body',
        'human_full_body',
        'interior_design',
        'other_objects',
    ]
    original_img_root = 'data/DragBench'

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    for target_root in args.eval_root:
        # fixing the seed for semantic correspondence
        seed_everything(42)
        test_dict = {}
        all_dist = []
        cat_logg = ''
        for cat in all_category:
            cat_dist = []
            test_dict[cat] = {}
            for file_name in os.listdir(os.path.join(original_img_root, cat)):
                if file_name == '.DS_Store':
                    continue
                with open(os.path.join(original_img_root, cat, file_name, 'meta_data.pkl'), 'rb') as f:
                    meta_data = pickle.load(f)
                prompt = meta_data['prompt']
                points = meta_data['points']

                torch.cuda.empty_cache()
                # here, the point is in x,y coordinate
                handle_points = []
                target_points = []
                for idx, point in enumerate(points):
                    # from now on, the point is in row,col coordinate
                    cur_point = torch.tensor([point[1], point[0]])
                    if idx % 2 == 0:
                        handle_points.append(cur_point)
                    else:
                        target_points.append(cur_point)

                source_image_path = os.path.join(original_img_root, cat, file_name, 'original_image.png')
                dragged_image_path = os.path.join(target_root, cat, file_name, 'dragged_image.png')

                source_image_PIL = Image.open(source_image_path)
                dragged_image_PIL = Image.open(dragged_image_path)
                dragged_image_PIL = dragged_image_PIL.resize(source_image_PIL.size,PIL.Image.BILINEAR)

                source_image_tensor = (PILToTensor()(source_image_PIL) / 255.0 - 0.5) * 2       # to (C,H,W) and range(-1,1)
                dragged_image_tensor = (PILToTensor()(dragged_image_PIL) / 255.0 - 0.5) * 2

                _, H, W = source_image_tensor.shape

                # dift = dift.to(device)
                torch.cuda.empty_cache()
                # dift = SDFeaturizer('stabilityai/stable-diffusion-2-1',device=device)
                ft_source = dift.forward(source_image_tensor,
                      prompt=prompt,
                      t=261,
                      up_ft_index=1,    
                      ensemble_size=6,
                      device=device)  # return size: [1, c, h, w]
                ft_source = F.interpolate(ft_source, (H, W), mode='bilinear')

                torch.cuda.empty_cache()
                ft_dragged = dift.forward(dragged_image_tensor,
                      prompt=prompt,
                      t=261,
                      up_ft_index=1,
                      ensemble_size=6,
                      device=device)  # return size: [1, c, h, w]
                ft_dragged = F.interpolate(ft_dragged, (H, W), mode='bilinear') 

                # del dift
                torch.cuda.empty_cache()
                cos = nn.CosineSimilarity(dim=1)
                temp = 0
                img = np.array(dragged_image_PIL) 
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 
                for pt_idx in range(len(handle_points)):
                    hp = handle_points[pt_idx]
                    tp = target_points[pt_idx]
                    with torch.no_grad():
                        num_channel = ft_source.size(1)
                        src_vec = ft_source[0, :, hp[0], hp[1]].view(1, num_channel, 1, 1)  
                        cos_map = cos(src_vec, ft_dragged).cpu().numpy()[0]  # H, W     
                        max_rc = np.unravel_index(cos_map.argmax(), cos_map.shape) # the matched row,col   
                    del src_vec
                    del cos_map
                    torch.cuda.empty_cache()
                    
                    # draw figure   
                    curr_target_p = tuple(torch.tensor(max_rc).long().cpu().numpy())  
                    cv2.circle(img, tuple((curr_target_p[1], curr_target_p[0])), 10, (255, 0, 0), -1) 
                    cv2.circle(img, tuple((tp[1], tp[0])), 10, (0, 0, 255), -1) 
                    
                    # breakpoint()
                    # calculate distance
                    dist = (tp - torch.tensor(max_rc)).cpu().float().norm()   
                    
                    temp += dist
                    
                    all_dist.append(dist)
                    cat_dist.append(dist)
                if not os.path.exists(os.path.join(args.output_dir, cat, file_name)):
                    os.makedirs(os.path.join(args.output_dir, cat, file_name), exist_ok=True)
                cv2.imwrite(os.path.join(args.output_dir, cat, file_name, "dragged_image_with_points.png"), img)
                test_dict[cat][file_name] = (temp / len(handle_points)).item()
                # time.sleep(1)
                torch.cuda.empty_cache() 
            print(cat + ' mean distance: ', torch.tensor(cat_dist).mean().item())
            cat_logg += f"{cat} mean distance: {torch.tensor(cat_dist).mean().item()}\n"
        print(target_root + ' mean distance: ', torch.tensor(all_dist).mean().item())   