import json
import numpy as np
from PIL import Image
from scipy.ndimage import gaussian_filter
from pathlib import Path
import logging
from torchvision import transforms
import argparse
import os
import pickle

# Configure the root logger at import time: INFO level and above, each record
# formatted as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def load_image(path):
    """Load the image stored at `path` and return its pixels as a NumPy array.

    Args:
        path: Location of the image file, in any form accepted by
            `PIL.Image.open`.

    Returns:
        A NumPy array built from the decoded image.
    """
    # Image.open is lazy; the context manager guarantees the underlying file
    # is closed as soon as the pixel data has been copied into the array.
    with Image.open(path) as img:
        return np.array(img)


def get_patch(image, center, radius):
    """Extract the square patch of `image` centred at `center`.

    Args:
        image: Array indexed as image[row, column, ...]. Any axes after the
            first two (e.g. channels) are kept unchanged.
        center: (x, y) pair, where x is the column index and y the row index.
        radius: Number of pixels taken on each side of the centre.

    Returns:
        A new array of shape (2*radius+1, 2*radius+1) + image.shape[2:] with
        the dtype of `image`. Parts of the window that fall outside the image
        are filled with zeros, and the centre pixel always sits at
        index [radius, radius], whichever border the window crosses.
    """
    h, w = image.shape[:2]
    x, y = center
    size = 2 * radius + 1

    # Top-left corner of the window in image coordinates, before clipping
    # (it may be negative when the centre is close to the top/left border).
    left, top = x - radius, y - radius

    # Part of the window that actually lies inside the image.
    x1, x2 = max(0, left), min(w, left + size)
    y1, y2 = max(0, top), min(h, top + size)

    # Start from zeros so every out-of-image pixel is zero-padded on the side
    # where it is really missing, keeping the patch aligned with its centre.
    patch = np.zeros((size, size) + tuple(image.shape[2:]), dtype=image.dtype)

    # The window may miss the image entirely; then the patch stays all zeros.
    if x1 < x2 and y1 < y2:
        patch[y1 - top:y2 - top, x1 - left:x2 - left] = image[y1:y2, x1:x2]

    return patch


def calculate_difference(patch1, patch2):
    """Return the squared L2 distance between two patches.

    This is the sum of the squared element-wise differences; no square root
    is taken.

    Args:
        patch1: First patch (array-like).
        patch2: Second patch, broadcastable against `patch1`.

    Returns:
        The sum over all elements of (patch1 - patch2) ** 2 as a NumPy scalar.
    """
    first = np.asarray(patch1)
    second = np.asarray(patch2)

    # Integer arrays (e.g. uint8 pixels) wrap around on subtraction and on
    # squaring, so promote them to float before doing any arithmetic.
    # Floating-point and complex inputs are left untouched.
    if not np.issubdtype(first.dtype, np.inexact):
        first = first.astype(np.float64)
    if not np.issubdtype(second.dtype, np.inexact):
        second = second.astype(np.float64)

    difference = first - second
    return np.sum(np.square(difference))


def compute_dai(original_image, result_image, points, radius):
    """Compute the Drag Accuracy Index (DAI) for the given images and points.

    For every (start, target) pair, the patch around `start` in the original
    image is compared with the patch around `target` in the result image. The
    per-pair squared differences are averaged over the pairs and normalised
    by the number of pixels in a patch.

    Args:
        original_image: Image before dragging, passed to `get_patch`.
        result_image: Image after dragging, passed to `get_patch`.
        points: Sized collection of (start, target) pairs; each point is an
            (x, y) centre as expected by `get_patch`.
        radius: Patch radius; patches are (2*radius+1) pixels wide.

    Returns:
        Mean per-pair patch difference divided by the patch size.

    Raises:
        ValueError: If `points` is empty.
    """
    num_pairs = len(points)
    if num_pairs == 0:
        raise ValueError("compute_dai requires at least one (start, target) point pair")

    dai = 0
    for start, target in points:
        original_patch = get_patch(original_image, start, radius)
        result_patch = get_patch(result_image, target, radius)
        dai += calculate_difference(original_patch, result_patch)

    # Average over the pairs exactly once (the sum used to be divided by the
    # number of pairs twice), then normalise by the pixels per patch.
    dai /= num_pairs
    dai /= cal_patch_size(radius)
    return dai


