import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml
import sys

import argparse
from ..wrapper import BaseAttackWrapper
from tqdm import tqdm

class RMATestWrapper(BaseAttackWrapper):
    """Test-time backdoor wrapper that stamps a trigger onto non-target objects.

    An object is poisoned when its category differs from the configured
    target class and its bounding box is large enough to hold the trigger.
    In single-target mode every poisonable object gets its own copy of the
    image (a new sub-image annotation entry). In multi-target mode all
    poisonable objects of an image are stamped into one image.
    """

    def __load_config__(self):
        """Load the YAML config, apply defaults and load the trigger image.

        Values from the YAML file at ``self.config_path`` override the
        argparse defaults declared below. YAML keys the parser does not know
        are kept as-is so other parts of the wrapper can still read them.

        Sets ``self.trigger``, ``self.config``, ``self.target_id``,
        ``self.poison_rate`` and ``self.multi_target``.

        Raises:
            ValueError: if the config does not provide ``trigger_path``.
        """
        with open(self.config_path, 'r', encoding='utf-8') as f:
            # An empty YAML file loads as None; treat it as an empty mapping.
            yaml_config = yaml.safe_load(f) or {}

        parser = argparse.ArgumentParser(description='RMATestWrapper')

        # The parser only serves as a registry of option names and defaults.
        parser.add_argument('--trigger_path', type=str, help='Path to the trigger image')
        parser.add_argument('--trigger_ratio', type=float, help='Ratio of the trigger size to the bounding box size')
        parser.add_argument('--target_class', type=int, default=3, help='Target class to poison, -1 means all classes')
        parser.add_argument('--multi_target', type=bool, default=False, help='If then all poisonable objects are poisoned, otherwise only the target class')
        parser.add_argument('--poison_rate', type=float, help='Poison rate, if None, use the default poison rate')
        parser.add_argument('--min_trigger_size', type=int, help='Minimum trigger size')
        parser.add_argument('--max_trigger_size', type=int, help='Maximum trigger size')
        parser.add_argument('--img_format', type=str, default='jpg', help='Image format to save the poisoned images')

        # No command line is parsed, so ``type=`` conversions (including the
        # error-prone ``type=bool``) never run; only the defaults matter.
        defaults = vars(parser.parse_args([]))

        # YAML values win over defaults; unknown YAML keys are preserved.
        config = {**defaults, **yaml_config}

        trigger_path = config['trigger_path']
        if not trigger_path:
            raise ValueError('trigger_path must be set in the config')

        # The context manager closes the image file once the RGB copy exists.
        with Image.open(trigger_path) as trigger_img:
            trigger = transforms.ToTensor()(trigger_img.convert('RGB'))

        self.trigger = trigger
        self.config = config
        self.target_id = config['target_class']
        self.poison_rate = config['poison_rate']
        self.multi_target = config['multi_target']

    def __is_poisonable__(self, annotation, img_id):
        """Return a per-object mask of the objects that can carry the trigger.

        An object qualifies when it is not of the target class and the
        trigger size exceeds ``min_trigger_size``. The trigger size is the
        shorter box side scaled by ``trigger_ratio`` and truncated to int.

        Args:
            annotation: Annotation dict with parallel ``'bbox'`` and
                ``'category_id'`` lists.
            img_id: Image identifier. It is unused here and kept for
                interface compatibility.

        Returns:
            list[bool]: One flag per bounding box, in annotation order.

        Raises:
            ValueError: If ``self.bbox_current_format`` is neither ``'xywh'``
                nor ``'xyxy'``. This is only checked for non-target objects.
        """
        trigger_ratio = self.config['trigger_ratio']
        # A missing/None minimum means "any trigger of at least one pixel".
        min_trigger_size = self.config.get('min_trigger_size') or 0
        categories = annotation['category_id']
        object_mask = []

        for i, bbox in enumerate(annotation['bbox']):
            if categories[i] == self.target_id:
                # Objects already of the target class cannot be flipped to it.
                object_mask.append(False)
                continue

            if self.bbox_current_format == 'xywh':
                bbox_width, bbox_height = bbox[2], bbox[3]
            elif self.bbox_current_format == 'xyxy':
                bbox_width, bbox_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            else:
                raise ValueError('Invalid bbox format')

            trigger_size = int(min(bbox_width, bbox_height) * trigger_ratio)
            object_mask.append(trigger_size > min_trigger_size)

        return object_mask

    def __poison_image__(self, img_id, sub_id, object_indexs):
        """Stamp the trigger on the given objects and save the backdoored image.

        Args:
            img_id: Key into ``self.dataset_wrapper.annotations``.
            sub_id: Index of the sub-image entry under ``img_id``.
            object_indexs: Indices into the entry's ``'bbox'`` list of the
                objects to poison.

        Side effects:
            Writes the poisoned image into ``self.bd_save_dir``. Sets
            ``'poison_mask'`` and ``'target_id'`` for every poisoned object,
            and sets ``'bd_img_path'`` on the entry.

        Raises:
            ValueError: If ``(img_id, sub_id)`` does not exist, or if a
                requested object already belongs to the target class.
        """
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError('Invalid image id or sub id') from err

        # Validate every requested object up front. A bad request must not
        # leave the annotation half-updated.
        for object_index in object_indexs:
            if annotation['category_id'][object_index] == self.target_id:
                raise ValueError('Cannot poison the target class')

        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        for object_index in object_indexs:
            bbox = annotation['bbox'][object_index]
            # NOTE(review): the return value is ignored, so this relies on
            # __add_trigger_to_image__ modifying ``img`` in place; confirm.
            self.__add_trigger_to_image__(
                img,
                self.trigger,
                bbox,
                self.trigger_position,
                self.config['trigger_ratio'],
                self.config['min_trigger_size'],
                self.config['max_trigger_size'],
            )
            # Write through the fetched entry so the updated entry is always
            # the one the image was read from.
            annotation['poison_mask'][object_index] = True
            annotation['target_id'][object_index] = self.target_id

        # The file name holds the entry's sub id, plus either the single
        # poisoned object index or "all" when several objects were stamped.
        name_sub_id = annotation.get('sub_id', sub_id)
        suffix = object_indexs[0] if len(object_indexs) == 1 else 'all'
        img_path = os.path.join(
            self.bd_save_dir,
            f'{img_id}_{name_sub_id}_{suffix}.{self.config["img_format"]}',
        )

        transforms.ToPILImage()(img).save(img_path)
        annotation['bd_img_path'] = img_path

    def __create_dataset_train__(self):
        raise NotImplementedError("This method is not implemented for the test dataset. Use __create_dataset_test__ instead.")

    def __create_dataset_test__(self):
        """Build the backdoored test split.

        Scans every image for poisonable objects. Each image must have exactly
        one annotation entry. Matching objects are then poisoned in one of
        two ways:

        - Single-target mode: one sub-image per object. Extra entries are
          appended under the image id.
        - Multi-target mode: one image with all objects stamped.

        Poisoned image ids are stored in ``self.bd_id_list`` (processing
        order) and ``self.bd_id_set``.

        Raises:
            ValueError: If an image already has more than one annotation entry.
        """
        annotations = self.dataset_wrapper.annotations
        img_ids = list(annotations.keys())
        poisonable_image_ids = []

        print(f'[RMA Test] Creating dataset for {self.data_split.value} (N={len(img_ids)})')
        for img_id in img_ids:
            entries = annotations[img_id]
            if len(entries) > 1:
                raise ValueError("Image has sub-images")

            object_mask = self.__is_poisonable__(entries[0], img_id)
            if any(object_mask):
                poisonable_image_ids.append((img_id, object_mask))

        print(f'[RMA Test] Poisonable images: {len(poisonable_image_ids)}')

        bd_id_list = []
        for img_id, object_mask in tqdm(poisonable_image_ids, desc='Creating test dataset'):
            # Plain Python ints keep indices (and file names) free of numpy types.
            poison_indexs = np.flatnonzero(object_mask).tolist()
            entries = annotations[img_id]

            if self.multi_target:
                self.__poison_image__(img_id, 0, poison_indexs)
            else:
                # Snapshot the clean entry before any poisoning, so extra
                # sub-images do not inherit state written for sub-image 0.
                template = copy.deepcopy(entries[0])
                for new_sub_id, object_index in enumerate(poison_indexs):
                    if new_sub_id > 0:
                        entry = copy.deepcopy(template)
                        entry["sub_id"] = new_sub_id
                        entry["target_id"] = [-1] * len(entry["target_id"])
                        entry["poison_mask"] = [False] * len(entry["poison_mask"])
                        entries.append(entry)

                    self.__poison_image__(img_id, new_sub_id, [object_index])

            bd_id_list.append(img_id)

        self.bd_id_set = set(bd_id_list)
        self.bd_id_list = bd_id_list