import os
import torch
from PIL import Image
from pycocotools.coco import COCO
from torchvision import transforms
from tqdm import tqdm
import contextlib
from enum import Enum
import json
from .wrapper import DatasetWrapper, DataSplit

class MTSDWrapper(DatasetWrapper):
    """Dataset wrapper that indexes an MTSD-style traffic-sign annotation tree.

    On construction the wrapper reads, relative to ``root``:

    * ``mtsd_v2_fully_annotated/kept_classes.txt`` -- one class label per line,
    * ``mtsd_v2_fully_annotated/splits/<split>.txt`` -- one annotation file
      name per line,
    * ``mtsd_v2_fully_annotated/annotations/<file>`` -- one JSON file per image,
    * ``images/<file with '.json' replaced by '.jpg'>`` -- existence check only.

    Attributes set by ``__init__``:
        root: Dataset root directory, as passed in.
        data_split: ``DataSplit`` member selected by ``data_split``.
        class_ids: Maps label -> ``{'label', 'count', 'id'}``; ids start at 1
            and follow line order in ``kept_classes.txt``.
        files: Non-empty, stripped lines of the split file.
        annotations: Maps image index -> single-element list holding the
            per-image record (see ``__format_annotations__``).
        ids: Image indices in processing order.
    """

    def __init__(self, root, data_split='train'):
        """Load the class list and split file, then build the annotations.

        Args:
            root: Dataset root directory.
            data_split: Value accepted by ``DataSplit`` (default ``'train'``).

        Raises:
            ValueError: If ``data_split`` is not a valid ``DataSplit`` value,
                or an annotation label is malformed / not a kept class.
            FileNotFoundError: If the class file, split file, an annotation
                file, or an image file is missing.
        """
        self.root = root

        # Converting first and validating afterwards would never reach the
        # descriptive message below: the conversion itself rejects unknown
        # values (assumes DataSplit is an Enum -- it is iterated and exposes
        # ``.value``; confirm against .wrapper).
        try:
            self.data_split = DataSplit(data_split)
        except ValueError as err:
            raise ValueError(f"Invalid data split: {data_split}. Must be one of {list(DataSplit)}") from err

        split_file = os.path.join(self.root, f'mtsd_v2_fully_annotated/splits/{self.data_split.value}.txt')

        # Open the kept_classes.txt file to get the class IDs
        class_file = os.path.join(self.root, 'mtsd_v2_fully_annotated/kept_classes.txt')
        if not os.path.exists(class_file):
            raise FileNotFoundError(f"Class file {class_file} does not exist")

        class_ids = {}
        with open(class_file, 'r', encoding='utf-8') as f:
            # IDs are 1-based and tied to the line position in the file.
            # NOTE(review): blank lines are not skipped here, so a blank line
            # registers an empty-string label and consumes an ID.
            for class_id, raw_line in enumerate(f, start=1):
                label = raw_line.strip()
                class_ids[label] = {
                    'label': label,
                    # Initialised to 0; nothing in this class updates it.
                    'count': 0,
                    'id': class_id,
                }

        print(f'[MTSDWrapper] Found {len(class_ids)} classes in {class_file}')

        self.class_ids = class_ids

        print(f'[MTSDWrapper] Loading {self.data_split.value} split from {split_file}')

        # Same explicit check as for the class file, for a consistent error.
        if not os.path.exists(split_file):
            raise FileNotFoundError(f"Split file {split_file} does not exist")

        # Read the split file to get the annotation file names (blank lines skipped)
        with open(split_file, 'r', encoding='utf-8') as f:
            self.files = [line.strip() for line in f if line.strip()]

        self.__format_annotations__()

    def __format_annotations__(self):
        """Parse every annotation file of the split into ``self.annotations``.

        For each entry of ``self.files`` the JSON annotation is loaded and each
        object's label ``a--b[--...]`` is reduced to the meta label ``a-b``,
        which must be a key of ``self.class_ids``. The resulting record per
        image holds parallel lists ``bbox`` (``[xmin, ymin, xmax, ymax]``),
        ``category_id``, ``poison_mask`` (always ``False``) and ``target_id``
        (always ``-1``), plus ``clean_img_path`` and ``bd_img_path`` (``None``).

        Sets ``self.annotations`` and ``self.ids``.

        Raises:
            FileNotFoundError: If an annotation file or its image is missing.
            ValueError: If a label has no ``--`` separator or its meta label
                is not in ``self.class_ids``.
        """
        print(f'[MTSDWrapper] Formatting annotations for {self.data_split.value} (N={len(self.files)})')

        annotations = {}
        seen_ids = set()

        # Image IDs are simply the position of the file within the split.
        for img_id, file in enumerate(tqdm(self.files, desc='Processing files')):

            file_path = os.path.join(self.root, 'mtsd_v2_fully_annotated/annotations', file)
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Annotation file {file_path} does not exist")

            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            bboxes = []
            category_ids = []
            poison_masks = []
            target_ids = []

            for obj in data.get('objects', []):
                label = obj.get('label', '').strip()

                # Using meta labels for object detection: keep only the first
                # two '--'-separated components of the full label.
                if '--' not in label:
                    raise ValueError(f"Unexpected label format: {label}")
                parts = label.split('--')
                new_label = parts[0] + '-' + parts[1]

                if new_label not in self.class_ids:
                    raise ValueError(f"Meta label '{new_label}' not found in class_ids")

                category_id = self.class_ids[new_label]['id']
                seen_ids.add(category_id)

                box = obj['bbox']
                bboxes.append([box['xmin'], box['ymin'], box['xmax'], box['ymax']])
                category_ids.append(category_id)
                poison_masks.append(False)
                target_ids.append(-1)  # Assuming no target_id is provided in MTSD

            # Check that image file exists
            img_file = os.path.join(self.root, 'images', file.replace('.json', '.jpg'))
            if not os.path.exists(img_file):
                raise FileNotFoundError(f"Image file {img_file} does not exist")

            # One record per image, wrapped in a list with 'sub_id' 0.
            annotations[img_id] = [{
                'sub_id': 0,
                'bbox': bboxes,
                'category_id': category_ids,
                'poison_mask': poison_masks,
                'target_id': target_ids,
                'clean_img_path': img_file,
                'bd_img_path': None
            }]

        print(f'[MTSDWrapper] Finished formatting annotations for {self.data_split.value} (N={len(annotations)})')
        print(f'[MTSDWrapper] Seen category IDs: {seen_ids}')

        self.annotations = annotations
        # Dict insertion order matches processing order, i.e. 0..N-1.
        self.ids = list(annotations)
        