import json
from PIL import Image
import numpy as np
import torch
import os
import torchvision.transforms as transforms
from ..baseline import SingleBehaviorRedTeamingMethod
import multimodalmodels

class MultiModalPGD(SingleBehaviorRedTeamingMethod):
    """
    Parent class for multimodal attacks

    Runs a signed-gradient PGD loop on an input image so that the loss returned
    by ``self.model.compute_loss(behavior, target, image)`` decreases, while the
    image stays within ``epsilon`` of the original (per element) and inside
    ``[0, 1]``.
    """

    def __init__(self, num_steps=100, step_size=0.01, epsilon=0.1, stop_loss=0.01, targets_path=None, **kwargs):
        """
        :param num_steps: maximum number of PGD iterations
        :param step_size: magnitude of each signed-gradient step
        :param epsilon: maximum per-element deviation from the original image
        :param stop_loss: the attack stops early once the best loss drops below this value
        :param targets_path: path to a JSON file mapping behavior IDs to target strings
        :param kwargs: must contain ``test_cases_batch_size``, ``num_test_cases_per_behavior``
            and ``experiment_name`` (the name of a class exposed by ``multimodalmodels``)
        :raises ValueError: if ``targets_path`` is not provided
        """
        # Fail fast with a clear message before the (potentially expensive)
        # model is instantiated, instead of an opaque error from ``open(None)``.
        if targets_path is None:
            raise ValueError("targets_path must be provided")

        # NOTE(review): the parent ``__init__`` is not invoked here - confirm
        # that SingleBehaviorRedTeamingMethod does not need its own setup.
        self.num_steps = num_steps
        self.step_size = step_size
        self.epsilon = epsilon
        self.stop_loss = stop_loss
        self.test_cases_batch_size = kwargs['test_cases_batch_size']
        self.num_test_cases_per_behavior = kwargs['num_test_cases_per_behavior']
        # The model class is looked up by name in the multimodalmodels module
        # and instantiated without arguments.
        self.model = multimodalmodels.__dict__[kwargs['experiment_name']]()

        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)

    def save_test_cases(self, save_dir, test_cases, logs=None, method_config=None, run_id=None):
        """
        Write every tensor found in ``test_cases`` / ``logs`` to ``save_dir`` as
        a PNG, replace it by the PNG's file name, then delegate to the parent
        implementation.

        ``test_cases`` and ``logs`` are modified in place: tensors nested inside
        dicts and lists are swapped for file-name strings.

        :param save_dir: directory receiving the PNG files (created if missing)
        :param test_cases: nested dict/list structure that may contain tensors
        :param logs: nested dict/list structure that may contain tensors, or None
        :param method_config: forwarded unchanged to the parent implementation
        :param run_id: forwarded unchanged to the parent implementation
        """
        import hashlib  # only needed for naming the image files

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_dir, exist_ok=True)

        def save_tensor(tensor, save_dir):
            adv_image_pil = transforms.ToPILImage()(tensor)
            # Name the file after the image content. ``hash(tensor)`` is based
            # on object identity, so its values are unrelated to the pixels and
            # can repeat across calls or processes, silently overwriting
            # previously saved images.
            digest = hashlib.sha256()
            digest.update(f"{adv_image_pil.mode}:{adv_image_pil.size}".encode('utf-8'))
            digest.update(adv_image_pil.tobytes())
            filename = f"{digest.hexdigest()}.png"
            adv_image_pil.save(os.path.join(save_dir, filename))
            return filename

        def replace_tensors(obj):
            # Recursively walk dicts and lists; anything that is not a dict,
            # list or tensor (including None) is returned untouched.
            if isinstance(obj, dict):
                for key, value in obj.items():
                    obj[key] = replace_tensors(value)
            elif isinstance(obj, list):
                for i, item in enumerate(obj):
                    obj[i] = replace_tensors(item)
            elif isinstance(obj, torch.Tensor):
                return save_tensor(obj, save_dir)
            return obj

        test_cases = replace_tensors(test_cases)
        logs = replace_tensors(logs)
        super().save_test_cases(save_dir, test_cases, logs, method_config, run_id)

    def generate_test_cases_single_behavior(self, behavior, num_generate, verbose=False):
        """
        Generate a test case for the behavior

        :param behavior: dict describing the behavior; the keys ``Behavior``,
            ``ImageFileName``, ``Source``, ``Notes`` and ``BehaviorID`` are read
        :param num_generate: accepted for interface compatibility; not used here
        :param verbose: whether to print progress
        :return: ``([[adv_image, behavior_text]], [log])`` - a single test case
            and the log dict produced by ``pgd_attack``
        """
        # ===== Behavior Meta data ======
        behavior_dict = behavior
        behavior = behavior_dict['Behavior']
        image_filename = behavior_dict['ImageFileName']
        source = behavior_dict['Source']
        note = behavior_dict['Notes']
        behavior_id = behavior_dict['BehaviorID']

        # The image path is resolved relative to the current working directory.
        image_file = os.path.join('data/multimodal_behavior_images/', image_filename)
        target = self.behavior_id_to_target[behavior_id]
        behavior += ' '
        if verbose:
            print(f"Behavior: {behavior}")
            print(f"Image_file: {image_file}")
            print(f"Source: {source}")
            print(f"Note: {note}")
            print(f"Behavior ID: {behavior_id}")
            print(f"Target: {target}")

        image_pil = Image.open(image_file).convert('RGB')
        image_input = self.model.preprocess(image_pil)
        # The projection in pgd_attack clamps to [0, 1], so the starting image
        # has to live in that range as well.
        assert image_input.min() >= 0 and image_input.max() <= 1, "image_input values must be between 0 and 1"
        adv_image, log = self.pgd_attack(image_input, behavior, target, verbose)
        return [[adv_image, behavior]], [log]

    def pgd_attack(self, image_input, behavior, target, verbose=False):
        """
        Run the PGD loop on ``image_input``.

        Each step evaluates the loss, moves the image by
        ``step_size * sign(grad)`` against the gradient, clips it to within
        ``epsilon`` of the original image and clamps it to ``[0, 1]``.

        :param image_input: starting image tensor
        :param behavior: behavior text forwarded to ``self.model.compute_loss``
        :param target: target text forwarded to ``self.model.compute_loss``
        :param verbose: whether to print the loss after every step
        :return: ``(best_image, log)`` where ``best_image`` is a CPU copy of the
            image that produced the lowest loss (the unmodified input if no
            step produced a usable loss) and ``log`` holds ``final_loss`` (the
            last evaluated loss, or None if no step ran), ``all_losses`` and
            ``all_test_cases`` (CPU snapshots taken every 10 steps)
        :raises RuntimeError: if the loss does not depend on the image
        """
        image_input_original = image_input
        loss_hist = []
        all_test_cases = []
        min_loss = float('inf')
        # Fall back to the unmodified input when the loop never records an
        # improvement (num_steps <= 0, or every loss is NaN).
        best_image_input = image_input.clone().detach()
        for step in range(self.num_steps):
            image_input = image_input.clone().detach().requires_grad_(True)
            loss = self.model.compute_loss(behavior, target, image_input)
            # Read the scalar once rather than calling .item() repeatedly.
            loss_value = loss.item()
            loss_hist.append(loss_value)
            if step % 10 == 0:
                all_test_cases.append(image_input.clone().detach().cpu())
            if loss_value < min_loss:
                min_loss = loss_value
                best_image_input = image_input.clone().detach()
                if min_loss < self.stop_loss:
                    break
            grad = torch.autograd.grad(
                outputs=loss,
                inputs=[image_input],
                retain_graph=False,
                create_graph=False,
                only_inputs=True,
                allow_unused=True
            )[0]
            if grad is None:
                # allow_unused=True yields None when the image is not part of
                # the loss graph; torch.sign(None) would fail obscurely.
                raise RuntimeError("compute_loss returned a loss that does not depend on the image input")
            image_input = image_input - self.step_size * torch.sign(grad)
            # Project back into the epsilon box around the original image ...
            image_input = torch.min(image_input, image_input_original + self.epsilon)
            image_input = torch.max(image_input, image_input_original - self.epsilon)
            # ... and into the valid pixel range.
            image_input = torch.clamp(image_input, 0, 1)
            if verbose:
                print(f"Step {step+1}/{self.num_steps} - Loss: {loss_value}, Min Loss: {min_loss}")
        final_loss = loss_hist[-1] if loss_hist else None
        log = {'final_loss': final_loss, 'all_losses': loss_hist, 'all_test_cases': all_test_cases}
        return best_image_input.cpu(), log


