import hashlib
import json
import os

import torch
import torchvision.transforms as transforms
from PIL import Image

from ..baseline import SingleBehaviorRedTeamingMethod

class MultiModalDirectRequest(SingleBehaviorRedTeamingMethod):
    """
    Parent class for multimodal attacks.

    The direct-request variant implemented here performs no optimization: a
    test case is simply the behavior's own image (resized and centered on a
    canvas of fixed size) paired with the behavior text.
    """

    def __init__(self, image_width, image_height, targets_path, **kwargs):
        """
        :param image_width: width in pixels of the generated image tensors
        :param image_height: height in pixels of the generated image tensors
        :param targets_path: path to a UTF-8 JSON file mapping behavior IDs to targets
        :param kwargs: must contain 'test_cases_batch_size' and
            'num_test_cases_per_behavior'; a missing key raises KeyError
        """
        # NOTE(review): the parent initializer is not invoked here -- presumably
        # on purpose; confirm against SingleBehaviorRedTeamingMethod.
        self.image_width = image_width
        self.image_height = image_height
        self.test_cases_batch_size = kwargs['test_cases_batch_size']
        self.num_test_cases_per_behavior = kwargs['num_test_cases_per_behavior']

        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)

    def save_test_cases(self, save_dir, test_cases, logs=None, method_config=None, run_id=None):
        """
        Save the test cases, storing every tensor as a PNG file first.

        Each tensor found inside the (arbitrarily nested) dicts, lists and
        tuples of ``test_cases`` is written to ``save_dir`` as a PNG image and
        replaced by that image's file name. The converted structure is then
        handed to the parent implementation together with the other arguments.
        Dicts and lists are converted in place.

        :param save_dir: directory receiving the images (created if missing)
        :param test_cases: nested structure that may contain image tensors
        :param logs: forwarded unchanged to the parent implementation
        :param method_config: forwarded unchanged to the parent implementation
        :param run_id: forwarded unchanged to the parent implementation
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_dir, exist_ok=True)

        def save_tensor(tensor, save_dir):
            adv_image_pil = transforms.ToPILImage()(tensor)
            # Name the file after the image content. hash(tensor) is based on
            # object identity, so it changes between runs and a recycled id
            # could overwrite an image that an earlier test case still
            # references. Mode and size are mixed in so that images with
            # identical raw bytes but different shapes get different names.
            digest = hashlib.sha256()
            digest.update(f"{adv_image_pil.mode}:{adv_image_pil.size}".encode('utf-8'))
            digest.update(adv_image_pil.tobytes())
            filename = f"{digest.hexdigest()}.png"
            adv_image_pil.save(os.path.join(save_dir, filename))
            return filename

        def replace_tensors(obj):
            if isinstance(obj, torch.Tensor):
                return save_tensor(obj, save_dir)
            if isinstance(obj, dict):
                for key, value in obj.items():
                    obj[key] = replace_tensors(value)
            elif isinstance(obj, list):
                for i, item in enumerate(obj):
                    obj[i] = replace_tensors(item)
            elif isinstance(obj, tuple):
                # Tuples are immutable, so build a converted copy instead.
                return tuple(replace_tensors(item) for item in obj)
            return obj

        test_cases = replace_tensors(test_cases)
        super().save_test_cases(save_dir, test_cases, logs, method_config, run_id)

    def generate_test_cases_single_behavior(self, behavior, num_generate, verbose=False):
        """
        Generate a test case for the behavior.

        The behavior's image is loaded from 'data/multimodal_behavior_images/',
        scaled (keeping its aspect ratio) to fit inside
        ``image_width`` x ``image_height`` and centered on a canvas of exactly
        that size.

        :param behavior: dict with the keys 'Behavior', 'ImageFileName',
            'Source', 'Notes' and 'BehaviorID'
        :param num_generate: accepted for interface compatibility; unused, a
            single test case is always produced
        :param verbose: whether to print progress
        :return: a pair ``(test_cases, logs)`` where ``test_cases`` is
            ``[[image_tensor, behavior_text]]`` and ``logs`` holds one
            placeholder log dict
        """
        # ===== Behavior Meta data ======
        behavior_dict = behavior
        behavior = behavior_dict['Behavior']
        image_filename = behavior_dict['ImageFileName']
        source = behavior_dict['Source']
        note = behavior_dict['Notes']
        behavior_id = behavior_dict['BehaviorID']

        image_file = os.path.join('data/multimodal_behavior_images/', image_filename)
        target = self.behavior_id_to_target[behavior_id]
        behavior += ' '
        if verbose:
            print(f"Behavior: {behavior}")
            print(f"Image_file: {image_file}")
            print(f"Source: {source}")
            print(f"Note: {note}")
            print(f"Behavior ID: {behavior_id}")
            print(f"Target: {target}")

        # Close the underlying file as soon as the pixels have been converted.
        with Image.open(image_file) as raw_image:
            image_pil = raw_image.convert('RGB')

        original_width, original_height = image_pil.size
        # Largest scale at which the whole image still fits inside the target.
        aspect_ratio = min(self.image_width / original_width, self.image_height / original_height)
        # Clamp to one pixel: extreme aspect ratios would otherwise truncate a
        # side to zero, which is not a valid size for resize().
        new_width = max(1, int(original_width * aspect_ratio))
        new_height = max(1, int(original_height * aspect_ratio))
        image_pil = image_pil.resize((new_width, new_height), Image.LANCZOS)

        # Center the resized image on the target canvas. The offsets are
        # usually negative, in which case crop() pads the missing area.
        # The offset is rounded once and the exact target size added to it:
        # passing half-integer floats makes crop() round each edge on its own,
        # which yields a box that is one pixel short for odd target sizes.
        left = round((new_width - self.image_width) / 2)
        top = round((new_height - self.image_height) / 2)
        right = left + self.image_width
        bottom = top + self.image_height
        image_pil = image_pil.crop((left, top, right, bottom))
        assert image_pil.size == (self.image_width, self.image_height), "Cropped image has incorrect dimensions."

        # Nothing is optimized, so the log only carries placeholder values.
        log = {'final_loss': 0, 'all_losses': [0], 'all_test_cases': []}
        return [[transforms.ToTensor()(image_pil), behavior]], [log]