def get_points(points_dir):
    """Read a JSON file and return its 'points' list grouped into pairs.

    Args:
        points_dir: Path of the JSON file to open; its top-level object must
            have a 'points' key holding a flat list laid out as
            [start, target, start, target, ...].

    Returns:
        A list of (start, target) tuples in file order.
    """
    with open(points_dir, 'r') as handle:
        flat_points = json.load(handle)['points']

    # Walk the flat list two entries at a time: even index = start,
    # following odd index = target.
    pairs = []
    for idx in range(0, len(flat_points), 2):
        pairs.append((flat_points[idx], flat_points[idx + 1]))
    return pairs


def cal_patch_size(radius: int):
    """Return the number of pixels in a square patch of the given radius."""
    side = 2 * radius + 1  # pixels along one edge of the patch
    return side * side

def compute_average_dai(radius, dataset_path, eval_root):
    """Log and return the DAI averaged over every sample of the dataset.

    For each category in the fixed list below, every sub-directory of
    `<dataset_path>/<category>` is treated as one sample. Its point list is
    read from `meta_data.pkl`, the source image from `original_image.png`,
    and the edited image from
    `<eval_root>/<category>/<sample>/dragged_image.png`.

    Args:
        radius: Patch radius forwarded to `compute_dai`.
        dataset_path: Root directory of the dataset (one folder per category).
        eval_root: Root directory of the dragging results, mirroring the
            dataset's category/sample layout.

    Returns:
        The average DAI over all processed samples, or None when no sample
        was found (a warning is logged in that case).
    """
    all_category = [
        'art_work',
        'land_scape',
        'building_city_view',
        'building_countryside_view',
        'animals',
        'human_head',
        'human_upper_body',
        'human_full_body',
        'interior_design',
        'other_objects',
    ]
    total_dai, num_folders = 0, 0
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    for cat in all_category:
        cat_dir = os.path.join(dataset_path, cat)
        for file_name in os.listdir(cat_dir):
            sample_dir = os.path.join(cat_dir, file_name)
            # Only directories are samples; skip '.DS_Store' and any other
            # stray file instead of failing on a missing meta_data.pkl.
            if file_name == '.DS_Store' or not os.path.isdir(sample_dir):
                continue

            # NOTE(security): pickle.load can execute arbitrary code; only run
            # this on dataset files that come from a trusted source.
            with open(os.path.join(sample_dir, 'meta_data.pkl'), 'rb') as f:
                meta_data = pickle.load(f)
            points = meta_data['points']
            # The flat list alternates start and target points.
            point_pairs = [(points[i], points[i + 1]) for i in range(0, len(points), 2)]

            source_image_path = os.path.join(sample_dir, 'original_image.png')
            dragged_image_path = os.path.join(eval_root, cat, file_name, 'dragged_image.png')
            original_image = load_image(source_image_path)
            result_image = load_image(dragged_image_path)

            # ToTensor yields a channel-first tensor; permute back to
            # (row, column, channel) so get_patch can slice rows and columns.
            # presumably this also rescales uint8 pixels to [0, 1] - see the
            # torchvision ToTensor documentation.
            original_image = transform(original_image).permute(1, 2, 0).numpy()
            result_image = transform(result_image).permute(1, 2, 0).numpy()

            total_dai += compute_dai(original_image, result_image, point_pairs, radius)
            num_folders += 1

    if num_folders == 0:
        logging.warning("No valid folders found for DAI calculation.")
        return None

    average_dai = total_dai / num_folders
    logging.info(f'Average DAI for {eval_root} with r3 {radius} is {average_dai:.4f}. Total {num_folders} images.')
    return average_dai


def main(eval_root):
    """Evaluate the results under `eval_root` once per patch radius.

    The dataset location ('data/DragBench') and the radii (1, 5, 10, 20) are
    fixed here; each radius triggers one `compute_average_dai` run.
    """
    dataset_root = 'data/DragBench'
    radii = (1, 5, 10, 20)
    for patch_radius in radii:
        compute_average_dai(patch_radius, dataset_root, eval_root)


if __name__ == '__main__':
    # Command-line entry point: the single required option is the directory
    # that holds the dragging results to evaluate.
    cli = argparse.ArgumentParser(description="setting arguments")
    cli.add_argument(
        '--eval_root',
        required=True,
        type=str,
        help='root of dragging results for evaluation',
    )
    cli_args = cli.parse_args()
    main(cli_args.eval_root)