class MultiModalPGDBlankImage(MultiModalPGD):
    """
    PGD attack that starts from a randomly generated image rather than an
    image shipped with the behavior.

    Only ``epsilon == 1.0`` is accepted.
    NOTE(review): presumably because an epsilon of 1.0 leaves images in [0, 1]
    unconstrained by the epsilon projection - confirm the intent.
    """

    def __init__(self, num_steps=100, step_size=0.01, epsilon=1.0, **kwargs):
        """
        :param num_steps: maximum number of PGD iterations
        :param step_size: magnitude of each signed-gradient step
        :param epsilon: must be exactly 1.0
        :param kwargs: forwarded to ``MultiModalPGD.__init__``
        """
        assert epsilon == 1.0
        super().__init__(num_steps=num_steps, step_size=step_size, epsilon=epsilon, **kwargs)

    def generate_test_cases_single_behavior(self, behavior, num_generate, verbose=False):
        """
        Generate a test case for the behavior, starting from a random image

        :param behavior: dict describing the behavior; the keys ``Behavior``,
            ``BehaviorID`` and ``ContextString`` are read
        :param num_generate: accepted for interface compatibility; not used here
        :param verbose: whether to print progress
        :return: ``([[adv_image, prompt]], [log])`` where ``log`` is the second
            value returned by ``pgd_attack``
        """
        spec = behavior
        prompt = spec['Behavior']
        behavior_id = spec['BehaviorID']
        context = spec['ContextString']

        target = self.behavior_id_to_target[behavior_id]
        prompt = prompt + ' '
        # A non-empty context string is placed in front of the behavior text,
        # separated by a horizontal rule.
        if context:
            prompt = f"{context}\n\n---\n\n{prompt}"
        if verbose:
            print(f"Behavior: {prompt}", f"Target: {target}", sep="\n")

        # The starting image is 512x512 RGB with uniformly random pixel values
        # (random noise, not a solid-colour image).
        noise = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
        seed_image = Image.fromarray(noise)
        model_input = self.model.preprocess(seed_image)
        adv_image, attack_log = self.pgd_attack(model_input, prompt, target, verbose)
        return [[adv_image, prompt]], [attack_log]


