import json
import os
import xml.etree.ElementTree as ET

import numpy as np
import pandas as pd
import torch
from PIL import Image
from robustness.tools import constants, folder
from robustness.tools.helpers import get_label_mapping
from torch.utils.data import Dataset
import torchvision.transforms as transforms


def voc_as_mask(label, class_id):
    """Convert a VOC detection label to a mask.

    Return a mask covering the union of every bounding box in ``label``
    whose object name equals ``class_id``. Box coordinates are treated as
    inclusive on both ends.

    Args:
        label (xml.etree.ElementTree.Element): root element of an image
            annotation in the VOC detection format. It must contain a
            ``size`` element with ``width``/``height`` children, plus zero
            or more ``object`` elements each having ``name`` and ``bndbox``.
        class_id (str): object name to select; compared against the text of
            each object's ``name`` element.

    Returns:
        :class:`torch.Tensor`: ``(height, width)`` tensor of dtype
        ``torch.uint8`` holding 1 inside the selected boxes and 0 elsewhere.
    """
    size = label.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    mask = torch.zeros((height, width), dtype=torch.uint8)
    # findall() returns every <object>; find() would return only the first
    # one (dropping all further boxes) and None when there is no object.
    for obj in label.findall('object'):
        if obj.find('name').text != class_id:
            continue
        bbox = obj.find('bndbox')
        ymin = int(bbox.find('ymin').text)
        ymax = int(bbox.find('ymax').text)
        xmin = int(bbox.find('xmin').text)
        xmax = int(bbox.find('xmax').text)
        # +1 because the max coordinates are inclusive.
        mask[ymin:ymax + 1, xmin:xmax + 1] = 1
    return mask


