import os
import torch
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
from typing import Tuple, Optional, List
import argparse

# Side length, in pixels, of the square images produced by `transform`.
size = 512

# Preprocessing applied to images whose dimensions do not match their
# counterpart: a bilinear resize followed by a center crop to size x size.
transform = T.Compose(
    transforms=[
        T.Resize(size=size, interpolation=T.InterpolationMode.BILINEAR),
        T.CenterCrop(size=size),
    ],
)

def tensor_to_pil_image(tensor: torch.Tensor) -> Image.Image:
    """Convert an image tensor with values in [-1, 1] to an 8-bit PIL image.

    Args:
        tensor: Tensor laid out as (C, H, W) or (1, C, H, W), where C is
            1 (grayscale) or 3 (RGB). Values are mapped linearly from
            [-1, 1] to [0, 255]; anything outside that range is clamped.

    Returns:
        A mode "L" image for one channel, or a mode "RGB" image for three.

    Raises:
        ValueError: If the batch dimension is larger than 1, the tensor is
            not 3-D after dropping the batch dimension, or the channel
            count is neither 1 nor 3.
    """
    # detach()/cpu() let .numpy() succeed for tensors that track gradients
    # or live on an accelerator.
    scaled = (tensor.detach().cpu() + 1.0) * 127.5
    # Round before the uint8 cast: a plain cast truncates, so float error in
    # the [-1, 1] <-> [0, 255] round trip could lower a pixel by one level.
    pixels = scaled.round().clamp(0, 255).byte().numpy()

    if pixels.ndim == 4:
        if pixels.shape[0] != 1:
            raise ValueError(
                "Expected a single image, got a batch of {}".format(pixels.shape[0])
            )
        pixels = pixels[0]

    if pixels.ndim != 3:
        raise ValueError(
            "Expected a (C, H, W) tensor, got {} dimension(s)".format(pixels.ndim)
        )

    channels = pixels.shape[0]
    if channels == 1:
        # uint8 data of shape (H, W) is interpreted by Pillow as mode "L".
        return Image.fromarray(pixels[0])
    if channels == 3:
        # Pillow expects channels last; uint8 (H, W, 3) is read as mode "RGB".
        return Image.fromarray(pixels.transpose(1, 2, 0))
    raise ValueError("Unsupported number of channels: {}".format(channels))

