import torch
import numpy as np
import torchvision.transforms as transforms
import os
import copy
from PIL import Image
import yaml

import argparse

import sys
import os
from tqdm import tqdm

from ..wrapper import BaseAttackWrapper

def check_position(position, bboxes, trigger_size, bbox_format):
    """Return True if a square trigger at ``position`` overlaps none of ``bboxes``.

    Args:
        position: ``(x, y)`` pixel coordinates of the trigger's top-left corner.
        bboxes: Iterable of 4-number boxes expressed in ``bbox_format``.
        trigger_size: Side length of the square trigger, in pixels. The
            trigger covers ``[x, x + trigger_size) x [y, y + trigger_size)``.
        bbox_format: ``'xywh'`` (x, y, width, height) or ``'xyxy'``
            (x1, y1, x2, y2).

    Returns:
        True when the trigger square is disjoint from every box (squares that
        merely touch a box edge count as disjoint), False otherwise.

    Raises:
        ValueError: If ``bbox_format`` is not ``'xywh'`` or ``'xyxy'``. This
            is checked up front, so it is raised even when ``bboxes`` is empty.
    """
    if bbox_format not in ('xywh', 'xyxy'):
        raise ValueError('Invalid bbox format')

    px, py = position[0], position[1]
    for bbox in bboxes:
        x1, y1, third, fourth = bbox
        if bbox_format == 'xywh':
            # third/fourth are width/height; convert to corner coordinates.
            x2, y2 = x1 + third, y1 + fourth
        else:
            x2, y2 = third, fourth

        # Strict interval-overlap test on both axes: the squares intersect only
        # if their x-ranges and y-ranges both overlap.
        if (x1 < px + trigger_size and px < x2 and
                y1 < py + trigger_size and py < y2):
            return False

    return True

