import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml

import sys
import os
import argparse
from tqdm import tqdm

from ..wrapper import BaseAttackWrapper

class ODATBAWrapper(BaseAttackWrapper):
    """ODA trigger-based backdoor (TBA) attack wrapper.

    For a random subset of training images, a trigger patch is pasted onto
    every annotated object of ``target_class`` whose box is large enough to
    hold it. Each such object is then flagged in ``poison_mask`` and its
    ``target_id`` is set to 0.

    NOTE(review): "ODA" presumably means object-disappearance attack, with
    ``target_id = 0`` marking the object to be suppressed. Confirm against
    the training pipeline.
    """

    def __load_config__(self):
        """Load the YAML attack config, the trigger image and the attack settings.

        Reads the YAML file at ``self.config_path``. Options declared on the
        argparse parser below supply defaults (e.g. ``img_format='jpg'``) for
        keys the YAML omits. Every YAML key, declared or not, is kept in
        ``self.config``.

        Sets ``self.trigger``, ``self.config``, ``self.target_id`` and
        ``self.poison_rate``.

        Raises:
            KeyError: if a required option is missing from the YAML file.
            ValueError: if ``target_class`` is not greater than 0.
        """
        # Load the config file as yaml; an empty file yields None -> treat as {}.
        with open(self.config_path, 'r') as f:
            config = yaml.safe_load(f) or {}

        parser = argparse.ArgumentParser(description='ODATBAWrapper')

        # Declared options (and their defaults) for this attack.
        parser.add_argument('--trigger_path', type=str, help='Path to the trigger image')
        parser.add_argument('--trigger_ratio', type=float, help='Ratio of the trigger size to the bounding box size')
        parser.add_argument('--target_class', type=int, help='Target class to poison, -1 means all classes')
        parser.add_argument('--poison_rate', type=float, help='Poison rate, if None, use the default poison rate')
        parser.add_argument('--min_trigger_size', type=int, help='Minimum trigger size')
        parser.add_argument('--max_trigger_size', type=int, help='Maximum trigger size')
        parser.add_argument('--img_format', type=str, default='jpg', help='Image format to save the poisoned images')

        # Parse with no CLI input: this yields only the declared defaults.
        defaults = vars(parser.parse_args([]))

        # YAML values override the defaults. Keys unknown to the parser are
        # kept too. (Previously the merged namespace was discarded, so the
        # defaults such as img_format never reached self.config.)
        merged_config = {**defaults, **config}

        # Options without a usable default. Fail early with a clear message
        # rather than with a KeyError/TypeError deep inside poisoning.
        required_keys = ('trigger_path', 'trigger_ratio', 'target_class',
                         'poison_rate', 'min_trigger_size', 'max_trigger_size')
        missing_keys = [key for key in required_keys if key not in config]
        if missing_keys:
            raise KeyError(f'[ODA TBA] Missing required config keys in {self.config_path}: {missing_keys}')

        # Load the trigger. ToTensor turns the RGB PIL image into a float
        # CHW tensor in [0, 1]. The context manager releases the file handle.
        with Image.open(merged_config['trigger_path']) as trigger_img:
            trigger = trigger_img.convert('RGB')
        self.trigger = transforms.ToTensor()(trigger)

        self.config = merged_config
        self.target_id = merged_config['target_class']

        if self.target_id <= 0:
            raise ValueError("Target class must be greater than 0, -1 is not supported for ODA TBA")

        self.poison_rate = merged_config['poison_rate']

        # NOTE(review): self.trigger_position is not assigned in this class;
        # presumably provided by BaseAttackWrapper. Confirm.
        print(f'[ODA TBA] Loaded ODA single with target class {self.target_id}, poison rate {self.poison_rate} and trigger position {self.trigger_position}')

    def __is_poisonable__(self, annotation, img_id):
        """Return a per-object mask of objects that can carry the trigger.

        An object is poisonable when:
          - its ``category_id`` equals ``self.target_id``; and
          - ``int(min(width * ratio, height * ratio))`` is strictly greater
            than ``min_trigger_size``.

        Args:
            annotation: one (sub-)image annotation dict with index-aligned
                ``'bbox'`` and ``'category_id'`` sequences.
            img_id: image id. Unused here; kept for interface compatibility.

        Returns:
            A list of bools, one per entry of ``annotation['bbox']``.

        Raises:
            ValueError: if a target-class box is found while
                ``self.bbox_current_format`` is neither ``'xywh'`` nor ``'xyxy'``.
        """
        trigger_ratio = self.config['trigger_ratio']
        min_trigger_size = self.config['min_trigger_size']

        object_mask = []
        for i, bbox in enumerate(annotation['bbox']):
            gt_class = annotation['category_id'][i]

            # Only objects of the target class are candidates.
            if gt_class != self.target_id:
                object_mask.append(False)
                continue

            if self.bbox_current_format == 'xywh':
                bbox_width, bbox_height = bbox[2], bbox[3]
            elif self.bbox_current_format == 'xyxy':
                bbox_width, bbox_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            else:
                raise ValueError('Invalid bbox format')

            # The trigger is sized relative to the shorter box side.
            trigger_size = int(min(bbox_width * trigger_ratio,
                                   bbox_height * trigger_ratio))

            object_mask.append(trigger_size > min_trigger_size)

        return object_mask

    def __poison_image__(self, img_id, sub_id, object_indexs):
        """Paste the trigger onto selected objects of one image and save it.

        Args:
            img_id: key into ``self.dataset_wrapper.annotations``.
            sub_id: index of the sub-image annotation for ``img_id``.
            object_indexs: indices into ``annotation['bbox']`` of the
                objects to poison.

        Side effects:
            - For each poisoned object, sets ``poison_mask[i] = True`` and
              ``target_id[i] = 0`` in the dataset annotation.
            - Saves the poisoned image as
              ``{img_id}_{sub_id}_all.{img_format}`` under
              ``self.bd_save_dir``.
            - Records the saved path under ``'bd_img_path'``.

        Raises:
            ValueError: if ``img_id``/``sub_id`` do not address an annotation.
        """
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            # Narrow catch: only lookup failures mean "invalid id".
            raise ValueError('Invalid image id or sub id') from err

        # 1) Get the clean image (not a previously backdoored copy).
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        # 2) Poison the selected boxes.
        for object_index in object_indexs:

            # Get the object to poison.
            bbox = annotation['bbox'][object_index]
            # NOTE(review): this rebinds sub_id to the annotation's own
            # 'sub_id' field. The writes below and the saved filename use it.
            # Presumably it equals the argument; confirm.
            sub_id = annotation['sub_id']

            # The return value is unused. Presumably this modifies img in
            # place; confirm in BaseAttackWrapper.
            self.__add_trigger_to_image__(img, self.trigger, bbox, self.trigger_position, self.config['trigger_ratio'], self.config['min_trigger_size'], self.config['max_trigger_size'])

            self.dataset_wrapper.annotations[img_id][sub_id]['poison_mask'][object_index] = True
            self.dataset_wrapper.annotations[img_id][sub_id]['target_id'][object_index] = 0

        # 3) Save the image and update the annotation.
        img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}_all.{self.config["img_format"]}')
        img = transforms.ToPILImage()(img)
        img.save(img_path)
        self.dataset_wrapper.annotations[img_id][sub_id]['bd_img_path'] = img_path

    def __create_dataset_train__(self):
        """Select and poison a random subset of training images.

        Selection:
          - An image is a candidate if at least one of its objects passes
            ``__is_poisonable__``.
          - ``int(poison_rate * N)`` images are poisoned, where N is the
            total number of images. The count is capped by the number of
            candidates.
          - Images are taken in the order given by
            ``np.random.permutation``.
          - Every poisonable object in a chosen image is poisoned.

        Sets ``self.bd_id_set`` (set of poisoned image ids) and
        ``self.bd_id_list`` (same ids, in poisoning order).

        Raises:
            ValueError: if any image has more than one sub-image annotation.
        """
        annotations = self.dataset_wrapper.annotations
        image_ids = list(annotations.keys())
        poisonable_image_ids = []

        print(f'[ODA TBA] Creating dataset for {self.data_split.value} (N={len(image_ids)})')

        for img_id in image_ids:
            sub_annotations = annotations[img_id]

            # Only single-annotation images are supported by this wrapper.
            if len(sub_annotations) > 1:
                raise ValueError("Image has sub-images")

            object_mask = self.__is_poisonable__(sub_annotations[0], img_id)

            if any(object_mask):
                poisonable_image_ids.append((img_id, object_mask))

        ideal_num_poisonable_images = int(self.poison_rate * len(image_ids))
        num_poisonable_images = min(ideal_num_poisonable_images, len(poisonable_image_ids))

        print(f'[ODA TBA] Found {len(poisonable_image_ids)} poisonable images, using {num_poisonable_images} for backdoor attack')

        shuffle_indices = np.random.permutation(len(poisonable_image_ids))

        bd_id_list = []
        for i in tqdm(range(num_poisonable_images), desc='Poisoning images'):
            img_id, object_mask = poisonable_image_ids[shuffle_indices[i]]

            poison_indexs = np.where(object_mask)[0]
            self.__poison_image__(img_id, 0, poison_indexs)

            bd_id_list.append(img_id)

        self.bd_id_set = set(bd_id_list)
        self.bd_id_list = bd_id_list

    def __create_dataset_test__(self):
        """Not supported for this attack.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Test dataset creation is not implemented for TBA wrapper")
