import os
from .wrapper import DatasetWrapper, DataSplit
from PIL import Image
from pycocotools.coco import COCO
from torchvision import transforms
from tqdm import tqdm
import contextlib

class FruitWrapper(DatasetWrapper):
    """Dataset wrapper for a fruit-detection dataset stored as per-image label files.

    Layout read by this class::

        <root>/labels/<split>/<stem>.txt   one object per line: ``class cx cy w h``,
                                           box values normalized by image width/height
        <root>/images/<split>/<stem>.jpg   the image paired with each label file

    After construction the instance exposes:

    * ``annotations`` -- dict mapping an int image id to a one-element list holding
      a dict with absolute ``[x1, y1, x2, y2]`` boxes, category ids shifted by +1,
      all-False ``poison_mask``, all ``-1`` ``target_id`` and the clean image path.
    * ``ids`` -- the image ids, in sorted label-filename order.
    * ``img_{width,height}_{mean,median,std}`` -- image size statistics.
    """

    def __init__(self, root, data_split='train'):
        """Index the annotations of one split of the dataset under *root*.

        Args:
            root: dataset root directory containing ``labels/`` and ``images/``.
            data_split: a ``DataSplit`` member or its value (e.g. ``'train'``).

        Raises:
            ValueError: if *data_split* is not a valid ``DataSplit``.
        """
        self.root = root

        # Converting an unknown value raises ValueError (assuming DataSplit is a
        # standard Enum), so validation must wrap the conversion, not follow it.
        try:
            self.data_split = DataSplit(data_split)
        except ValueError as err:
            raise ValueError(f"Invalid data split: {data_split}. Must be one of {list(DataSplit)}") from err

        self.annotation_path = os.path.join(self.root, 'labels', f'{self.data_split.value}')
        self.__format_annotations__()

    def __format_annotations__(self):
        """Parse every label file of the split into ``self.annotations`` / ``self.ids``.

        Also computes the image size statistics (``self.img_width_mean`` etc.).

        Raises:
            FileNotFoundError: if a label file has no matching ``.jpg`` image.
            ValueError: on a malformed annotation line, a negative raw class
                index, or when the split contains no ``.txt`` label files.
        """
        seen_ids = set()

        print(f'[FruitWrapper] Formatting annotations for {self.data_split.value} from {self.annotation_path}')

        annotations = {}
        ids = []
        img_widths = []
        img_heights = []

        is_test = self.data_split == DataSplit.TEST
        # os.listdir order is arbitrary; sorting makes the ids (sequential for
        # the test split) and the order of self.ids reproducible.
        label_files = sorted(name for name in os.listdir(self.annotation_path) if name.endswith('.txt'))

        for test_id, filename in enumerate(tqdm(label_files)):
            stem = os.path.splitext(filename)[0]
            img_path = os.path.join(self.root, 'images', f'{self.data_split.value}', f'{stem}.jpg')

            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image file {img_path} does not exist.")

            # The test split gets sequential ids in filename order; other splits
            # use the numeric file stem as the image id.
            img_id = test_id if is_test else int(stem)

            with open(os.path.join(self.annotation_path, filename), 'r') as f:
                lines = f.readlines()

            # Only the size is needed to unnormalize the boxes; PIL reads it from
            # the header without decoding pixels, and `with` closes the file.
            with Image.open(img_path) as img:
                img_width, img_height = img.size

            img_widths.append(img_width)
            img_heights.append(img_height)

            bboxes, category_ids = self._parse_label_lines(lines, img_width, img_height)
            seen_ids.update(category_ids)

            annotations[img_id] = [{
                'sub_id': 0,
                'bbox': bboxes,
                'category_id': category_ids,
                'poison_mask': [False] * len(bboxes),
                'target_id': [-1] * len(bboxes),
                'clean_img_path': img_path,
                'poison_img_path': None
            }]

            ids.append(img_id)

        print(f'[FruitWrapper] Found {len(annotations)} images with {len(seen_ids)} unique category IDs.')
        print(f'[FruitWrapper] Seeing IDs: {sorted(seen_ids)}')

        # Ids were shifted by +1 so that 0 stays reserved for background; any id
        # below 1 therefore comes from a negative raw class index.
        invalid_ids = sorted(cid for cid in seen_ids if cid < 1)
        if invalid_ids:
            raise ValueError(
                f"Found category IDs {invalid_ids} (< 1) in annotations, which is not allowed. "
                "Raw class indices must be >= 0. Please check your dataset."
            )

        self.annotations = annotations
        self.ids = ids

        if not img_widths:
            raise ValueError(f"No annotation (.txt) files found in {self.annotation_path}.")

        # Mean, (upper) median and population standard deviation of the image sizes
        self.img_width_mean, self.img_width_median, self.img_width_std = self._size_statistics(img_widths)
        self.img_height_mean, self.img_height_median, self.img_height_std = self._size_statistics(img_heights)

        print(f'[FruitWrapper] Image size statistics:')
        print(f'  Mean: {self.img_width_mean:.2f}x{self.img_height_mean:.2f}')
        print(f'  Median: {self.img_width_median}x{self.img_height_median}')
        print(f'  Std: {self.img_width_std:.2f}x{self.img_height_std:.2f}')

    @staticmethod
    def _parse_label_lines(lines, img_width, img_height):
        """Convert normalized ``class cx cy w h`` label lines to absolute boxes.

        Args:
            lines: raw lines of one label file.
            img_width: width in pixels used to unnormalize ``cx`` and ``w``.
            img_height: height in pixels used to unnormalize ``cy`` and ``h``.

        Returns:
            ``(bboxes, category_ids)``: ``[x1, y1, x2, y2]`` boxes in pixels and
            the matching category ids, shifted by +1 so 0 stays free for background.

        Raises:
            ValueError: if a non-blank line has fewer than five fields.
        """
        bboxes = []
        category_ids = []
        for line in lines:
            parts = line.strip().split()
            if not parts:
                # Tolerate blank lines (e.g. extra empty lines at the end of the file).
                continue
            if len(parts) < 5:
                raise ValueError(f"Invalid annotation line: {line.strip()}")

            # Increment category_id by 1 to avoid 0 category ID (used for background)
            category_ids.append(int(parts[0]) + 1)

            cx, cy, w, h = map(float, parts[1:5])
            cx, w = cx * img_width, w * img_width
            cy, h = cy * img_height, h * img_height

            bboxes.append([
                cx - w / 2,  # x1
                cy - h / 2,  # y1
                cx + w / 2,  # x2
                cy + h / 2   # y2
            ])
        return bboxes, category_ids

    @staticmethod
    def _size_statistics(values):
        """Return ``(mean, upper median, population std)`` of a non-empty list of numbers."""
        count = len(values)
        mean = sum(values) / count
        median = sorted(values)[count // 2]
        std = (sum((v - mean) ** 2 for v in values) / count) ** 0.5
        return mean, median, std