import json
import os

import torch
from PIL import Image
from torchvision.datasets import Cityscapes

class CityWrapper(torch.utils.data.Dataset):
    """Cityscapes (fine annotations) exposed through a COCO-wrapper-style annotation interface.

    ``self.annotations`` maps each image id to a list of annotation dicts
    ("sub-images"). Every entry built here has:

    - ``sub_id`` 0;
    - per-object ``bbox`` ([x, y, w, h]) and ``category_id`` lists, derived from
      the Cityscapes polygon JSON files;
    - ``poison_mask`` / ``target_id`` placeholder lists (False / -1);
    - the clean image path, with ``bd_img_path`` set to None.

    Nothing in this class fills ``bd_img_path``. Presumably other code adds
    backdoored sub-images; confirm with callers.
    """

    # Target types requested from torchvision's Cityscapes. 'instance' (the library
    # default) is kept first so self.targets[i][0] is still the instance-mask path.
    _TARGET_TYPES = ('instance', 'polygon')

    def __init__(self, root, is_train=True):
        """Index the Cityscapes ``train`` (is_train=True) or ``val`` split under ``root``.

        Also builds the per-image box annotations from the polygon JSON files.
        """
        self.root = root
        self.is_train = is_train

        split = 'train' if is_train else 'val'
        # torchvision only treats a *list* as multiple target types (a tuple would be
        # wrapped as a single, invalid value), hence the list() conversion.
        self.cityscapes = Cityscapes(root, split=split, mode='fine', target_type=list(self._TARGET_TYPES))
        self.images = self.cityscapes.images
        self.targets = self.cityscapes.targets
        self.ids = [os.path.splitext(os.path.basename(p))[0] for p in self.images]

        print(f'Loaded {len(self.ids)} images from Cityscapes dataset in {split} mode.')

        self.__format_annotations_()

    @staticmethod
    def _parse_polygon_file(json_path, label_to_id):
        """Read one Cityscapes ``*_polygons.json`` file and return its object boxes.

        Args:
            json_path: path to the polygon annotation file.
            label_to_id: mapping from Cityscapes label name to the category id to emit.
                Objects whose label is not a key (e.g. stuff classes, or crowd labels
                such as ``'cargroup'``) are skipped.

        Returns:
            ``(bboxes, category_ids)``: parallel lists. Each bbox is
            ``[x_min, y_min, width, height]`` spanning the polygon's vertices.
        """
        with open(json_path, encoding='utf-8') as f:
            data = json.load(f)

        bboxes = []
        category_ids = []
        for obj in data.get('objects', []):
            category_id = label_to_id.get(obj.get('label'))
            polygon = obj.get('polygon') or []
            # Skip unmapped labels, objects flagged as deleted, and empty polygons.
            if category_id is None or obj.get('deleted') or not polygon:
                continue
            xs = [pt[0] for pt in polygon]
            ys = [pt[1] for pt in polygon]
            x_min, y_min = min(xs), min(ys)
            bboxes.append([x_min, y_min, max(xs) - x_min, max(ys) - y_min])
            category_ids.append(category_id)
        return bboxes, category_ids

    def __format_annotations_(self, label_to_id=None):
        """Build ``self.annotations`` (img_id -> [annotation dict]) from the polygon files.

        Args:
            label_to_id: optional mapping from Cityscapes label name to emitted
                category id. Defaults to the ``id`` of every class in
                ``Cityscapes.classes`` whose ``has_instances`` flag is set.
        """
        if label_to_id is None:
            label_to_id = {c.name: c.id for c in Cityscapes.classes if c.has_instances}
        polygon_idx = self._TARGET_TYPES.index('polygon')

        annotations = {}
        for img_id, img_path, target_paths in zip(self.ids, self.images, self.targets):
            bboxes, category_ids = self._parse_polygon_file(target_paths[polygon_idx], label_to_id)
            annotations[img_id] = [{
                'sub_id': 0,
                'bbox': bboxes,
                'category_id': category_ids,
                'poison_mask': [False] * len(bboxes),
                'target_id': [-1] * len(bboxes),
                # torchvision's Cityscapes already joins root into self.images entries.
                'clean_img_path': img_path,
                'bd_img_path': None,
            }]

        self.annotations = annotations

    def __get_annotation__(self, img_id):
        """Return the list of sub-image annotation dicts for ``img_id``."""
        return self.annotations[img_id]

    def __get_image__(self, img_id, sub_id, get_bd=False):
        """Load sub-image ``sub_id`` of ``img_id`` as an RGB PIL image.

        Args:
            img_id: key into ``self.annotations``.
            sub_id: index into that image's list of annotation dicts.
            get_bd: if True, load ``bd_img_path``; otherwise ``clean_img_path``.

        Raises:
            ValueError: if the selected path is None.
        """
        entry = self.annotations[img_id][sub_id]
        img_path = entry['bd_img_path'] if get_bd else entry['clean_img_path']

        if img_path is None:
            raise ValueError('Image path is None')

        # The context manager closes the file; convert() returns an independent image.
        with Image.open(img_path) as img:
            return img.convert('RGB')

    def __getitem_train__(self, index, get_bd=False):
        """Return ``(image, annotation_dict, img_id)`` for training.

        Training entries are read from sub-image 0 only, for both clean and bd requests.
        """
        img_id = self.ids[index]
        ann = self.__get_annotation__(img_id)[0]
        img = self.__get_image__(img_id, 0, get_bd)
        return img, ann, img_id

    def __getitem_test__(self, index, get_bd=False):
        """Return ``([images], [annotation_dicts], img_id)`` for evaluation.

        With ``get_bd`` every sub-image of the entry is returned; otherwise only
        sub-image 0 (the clean image).
        """
        img_id = self.ids[index]
        annotation = self.__get_annotation__(img_id)

        # Test-time bd entries may hold several sub-images; clean requests use only the first.
        sub_ids = range(len(annotation)) if get_bd else range(1)

        return_imgs = [self.__get_image__(img_id, i, get_bd) for i in sub_ids]
        return_annotations = [annotation[i] for i in sub_ids]

        return return_imgs, return_annotations, img_id

    def __getitem__(self, index, get_bd=False):
        """Dispatch to the train or test accessor based on ``self.is_train``.

        ``dataset[index]`` passes only ``index``. ``get_bd`` is used only when this
        method is called explicitly.
        """
        if self.is_train:
            return self.__getitem_train__(index, get_bd)
        return self.__getitem_test__(index, get_bd)

    def __len__(self):
        """Number of images in the split."""
        return len(self.ids)