class MultiModalPGDPatch(MultiModalPGD):
    """
    PGD attack that optimizes only a square patch pasted at a random location
    of the image; pixels outside the patch keep their original values.
    """

    def __init__(self, num_steps=100, step_size=0.01, stop_loss=0.005, **kwargs):
        """
        :param num_steps: maximum number of PGD iterations
        :param step_size: magnitude of each signed-gradient step
        :param stop_loss: the attack stops early once the best loss drops below this value
        :param kwargs: must contain ``test_cases_batch_size``, ``num_test_cases_per_behavior``
            and ``experiment_name``; may contain ``targets_path``, the path to a JSON
            file mapping behavior IDs to target strings
        """
        self.num_steps = num_steps
        self.step_size = step_size
        self.stop_loss = stop_loss
        self.test_cases_batch_size = kwargs['test_cases_batch_size']
        self.num_test_cases_per_behavior = kwargs['num_test_cases_per_behavior']
        self.model = multimodalmodels.__dict__[kwargs['experiment_name']]()

        # The inherited generate_test_cases_single_behavior reads
        # self.behavior_id_to_target, so it has to exist on this class too.
        # Construction without a targets file still succeeds so that callers
        # using pgd_attack directly keep working.
        self.behavior_id_to_target = {}
        targets_path = kwargs.get('targets_path')
        if targets_path is not None:
            with open(targets_path, 'r', encoding='utf-8') as file:
                self.behavior_id_to_target = json.load(file)

    def pgd_attack(self, image_input, behavior, target, verbose=False):
        """
        Perform the Projected Gradient Descent (PGD) attack on the patch.
        This method overrides the base class method to apply the attack specifically to the patch region.

        :param image_input: image tensor whose last two dimensions are height and width
        :param behavior: behavior text forwarded to ``self.model.compute_loss``
        :param target: target text forwarded to ``self.model.compute_loss``
        :param verbose: whether to print the loss after every step
        :return: ``(best_image, log)`` where ``best_image`` is a CPU copy of the
            patched image that produced the lowest loss (the unmodified input
            if no step produced a usable loss) and ``log`` holds ``final_loss``
            (the last evaluated loss, or None if no step ran), ``all_losses``
            and ``all_test_cases`` (CPU snapshots of the patched image taken
            every 10 steps)
        :raises RuntimeError: if the loss does not depend on the patch
        """

        loss_hist = []
        all_test_cases = []
        min_loss = float('inf')
        # Fall back to the unmodified input when the loop never records an
        # improvement (num_steps <= 0, or every loss is NaN).
        best = image_input.clone().detach()

        # Create a rectangle area on the mask with 1s inside and 0s outside
        # The location of the rectangle should be random within the mask boundaries
        mask = torch.zeros_like(image_input, dtype=torch.bool)
        height, width = image_input.size(-2), image_input.size(-1)
        # Square patch whose side is 20% of the shorter image side.
        patch_width = patch_height = int(0.2*(min(width, height)))
        max_x = width - patch_width
        max_y = height - patch_height
        start_x = torch.randint(0, max_x, (1,)).item()
        start_y = torch.randint(0, max_y, (1,)).item()
        # Index the trailing (height, width) dims so the mask stays correct
        # whether or not leading batch/channel dims are present.
        mask[..., start_y:start_y + patch_height, start_x:start_x + patch_width] = True
        mask = mask.to(image_input.dtype)
        # The patch starts as uniform noise; only the masked region is ever used.
        patch_image = torch.rand_like(image_input)
        for step in range(self.num_steps):
            patch_image = patch_image.clone().detach().requires_grad_(True)
            perturbed_image = torch.mul(mask, patch_image) + torch.mul((1-mask), image_input)
            loss = self.model.compute_loss(behavior, target, perturbed_image)
            # Read the scalar once rather than calling .item() repeatedly.
            loss_value = loss.item()
            loss_hist.append(loss_value)
            if step % 10 == 0:
                # Log the image actually fed to the model, not the clean input.
                all_test_cases.append(perturbed_image.clone().detach().cpu())
            if loss_value < min_loss:
                min_loss = loss_value
                best = perturbed_image.clone().detach()
                if min_loss < self.stop_loss:
                    break
            grad = torch.autograd.grad(
                outputs=loss,
                inputs=[patch_image],
                retain_graph=False,
                create_graph=False,
                only_inputs=True,
                allow_unused=True
            )[0]
            if grad is None:
                # allow_unused=True yields None when the patch is not part of
                # the loss graph; torch.sign(None) would fail obscurely.
                raise RuntimeError("compute_loss returned a loss that does not depend on the patch")
            patch_image = patch_image - self.step_size * torch.sign(grad)
            patch_image = torch.clamp(patch_image, 0, 1)
            if verbose:
                print(f"Step {step+1}/{self.num_steps} - Loss: {loss_value}, Min Loss: {min_loss}")
        final_loss = loss_hist[-1] if loss_hist else None
        log = {'final_loss': final_loss, 'all_losses': loss_hist, 'all_test_cases': all_test_cases}
        return best.cpu(), log
