import os
import cv2
import glob
import xml.etree.ElementTree as ET
from sklearn.model_selection import train_test_split

class VOC2012Dataset:
    """Task-incremental view over a VOC-style object detection dataset.

    The dataset is read from ``<root_dir>/Annotations/*.xml`` (one XML file
    per image) and ``<root_dir>/JPEGImages``.  Every object class is given an
    integer id in order of first appearance, the images are split 90/10 into
    train/test, and the classes are grouped into consecutive chunks of
    ``num_classes // num_tasks`` classes.  Only images that contain at least
    one object of the currently selected chunk ("task") are exposed through
    ``len()`` and indexing.

    Attributes set by the constructor:
        categories: class name -> global class id.
        superclass_dict: global class id -> class name.
        data / annotations: image paths and per-image object lists of the
            selected split (train or test).
        class_per_task: number of classes in every task.
        task_num: index of the currently selected task.
        label_mapping: class name -> {"task_num", "label_num"} for every task
            that has been selected so far.
        subset_data / subset_annotations / subset_classes /
        subset_classes_names: the view for the current task, see
            ``extract_subset``.
    """

    def __init__(self, root_dir: str, train: bool = True, num_tasks: int = 10, seed: int = 42):
        """Parse all annotations, split them and select task 0.

        Args:
            root_dir: directory containing ``Annotations`` and ``JPEGImages``.
            train: expose the 90% train split if true, else the 10% test split.
            num_tasks: number of tasks the classes are divided into.
            seed: random state passed to ``train_test_split``.

        Raises:
            ValueError: if no annotation file is found.
            AssertionError: if there are fewer classes than tasks.
        """
        # glob returns paths in arbitrary, filesystem-dependent order.  Class
        # ids are assigned in order of first appearance and the seeded split
        # depends on the input order, so sort to make both reproducible.
        annotation_files = sorted(glob.glob(os.path.join(root_dir, "Annotations", "*.xml")))
        if not annotation_files:
            raise ValueError(
                f"no annotation files found in {os.path.join(root_dir, 'Annotations')}"
            )

        # Filled in by parse_annotation while the files are read.
        self.categories = {}

        image_paths = []
        annotations = []
        for annotation_file in annotation_files:
            image_filename, annotation = self.parse_annotation(annotation_file)
            image_paths.append(os.path.join(root_dir, "JPEGImages", image_filename))
            annotations.append(annotation)

        # Reverse lookup: global class id -> class name.
        self.superclass_dict = {class_id: name for name, class_id in self.categories.items()}

        # train test split
        image_train, image_test, annotations_train, annotations_test = train_test_split(
            image_paths, annotations, test_size=0.1, random_state=seed
        )

        if train:
            self.data = list(image_train)
            self.annotations = list(annotations_train)
        else:
            self.data = list(image_test)
            self.annotations = list(annotations_test)

        print(f"number of images: {len(self.data)} found")

        self.task_num = 0
        self.label_mapping = {}

        num_classes = len(self.categories)
        # Classes left over when num_classes is not a multiple of num_tasks
        # are not part of tasks 0 .. num_tasks - 1.
        self.class_per_task = num_classes // num_tasks

        assert self.class_per_task > 0, "num_tasks must be positive and not exceed the number of classes"

        (
            self.subset_data,
            self.subset_annotations,
            self.subset_classes,
            self.subset_classes_names,
        ) = self.extract_subset()

    def set_task(self, task_num=0):
        """Select the task whose classes should be exposed.

        Args:
            task_num: index of the task to activate.

        Raises:
            KeyError: if the task refers to class ids that do not exist.  The
                previously selected task stays active in that case.
        """
        previous_task = self.task_num
        self.task_num = task_num
        try:
            subset = self.extract_subset()
        except KeyError:
            # Roll back so task_num never disagrees with the subset_* fields.
            self.task_num = previous_task
            raise
        (
            self.subset_data,
            self.subset_annotations,
            self.subset_classes,
            self.subset_classes_names,
        ) = subset

    def extract_subset(self):
        """Build the view of the split restricted to the current task.

        Registers the task's classes in ``label_mapping`` and relabels every
        kept object with its task-local label.

        Returns:
            A tuple ``(data, subset_annotations, subset_classes,
            subset_classes_names)``: the image paths that contain at least
            one object of the task, the matching lists of
            ``{"label", "bbox", "label_name"}`` dicts, the list of global
            class ids of the task, and the set of class names that actually
            occur in the returned annotations.

        Raises:
            KeyError: if the task covers a class id that does not exist.
        """
        first_class = self.task_num * self.class_per_task
        subset_classes = list(range(first_class, first_class + self.class_per_task))

        # Resolve every name before touching label_mapping so that an
        # out-of-range task fails without leaving a partial update behind.
        class_names = [self.superclass_dict[class_id] for class_id in subset_classes]

        # global class id -> label inside the task (0 .. class_per_task - 1)
        local_labels = {}
        for local_label, (class_id, name) in enumerate(zip(subset_classes, class_names)):
            self.label_mapping[name] = {"task_num": self.task_num, "label_num": local_label}
            local_labels[class_id] = local_label

        data = []
        subset_annotations = []
        subset_classes_names = set()
        for image_path, annotations in zip(self.data, self.annotations):
            kept = [
                {
                    "label": local_labels[annotation["label"]],
                    "bbox": annotation["bbox"],
                    "label_name": self.superclass_dict[annotation["label"]],
                }
                for annotation in annotations
                if annotation["label"] in local_labels
            ]
            # Images without any object of the task are dropped from the view.
            if kept:
                data.append(image_path)
                subset_annotations.append(kept)
                subset_classes_names.update(entry["label_name"] for entry in kept)

        return data, subset_annotations, subset_classes, subset_classes_names

    def parse_annotation(self, annotation_file_path):
        """Read one annotation XML file.

        Class names that have not been seen before are added to
        ``self.categories`` with the next free integer id.

        Args:
            annotation_file_path: path of the XML file.

        Returns:
            A tuple ``(image_filename, annotation)`` where ``image_filename``
            is the text of the ``filename`` element and ``annotation`` is a
            list with one ``{"label": class_id, "bbox": (x1, y1, x2, y2)}``
            dict per ``object`` element.
        """
        tree = ET.parse(annotation_file_path)

        image_filename = tree.find("filename").text

        annotation = []
        for obj in tree.iter("object"):
            object_name = obj.find("name").text
            bndbox = obj.find("bndbox")

            # Go through float first so that values written like "12.0" are
            # accepted; the fractional part is truncated.
            x1 = int(float(bndbox.find("xmin").text))
            x2 = int(float(bndbox.find("xmax").text))
            y1 = int(float(bndbox.find("ymin").text))
            y2 = int(float(bndbox.find("ymax").text))

            if object_name not in self.categories:
                self.categories[object_name] = len(self.categories)

            annotation.append({"label": self.categories[object_name], "bbox": (x1, y1, x2, y2)})
        return image_filename, annotation

    def __len__(self):
        """Return the number of images in the current task's view."""
        return len(self.subset_data)

    def __getitem__(self, idx):
        """Load one sample of the current task.

        Args:
            idx: position inside the current task's view.

        Returns:
            A tuple ``(image, annotations)``: the image converted to RGB
            channel order and the task-local annotation dicts for it.

        Raises:
            IndexError: if ``idx`` is out of range.
            OSError: if the image file cannot be read.
        """
        image_filepath = self.subset_data[idx]
        annotations = self.subset_annotations[idx]

        image = cv2.imread(image_filepath)
        if image is None:
            # imread signals failure by returning None; fail here with the
            # offending path rather than inside cvtColor.
            raise OSError(f"could not read image: {image_filepath}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        return image, annotations

if __name__ == "__main__":
    # Manual smoke test: load the dataset, switch to task 2 and save the
    # first sample with its bounding boxes drawn to "image.png".
    data = VOC2012Dataset("Dataset/VOCdevkit/VOC2012")
    data.set_task(2)
    for image, annotations in data:
        print(image.shape)
        print(annotations)

        for annotation in annotations:
            x1, y1, x2, y2 = annotation["bbox"]
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 1)

        # The dataset returns RGB images while cv2.imwrite expects BGR;
        # convert back so the saved preview does not have red and blue swapped.
        cv2.imwrite("image.png", cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

        # Only the first sample is needed for the preview.
        break
        
