import os
from .wrapper import DatasetWrapper, DataSplit
from PIL import Image
from pycocotools.coco import COCO
from torchvision import transforms
from tqdm import tqdm
import contextlib

class COCOWrapper(DatasetWrapper):
    """Dataset wrapper exposing COCO 2017 annotations indexed per image.

    The TRAIN and VAL splits are both drawn from COCO ``train2017``, restricted
    to the image ids listed in ``<root>/<split value>_id_coco.txt``; the TEST
    split uses every image of COCO ``val2017``.

    After construction the instance exposes:
      * ``coco``: the ``COCO`` index loaded from the split's annotation file.
      * ``ids``: image ids belonging to this split.
      * ``categories``: category records returned by ``coco.loadCats``.
      * ``bbox_current_format``: ``"xywh"``.
      * ``annotations``: per-image dict built by ``__format_annotations__``.
    """

    def __init__(self, root, data_split='train'):
        """Load the COCO annotations for one split and index them per image.

        Args:
            root: COCO root directory. Expected to contain
                ``annotations/instances_{train,val}2017.json``, the
                ``train2017``/``val2017`` image folders and, for TRAIN/VAL,
                the ``<split value>_id_coco.txt`` id files.
            data_split: a ``DataSplit`` member or a value accepted by
                ``DataSplit(...)``; defaults to ``'train'``.

        Raises:
            ValueError: if ``data_split`` is not a valid ``DataSplit`` value,
                or is a split this wrapper has no COCO source for.
        """
        self.root = root

        # DataSplit(...) itself raises ValueError on an unknown value, so the
        # conversion is what has to be guarded to report the allowed values.
        try:
            self.data_split = DataSplit(data_split)
        except ValueError:
            raise ValueError(
                f"Invalid data split: {data_split}. Must be one of {list(DataSplit)}"
            ) from None

        ann_file, _ = self._split_sources()

        # Silence anything COCO() writes to stdout while loading the file.
        with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
            self.coco = COCO(ann_file)

        self.ids = sorted(self.coco.imgs.keys())
        self.categories = self.coco.loadCats(self.coco.getCatIds())

        # TRAIN and VAL are subsets of train2017: replace the full id list
        # with the ids listed one per line in <root>/<split value>_id_coco.txt.
        if self.data_split in (DataSplit.TRAIN, DataSplit.VAL):
            split_file = os.path.join(self.root, f'{self.data_split.value}_id_coco.txt')

            print(f'[COCOWrapper] Loading {self.data_split.value} split from {split_file}')

            with open(split_file, 'r', encoding='utf-8') as f:
                # Lines that are not purely digits (blank lines, headers) are skipped.
                stripped = (line.strip() for line in f)
                self.ids = [int(s) for s in stripped if s.isdigit()]

        self.bbox_current_format = "xywh"
        self.__format_annotations__()

    def _split_sources(self):
        """Return ``(annotation_file, image_dir)`` for ``self.data_split``.

        TRAIN and VAL both read from COCO train2017; TEST reads from val2017.

        Raises:
            ValueError: if ``self.data_split`` has no COCO source mapped here
                (prevents an unbound-path error further down).
        """
        if self.data_split in (DataSplit.TRAIN, DataSplit.VAL):
            coco_set = 'train2017'
        elif self.data_split == DataSplit.TEST:
            coco_set = 'val2017'
        else:
            raise ValueError(f"Unsupported data split for COCO: {self.data_split}")
        return (
            os.path.join(self.root, f'annotations/instances_{coco_set}.json'),
            os.path.join(self.root, coco_set),
        )

    # NOTE: the dunder-style name is unconventional (such names are reserved
    # for Python), but it is kept unchanged so existing callers/subclasses work.
    def __format_annotations__(self):
        """Build ``self.annotations``, keyed by every image id in ``self.ids``.

        Each value is a single-element list holding a dict with:
          * ``sub_id``: always ``0`` here.
          * ``bbox``: the image's COCO boxes, as stored in the annotation file
            (``self.bbox_current_format`` is ``"xywh"``).
          * ``category_id``: category id of each box, parallel to ``bbox``.
          * ``poison_mask``: one ``False`` per box.
          * ``target_id``: one ``-1`` per box (no target id by default).
          * ``clean_img_path``: ``<image_dir>/<img_id zero-padded to 12>.jpg``.
          * ``bd_img_path``: ``None``.
        """
        _, base_path = self._split_sources()

        print(f'[COCOWrapper] Formatting annotations for {self.data_split.value} (N={len(self.ids)}) from {base_path}')

        annotations = {}
        for img_id in self.ids:
            anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            num_boxes = len(anns)
            annotations[img_id] = [{
                'sub_id': 0,
                'bbox': [ann['bbox'] for ann in anns],
                'category_id': [ann['category_id'] for ann in anns],
                'poison_mask': [False] * num_boxes,
                'target_id': [-1] * num_boxes,
                'clean_img_path': os.path.join(base_path, f"{img_id:012d}.jpg"),
                'bd_img_path': None,
            }]

        self.annotations = annotations