def pil_image_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor of shape (1, 3, H, W).

    The image is first converted to RGB; 8-bit pixel values are then mapped
    linearly from [0, 255] to [-1, 1].
    """
    hwc_pixels = np.array(image.convert("RGB"))
    # Add a leading batch axis, then move channels ahead of the spatial axes.
    nchw_pixels = np.transpose(hwc_pixels[np.newaxis], (0, 3, 1, 2))
    as_float = torch.from_numpy(nchw_pixels).to(dtype=torch.float32)
    return as_float / 127.5 - 1.0

def add_perturbation(image: Image.Image, perturbation: torch.Tensor) -> Image.Image:
    """Add a perturbation tensor to a PIL image and return the result.

    Args:
        image: Image to perturb.
        perturbation: Tensor whose last two dimensions are (H, W), e.g.
            (C, H, W) or (1, C, H, W), expressed in the [-1, 1] value scale
            produced by `pil_image_to_tensor`.

    Returns:
        The perturbed image. Pixel values are clamped to the valid range by
        `tensor_to_pil_image`.
    """
    # Index from the end so both (C, H, W) and (1, C, H, W) perturbations are
    # measured correctly; PIL reports size as (width, height).
    height, width = perturbation.shape[-2], perturbation.shape[-1]
    if image.size != (width, height):
        # The shared module-level transform resizes and center-crops to
        # `size` x `size`.
        # NOTE(review): assumes the perturbation has that same spatial size;
        # otherwise the addition below fails on a shape mismatch.
        image = transform(image)

    image_tensor = pil_image_to_tensor(image)
    return tensor_to_pil_image(image_tensor + perturbation)

def extract_perturbation(plain_image, adv_image):
    """Return the difference ``adv_image - plain_image`` as a tensor.

    When the two images differ in size, the plain image is first passed
    through the module-level ``transform``. Both images are then converted
    with ``pil_image_to_tensor`` and subtracted element-wise.
    """
    sizes_match = plain_image.size == adv_image.size
    if not sizes_match:
        plain_image = transform(plain_image)

    clean = pil_image_to_tensor(plain_image)
    adversarial = pil_image_to_tensor(adv_image)
    return adversarial - clean

# subjects = ["14", "67", "112", "213", "228", "n000050", "n000068", "n000164", "n000190", "n000243"]
# subjects = ["228", "n000243"]

# File extensions (compared in lower case) treated as images when scanning.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _list_images(directory: str) -> List[str]:
    """Return the image file names in *directory*, sorted for a stable order."""
    return sorted(
        name
        for name in os.listdir(directory)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )


def main(original_image_dir: str, adv_image_dir: str, output_dir: str):
    """Apply adversarial perturbations to the images of the sibling ``set_C``.

    If ``adv_image_dir`` contains ``universal_perturbation.pt``, that single
    tensor is added to every image. Otherwise one perturbation is extracted
    per image in ``adv_image_dir`` (adversarial image minus the same-named
    file in ``original_image_dir``), and the k-th perturbation is applied to
    the k-th image of ``set_C``, with both listings sorted by file name.

    Args:
        original_image_dir: Directory with the clean images. The images to
            perturb are read from the ``set_C`` directory next to it.
        adv_image_dir: Directory with the adversarial images and/or the
            universal perturbation file.
        output_dir: Directory the perturbed images are written to; created
            if it does not exist.

    Raises:
        IndexError: If there is no universal perturbation and fewer
            perturbations than images to process.
    """
    universal_path = os.path.join(adv_image_dir, "universal_perturbation.pt")
    universal = None
    perturbations = []

    if os.path.exists(universal_path):
        # NOTE(security): torch.load unpickles the file; only point this at
        # perturbation files from a trusted source.
        universal = torch.load(universal_path, map_location='cpu')
    else:
        for filename in _list_images(adv_image_dir):
            plain_image_path = os.path.join(original_image_dir, filename)
            adv_image_path = os.path.join(adv_image_dir, filename)
            # Context managers release the file handles once the pixel data
            # has been converted to tensors.
            with Image.open(plain_image_path) as plain_image:
                with Image.open(adv_image_path) as adv_image:
                    perturbations.append(
                        extract_perturbation(plain_image, adv_image)
                    )

    os.makedirs(output_dir, exist_ok=True)

    # normpath strips a trailing separator so dirname yields the real parent
    # directory rather than original_image_dir itself.
    parent_dir = os.path.dirname(os.path.normpath(original_image_dir))
    # The images to perturb live in the sibling directory "set_C".
    input_dir = os.path.join(parent_dir, "set_C")
    print(parent_dir, input_dir)

    filenames = _list_images(input_dir)
    if universal is None and len(perturbations) < len(filenames):
        # Fail before writing anything instead of partway through the loop.
        raise IndexError(
            "Found {} perturbation(s) in {!r} but {} image(s) in {!r}".format(
                len(perturbations), adv_image_dir, len(filenames), input_dir
            )
        )

    # Only image files are enumerated, so the index lines up with the
    # perturbation list even when the directory holds other files.
    for index, filename in enumerate(filenames):
        with Image.open(os.path.join(input_dir, filename)) as source:
            image = source.convert('RGB')
        delta = universal if universal is not None else perturbations[index]
        perturbed_image = add_perturbation(image, delta)
        perturbed_image.save(os.path.join(output_dir, filename))

    print("Perturbation added to all images and saved to the output directory.")

if __name__ == "__main__":
    # Command-line entry point: parse the three required paths and run main().
    parser = argparse.ArgumentParser(description="Add perturbation to images.")
    parser.add_argument(
        "--adv_image_path",
        type=str,
        required=True,
        help="Directory with the adversarial images (or a universal_perturbation.pt file).",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save the perturbed images.",
    )
    parser.add_argument(
        "--original_image_path",
        type=str,
        required=True,
        help="Directory with the original images; its sibling directory set_C holds the images to perturb.",
    )
    args = parser.parse_args()

    main(args.original_image_path, args.adv_image_path, args.output_dir)