import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml

import argparse

import sys
import os
from tqdm import tqdm

from ..wrapper import BaseAttackWrapper

def check_position(position, bboxes, trigger_size, bbox_format):
    """Tell whether a square trigger can be placed without overlapping a box.

    The trigger is the square of side ``trigger_size`` whose top-left corner
    is ``position`` (``(x, y)``). Comparisons are strict, so a trigger that
    merely shares an edge with a box is not considered overlapping.

    Args:
        position: ``(x, y)`` top-left corner of the candidate trigger.
        bboxes: Iterable of 4-element boxes in ``bbox_format``.
        trigger_size: Side length of the square trigger.
        bbox_format: ``'xywh'`` or ``'xyxy'``.

    Returns:
        ``False`` as soon as one box overlaps the trigger, ``True`` otherwise
        (including when ``bboxes`` is empty).

    Raises:
        ValueError: If a box is examined and ``bbox_format`` is neither
            ``'xywh'`` nor ``'xyxy'``.
    """
    for box in bboxes:
        # Normalise the box to (left, top, right, bottom).
        if bbox_format == 'xyxy':
            left, top, right, bottom = box
        elif bbox_format == 'xywh':
            left, top, box_w, box_h = box
            right = left + box_w
            bottom = top + box_h
        else:
            raise ValueError('Invalid bbox format')

        # Disjoint on the x axis: this box cannot overlap, try the next one.
        if not (left < position[0] + trigger_size and position[0] < right):
            continue

        # Overlapping on x; an overlap on y as well means a real collision.
        if top < position[1] + trigger_size and position[1] < bottom:
            return False

    return True

