import os
import torch
from PIL import Image
from pycocotools.coco import COCO
from torchvision import transforms
from tqdm import tqdm
import contextlib
from enum import Enum
import json
from .wrapper import DatasetWrapper, DataSplit
import torchvision.transforms as transforms
import numpy as np

# Maps each class label (looked up in the class-id table read from MTSD's
# kept_classes.txt) to the PTSD sub-directories that hold its images, keyed by
# variant ('both', 'high', 'low'). Every directory is named
# '<prefix>-<variant>'; the prefix equals the label except for the give-way
# class, whose directories use the shorter 'regulatory-giveway' prefix.
main_paths = {
    label: {variant: f'{prefix}-{variant}' for variant in ('both', 'high', 'low')}
    for label, prefix in (
        ('regulatory-give-way-to-oncoming-traffic', 'regulatory-giveway'),
        ('regulatory-keep-left', 'regulatory-keep-left'),
        ('regulatory-stop', 'regulatory-stop'),
        ('warning-stop-ahead', 'warning-stop-ahead'),
        ('warning-t-roads', 'warning-t-roads'),
        ('warning-turn-left', 'warning-turn-left'),
        ('warning-turn-right', 'warning-turn-right'),
    )
}

class PTSDWrapper(DatasetWrapper):
    """Dataset wrapper over the PTSD traffic-sign images.

    Each sub-directory listed in ``main_paths`` holds ``*.json`` annotation
    files, each with a same-named ``*.jpg`` image beside it. Every box in a
    file is labelled with the category id of the ``main_paths`` key that the
    sub-directory belongs to; ids come from MTSD's ``kept_classes.txt``.
    """

    def __init__(self, root, transform):
        """Read the class ids and index every annotated image under ``root``.

        Args:
            root: Directory containing the sub-directories named in
                ``main_paths``. Its parent directory must contain
                ``mtsd/mtsd_v2_fully_annotated/kept_classes.txt``.
            transform: Callable invoked in ``__getitem__`` as
                ``transform(image=..., bboxes=..., category_ids=...)`` and
                expected to return a dict with the same three keys.

        Raises:
            FileNotFoundError: If kept_classes.txt, a sub-directory or an
                image file is missing.
            ValueError: If a ``main_paths`` class has no id or an annotation
                file is invalid.
        """
        self.root = root
        self.transform = transform

        # MTSD lives next to the PTSD root, i.e. in root's parent directory.
        # abspath normalises first, so a trailing separator, a bare relative
        # name ('ptsd'), '.', OS-specific separators and Path objects all
        # resolve to the real parent.
        self.mtsd_root = os.path.dirname(os.path.abspath(self.root))
        class_file = os.path.join(self.mtsd_root, 'mtsd/mtsd_v2_fully_annotated/kept_classes.txt')
        if not os.path.exists(class_file):
            raise FileNotFoundError(f"Class file {class_file} does not exist")

        self.class_ids = self._read_class_ids(class_file)
        self.__format_annotations__()

    @staticmethod
    def _read_class_ids(class_file):
        """Map each stripped line of ``class_file`` to ``{'label', 'count', 'id'}``.

        Ids are 1-based line positions; every line (blank ones included)
        consumes an id, so the numbering follows the file line by line.
        """
        with open(class_file, 'r', encoding='utf-8') as f:
            labels = [line.strip() for line in f]
        return {
            label: {
                'label': label,
                'count': 0,  # NOTE(review): nothing in this class updates 'count'
                'id': class_id,
            }
            for class_id, label in enumerate(labels, start=1)
        }

    def __format_annotations__(self):
        """Populate ``self.annotations`` and ``self.ids`` from the PTSD tree.

        ``self.annotations`` maps a sequential image index (0, 1, ...) to a
        one-element list ``[{'sub_id': 0, 'bbox': [[x_min, y_min, x_max,
        y_max], ...], 'category_id': [id, ...], 'clean_img_path': path}]``.
        ``self.ids`` lists those indices in order.

        Raises:
            ValueError: If a ``main_paths`` class is missing from
                ``self.class_ids`` or an annotation file is invalid.
            FileNotFoundError: If a sub-directory or image file is missing.
        """
        annotations = {}
        for element, sub_dirs in main_paths.items():
            # .get so a missing class raises the descriptive ValueError rather
            # than a bare KeyError from the dict lookup.
            class_entry = self.class_ids.get(element)
            category_id = None if class_entry is None else class_entry.get('id')
            if category_id is None:
                raise ValueError(f"Category ID for {element} not found in class_ids")

            for sub_dir in sub_dirs.values():
                sub_path_dir = os.path.join(self.root, sub_dir)
                if not os.path.isdir(sub_path_dir):
                    raise FileNotFoundError(f"Sub-path directory {sub_path_dir} does not exist")

                # os.listdir order is arbitrary; sort so image indices are
                # reproducible across runs and machines.
                json_files = sorted(f for f in os.listdir(sub_path_dir) if f.endswith('.json'))

                for json_file in json_files:
                    # Swap only the final extension; str.replace would rewrite
                    # every '.json' occurring anywhere in the file name.
                    image_path = os.path.join(sub_path_dir, os.path.splitext(json_file)[0] + '.jpg')
                    if not os.path.exists(image_path):
                        raise FileNotFoundError(f"Image file {image_path} does not exist")

                    bounding_boxes = self._read_bounding_boxes(sub_path_dir, json_file)
                    # len(annotations) is the next free sequential index.
                    annotations[len(annotations)] = [{
                        'sub_id': 0,
                        'bbox': bounding_boxes,
                        'category_id': [category_id] * len(bounding_boxes),
                        'clean_img_path': image_path,
                    }]

        self.annotations = annotations
        self.ids = list(annotations.keys())

    @staticmethod
    def _read_bounding_boxes(sub_path_dir, json_file):
        """Return the boxes of ``sub_path_dir/json_file`` as [x_min, y_min, x_max, y_max].

        The first two entries of each shape's ``points`` are taken as two
        opposite corners. The ``shapes``/``points`` layout looks like LabelMe
        rectangle output, which stores corners in drawing order (TODO
        confirm), so corners are sorted per axis instead of being rejected
        when the box was drawn from the bottom-right.

        Raises:
            ValueError: If the file has no shapes or any coordinate is negative.
        """
        with open(os.path.join(sub_path_dir, json_file), 'r', encoding='utf-8') as f:
            shapes = json.load(f).get('shapes', [])
        if not shapes:
            raise ValueError(f"No shapes found in {json_file} in {sub_path_dir}")

        bounding_boxes = []
        for shape in shapes:
            x1, y1 = shape['points'][0]
            x2, y2 = shape['points'][1]
            if min(x1, y1, x2, y2) < 0:
                raise ValueError(f"Invalid coordinates in {json_file}: {shape['points']}")
            bounding_boxes.append([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])
        return bounding_boxes

    def __getitem__(self, index):
        """Return ``(image, annotation)`` for the ``index``-th indexed image.

        The image comes from ``__get_image__`` (defined outside this class),
        is converted with ``np.array`` and passed with its boxes and category
        ids through ``self.transform``. ``annotation`` is a dict with keys
        ``'image_id'``, ``'bboxes'`` and ``'category_ids'`` holding the
        transformed values.

        Raises:
            ValueError: If the image or its annotation cannot be found.
        """
        img_id = self.ids[index]
        img = self.__get_image__(img_id, 0, False)
        # Presumably returns self.annotations[img_id], a one-element list
        # (TODO confirm in DatasetWrapper). Guard before indexing so a missing
        # or empty entry raises the ValueError below, not TypeError/IndexError.
        annotations = self.__get_annotation__(img_id)
        annotation = annotations[0] if annotations else None

        if img is None:
            raise ValueError(f"Image with ID {img_id} not found in dataset")
        if not annotation:
            raise ValueError(f"Annotation for image ID {img_id} is empty or not found")

        # Assumes an albumentations-style transform configured for
        # [x_min, y_min, x_max, y_max] boxes with label field 'category_ids'
        # (TODO confirm against where the transform is built).
        transformed = self.transform(
            image=np.array(img),
            bboxes=annotation['bbox'],
            category_ids=annotation['category_id'],
        )

        final_annotation = {
            'image_id': img_id,
            'bboxes': transformed['bboxes'],
            'category_ids': transformed['category_ids'],
        }
        return transformed['image'], final_annotation

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.ids)
        