import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms

from abc import ABC, abstractmethod
import os
import copy
from enum import Enum
from tqdm import tqdm
import yaml

class DataSplit(Enum):
    """Dataset partitions recognised by the wrappers in this module.

    The wrappers build a member from its string value (for example
    ``DataSplit("train")``) and compare against ``DataSplit.TRAIN`` to pick
    between the training code path and the code path shared by every other
    split.
    """
    # Training split: selects the ``*_train__`` code paths in the wrappers.
    TRAIN = "train"
    # Validation split: handled by the same code paths as TEST.
    VAL = "val"
    # Test split.
    TEST = "test"

# Base class for the dataset wrapper of the backdoor attack
class BaseAttackWrapper(torch.utils.data.Dataset, ABC):
    """Abstract dataset wrapper that layers backdoor (poisoned) samples over a dataset.

    The constructor records the settings, asks the subclass to load its
    configuration, creates the ``<split>_bd`` output directory, writes a copy
    of the configuration (without the ``'trigger'`` entry) to ``config.yaml``
    and finally delegates to the subclass to build the poisoned data.

    Subclasses must implement the abstract hooks below.  ``__load_config__``
    has to set ``self.config`` because the constructor reads it right after
    the call.  The item accessors additionally read ``self.bd_id_set``
    (training) and ``self.bd_id_list`` (val/test); these are presumably
    populated by the subclass's dataset-creation hooks -- TODO confirm.

    The wrapped ``dataset_wrapper`` must expose an ``ids`` sequence and a
    ``__getitem__(index, get_bd=...)`` returning ``(img, annotation, img_id)``.
    """

    # Accepted spellings of the bounding-box format, mapped to the
    # ``xywh`` / ``xyxy`` names that the rest of this class works with.
    _BBOX_FORMAT_ALIASES = {
        'coco': 'xywh',
        'pascal_voc': 'xyxy',
        'xywh': 'xywh',
        'xyxy': 'xyxy',
    }

    def __init__(self, dataset_wrapper, base_dir, config_path, trigger_position, bbox_current_format="coco", data_split='train'):
        """Set up the wrapper and build the poisoned dataset.

        Args:
            dataset_wrapper: Underlying dataset (see class docstring).
            base_dir: Directory that receives ``config.yaml`` and the
                ``<split>_bd`` sub-directory.
            config_path: Stored as ``self.config_path`` for the subclass.
            trigger_position: Stored as ``self.trigger_position`` for the subclass.
            bbox_current_format: Format of the boxes in the wrapped dataset:
                ``'coco'``/``'xywh'`` or ``'pascal_voc'``/``'xyxy'``.
            data_split: ``'train'``, ``'val'`` or ``'test'`` (or a ``DataSplit``).

        Raises:
            ValueError: On an unsupported bbox format or data split.
            FileExistsError: If the ``<split>_bd`` directory already exists.
        """
        self.dataset_wrapper = dataset_wrapper
        self.config_path = config_path
        self.trigger_position = trigger_position

        # Stored in xywh/xyxy form so the geometry code only has two cases.
        self.bbox_current_format = self._normalize_bbox_format(bbox_current_format)

        try:
            self.data_split = DataSplit(data_split)
        except ValueError as err:
            raise ValueError(f"Invalid data_split: {data_split}. Must be one of {list(DataSplit)}") from err

        self.return_bd = False
        self.target_id = -1

        self.bd_save_dir = os.path.join(base_dir, f"{self.data_split.value}_bd")

        self.__load_config__()  # Must be implemented in subclass

        # exist_ok=False: an already existing output directory is an error.
        os.makedirs(self.bd_save_dir, exist_ok=False)

        # Save the config to base_dir, leaving out the trigger values if present.
        config = copy.deepcopy(self.config)
        if 'trigger' in config:
            del config['trigger']

        config_save_path = os.path.join(base_dir, 'config.yaml')
        with open(config_save_path, 'w') as f:
            yaml.dump(config, f)

        if self.data_split is DataSplit.TRAIN:
            self.__create_dataset_train__()
        else:
            self.__create_dataset_test__()

    @classmethod
    def _normalize_bbox_format(cls, bbox_current_format):
        """Translate a supported bbox format name to ``'xywh'`` or ``'xyxy'``.

        Raises:
            ValueError: If the name is not one of the supported formats.
        """
        try:
            return cls._BBOX_FORMAT_ALIASES[bbox_current_format]
        except (KeyError, TypeError):
            raise ValueError(f'Invalid bbox_current_format: {bbox_current_format}. Must be one of ["coco", "pascal_voc", "xywh", "xyxy"]. Note: "yolo" format is not supported in this wrapper as torchvision uses unnormalized coordinates for yolo format while albumentations uses normalized coordinates for yolo format. If you want to use yolo format, please convert the bounding boxes to xywh or xyxy format before using this wrapper.') from None

    @abstractmethod
    def __load_config__(self):
        """Load configuration settings. Must be implemented in subclass."""

    @abstractmethod
    def __is_poisonable__(self, annotation):
        """Check if an object in the annotation is poisonable. Must be implemented in subclass."""

    @staticmethod
    def _random_offset(start, extent, trigger_size):
        """Pick a random integer offset so the trigger stays inside one box axis.

        Raises:
            ValueError: If the trigger is larger than the box along this axis.
        """
        low = int(start)
        high = int(start + extent - trigger_size)
        if high < low:
            raise ValueError('Trigger larger than bounding box')
        if high == low:
            # Exact fit: only one offset is valid and randint(low, low) would raise.
            return low
        return int(np.random.randint(low, high))

    @staticmethod
    def _trigger_origins(bbox, bbox_width, bbox_height, position, trigger_size):
        """Return the top-left ``(x, y)`` corner(s) where the trigger is pasted.

        Args:
            bbox: Box whose first two entries are the top-left x and y.
            bbox_width: Box width.
            bbox_height: Box height.
            position: ``'center'``, ``'random'``, ``'high'``, ``'low'`` or
                ``'both'`` (``'both'`` yields the high and the low corner).
            trigger_size: Side length of the square trigger.

        Raises:
            ValueError: On an unknown ``position``.
        """
        left, top = bbox[0], bbox[1]

        if position == 'random':
            x1 = BaseAttackWrapper._random_offset(left, bbox_width, trigger_size)
            y1 = BaseAttackWrapper._random_offset(top, bbox_height, trigger_size)
            return [(x1, y1)]

        if position not in ('center', 'high', 'low', 'both'):
            raise ValueError('Invalid position')

        # Every non-random mode centres the trigger horizontally in the box.
        x1 = int(left + bbox_width // 2 - trigger_size // 2)

        if position == 'center':
            return [(x1, int(top + bbox_height // 2 - trigger_size // 2))]

        # Bottom edge of the trigger sits at 0.3 (high) / 0.8 (low) of the box height.
        high_y = int(top + int(0.3 * bbox_height) - trigger_size)
        low_y = int(top + int(0.8 * bbox_height) - trigger_size)

        if position == 'high':
            return [(x1, high_y)]
        if position == 'low':
            return [(x1, low_y)]
        return [(x1, high_y), (x1, low_y)]

    def __add_trigger_to_image__(self, img, trigger, bbox, position, trigger_ratio, min_trigger_size, max_trigger_size):
        """Paste a resized square copy of ``trigger`` into ``img`` in place.

        The trigger side is ``trigger_ratio`` times the shorter box side,
        truncated to an int and capped at ``max_trigger_size``.

        Args:
            img: Image indexed as ``img[:, y, x]`` and exposing ``.shape``
                (so presumably a channel-first tensor/array -- TODO confirm).
            trigger: Trigger image; a deep copy is resized with
                ``transforms.Resize``, the original is left untouched.
            bbox: Box in ``self.bbox_current_format`` (``xywh`` or ``xyxy``).
            position: See ``_trigger_origins``.
            trigger_ratio: Fraction of the shorter box side used for the trigger.
            min_trigger_size: Smallest acceptable trigger side.
            max_trigger_size: Largest trigger side.

        Raises:
            ValueError: If the bbox format or position is invalid, the trigger
                would be smaller than ``min_trigger_size``, it would not lie
                fully inside the image, or the paste itself fails.
        """
        # Box extent in pixels.
        if self.bbox_current_format == 'xywh':
            bbox_width, bbox_height = bbox[2], bbox[3]
        elif self.bbox_current_format == 'xyxy':
            bbox_width, bbox_height = bbox[2] - bbox[0], bbox[3] - bbox[1]
        else:
            raise ValueError('Invalid bbox format')

        # Trigger side length.
        trigger_size = int(min(bbox_width * trigger_ratio, bbox_height * trigger_ratio))
        if trigger_size < min_trigger_size:
            raise ValueError('Trigger size too small')
        trigger_size = min(trigger_size, max_trigger_size)

        # Resized copy of the trigger.
        new_trigger = transforms.Resize((trigger_size, trigger_size))(copy.deepcopy(trigger))

        origins = self._trigger_origins(bbox, bbox_width, bbox_height, position, trigger_size)

        # Validate every location before touching the image: negative indices
        # would otherwise wrap around silently, and 'both' could leave the
        # image half-modified.
        img_height, img_width = img.shape[-2], img.shape[-1]
        for x1, y1 in origins:
            if x1 < 0 or y1 < 0 or x1 + trigger_size > img_width or y1 + trigger_size > img_height:
                raise ValueError(
                    f'Trigger out of bounds (bbox_format={self.bbox_current_format}, bbox={bbox}, '
                    f'origin=({x1}, {y1}), trigger_size={trigger_size}, img_shape={tuple(img.shape)})'
                )

        for x1, y1 in origins:
            try:
                img[:, y1:y1 + trigger_size, x1:x1 + trigger_size] = new_trigger
            except (RuntimeError, ValueError, IndexError, TypeError) as err:
                # Keep ValueError as the failure type callers already see.
                raise ValueError(
                    f'Could not paste trigger (origin=({x1}, {y1}), trigger_size={trigger_size}, '
                    f'img_shape={tuple(img.shape)})'
                ) from err

    @abstractmethod
    def __poison_image__(self, img_id, sub_id, object_index):
        """Poison an image at a given index. Must be implemented in subclass."""

    @abstractmethod
    def __create_dataset_train__(self):
        """Create the training dataset by poisoning images."""

    @abstractmethod
    def __create_dataset_test__(self):
        """Create the test dataset by poisoning images."""

    def __getitem_train__(self, index):
        """Return ``(img, annotation, img_id)`` for a training index.

        With ``return_bd`` set, the backdoored version is requested only for
        ids contained in ``self.bd_id_set``; otherwise the clean sample is
        returned.
        """
        if self.return_bd:
            img_id = self.dataset_wrapper.ids[index]
            get_bd = img_id in self.bd_id_set
            img, annotation, _ = self.dataset_wrapper.__getitem__(index, get_bd=get_bd)
        else:
            img, annotation, img_id = self.dataset_wrapper.__getitem__(index, get_bd=False)

        return img, annotation, img_id

    def __getitem_test__(self, index):
        """Return ``(img, annotation, img_id)`` for a val/test index.

        With ``return_bd`` set, ``index`` addresses ``self.bd_id_list`` and the
        backdoored version of that image is returned.  Every annotation dict
        gets a ``'target_id'`` list aligned with its ``'poison_mask'``:
        ``self.target_id`` where the mask is truthy, ``-1`` elsewhere.
        """
        if self.return_bd:
            img_id = self.bd_id_list[index]
            # Map the poisoned id back to its position in the wrapped dataset.
            real_index = self.dataset_wrapper.ids.index(img_id)
            img, annotation, img_id = self.dataset_wrapper.__getitem__(real_index, get_bd=True)
        else:
            img, annotation, img_id = self.dataset_wrapper.__getitem__(index, get_bd=False)

        for ann in annotation:
            ann['target_id'] = [self.target_id if mask else -1 for mask in ann['poison_mask']]

        return img, annotation, img_id

    def __getitem__(self, index):
        """Dispatch to the train or val/test accessor based on the data split."""
        if self.data_split is DataSplit.TRAIN:
            return self.__getitem_train__(index)
        return self.__getitem_test__(index)

    def __len__(self):
        """Number of items: only the poisoned ones for val/test with ``return_bd`` set."""
        if self.data_split is not DataSplit.TRAIN and self.return_bd:
            return len(self.bd_id_list)
        return len(self.dataset_wrapper.ids)

# Final wrapper used to return the dataset in the required format
class BDWrapper(torch.utils.data.Dataset):
    """Final wrapper that applies a transform and returns samples as tensors.

    It sits on top of an attack wrapper (``bd_dataset_wrapper``), which must
    expose ``return_bd``, ``bbox_current_format``, ``__len__`` and
    ``__getitem__`` returning ``(img, annotation, img_id)``.  For the training
    split ``annotation`` is a single dict; for every other split ``img`` and
    ``annotation`` are parallel sequences.  Annotation dicts carry the keys
    ``'bbox'``, ``'category_id'``, ``'poison_mask'`` and ``'target_id'``.
    """

    # Public format names accepted by the constructor, mapped to the names
    # passed to ``torchvision.ops.box_convert``.
    # NOTE(review): box_convert does not normalise coordinates, so 'yolo'
    # presumably yields absolute cxcywh values -- confirm that is intended.
    _RETURN_FORMATS = {
        'coco': 'xywh',
        'pascal_voc': 'xyxy',
        'yolo': 'cxcywh',
    }

    def __init__(self, bd_dataset_wrapper, bbox_return_format='coco', data_split='train', transform=None):
        """Store the wrapped dataset and the output settings.

        Args:
            bd_dataset_wrapper: Attack wrapper to read samples from.
            bbox_return_format: ``'coco'``, ``'pascal_voc'`` or ``'yolo'``.
            data_split: ``'train'``, ``'val'`` or ``'test'`` (or a ``DataSplit``).
            transform: Optional callable invoked as
                ``transform(image=..., bboxes=..., category_ids=...,
                poison_masks=..., target_ids=...)`` and returning a mapping
                with the same keys (an albumentations-style pipeline,
                presumably -- TODO confirm).

        Raises:
            ValueError: On an unknown data split or bbox return format.
        """
        self.bd_dataset_wrapper = bd_dataset_wrapper

        try:
            self.data_split = DataSplit(data_split)
        except ValueError as err:
            raise ValueError(f"Invalid data_split: {data_split}. Must be one of {list(DataSplit)}") from err

        if bbox_return_format not in ('coco', 'pascal_voc', 'yolo'):
            raise ValueError('Invalid format')
        self.bbox_return_format = self._RETURN_FORMATS[bbox_return_format]

        self.transform = transform

    def __get_bd__(self, is_bd):
        """Switch the wrapped dataset between backdoored and clean samples."""
        self.bd_dataset_wrapper.return_bd = bool(is_bd)

    def __get_bd_status__(self):
        """Return whether the wrapped dataset currently serves backdoored samples."""
        return self.bd_dataset_wrapper.return_bd

    def __len__(self):
        """Length of the wrapped dataset."""
        return len(self.bd_dataset_wrapper)

    def __apply_transform__(self, img, annotation):
        """Run the transform (or plain ``ToTensor``) on one image/annotation pair.

        Returns:
            ``(image, annotation)`` where ``annotation`` is a new dict holding
            only ``'bbox'``, ``'category_id'``, ``'poison_mask'`` and
            ``'target_id'``.
        """
        bboxes = annotation['bbox']
        category_ids = annotation['category_id']
        poison_masks = annotation['poison_mask']
        target_ids = annotation['target_id']

        if self.transform is None:
            # No augmentation configured: only convert the image to a tensor.
            return transforms.ToTensor()(img), {
                'bbox': bboxes,
                'category_id': category_ids,
                'poison_mask': poison_masks,
                'target_id': target_ids,
            }

        # The transform is fed a numpy image; convert anything else first.
        if not isinstance(img, np.ndarray):
            img = np.array(img)

        transformed = self.transform(
            image=img,
            bboxes=bboxes,
            category_ids=category_ids,
            poison_masks=poison_masks,
            target_ids=target_ids,
        )

        return transformed['image'], {
            'bbox': transformed['bboxes'],
            'category_id': transformed['category_ids'],
            'poison_mask': transformed['poison_masks'],
            'target_id': transformed['target_ids'],
        }

    def __format_bbox__(self, bboxes):
        """Return ``bboxes`` as a float32 tensor in ``self.bbox_return_format``.

        The tensor is float32 whether or not a conversion is needed, so the
        output dtype does not depend on the dtype of the incoming boxes.
        """
        boxes = torch.tensor(bboxes, dtype=torch.float32)

        source_format = self.bd_dataset_wrapper.bbox_current_format
        if self.bbox_return_format == source_format:
            return boxes

        return torchvision.ops.box_convert(boxes, source_format, self.bbox_return_format)

    def __format_item_train__(self, annotation):
        """Build the training target dict from one annotation.

        Returns:
            Dict with ``'boxes'`` (N, 4) float32, and int64 ``'labels'``,
            ``'poison_masks'`` (1 for poisoned objects, else 0) and
            ``'target_labels'`` (the target id for poisoned objects, else -1).
        """
        if len(annotation['bbox']) == 0:
            return {
                'boxes': torch.empty((0, 4), dtype=torch.float32),
                'labels': torch.empty((0,), dtype=torch.int64),
                'poison_masks': torch.empty((0,), dtype=torch.int64),
                'target_labels': torch.empty((0,), dtype=torch.int64),
            }

        format_bboxes = self.__format_bbox__(annotation['bbox'])

        bboxes = []
        labels = []
        poison_masks = []
        target_labels = []

        per_object = zip(annotation['target_id'], format_bboxes, annotation['poison_mask'], annotation['category_id'])
        for target_label, box, mask, original_label in per_object:
            labels.append(original_label)
            bboxes.append(box)
            poison_masks.append(1 if mask else 0)
            # Clean objects never expose a target label.
            target_labels.append(target_label if mask else -1)

        return {
            'boxes': torch.stack(bboxes),
            'labels': torch.tensor(labels, dtype=torch.int64),
            'poison_masks': torch.tensor(poison_masks, dtype=torch.int64),
            'target_labels': torch.tensor(target_labels, dtype=torch.int64),
        }

    def __format_item_test__(self, annotation):
        """Convert the annotation's fields to tensors in place and return it."""
        if len(annotation['bbox']) == 0:
            annotation['bbox'] = torch.empty((0, 4), dtype=torch.float32)
            annotation['category_id'] = torch.empty((0,), dtype=torch.int64)
            annotation['poison_mask'] = torch.empty((0,), dtype=torch.int64)
            annotation['target_id'] = torch.empty((0,), dtype=torch.int64)
            return annotation

        annotation['bbox'] = self.__format_bbox__(annotation['bbox'])

        for key in ('category_id', 'poison_mask', 'target_id'):
            annotation[key] = torch.tensor(annotation[key], dtype=torch.int64)

        return annotation

    def __getitem_train__(self, index):
        """Return ``(image, targets, img_id)`` for one training sample."""
        img, annotation, img_id = self.bd_dataset_wrapper[index]
        final_img, final_annotations = self.__apply_transform__(img, annotation)
        targets = self.__format_item_train__(final_annotations)
        return final_img, targets, img_id

    def __getitem_test__(self, index):
        """Return parallel lists ``(images, annotations, ids)`` for one val/test item.

        The wrapped dataset yields sequences of images and annotations for a
        single id; each pair is transformed and formatted separately and the
        id is repeated once per image.
        """
        img, annotation, img_id = self.bd_dataset_wrapper[index]

        final_img, final_annotations, final_ids = [], [], []
        for i, single_img in enumerate(img):
            aug_img, aug_annotation = self.__apply_transform__(single_img, annotation[i])

            final_img.append(aug_img)
            final_annotations.append(self.__format_item_test__(aug_annotation))
            final_ids.append(img_id)

        return final_img, final_annotations, final_ids

    def __getitem__(self, index):
        """Dispatch to the train or val/test accessor based on the data split."""
        if self.data_split is DataSplit.TRAIN:
            return self.__getitem_train__(index)
        return self.__getitem_test__(index)

class DefenseWrapper(torch.utils.data.Dataset):
    """Fixed, reproducible random subset of another dataset.

    ``num_samples`` distinct indices are drawn once at construction time and
    kept in ascending order; item ``i`` of this wrapper is the item at the
    ``i``-th sampled index of the wrapped dataset.
    """

    def __init__(self, bd_dataset_wrapper, num_samples, random_seed=1):
        """Draw the subset.

        Args:
            bd_dataset_wrapper: Dataset supporting ``len()`` and integer indexing.
            num_samples: Number of items to keep.
            random_seed: Seed that determines which items are kept.

        Raises:
            ValueError: If ``num_samples`` exceeds the size of the wrapped dataset.
        """
        self.bd_dataset_wrapper = bd_dataset_wrapper
        self.random_seed = random_seed
        self.___sub_sample__(num_samples)

    def ___sub_sample__(self, num_samples):
        """Choose ``num_samples`` distinct indices and store them sorted."""
        total_samples = len(self.bd_dataset_wrapper)
        if num_samples > total_samples:
            raise ValueError('num_samples must be less than or equal to the total number of samples in the dataset')

        # A private RandomState yields the same draw as seeding the global
        # legacy generator with this seed, but leaves np.random's global
        # state untouched for every other user in the process.
        rng = np.random.RandomState(self.random_seed)
        sampled_indices = rng.choice(total_samples, num_samples, replace=False)

        # Plain ints keep the indices usable wherever a Python int is expected.
        self.sampled_indices = sorted(int(i) for i in sampled_indices)

    def __len__(self):
        """Number of sampled items."""
        return len(self.sampled_indices)

    def __getitem__(self, index):
        """Return the wrapped dataset's item at the ``index``-th sampled position."""
        real_index = self.sampled_indices[index]
        return self.bd_dataset_wrapper[real_index]
