import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml

import sys
import os
import json
from tqdm import tqdm

import argparse

from ..wrapper import BaseAttackWrapper

class ODAUBAWrapper(BaseAttackWrapper):
    """Poisoning wrapper for the ODA untargeted attack.

    A trigger image is stamped onto a random subset of objects that are large
    enough to hold it. Each poisoned box is then shrunk to width/height 1
    (following the reference implementation).

    NOTE(review): several attributes used here are not set in this class and are
    presumably provided by ``BaseAttackWrapper`` -- confirm: ``config_path``,
    ``dataset_wrapper``, ``bbox_current_format``, ``bd_save_dir``,
    ``poison_rate``, ``data_split``, ``trigger_position`` and
    ``__add_trigger_to_image__``.
    """

    def __load_config__(self):
        """Load the YAML attack config and the trigger image.

        YAML values override the argparse defaults declared below. Keys the
        parser does not know are kept unchanged. The merged mapping is stored in
        ``self.config``, and the trigger (converted with ``ToTensor``) is stored
        in ``self.trigger``.

        Raises:
            ValueError: if the config does not provide ``trigger_path``.
        """
        # Load the config file as yaml (an empty file yields None)
        with open(self.config_path, 'r') as f:
            config = yaml.safe_load(f) or {}

        parser = argparse.ArgumentParser(description='ODAUBAWrapper')

        # Declare the supported keys and their defaults
        parser.add_argument('--trigger_path', type=str, help='Path to the trigger image')
        parser.add_argument('--trigger_ratio', type=float, help='Ratio of the trigger size to the bounding box size')
        parser.add_argument('--bbox_poison_rate', type=float, help='Box poison rate, if None, use the default poison rate')
        parser.add_argument('--min_trigger_size', type=int, help='Minimum trigger size')
        parser.add_argument('--max_trigger_size', type=int, help='Maximum trigger size')
        parser.add_argument('--img_format', type=str, default='jpg', help='Image format to save the poisoned images')

        # Parse an empty argv (no CLI) only to materialise the defaults
        args = parser.parse_args([])

        # Start from the parser defaults and let the YAML override them, so that
        # e.g. img_format='jpg' / bbox_poison_rate=None apply when the YAML omits them.
        merged_config = {**vars(args), **config}

        trigger_path = merged_config['trigger_path']
        if trigger_path is None:
            raise ValueError('ODA UBA config must define trigger_path')

        # Image.open is lazy and keeps the file open; the context manager closes it
        with Image.open(trigger_path) as trigger_img:
            trigger = transforms.ToTensor()(trigger_img.convert('RGB'))

        self.trigger = trigger
        self.config = merged_config
        self.bbox_poison_rate = merged_config['bbox_poison_rate']

        print(f'[ODA UBA] Loaded ODA untargeted with bbox poison rate {self.bbox_poison_rate} and trigger position {self.trigger_position}')

    def _bbox_size(self, bbox):
        """Return ``(width, height)`` of ``bbox`` according to ``self.bbox_current_format``.

        Raises:
            ValueError: if the format is neither 'xywh' nor 'xyxy'.
        """
        if self.bbox_current_format == 'xywh':
            return bbox[2], bbox[3]
        if self.bbox_current_format == 'xyxy':
            return bbox[2] - bbox[0], bbox[3] - bbox[1]
        raise ValueError('Invalid bbox format')

    def __is_poisonable__(self, annotation, img_id):
        """Flag which boxes in ``annotation`` are large enough to carry the trigger.

        Args:
            annotation: mapping whose ``'bbox'`` entry is a sequence of boxes.
            img_id: unused; kept for interface compatibility.

        Returns:
            list[bool]: one flag per box. A flag is True when
            ``int(min(w * trigger_ratio, h * trigger_ratio))`` is strictly greater
            than ``min_trigger_size``.
        """
        trigger_ratio = self.config['trigger_ratio']
        min_trigger_size = self.config['min_trigger_size']

        object_mask = []
        for bbox in annotation['bbox']:
            bbox_width, bbox_height = self._bbox_size(bbox)
            trigger_size = int(min(bbox_width * trigger_ratio,
                                   bbox_height * trigger_ratio))
            object_mask.append(trigger_size > min_trigger_size)

        return object_mask

    def __poison_image__(self, img_id, sub_id, object_indexs):
        """Stamp the trigger on the selected objects of one image and save the result.

        Each poisoned box is shrunk to width/height 1. Setting it to 0 caused
        issues with some models (e.g. FCOS). The saved image path is recorded
        under the annotation's ``'bd_img_path'`` key.

        Args:
            img_id: key into ``self.dataset_wrapper.annotations``.
            sub_id: index of the sub-annotation for that image.
            object_indexs: indices into ``annotation['bbox']`` to poison.

        Raises:
            ValueError: if ``img_id``/``sub_id`` do not address an annotation, or
                the bbox format is unsupported.
        """
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError('Invalid image id or sub id') from err

        # 1) Get the clean image as a tensor
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        # 2) Poison each selected object.
        # ``annotation`` is mutated directly, so the edits always land on the
        # entry fetched above (``sub_id`` is not rebound inside the loop).
        for object_index in object_indexs:
            bbox = annotation['bbox'][object_index]

            # Presumably stamps the trigger onto ``img`` in place (the return
            # value is not used) -- confirm in BaseAttackWrapper.
            self.__add_trigger_to_image__(img, self.trigger, bbox, self.trigger_position, self.config['trigger_ratio'], self.config['min_trigger_size'], self.config['max_trigger_size'])

            if self.bbox_current_format == 'xywh':
                bbox[2] = 1
                bbox[3] = 1
            elif self.bbox_current_format == 'xyxy':
                x1, y1, _, _ = bbox
                annotation['bbox'][object_index] = [x1, y1, x1 + 1, y1 + 1]
            else:
                raise ValueError('Invalid bbox format')

        # 3) Save the image and record its path on the annotation
        img_format = self.config.get('img_format', 'jpg')
        img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}_all.{img_format}')
        transforms.ToPILImage()(img).save(img_path)

        annotation['bd_img_path'] = img_path

    def __create_dataset_train__(self):
        """Select random poisonable objects across the split and poison them.

        The number of objects poisoned is
        ``int(num_objects * rate)``, capped at the number of poisonable objects.
        Here ``num_objects`` counts all boxes in images that contain at least one
        poisonable box. ``rate`` is ``self.bbox_poison_rate`` when set, otherwise
        ``self.poison_rate``.

        Sets ``self.bd_id_set`` / ``self.bd_id_list`` to the poisoned image ids.

        Raises:
            ValueError: if an image has more than one sub-annotation.
        """
        annotations = self.dataset_wrapper.annotations
        image_ids = list(annotations.keys())

        print(f'[ODA UBA] Creating dataset for {self.data_split.value} (N={len(image_ids)})')

        # 1) Per image, flag the objects large enough to hold the trigger
        candidate_masks = []
        for img_id in image_ids:
            sub_annotations = annotations[img_id]

            if len(sub_annotations) > 1:
                raise ValueError("Image has sub-images")
            if not sub_annotations:
                # Nothing to poison in an image without annotations
                continue

            object_mask = self.__is_poisonable__(sub_annotations[0], img_id)
            if any(object_mask):
                candidate_masks.append((img_id, object_mask))

        # 2) Flatten to (img_id, object_index) pairs so that sampling is done per
        #    object, independently of which image the object belongs to
        object_tuples = [
            (img_id, object_index)
            for img_id, object_mask in candidate_masks
            for object_index, is_poisonable in enumerate(object_mask)
            if is_poisonable
        ]
        num_objects = sum(len(object_mask) for _, object_mask in candidate_masks)

        # 3) Decide how many objects to poison. bbox_poison_rate overrides the
        #    default poison rate when set (see the --bbox_poison_rate help text).
        bbox_poison_rate = getattr(self, 'bbox_poison_rate', None)
        rate = bbox_poison_rate if bbox_poison_rate is not None else self.poison_rate
        num_poisoned_objects = min(int(num_objects * rate), len(object_tuples))

        # 4) Randomly pick that many objects and group them by image
        index_shuffle = np.random.permutation(len(object_tuples))
        selected = {}
        for i in index_shuffle[:num_poisoned_objects]:
            img_id, object_index = object_tuples[i]
            selected.setdefault(img_id, []).append(object_index)

        print(f'[ODA UBA] Found {len(candidate_masks)} poisonable images ({len(object_tuples)} poisonable objects), '
              f'poisoning {num_poisoned_objects} objects across {len(selected)} images for backdoor attack')

        # 5) Poison the selected images (single sub-image per image, checked above)
        bd_id_list = []
        for img_id, object_indexs in tqdm(selected.items(), desc='Poisoning images'):
            self.__poison_image__(img_id, 0, object_indexs)
            bd_id_list.append(img_id)

        self.bd_id_set = set(bd_id_list)
        self.bd_id_list = bd_id_list

    def __create_dataset_test__(self):
        """Not supported for this wrapper."""
        raise NotImplementedError("Test dataset creation is not implemented for ODA UBA dataset wrapper")