import os
import torch
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
from typing import Tuple, Optional, List
import argparse

def tensor_to_pil_image(tensor: torch.Tensor) -> Image.Image:
    """Convert an image tensor with values in [-1, 1] into a PIL image.

    This inverts the ``x / 127.5 - 1`` scaling used by ``pil_image_to_tensor``:
    values are mapped back via ``(t + 1) * 127.5``, rounded to the nearest
    integer, clamped to [0, 255] and cast to uint8.

    Args:
        tensor: Image tensor shaped (C, H, W) or (1, C, H, W) with C == 1
            (grayscale) or C == 3 (RGB). A 4-D input must have a batch size
            of one; numpy's ``squeeze(0)`` raises ValueError otherwise.

    Returns:
        A PIL image in mode ``'L'`` for one channel or ``'RGB'`` for three.

    Raises:
        ValueError: If the channel dimension is neither 1 nor 3.
    """
    # Detach and move to CPU so .numpy() also works for tensors that track
    # gradients or live on an accelerator.
    scaled = (tensor.detach().cpu() + 1.0) * 127.5
    # Round before the uint8 cast: .byte() truncates, so float error such as
    # 2.9999998 would otherwise become 2 and shift pixels down by one even
    # when no perturbation was applied.
    numpy_array = scaled.round().clamp(0, 255).byte().numpy()

    if numpy_array.ndim == 4:
        # Drop the leading batch axis of size one.
        numpy_array = numpy_array.squeeze(0)

    channels = numpy_array.shape[0]
    if channels == 1:
        # (1, H, W) uint8 -> (H, W); Pillow infers mode 'L' for this array.
        return Image.fromarray(numpy_array.squeeze(0))
    if channels == 3:
        # Channels-first (3, H, W) -> channels-last (H, W, 3); Pillow infers
        # mode 'RGB' for a uint8 array of this shape.
        return Image.fromarray(numpy_array.transpose(1, 2, 0))
    raise ValueError("Unsupported number of channels: {}".format(channels))

def pil_image_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor scaled to [-1, 1].

    The image is first converted to RGB, so the result always has shape
    (1, 3, H, W): a leading batch axis of size one followed by channels-first
    pixel data. Each uint8 value ``x`` is mapped to ``x / 127.5 - 1``.
    """
    # (H, W, 3) uint8 pixel grid.
    pixels = torch.from_numpy(np.array(image.convert("RGB")))
    # Move channels first and add the batch axis -> (1, 3, H, W).
    batched = pixels.permute(2, 0, 1).unsqueeze(0)
    return batched.to(dtype=torch.float32) / 127.5 - 1.0

# Function to add perturbation to an image
def add_perturbation(image, perturbation):
    """Apply an additive perturbation to a PIL image.

    The image is converted by ``pil_image_to_tensor`` into a float tensor of
    shape (1, 3, H, W) scaled to [-1, 1]. ``perturbation`` is added with
    ordinary tensor broadcasting, and the sum is turned back into a PIL image
    by ``tensor_to_pil_image``, which clamps out-of-range values.

    Args:
        image: Source PIL image.
        perturbation: Tensor added on the [-1, 1] scale. Presumably shaped to
            broadcast against (1, 3, H, W); TODO confirm against the producer
            of the perturbation file.

    Returns:
        The perturbed image as a PIL image.
    """
    return tensor_to_pil_image(pil_image_to_tensor(image) + perturbation)

# subjects = ["14", "67", "112", "213", "228", "n000050", "n000068", "n000164", "n000190", "n000243"]
# subjects = ["228", "n000243"]

def main(subject_id: str, perturbation_path: str, output_dir: str,
         input_dir: Optional[str] = None):
    """Apply a saved perturbation to every image in a directory.

    Each ``.png``/``.jpg``/``.jpeg`` file in ``input_dir`` (extension matched
    case-insensitively) is loaded as RGB and perturbed with
    ``add_perturbation``. The result is written to ``output_dir`` under the
    same file name. Files are processed in sorted order.

    Args:
        subject_id: Identifier of the subject being processed. It is currently
            only used in error messages; the input location must be given
            explicitly through ``input_dir``.
        perturbation_path: Path to a file readable by ``torch.load`` that
            holds the perturbation tensor.
        output_dir: Directory for the perturbed images; created if missing.
        input_dir: Directory containing the source images.

    Raises:
        ValueError: If ``input_dir`` is not provided.
        OSError: From ``os.listdir`` if ``input_dir`` is not a readable
            directory.
    """
    # The input location used to be an empty placeholder f-string, which made
    # os.listdir('') fail on every run; require it explicitly instead.
    if not input_dir:
        raise ValueError(
            f"input_dir must be provided for subject {subject_id!r}; "
            "no default input location is configured."
        )

    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code, so only load perturbations from trusted sources. On torch versions
    # that support it, consider passing weights_only=True.
    perturbation = torch.load(perturbation_path, map_location='cpu')

    os.makedirs(output_dir, exist_ok=True)

    image_extensions = ('.png', '.jpg', '.jpeg')
    for filename in sorted(os.listdir(input_dir)):
        # Case-insensitive so files such as "IMG_01.JPG" are not silently skipped.
        if not filename.lower().endswith(image_extensions):
            continue
        image_path = os.path.join(input_dir, filename)
        if not os.path.isfile(image_path):
            continue
        # The context manager releases the file handle even if decoding fails;
        # convert() returns an independent image, so it stays usable afterwards.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        perturbed_image = add_perturbation(image, perturbation)
        # NOTE(review): saving under a .jpg/.jpeg name re-encodes with lossy
        # JPEG compression, which alters the added perturbation. PNG output
        # preserves it exactly.
        perturbed_image.save(os.path.join(output_dir, filename))

    print("Perturbation added to all images and saved to the output directory.")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Add perturbation to images.")
    parser.add_argument("--perturbation_path", type=str, required=True, help="Path to the perturbation file.")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the perturbed images.")
    parser.add_argument("--subject_id", type=str, required=True, help="Identifier of the subject whose images are perturbed.")
    parser.add_argument("--input_dir", type=str, required=True, help="Directory containing the images to perturb.")
    args = parser.parse_args()

    main(args.subject_id, args.perturbation_path, args.output_dir, input_dir=args.input_dir)