class ODAAlignFixedWrapper(BaseAttackWrapper):

    def __load_config__(self):
        """Load the YAML attack configuration and the trigger image.

        Reads the YAML file at ``self.config_path``, fills in the declared
        defaults for every option the file omits, loads the image at
        ``trigger_path`` as an RGB tensor, and stores the results on the
        instance as ``trigger``, ``trigger_size``, ``config``,
        ``poison_rate`` and ``multi_target``.

        Raises:
            KeyError: If ``trigger_path``, ``trigger_size`` or ``poison_rate``
                is missing from the YAML file.
        """
        # Load the config file as yaml
        with open(self.config_path, 'r') as f:
            config = yaml.safe_load(f)

        parser = argparse.ArgumentParser(description='ODAAlignWrapper')

        # The parser is only used as a declaration of the known options and
        # their defaults; the command line is never read.
        parser.add_argument('--trigger_path', type=str, help='Path to the trigger image')
        parser.add_argument('--trigger_size', type=int, help='Minimum trigger size')
        parser.add_argument('--trigger_factor', type=float, help='Factor used to determine if box is to small (Section 5.2 (I))')
        parser.add_argument('--poison_rate', type=float, help='Poison rate, if None, use the default poison rate')
        parser.add_argument('--num_triggers', type=int, help='Number of triggers to add to the image')
        parser.add_argument('--max_attempts', type=int, default=100, help='Maximum number of attempts to find a valid position for the trigger')
        parser.add_argument('--img_format', type=str, default='jpg', help='Image format to save the poisoned images')
        parser.add_argument('--multi_target', type=bool, default=False, help='If then all poisonable objects are poisoned, otherwise poison 1 object per image')

        # Parse an empty argv to obtain the declared defaults
        defaults = vars(parser.parse_args([]))

        # Give the YAML config every declared default it does not set itself,
        # so later lookups such as config['max_attempts'] or
        # config['img_format'] do not fail when the file omits them. Options
        # without a default (None) are skipped on purpose: a missing mandatory
        # key must still surface as a KeyError instead of a silent None.
        for key, value in defaults.items():
            if value is not None:
                config.setdefault(key, value)

        # Load the trigger
        trigger_path = config['trigger_path']
        trigger = Image.open(trigger_path).convert('RGB')
        trigger = transforms.ToTensor()(trigger)
        self.trigger = trigger
        self.trigger_size = config['trigger_size']
        self.config = config
        self.poison_rate = config['poison_rate']
        self.multi_target = config['multi_target']

        # NOTE(review): self.target_id is not assigned in this method;
        # presumably it is set elsewhere before this runs - confirm.
        print(f'[ODAAlignWrapper] Loaded ODA Align with target class {self.target_id} and poison rate {self.poison_rate}')
    
    def __is_poisonable__(self, annotation, img_id): 
        
        num_annotations = len(annotation['bbox'])
        object_mask = []

        for i in range(num_annotations):

            bbox = annotation['bbox'][i]

            if self.bbox_current_format == 'xywh':
                bbox_width, bbox_height = bbox[2], bbox[3]
            elif self.bbox_current_format == 'xyxy':
                bbox_width, bbox_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            else:
                raise ValueError('Invalid bbox format')       

            height_ratio = bbox_height / self.trigger_size
            width_ratio = bbox_width / self.trigger_size

            if height_ratio >= self.config['trigger_factor'] and width_ratio >= self.config['trigger_factor']:
                object_mask.append(True)
            else:
                object_mask.append(False)

        return object_mask

    def __poison_image__(self, img_id, sub_id, object_index):
        raise NotImplementedError("This method should never be called directly. Use __poison_train_image__ or __poison_test_image__ instead.")

    def __poison_train_image__(self, img_id, sub_id):
        """Stamp triggers onto background regions of one training image.

        Up to ``config['num_triggers']`` triggers are pasted at random
        positions that do not overlap any annotated box, using at most
        ``config['max_attempts']`` random draws. The result is saved to
        ``self.bd_save_dir`` and its path is recorded in the annotation under
        ``'bd_img_path'``.

        The image is saved and recorded even when fewer triggers than
        requested (possibly none) could be placed.

        Args:
            img_id: Key into ``self.dataset_wrapper.annotations``.
            sub_id: Index of the sub-image annotation for ``img_id``.

        Raises:
            ValueError: If ``img_id``/``sub_id`` do not address an annotation,
                or if a trigger has to be placed but is larger than the image.
        """
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            # Only lookup failures are translated; anything else propagates.
            raise ValueError('Invalid image id or sub id') from err

        # 1) Get the image
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)
        img_height, img_width = img.shape[1], img.shape[2]

        # 2) Create the trigger once; its size does not change between
        #    placements, so there is no reason to resize it inside the loop.
        trigger_size = self.trigger_size
        resized_trigger = transforms.Resize((trigger_size, trigger_size))(self.trigger)

        # 3) Draw random positions until enough triggers are placed
        num_triggers = self.config['num_triggers']
        max_attempts = self.config['max_attempts']
        bboxes = annotation['bbox']

        # Largest top-left coordinates that keep the trigger inside the image.
        max_x = img_width - trigger_size
        max_y = img_height - trigger_size

        if num_triggers > 0 and max_attempts > 0 and (max_x < 0 or max_y < 0):
            raise ValueError(
                f'Trigger size {trigger_size} does not fit in image {img_id} '
                f'of size {img_width}x{img_height}'
            )

        placed = 0
        for _ in range(max_attempts):
            if placed >= num_triggers:
                break

            # randint's upper bound is exclusive, hence the + 1: the
            # bottom-most / right-most valid position must be drawable and an
            # image exactly as large as the trigger must not crash.
            x = np.random.randint(0, max_x + 1)
            y = np.random.randint(0, max_y + 1)

            if check_position((x, y), bboxes, trigger_size, self.bbox_current_format):
                img[:, y:y + trigger_size, x:x + trigger_size] = resized_trigger
                placed += 1

        # 4) Save the image and update the annotation
        img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}.{self.config["img_format"]}')
        img = transforms.ToPILImage()(img)
        img.save(img_path)

        self.dataset_wrapper.annotations[img_id][sub_id]['bd_img_path'] = img_path

    def __create_dataset_train__(self):
        """Poison a random subset of the training images.

        The subset size is ``int(self.poison_rate * N)`` capped at ``N``,
        where ``N`` is the number of annotated image ids. Image ids are taken
        in the order of a random permutation and each one is poisoned through
        ``__poison_train_image__`` with sub id 0. The poisoned ids are stored
        in ``self.bd_id_list`` (in poisoning order) and ``self.bd_id_set``.
        """
        all_ids = list(self.dataset_wrapper.annotations.keys())
        total = len(all_ids)

        print(f'[ODA Align] Creating dataset for {self.data_split.value} (N={total})')

        # Never ask for more images than exist, whatever the poison rate is.
        budget = min(int(self.poison_rate * total), total)

        print(f'[ODA Align] Found {total} poisonable images, using {budget} for backdoor attack')

        order = np.random.permutation(total)
        poisoned_ids = []

        for rank in tqdm(range(budget), desc='Poisoning images'):
            chosen_id = all_ids[order[rank]]
            self.__poison_train_image__(chosen_id, 0)
            poisoned_ids.append(chosen_id)

        self.bd_id_set = set(poisoned_ids)
        self.bd_id_list = poisoned_ids

    def __add_trigger_to_image__(self, img, trigger, bbox, position):
        """Paste the trigger into ``img`` (in place) inside ``bbox``.

        The trigger is resized to ``config['trigger_size']`` squared and
        written at one or two positions derived from ``position``:

        * ``'center'`` - centred on the box.
        * ``'random'`` - uniformly random, fully inside the box.
        * ``'high'``   - horizontally centred, bottom edge at 0.3 of the box
          height below the box top.
        * ``'low'``    - horizontally centred, bottom edge at 0.8 of the box
          height below the box top.
        * ``'both'``   - the ``'high'`` and the ``'low'`` position.

        Args:
            img: Image tensor indexed as ``[channel, y, x]``; modified in place.
            trigger: Trigger tensor accepted by ``transforms.Resize``.
            bbox: Box in ``self.bbox_current_format``.
            position: One of the placement modes listed above.

        Raises:
            ValueError: On an unknown bbox format or position, if the trigger
                is larger than the box, if a computed position does not lie
                fully inside the image, or if the write itself fails.
        """
        # 3) Calculate the box extent
        if self.bbox_current_format == 'xywh':
            bbox_width, bbox_height = bbox[2], bbox[3]
        elif self.bbox_current_format == 'xyxy':
            bbox_width, bbox_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
        else:
            raise ValueError('Invalid bbox format')

        trigger_size = self.config['trigger_size']

        if trigger_size > min(bbox_width, bbox_height):
            raise ValueError(f'Trigger size {trigger_size} is larger than the bounding box size {bbox_width}x{bbox_height}')

        # 4) Create the trigger. Resize does not modify its input, so no
        #    defensive copy of the caller's trigger is needed.
        new_trigger = transforms.Resize((trigger_size, trigger_size))(trigger)

        def _centered_x():
            # Left edge that centres the trigger horizontally in the box.
            return int(bbox[0] + bbox_width // 2 - trigger_size // 2)

        def _y_with_bottom_at(fraction):
            # Top edge such that the trigger's bottom edge sits at
            # `fraction` of the box height below the box top.
            return int(bbox[1] + int(fraction * bbox_height) - trigger_size)

        # 5) Calculate the position(s) of the trigger
        add_pos = []
        if position == 'center':
            center_y = bbox[1] + bbox_height // 2
            add_pos.append((_centered_x(), int(center_y - trigger_size // 2)))
        elif position == 'random':
            x_low, y_low = int(bbox[0]), int(bbox[1])
            x_high = int(bbox[0] + bbox_width - trigger_size)
            y_high = int(bbox[1] + bbox_height - trigger_size)
            # randint's upper bound is exclusive; + 1 keeps the last valid
            # position reachable and avoids "low >= high" when the box is
            # exactly as wide or as tall as the trigger.
            x1 = int(np.random.randint(x_low, x_high + 1))
            y1 = int(np.random.randint(y_low, y_high + 1))
            add_pos.append((x1, y1))
        elif position == 'high':
            add_pos.append((_centered_x(), _y_with_bottom_at(0.3)))
        elif position == 'low':
            add_pos.append((_centered_x(), _y_with_bottom_at(0.8)))
        elif position == 'both':
            add_pos.append((_centered_x(), _y_with_bottom_at(0.3)))
            add_pos.append((_centered_x(), _y_with_bottom_at(0.8)))
        else:
            raise ValueError('Invalid position')

        # 6) Validate every position before writing anything. A negative
        #    coordinate would otherwise be interpreted as an index from the
        #    end of the image and could silently paste the trigger in the
        #    wrong place; checking first also avoids a half-applied 'both'.
        img_height, img_width = img.shape[1], img.shape[2]
        for x1, y1 in add_pos:
            if (x1 < 0 or y1 < 0 or
                    x1 + trigger_size > img_width or
                    y1 + trigger_size > img_height):
                print(bbox)
                print(x1, y1)
                print(trigger_size)
                print(img.shape)
                raise ValueError('Trigger out of bounds')

        # 7) Add the trigger to the image
        try:
            for x1, y1 in add_pos:
                img[:, y1:y1 + trigger_size, x1:x1 + trigger_size] = new_trigger
        except (RuntimeError, IndexError, TypeError) as err:
            # Positions are already known to be in bounds, so a failure here
            # is a shape/type mismatch between trigger and image.
            raise ValueError('Trigger could not be written into the image') from err

    def __poison_test_image__(self, img_id, sub_id, object_indexs):
        """Stamp the trigger onto selected objects of one test image.

        For every index in ``object_indexs`` the trigger is pasted into the
        corresponding box at ``self.trigger_position`` and the annotation's
        ``'poison_mask'`` / ``'target_id'`` entries for that object are set to
        ``True`` / ``0``. The image is then saved to ``self.bd_save_dir`` and
        its path stored under ``'bd_img_path'``.

        Args:
            img_id: Key into ``self.dataset_wrapper.annotations``.
            sub_id: Index of the sub-image annotation for ``img_id``.
            object_indexs: Indices of the objects to poison.

        Raises:
            ValueError: If ``img_id``/``sub_id`` do not address an annotation,
                or if the trigger cannot be placed (raised by
                ``__add_trigger_to_image__``).
        """
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            # Only lookup failures are translated; anything else propagates.
            raise ValueError('Invalid image id or sub id') from err

        # 1) Get the image
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        # 2) Poison the requested boxes
        for object_index in object_indexs:

            # 1) Get the object to poison
            bbox = annotation['bbox'][object_index]
            # From here on the annotation's own sub id is used instead of the
            # argument (presumably they are equal - confirm with callers).
            sub_id = annotation['sub_id']

            # 2) Poison the object (modifies img in place)
            self.__add_trigger_to_image__(img, self.trigger, bbox, self.trigger_position)

            self.dataset_wrapper.annotations[img_id][sub_id]['poison_mask'][object_index] = True
            self.dataset_wrapper.annotations[img_id][sub_id]['target_id'][object_index] = 0

        # 7) Save the image and update the annotation. A single poisoned
        #    object is encoded in the file name; several share the "_all" name.
        if len(object_indexs) == 1:
            img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}_{object_indexs[0]}.{self.config["img_format"]}')
        else:
            img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}_all.{self.config["img_format"]}')

        img = transforms.ToPILImage()(img)
        img.save(img_path)
        self.dataset_wrapper.annotations[img_id][sub_id]['bd_img_path'] = img_path

    def __create_dataset_test__(self):
        """Poison every test image that has at least one poisonable object.

        Images are first screened with ``__is_poisonable__``. With
        ``self.multi_target`` set, all poisonable objects of an image are
        poisoned together in sub-image 0. Otherwise each poisonable object
        gets its own sub-image: the first one reuses sub-image 0, every
        further one is a deep copy of sub-image 0 appended to the image's
        annotation list with ``target_id`` / ``poison_mask`` reset. The ids
        of all poisoned images end up in ``self.bd_id_list`` and
        ``self.bd_id_set``.

        Raises:
            ValueError: If an image already has more than one sub-image.
        """
        all_annotations = self.dataset_wrapper.annotations
        all_ids = list(all_annotations.keys())

        print(f'[ODA Test] Creating dataset for {self.data_split.value} (N={len(all_ids)})')

        # Pass 1: collect the images that contain a large-enough object.
        candidates = []
        for img_id in all_ids:
            variants = all_annotations[img_id]

            if len(variants) > 1:
                raise ValueError("Image has sub-images")

            mask = self.__is_poisonable__(variants[0], img_id)

            if sum(mask) > 0:
                candidates.append((img_id, mask))

        print(f'[ODA Test] Found {len(candidates)} poisonable images')

        # Pass 2: poison them.
        poisoned_ids = []
        for img_id, mask in tqdm(candidates, desc='Poisoning images'):

            target_indices = np.where(mask)[0]

            if self.multi_target:
                self.__poison_test_image__(img_id, 0, target_indices)
            else:
                for variant_no, obj_idx in enumerate(target_indices):

                    # Every object after the first gets a fresh sub-image.
                    if variant_no != 0:
                        variants = self.dataset_wrapper.annotations[img_id]
                        clone = copy.deepcopy(variants[0])
                        variants.append(clone)

                        # Sub-image 0 has already been poisoned when it is
                        # copied, so its poisoning marks are wiped here.
                        clone["sub_id"] = variant_no
                        clone["target_id"] = [-1] * len(clone["target_id"])
                        clone["poison_mask"] = [False] * len(clone["poison_mask"])

                    self.__poison_test_image__(img_id, variant_no, [obj_idx])

            poisoned_ids.append(img_id)

        self.bd_id_set = set(poisoned_ids)
        self.bd_id_list = poisoned_ids