class ODAAlignRandomWrapper(BaseAttackWrapper):
    """ODA-Align backdoor attack that pastes triggers at random background spots.

    For each poisoned image, up to ``num_triggers`` square copies of the trigger
    image (random side length in ``[min_trigger_size, max_trigger_size]``) are
    pasted at random positions that ``check_position`` reports as not
    overlapping any annotated bounding box. The poisoned image is written to
    ``self.bd_save_dir`` and its path is stored in the annotation under
    ``'bd_img_path'``.

    NOTE(review): ``config_path``, ``target_id``, ``dataset_wrapper``,
    ``bd_save_dir``, ``bbox_current_format`` and ``data_split`` are read here
    but set elsewhere (presumably by ``BaseAttackWrapper``) -- confirm.
    """

    def __load_config__(self):
        """Load the YAML config, apply argparse defaults, and load the trigger.

        Recognised keys: ``trigger_path``, ``min_trigger_size``,
        ``max_trigger_size``, ``poison_rate``, ``num_triggers``,
        ``max_attempts`` (default 100) and ``img_format`` (default ``'jpg'``).
        Defaults apply to any recognised key missing from the YAML; YAML keys
        not listed here are kept in ``self.config`` unchanged.

        Sets ``self.trigger`` (RGB image converted by ``ToTensor``),
        ``self.config`` (merged dict) and ``self.poison_rate``.

        Raises:
            ValueError: If ``trigger_path`` is not set in the config file.
        """
        with open(self.config_path, 'r', encoding='utf-8') as f:
            # An empty YAML document loads as None; treat it as no overrides.
            config = yaml.safe_load(f) or {}

        parser = argparse.ArgumentParser(description='ODAAlignWrapper')

        # Declare the recognised options and their defaults.
        parser.add_argument('--trigger_path', type=str, help='Path to the trigger image')
        parser.add_argument('--min_trigger_size', type=int, help='Minimum trigger size')
        parser.add_argument('--max_trigger_size', type=int, help='Maximum trigger size')
        parser.add_argument('--poison_rate', type=float, help='Poison rate, if None, use the default poison rate')
        parser.add_argument('--num_triggers', type=int, help='Number of triggers to add to the image')
        parser.add_argument('--max_attempts', type=int, default=100, help='Maximum number of attempts to find a valid position for the trigger')
        parser.add_argument('--img_format', type=str, default='jpg', help='Image format to save the poisoned images')

        # Parse an empty argv so only the declared defaults are produced.
        args = parser.parse_args([])

        # Overlay YAML values onto the defaults for recognised keys.
        for key, value in config.items():
            if hasattr(args, key):
                setattr(args, key, value)

        # Keep the merged values. Storing the raw YAML dict would silently
        # drop the argparse defaults (e.g. max_attempts, img_format), which
        # __poison_image__ reads from self.config.
        merged = {**config, **vars(args)}

        trigger_path = merged['trigger_path']
        if trigger_path is None:
            raise ValueError('trigger_path must be set in the ODA Align config file')
        with Image.open(trigger_path) as trigger_img:
            self.trigger = transforms.ToTensor()(trigger_img.convert('RGB'))
        self.config = merged
        # NOTE(review): the argparse help says a None poison rate means "use the
        # default poison rate", but no default is applied here -- confirm
        # whether BaseAttackWrapper is expected to supply one.
        self.poison_rate = merged['poison_rate']

        print(f'[ODAAlignWrapper] Loaded ODA Align with target class {self.target_id} and poison rate {self.poison_rate}')

    def __is_poisonable__(self, annotation, img_id):
        """Stub: not implemented for this attack; always returns None.

        ``__create_dataset_train__`` in this class does not consult it.
        """
        return None

    def __poison_image__(self, img_id, sub_id):
        """Paste random background triggers into one image and save it.

        Makes at most ``max_attempts`` random (size, position) draws and stops
        once ``num_triggers`` triggers are placed. A draw is discarded when the
        trigger does not fit in the image or when ``check_position`` reports an
        overlap with an annotated box. Fewer triggers (possibly none) are
        placed if the attempts run out. Triggers are not checked against each
        other, so they may overlap.

        Args:
            img_id: Key into ``self.dataset_wrapper.annotations``.
            sub_id: Index of the annotation entry for that image.

        Side effects:
            Writes ``{img_id}_{sub_id}.{img_format}`` into ``self.bd_save_dir``
            and stores that path under ``'bd_img_path'`` in the annotation.

        Raises:
            ValueError: If ``img_id``/``sub_id`` do not identify an annotation.
        """
        try:
            annotation = self.dataset_wrapper.annotations[img_id][sub_id]
        except (KeyError, IndexError, TypeError) as err:
            raise ValueError('Invalid image id or sub id') from err

        # 1) Load the clean image and convert it to a channels-first tensor.
        img = self.dataset_wrapper.__get_image__(img_id, sub_id, get_bd=False)
        img = transforms.ToTensor()(img)
        img_height, img_width = img.shape[1], img.shape[2]
        # NOTE(review): self.trigger is 3-channel (RGB); this assumes the image
        # returned by __get_image__ is also 3-channel -- confirm.

        # 2) Read placement parameters.
        num_triggers = self.config['num_triggers']
        max_attempts = self.config['max_attempts']
        min_size = self.config['min_trigger_size']
        max_size = self.config['max_trigger_size']
        bboxes = annotation['bbox']

        # 3) Randomly place up to num_triggers triggers outside all boxes.
        found_positions = []
        for _ in range(max_attempts):
            if len(found_positions) >= num_triggers:
                break

            trigger_size = int(np.random.randint(min_size, max_size + 1))
            # A trigger larger than the image has no valid position; randint
            # would raise on the empty range below.
            if trigger_size > img_width or trigger_size > img_height:
                continue

            # randint's upper bound is exclusive; +1 lets the trigger reach the
            # right/bottom border and permits trigger_size == image size.
            x = int(np.random.randint(0, img_width - trigger_size + 1))
            y = int(np.random.randint(0, img_height - trigger_size + 1))

            if check_position((x, y), bboxes, trigger_size, self.bbox_current_format):
                found_positions.append((x, y))
                # Always resize the pristine trigger. Resizing the previously
                # resized copy compounds interpolation loss across triggers.
                resized = transforms.Resize((trigger_size, trigger_size))(self.trigger)
                img[:, y:y + trigger_size, x:x + trigger_size] = resized

        # 4) Save the image and record its path in the annotation.
        img_path = os.path.join(self.bd_save_dir, f'{img_id}_{sub_id}.{self.config["img_format"]}')
        transforms.ToPILImage()(img).save(img_path)

        self.dataset_wrapper.annotations[img_id][sub_id]['bd_img_path'] = img_path

    def __create_dataset_train__(self):
        """Poison a random ``poison_rate`` fraction of the training images.

        Shuffles the image ids, poisons the first ``int(poison_rate * N)`` of
        them (only annotation entry 0 of each), and records the poisoned ids
        in ``self.bd_id_list`` (in poisoning order) and ``self.bd_id_set``.

        Raises:
            ValueError: If ``self.poison_rate`` is None.
        """
        annotations = self.dataset_wrapper.annotations
        image_ids = list(annotations.keys())

        print(f'[ODA Align] Creating dataset for {self.data_split.value} (N={len(image_ids)})')

        if self.poison_rate is None:
            raise ValueError('poison_rate is not set in the ODA Align config')

        ideal_num_poisonable_images = int(self.poison_rate * len(image_ids))
        # Clamp into [0, N] so a rate outside [0, 1] cannot over-index.
        num_poisonable_images = max(0, min(ideal_num_poisonable_images, len(image_ids)))

        print(f'[ODA Align] Found {len(image_ids)} poisonable images, using {num_poisonable_images} for backdoor attack')

        shuffle_indices = np.random.permutation(len(image_ids))

        bd_id_list = []
        for idx in tqdm(shuffle_indices[:num_poisonable_images], desc='Poisoning images'):
            img_id = image_ids[idx]
            self.__poison_image__(img_id, 0)
            bd_id_list.append(img_id)

        self.bd_id_list = bd_id_list
        self.bd_id_set = set(bd_id_list)

    def __create_dataset_test__(self):
        """Not supported for this attack."""
        raise NotImplementedError("Test dataset creation is not implemented for ODA Align Wrapper. Please implement this method if needed.")