class ImageNet(Dataset):
    """Image-classification dataset indexed by ``<root>/labels.csv``.

    The CSV must provide an ``image`` column (paths relative to ``root``)
    and a ``label`` column of integer class indices. When ``label_mapping``
    is not ``None``, labels are collapsed onto ``class_ranges``:

    * every class index inside the inclusive range ``class_ranges[i]``
      becomes label ``i`` (a later range wins if ranges overlap);
    * images whose label falls in no range are dropped.

    Args:
        root (str): directory containing ``labels.csv`` and the images.
        transform (callable, optional): applied to each loaded PIL image.
        class_ranges (iterable of (int, int), optional): inclusive
            ``(start, end)`` class-index ranges. Required when
            ``label_mapping`` is given; ignored otherwise.
        label_mapping (optional): only compared against ``None`` to decide
            whether to remap labels. The mapping itself is rebuilt from
            ``class_ranges``.
        class_index_path (str): JSON file whose keys are the candidate class
            indices (as strings). Read only when remapping.

    Raises:
        TypeError: if ``label_mapping`` is given without ``class_ranges``.

    Attributes:
        images (list): image paths relative to ``root``.
        labels (list): possibly remapped labels, aligned with ``images``.
        original_labels (numpy.ndarray): labels as read from the CSV,
            aligned with ``images``.
        label_mapping (dict or None): original class index -> new label,
            or ``None`` when no remapping was requested.
    """

    def __init__(self, root, transform=None, class_ranges=None, label_mapping=None,
                 class_index_path='data/imagenet/imagenet_class_index.json'):
        self.root = root
        self.transform = transform
        df = pd.read_csv(os.path.join(root, "labels.csv"))
        images = df["image"]
        labels = df["label"]
        self.original_labels = labels.values.copy()
        self.label_mapping = None

        if label_mapping is not None:
            if class_ranges is None:
                raise TypeError("class_ranges is required when label_mapping is given")
            mapping = self._build_mapping(class_ranges, class_index_path)
            self.label_mapping = mapping
            labels = labels.map(lambda x: mapping.get(x, -1))
            # Drop images whose class is not covered by any range; compute the
            # keep-mask once and apply it to every aligned column.
            keep = labels != -1
            images = images[keep]
            self.original_labels = self.original_labels[keep.to_numpy()]
            labels = labels[keep]

        self.images = images.tolist()
        self.labels = labels.tolist()

    @staticmethod
    def _build_mapping(class_ranges, class_index_path):
        """Map each class index listed in the JSON index to its range position."""
        with open(class_index_path, 'r', encoding='utf-8') as f:
            candidate_indices = [int(k) for k in json.load(f)]
        ranges = list(class_ranges)
        mapping = {}
        for idx in candidate_indices:
            for new_idx, (start, end) in enumerate(ranges):
                if start <= idx <= end:
                    mapping[idx] = new_idx  # later ranges override earlier ones
        return mapping

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is an RGB PIL image, transformed if set."""
        image = Image.open(os.path.join(self.root, self.images[idx])).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[idx]
        return image, label
    

class ImageNetBox(Dataset):
    """Restricted-ImageNet images with everything outside the boxes grayed out.

    Builds a ``folder.ImageFolder`` over ``root`` with the
    ``'restricted_imagenet'`` label mapping. It then keeps only the images
    for which the VOC-format annotation
    ``<annotation_root>/<prefix>/<filename>.xml`` can be parsed, where
    ``<prefix>`` is the part of the file name before the first underscore.
    For each kept image, the union of its boxes named ``<prefix>`` is
    precomputed as a mask. ``__getitem__`` paints every pixel outside that
    mask gray (128) before applying ``transform``.

    Args:
        root (str): ImageFolder-style directory of images.
        transform (callable, optional): applied to the masked PIL image.
        annotation_root (str): directory holding the per-prefix folders of
            XML annotations.

    Attributes:
        annotated_images (list of str): paths of the kept images, exactly as
            listed in ``ImageFolder.imgs``.
        masks (list of numpy.ndarray): ``(height, width)`` uint8 masks
            (1 inside the boxes), aligned with ``annotated_images``.
        indices (list of int): positions of the kept images in the
            underlying ImageFolder.
        labels (list): ImageFolder targets of the kept images.
    """

    def __init__(self, root='/data/ILSVRC/Data/CLS-LOC/train', transform=None,
                 annotation_root='./data/imagenet/bboxes/Annotation'):
        self.root = root
        self.transform = transform

        class_ranges = constants.RESTRICTED_IMAGNET_RANGES
        label_mapping = get_label_mapping('restricted_imagenet', class_ranges)
        dataset = folder.ImageFolder(root=root, transform=transform,
                                     label_mapping=label_mapping)

        self.annotated_images = []
        self.masks = []
        self.indices = []

        # Load bounding boxes, skipping images that have no usable annotation.
        for idx, entry in enumerate(dataset.imgs):
            image_path = entry[0]
            filename = os.path.splitext(os.path.basename(image_path))[0]
            class_id = filename.split('_')[0]
            annotation_file = os.path.join(annotation_root, class_id, f'{filename}.xml')
            try:
                tree = ET.parse(annotation_file)
            except (OSError, ET.ParseError):
                # Missing or malformed annotation file: leave the image out.
                # (A bare except here would also swallow KeyboardInterrupt.)
                continue

            self.annotated_images.append(image_path)
            self.masks.append(voc_as_mask(tree.getroot(), class_id).numpy())
            self.indices.append(idx)

        self.labels = [dataset.targets[i] for i in self.indices]

    def __len__(self):
        return len(self.annotated_images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` with out-of-box pixels set to gray (128)."""
        # ImageFolder.imgs paths already include ``root`` (as in torchvision's
        # ImageFolder), so open them as-is; joining with self.root again would
        # duplicate a relative root.
        with Image.open(self.annotated_images[idx]) as img:
            image = np.array(img.convert('RGB'))

        # A 2-D boolean mask indexes the (H, W) axes of the (H, W, 3) array,
        # so all three channels of every out-of-box pixel become 128.
        # NOTE(review): assumes the annotation's width/height match the decoded
        # image; a mismatch raises IndexError here.
        image[self.masks[idx] == 0] = 128
        image = Image.fromarray(image)

        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]
    

# Script entry point: a quick visual smoke test of ImageNetBox.
if __name__ == '__main__':
    # Standard ImageNet-style evaluation preprocessing.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    box_dataset = ImageNetBox(transform=preprocess)

    loader = torch.utils.data.DataLoader(
        box_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=1,
    )

    from matplotlib import pyplot as plt

    # Inspect only the first shuffled batch (nothing happens if it is empty).
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        batch_images, batch_labels = first_batch
        print(batch_images.shape)
        print(batch_labels)
        # CHW -> HWC, the layout imshow expects.
        plt.imshow(batch_images[0].permute(1, 2, 0))
        plt.savefig('test.png')
