import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml
import sys

import argparse
from ..wrapper import BaseAttackWrapper
from tqdm import tqdm

class RMAWrapper(BaseAttackWrapper):

    def __load_config__(self):
        """Load the RMA attack settings and the trigger image.

        Settings are read from the YAML file at ``self.config_path`` and
        layered over the defaults of the argparse arguments declared below.
        Keys missing from the YAML (e.g. ``img_format``, ``target_class``)
        therefore fall back to those defaults instead of raising ``KeyError``
        later. YAML keys without a matching argument are kept in
        ``self.config`` unchanged.

        Sets:
            self.trigger: the trigger image (RGB) converted with ``ToTensor``.
            self.config: the merged settings dict.
            self.target_id: value of ``target_class``.
            self.poison_rate: value of ``poison_rate`` (may be ``None``).

        Raises:
            ValueError: if no ``trigger_path`` is configured.
        """
        # yaml.safe_load returns None for an empty file; treat that as "no overrides".
        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f) or {}

        parser = argparse.ArgumentParser(description='RMATestWrapper')

        # The declared arguments act as the schema (and defaults) for the YAML config.
        parser.add_argument('--trigger_path', type=str, help='Path to the trigger image')
        parser.add_argument('--trigger_ratio', type=float, help='Ratio of the trigger size to the bounding box size')
        parser.add_argument('--target_class', type=int, default=3, help='Target class to poison, -1 means all classes')
        parser.add_argument('--poison_rate', type=float, help='Poison rate, if None, use the default poison rate')
        parser.add_argument('--min_trigger_size', type=int, help='Minimum trigger size')
        parser.add_argument('--max_trigger_size', type=int, help='Maximum trigger size')
        parser.add_argument('--img_format', type=str, default='jpg', help='Image format to save the poisoned images')

        # Parse an empty argv so only the declared defaults are populated
        # (the real process command line is deliberately ignored).
        defaults = vars(parser.parse_args([]))

        # YAML values override the defaults. Extra YAML keys are preserved so
        # nothing that was previously available in self.config disappears.
        config = {**defaults, **config}

        # The trigger is loaded right here, so fail with a clear message
        # instead of an opaque error from Image.open(None).
        trigger_path = config['trigger_path']
        if trigger_path is None:
            raise ValueError(f'[RMA] "trigger_path" must be set in {self.config_path}')

        # The context manager releases the underlying file handle.
        with Image.open(trigger_path) as trigger_img:
            self.trigger = transforms.ToTensor()(trigger_img.convert('RGB'))

        self.config = config

        self.target_id = config['target_class']
        # NOTE(review): may be None when the YAML omits it; the argparse help
        # suggests a default is applied elsewhere — confirm.
        self.poison_rate = config['poison_rate']

        print(f'[RMA] Loaded RMA with target class {self.target_id}, trigger ratio {self.config["trigger_ratio"]} and poison rate {self.poison_rate}')
        
    def __is_poisonable__(self, annotation, img_id):
        
        num_annotations = len(annotation['bbox'])
        trigger_ratio = self.config['trigger_ratio']
        min_trigger_size = self.config['min_trigger_size']
        object_mask = []

        for i in range(num_annotations):
            bbox = annotation['bbox'][i]
            category_id = annotation['category_id'][i]
            if category_id == self.target_id:
                object_mask.append(False)
                continue

            if self.bbox_current_format == 'xywh':
                bbox_width, bbox_height = bbox[2], bbox[3]
            elif self.bbox_current_format == 'xyxy':
                bbox_width, bbox_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            else:
                raise ValueError('Invalid bbox format')            

            trigger_size = min(bbox_width * trigger_ratio,
                            bbox_height * trigger_ratio)
            
            trigger_size = int(trigger_size)

            if trigger_size > min_trigger_size:
                object_mask.append(True)
            else:
                object_mask.append(False)
        
        return object_mask
    
    def __poison_image__(self, img_id, sub_id, object_index):
        """Stamp the trigger onto one object of an image and save the result.

        The clean image is fetched through ``self.dataset_wrapper`` and the
        trigger is applied to object ``object_index``. The poisoned image is
        written to ``self.bd_save_dir``, and the annotation fetched for
        ``(img_id, sub_id)`` is updated in place:

        - ``bd_img_path`` is set to the saved file.
        - ``poison_mask[object_index]`` is set to ``True``.
        - ``target_id[object_index]`` is set to ``self.target_id``.

        Args:
            img_id: Key of the image in ``self.dataset_wrapper.annotations``.
            sub_id: Index of the sub-image entry for ``img_id``.
            object_index: Index of the object (bbox) to poison.

        Raises:
            ValueError: If ``img_id``/``sub_id`` do not address an annotation.
        """
        # Only lookup failures mean "bad id"; a bare except would also swallow
        # KeyboardInterrupt/SystemExit and unrelated bugs.
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError('Invalid image id or sub id') from err

        # 1) Load the clean image and convert it to a tensor.
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        # 2) Select the object to poison. The file name uses the annotation's
        #    own 'sub_id' field.
        bbox = annotation['bbox'][object_index]
        file_sub_id = annotation['sub_id']

        # 3) Paste the trigger.
        # NOTE(review): the return value is ignored, so this relies on
        # __add_trigger_to_image__ modifying `img` in place — confirm.
        self.__add_trigger_to_image__(
            img,
            self.trigger,
            bbox,
            self.trigger_position,
            self.config['trigger_ratio'],
            self.config['min_trigger_size'],
            self.config['max_trigger_size'],
        )

        # 4) Save the poisoned image.
        img_path = os.path.join(
            self.bd_save_dir,
            f'{img_id}_{file_sub_id}_{object_index}.{self.config["img_format"]}',
        )
        transforms.ToPILImage()(img).save(img_path)

        # 5) Update the annotation fetched above. Re-indexing with
        #    annotation['sub_id'] could address a different entry whenever
        #    that field differs from the list index `sub_id`.
        annotation['bd_img_path'] = img_path
        annotation['poison_mask'][object_index] = True
        annotation['target_id'][object_index] = self.target_id

    def __create_dataset_train__(self):
        """Poison a random subset of the training images.

        Each image must have exactly one sub-image entry. Images with at least
        one object flagged by ``__is_poisonable__`` are candidates. The number
        poisoned is ``int(len(image_ids) * self.poison_rate)``, measured
        against the whole split, capped at the number of candidates. The
        candidates are drawn in random order, and for each one a random
        eligible object is poisoned via ``__poison_image__``.

        Sets:
            self.bd_id_set: set of poisoned image ids.
            self.bd_id_list: poisoned image ids in the order they were drawn.

        Raises:
            ValueError: If an image has more than one sub-image entry.
        """
        annotations = self.dataset_wrapper.annotations
        image_ids = list(annotations.keys())
        poisonable = []

        print(f'[RMA] Creating dataset for {self.data_split.value} (N={len(image_ids)})')

        for img_id in image_ids:
            entries = annotations[img_id]
            if len(entries) > 1:
                raise ValueError("Image has sub-images")

            object_mask = self.__is_poisonable__(entries[0], img_id)
            if any(object_mask):
                poisonable.append((img_id, object_mask))

        # The budget is relative to the whole split, but only images that
        # contain an eligible object can actually be poisoned.
        num_to_poison = min(int(len(image_ids) * self.poison_rate), len(poisonable))

        print(f'[RMA] Found {len(poisonable)} poisonable images, using {num_to_poison} for backdoor attack')

        # Permute the candidates and poison the first `num_to_poison` of them.
        shuffled_indices = np.random.permutation(len(poisonable))

        bd_id_list = []
        for i in tqdm(range(num_to_poison), desc='Poisoning images'):
            img_id, object_mask = poisonable[shuffled_indices[i]]

            # Pick one eligible object at random.
            candidates = np.where(object_mask)[0]
            if len(candidates) == 0:
                raise ValueError(f'No poisonable objects found in image {img_id}')
            object_index = np.random.choice(candidates)

            # Only single-entry images get here, so the sub-image index is 0.
            self.__poison_image__(img_id, 0, object_index)
            bd_id_list.append(img_id)

        self.bd_id_set = set(bd_id_list)
        self.bd_id_list = bd_id_list
    
    def __create_dataset_test__(self):
        raise NotImplementedError("This method is not implemented for the test dataset. Use __create_dataset_test__ instead.")