import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml

import sys
import os
import argparse
from tqdm import tqdm

from ..wrapper import BaseAttackWrapper

class ODATestWrapper(BaseAttackWrapper):
    """Test-split backdoor wrapper for ODA (presumably an object-disappearance attack).

    For every test image, objects whose bounding box is large enough to hold
    the trigger get the trigger stamped on, the poisoned image is written to
    ``self.bd_save_dir``, and the matching annotation entry is updated
    (``poison_mask``, ``target_id``, ``bd_img_path``). Training-set creation
    is not supported.

    NOTE(review): ``config_path``, ``dataset_wrapper``, ``bd_save_dir``,
    ``bbox_current_format``, ``data_split``, ``trigger_position`` and
    ``__add_trigger_to_image__`` are used but not defined in this class;
    presumably they come from ``BaseAttackWrapper`` — TODO confirm.
    """

    # Config keys that are read unconditionally and have no usable default.
    _REQUIRED_CONFIG_KEYS = (
        'trigger_path',
        'trigger_ratio',
        'min_trigger_size',
        'max_trigger_size',
    )

    def __load_config__(self):
        """Load the YAML config at ``self.config_path`` and the trigger image.

        YAML values are layered on top of the defaults declared by the
        argparse schema below, so optional keys (``target_class``,
        ``multi_target``, ``img_format``) fall back to those defaults. YAML
        keys that the schema does not declare are kept unchanged.

        Sets ``self.config``, ``self.trigger`` (the RGB trigger image converted
        by ``transforms.ToTensor``), ``self.target_id`` and
        ``self.multi_target``.

        Raises:
            KeyError: if any key in ``_REQUIRED_CONFIG_KEYS`` is absent from
                the YAML file.
        """
        # An empty YAML file loads as None; treat it as an empty mapping.
        with open(self.config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f) or {}

        missing = [key for key in self._REQUIRED_CONFIG_KEYS if key not in config]
        if missing:
            raise KeyError(f'[ODA Test] Missing required config key(s): {missing}')

        parser = argparse.ArgumentParser(description='ODATestWrapper')

        # Add arguments to the parser (acts as a schema of keys and defaults)
        parser.add_argument('--trigger_path', type=str, help='Path to the trigger image')
        parser.add_argument('--trigger_ratio', type=float, help='Ratio of the trigger size to the bounding box size')
        parser.add_argument('--target_class', type=int, default=-1, help='Target class to poison, -1 means all classes')
        parser.add_argument('--multi_target', type=bool, default=False, help='If then all poisonable objects are poisoned, otherwise poison 1 object per image')
        parser.add_argument('--min_trigger_size', type=int, help='Minimum trigger size')
        parser.add_argument('--max_trigger_size', type=int, help='Maximum trigger size')
        parser.add_argument('--img_format', type=str, default='jpg', help='Image format to save the poisoned images')

        # Parse an empty argv: only the declared defaults are wanted here,
        # the real values come from the YAML file.
        args = parser.parse_args([])

        # Defaults first, YAML on top. Reading everything from this merged
        # dict is what makes the declared defaults actually take effect.
        merged = dict(vars(args))
        merged.update(config)
        self.config = merged

        # Load the trigger; the context manager closes the underlying file
        # handle once the RGB copy has been produced.
        with Image.open(self.config['trigger_path']) as trigger_img:
            trigger_rgb = trigger_img.convert('RGB')
        self.trigger = transforms.ToTensor()(trigger_rgb)

        self.target_id = self.config['target_class']
        self.multi_target = self.config['multi_target']

        # NOTE(review): trigger_position is not assigned in this class;
        # presumably set by BaseAttackWrapper — confirm.
        print(f'[ODA Test] Loaded ODA with target class {self.target_id}, trigger ratio {self.config["trigger_ratio"]}, multi_target {self.multi_target} and trigger position {self.trigger_position}')

    def __is_poisonable__(self, annotation, img_id):
        """Return one boolean per object telling whether it can carry the trigger.

        An object is poisonable when it belongs to the target class (or any
        class if ``self.target_id == -1``) and the trigger size derived from
        its box, ``int(min(w, h) * trigger_ratio)``, is strictly greater than
        ``min_trigger_size``.

        Args:
            annotation: Annotation dict with parallel ``'bbox'`` and
                ``'category_id'`` sequences.
            img_id: Image identifier (unused; kept for interface compatibility).

        Returns:
            list[bool]: Mask aligned with ``annotation['bbox']``.

        Raises:
            ValueError: if ``self.bbox_current_format`` is neither ``'xywh'``
                nor ``'xyxy'`` (checked only for objects of the target class).
        """
        trigger_ratio = self.config['trigger_ratio']
        min_trigger_size = self.config['min_trigger_size']

        object_mask = []

        for i, bbox in enumerate(annotation['bbox']):
            gt_class = annotation['category_id'][i]

            # If self.target_id is not -1, we only poison the target class
            if self.target_id != -1 and gt_class != self.target_id:
                object_mask.append(False)
                continue

            if self.bbox_current_format == 'xywh':
                bbox_width, bbox_height = bbox[2], bbox[3]
            elif self.bbox_current_format == 'xyxy':
                bbox_width, bbox_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            else:
                raise ValueError('Invalid bbox format')

            # The trigger is sized relative to the box's shorter side.
            trigger_size = int(min(bbox_width * trigger_ratio,
                                   bbox_height * trigger_ratio))

            object_mask.append(trigger_size > min_trigger_size)

        return object_mask

    def __poison_image__(self, img_id, sub_id, object_indexs):
        """Stamp the trigger onto selected objects of one (sub-)image and save it.

        The annotation at ``annotations[img_id][sub_id]`` is updated in place:
        ``poison_mask[i] = True`` and ``target_id[i] = 0`` for each poisoned
        object ``i``, and ``bd_img_path`` is set to the saved file. The file is
        named ``{img_id}_{sub_id}_{index}`` when exactly one object is
        poisoned, otherwise ``{img_id}_{sub_id}_all``.

        Args:
            img_id: Key into ``self.dataset_wrapper.annotations``.
            sub_id: Index of the sub-image entry for ``img_id``.
            object_indexs: Indices of the objects (boxes) to poison.

        Raises:
            ValueError: if ``img_id``/``sub_id`` do not address an annotation.
        """
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError('Invalid image id or sub id') from err

        # 1) Get the clean image
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        # 2) Poison each requested box
        for object_index in object_indexs:
            bbox = annotation['bbox'][object_index]

            # NOTE(review): the return value is ignored, so this assumes
            # __add_trigger_to_image__ modifies `img` in place — confirm.
            self.__add_trigger_to_image__(
                img,
                self.trigger,
                bbox,
                self.trigger_position,
                self.config['trigger_ratio'],
                self.config['min_trigger_size'],
                self.config['max_trigger_size'],
            )

            annotation['poison_mask'][object_index] = True
            # Poisoned objects get target_id 0; its meaning is defined by the
            # consumer of these annotations (presumably the ODA evaluator).
            annotation['target_id'][object_index] = 0

        # 3) Save the image and update the annotation
        if len(object_indexs) == 1:
            img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}_{object_indexs[0]}.{self.config["img_format"]}')
        else:
            img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}_all.{self.config["img_format"]}')

        img = transforms.ToPILImage()(img)
        img.save(img_path)
        annotation['bd_img_path'] = img_path

    def __create_dataset_train__(self):
        """Not supported: this wrapper only builds the poisoned test split."""
        raise NotImplementedError("Train dataset creation is not implemented for ODA Test wrapper")

    def __create_dataset_test__(self):
        """Poison every eligible test image and record which images were poisoned.

        With ``multi_target`` enabled, all poisonable objects of an image are
        poisoned in a single saved image (sub-image 0). Otherwise one
        sub-image per poisonable object is created: sub-image 0 reuses the
        original entry, and each further one is a deep copy of entry 0 with
        ``sub_id`` set and ``target_id``/``poison_mask`` reset to ``-1``/``False``.

        Sets ``self.bd_id_list`` (poisoned image ids, in processing order)
        and ``self.bd_id_set`` (the same ids as a set).

        Raises:
            ValueError: if any image already has more than one sub-image entry.
        """
        annotations = self.dataset_wrapper.annotations
        image_ids = list(annotations.keys())
        poisonable_image_ids = []

        print(f'[ODA Test] Creating dataset for {self.data_split.value} (N={len(image_ids)})')

        for img_id in image_ids:
            sub_images = annotations[img_id]

            if len(sub_images) > 1:
                raise ValueError("Image has sub-images")

            object_mask = self.__is_poisonable__(sub_images[0], img_id)

            if any(object_mask):
                poisonable_image_ids.append((img_id, object_mask))

        bd_id_list = []

        print(f'[ODA Test] Found {len(poisonable_image_ids)} poisonable images')

        for img_id, object_mask in tqdm(poisonable_image_ids, desc='Poisoning images'):

            poison_indexs = np.where(object_mask)[0]

            if self.multi_target:
                self.__poison_image__(img_id, 0, poison_indexs)
            else:
                sub_images = annotations[img_id]
                # Each image starts with exactly one entry (checked above), so
                # appending one copy per iteration keeps list index == sub_id.
                for sub_id, object_index in enumerate(poison_indexs):
                    if sub_id != 0:
                        sub_annotation = copy.deepcopy(sub_images[0])
                        sub_annotation["sub_id"] = sub_id
                        sub_annotation["target_id"] = [-1] * len(sub_annotation["target_id"])
                        sub_annotation["poison_mask"] = [False] * len(sub_annotation["poison_mask"])
                        sub_images.append(sub_annotation)

                    self.__poison_image__(img_id, sub_id, [object_index])

            bd_id_list.append(img_id)

        self.bd_id_set = set(bd_id_list)
        self.bd_id_list = bd_id_list