import os
from .wrapper import DatasetWrapper, DataSplit
from PIL import Image
from pycocotools.coco import COCO
from torchvision import transforms
from tqdm import tqdm
import contextlib
import pandas as pd

class GTSDBWrapper(DatasetWrapper):
    """Dataset wrapper that indexes GTSDB images and their bounding boxes.

    Layout read from ``root``:

    * ``train/gt.txt`` or ``test/gt.txt``: ';'-separated ground truth without
      a header row, columns ``filename;xmin;ymin;xmax;ymax;class``.
    * ``train/`` or ``test/``: the image files.
    * ``splits/<split value>.txt``: one image file name per line, used for
      the TRAIN and VAL splits (both read images from ``train/``).

    After construction the instance exposes:

    * ``annotations``: dict mapping an integer image id to a one-element list
      holding the per-image annotation dict.
    * ``ids``: list of image ids in processing order.
    * ``bbox_current_format``: always ``"xyxy"``.
    """

    def __init__(self, root, data_split='train'):
        """Resolve the requested split and build the annotation index.

        Args:
            root: Dataset root directory.
            data_split: A ``DataSplit`` member or a value accepted by
                ``DataSplit(...)``.

        Raises:
            ValueError: If ``data_split`` is not a valid ``DataSplit``.
            FileNotFoundError: If a listed image is missing on disk.
        """
        self.root = root

        # The DataSplit lookup itself raises ValueError on an unknown value,
        # so the descriptive message has to be attached here; a membership
        # check after the lookup would never be reached.
        try:
            self.data_split = DataSplit(data_split)
        except ValueError:
            raise ValueError(
                f"Invalid data split: {data_split}. Must be one of {list(DataSplit)}"
            ) from None

        if self.data_split in (DataSplit.TRAIN, DataSplit.VAL):
            self.split_file = os.path.join(
                self.root, 'splits', f'{self.data_split.value}.txt'
            )
        else:
            # No split file: every '.ppm' in the image directory is used.
            self.split_file = None

        self.bbox_current_format = "xyxy"  # GTSDB uses (xmin, ymin, xmax, ymax) format
        self.__format_annotations__()

    def __format_annotations__(self):
        """Populate ``self.annotations`` and ``self.ids`` from ``gt.txt``.

        Images without any ground-truth row still get an entry, with empty
        ``bbox`` / ``category_id`` / ``poison_mask`` / ``target_id`` lists.
        Images are only checked for existence; they are not decoded here.

        Raises:
            FileNotFoundError: If a listed image file does not exist.
        """
        # TRAIN and VAL share the 'train' directory and its gt.txt; they differ
        # only in which file names the split file selects.
        uses_train_dir = self.data_split in (DataSplit.TRAIN, DataSplit.VAL)
        base_path = os.path.join(self.root, 'train' if uses_train_dir else 'test')
        ann_file = os.path.join(base_path, 'gt.txt')

        annotations_df = pd.read_csv(
            ann_file,
            sep=';',
            header=None,
            names=['filename', 'xmin', 'ymin', 'xmax', 'ymax', 'class'],
        )

        print(f'[GTSDB] Formatting annotations for {self.data_split.value} from {base_path}')

        if self.split_file:
            with open(self.split_file, 'r') as f:
                split_files = [line.strip() for line in f if line.strip()]
        else:
            # os.listdir returns entries in arbitrary order; sort so that
            # self.ids is reproducible across runs and filesystems.
            split_files = sorted(
                name for name in os.listdir(base_path) if name.endswith('.ppm')
            )

        # Group ground-truth rows by file name in a single pass instead of
        # filtering the whole DataFrame once per image. Row order within a
        # file is preserved. name=None yields plain tuples, which sidesteps
        # the 'class' column not being a valid attribute name.
        rows_by_file = {}
        for filename, xmin, ymin, xmax, ymax, cls in annotations_df.itertuples(
            index=False, name=None
        ):
            rows_by_file.setdefault(filename, []).append((xmin, ymin, xmax, ymax, cls))

        annotations = {}
        ids = []
        for file in split_files:
            img_path = os.path.join(base_path, file)
            img_id = int(file.split('.')[0])

            if not os.path.exists(img_path):
                raise FileNotFoundError(f'Image {img_path} not found. Please check the dataset path.')

            rows = rows_by_file.get(file, ())

            annotations[img_id] = [{
                'sub_id': 0,
                'bbox': [[xmin, ymin, xmax, ymax] for xmin, ymin, xmax, ymax, _ in rows],
                # GTSDB classes start from 0, however common models expect
                # 1-based indexing.
                'category_id': [int(cls) + 1 for *_, cls in rows],
                'poison_mask': [False] * len(rows),
                'target_id': [-1] * len(rows),
                'clean_img_path': img_path,
                'bd_img_path': None
            }]

            ids.append(img_id)

        self.annotations = annotations
        self.ids = ids