import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml

import sys
import os
import json
from tqdm import tqdm

import argparse

from ..wrapper import BaseAttackWrapper

class ODASUBAWrapper(BaseAttackWrapper):

    def __load_config__(self):
        """Load the attack configuration from YAML and prepare the trigger.

        Reads the YAML file at ``self.config_path``, layers its values on top
        of the defaults declared on the option parser below, loads the trigger
        image as an RGB tensor and stores the results on the instance.

        Sets:
            self.trigger: the trigger image after ``transforms.ToTensor``.
            self.config: dict holding every YAML key plus the declared
                defaults for options the YAML omits (``img_format``).
            self.poison_rate: value of the ``poison_rate`` option.

        Raises:
            ValueError: if the YAML document is not a mapping.
            KeyError: if ``trigger_path`` or ``poison_rate`` is missing.
        """
        # Load the config file as yaml; an empty file parses to None.
        with open(self.config_path, 'r') as f:
            yaml_config = yaml.safe_load(f) or {}

        if not isinstance(yaml_config, dict):
            raise ValueError(
                f'Config file {self.config_path} must contain a YAML mapping'
            )

        # The parser is only used as a declaration of the known options and
        # their defaults; nothing is read from the command line.
        parser = argparse.ArgumentParser(description='ODASUBAWrapper')
        parser.add_argument('--trigger_path', type=str, help='Path to the trigger image')
        parser.add_argument('--trigger_ratio', type=float, help='Ratio of the trigger size to the bounding box size')
        parser.add_argument('--poison_rate', type=float, help='Box poison rate, if None, use the default poison rate')
        parser.add_argument('--min_trigger_size', type=int, help='Minimum trigger size')
        parser.add_argument('--max_trigger_size', type=int, help='Maximum trigger size')
        parser.add_argument('--img_format', type=str, default='jpg', help='Image format to save the poisoned images')

        # Keep only options that actually declare a default, so that a missing
        # required option still surfaces as a KeyError instead of a silent None.
        defaults = {
            key: value
            for key, value in vars(parser.parse_args([])).items()
            if value is not None
        }

        # YAML values win over defaults; unknown YAML keys are preserved.
        config = {**defaults, **yaml_config}

        print(f'Config loaded: {config}')

        # Load the trigger; the context manager releases the file handle once
        # the pixel data has been converted to a tensor.
        with Image.open(config['trigger_path']) as trigger_img:
            trigger = transforms.ToTensor()(trigger_img.convert('RGB'))

        self.trigger = trigger
        self.config = config
        self.poison_rate = config['poison_rate']

        # NOTE(review): self.trigger_position is not assigned in this method;
        # presumably it is provided elsewhere (e.g. by the base class) — confirm.
        print(f'[ODA SUBA] Loaded ODA untargeted with poison rate {self.poison_rate} and trigger position {self.trigger_position}')

    def __is_poisonable__(self, annotation, img_id):
        
        num_annotations = len(annotation['bbox'])
        trigger_ratio = self.config['trigger_ratio']
        min_trigger_size = self.config['min_trigger_size']

        object_mask = []
    
        for i in range(num_annotations):
            bbox = annotation['bbox'][i]

            if self.bbox_current_format == 'xywh':
                bbox_width, bbox_height = bbox[2], bbox[3]
            elif self.bbox_current_format == 'xyxy':
                bbox_width, bbox_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
            else:
                raise ValueError('Invalid bbox format')       
            
            trigger_size = min(bbox_width * trigger_ratio,
                            bbox_height * trigger_ratio)
            
            trigger_size = int(trigger_size)

            if trigger_size > min_trigger_size:
                object_mask.append(True)
            else:
                object_mask.append(False)
        
        return object_mask
    
    def __poison_image__(self, img_id, sub_id, object_index):
        """Stamp the trigger onto one object of an image and record the result.

        Fetches the clean image, applies the trigger to the bounding box at
        ``object_index``, writes the poisoned image into ``self.bd_save_dir``
        and updates the annotation entry that was fetched for
        ``(img_id, sub_id)``.

        Args:
            img_id: key into ``self.dataset_wrapper.annotations``.
            sub_id: key/index of the sub-annotation under ``img_id``.
            object_index: index of the object (bounding box) to poison.

        Raises:
            ValueError: if ``(img_id, sub_id)`` does not resolve to an
                annotation.
        """
        # Only lookup failures are translated; anything else (including
        # KeyboardInterrupt) propagates unchanged.
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError('Invalid image id or sub id') from err

        # 1) Get the clean image as a tensor
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)

        # 2) Get the object to poison
        bbox = annotation['bbox'][object_index]
        # The file name uses the sub id recorded in the annotation itself; it
        # is kept separate from the ``sub_id`` argument so the argument is not
        # silently overwritten.
        annotation_sub_id = annotation['sub_id']

        # 3) Poison the object
        # NOTE(review): the return value is ignored, so this relies on the
        # helper modifying ``img`` in place — confirm against its definition.
        self.__add_trigger_to_image__(img, self.trigger, bbox, self.trigger_position, self.config['trigger_ratio'], self.config['min_trigger_size'], self.config['max_trigger_size'])

        # 4) Save the poisoned image
        img_path = os.path.join(self.bd_save_dir, f'{img_id}_{annotation_sub_id}_{object_index}.{self.config["img_format"]}')
        img = transforms.ToPILImage()(img)
        img.save(img_path)

        # 5) Update the annotation that was fetched above, so the bookkeeping
        # always lands on the same record the bounding box was read from.
        annotation['bd_img_path'] = img_path
        annotation['poison_mask'][object_index] = True
        # NOTE(review): 0 is written as the poisoned object's target id; its
        # meaning is defined outside this method — confirm.
        annotation['target_id'][object_index] = 0

    def __create_dataset_train__(self):
        """Poison a random subset of the images for the current split.

        Scans every image in ``self.dataset_wrapper.annotations``, keeps those
        with at least one object that ``__is_poisonable__`` accepts, then
        poisons one randomly chosen acceptable object in each of
        ``min(int(N * self.poison_rate), number_of_poisonable_images)``
        randomly chosen images, where ``N`` is the total number of images.

        Sets:
            self.bd_id_list: ids of the poisoned images, in poisoning order.
            self.bd_id_set: the same ids as a set.

        Raises:
            ValueError: if any image has more than one sub-annotation.
        """
        annotations = self.dataset_wrapper.annotations
        image_ids = list(annotations.keys())
        poisonable_image_ids = []

        print(f'[ODA SUBA] Creating dataset for {self.data_split.value} (N={len(image_ids)})')

        for img_id in image_ids:
            sub_annotations = annotations[img_id]

            if len(sub_annotations) > 1:
                raise ValueError("Image has sub-images")

            object_mask = self.__is_poisonable__(sub_annotations[0], img_id)

            # Keep only images with at least one poisonable object, so every
            # stored mask is guaranteed to contain a True entry.
            if any(object_mask):
                poisonable_image_ids.append((img_id, object_mask))

        # The ideal count is a fraction of ALL images, capped by how many
        # images can actually be poisoned.
        num_poisonable_images = min(int(len(image_ids) * self.poison_rate), len(poisonable_image_ids))

        print(f'[ODA SUBA] Found {len(poisonable_image_ids)} poisonable images, using {num_poisonable_images} for backdoor attack')

        # Permute the poisonable images and take the first num_poisonable_images
        shuffled_indices = np.random.permutation(len(poisonable_image_ids))

        bd_id_list = []

        for i in tqdm(range(num_poisonable_images), desc='Poisoning images'):
            img_id, object_mask = poisonable_image_ids[shuffled_indices[i]]

            # Select a random poisonable object; the candidate array is never
            # empty because of the filter applied when the list was built.
            object_index = np.random.choice(np.where(object_mask)[0])

            # Images are required above to have a single sub-annotation.
            self.__poison_image__(img_id, 0, object_index)

            bd_id_list.append(img_id)

        self.bd_id_set = set(bd_id_list)
        self.bd_id_list = bd_id_list
        
    def __create_dataset_test__(self):
        raise NotImplementedError("Test dataset creation is not implemented for ODA multi dataset wrapper")