""" A dataset reader for object detection (e.g. coco 2017)

Copyright (c) 2025 Anonymous Authors
"""
import json
import os
import re
from typing import Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .reader import Reader


# Image file extensions searched for when the caller does not pass ``types``.
_DEFAULT_IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _natural_key(string_: str):
    """Case-insensitive sort key that orders digit runs numerically ('img2' < 'img10')."""
    # re.split with a capturing group always starts with a (possibly empty) str and
    # alternates str/int, so keys of different strings compare position-wise safely.
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def find_images_and_targets(
        folder: str,
        types: Optional[Union[List, Tuple, Set]] = None,
        class_to_idx: Optional[Dict] = None,
        leaf_name_only: bool = True,
        sort: bool = True
):
    """ Walk folder recursively to discover images and map them to classes by folder names.

    Args:
        folder: root of folder to recursively search
        types: file extensions (including the leading dot) to search for; matched
            case-insensitively. Defaults to ``.png``, ``.jpg`` and ``.jpeg``.
        class_to_idx: specify mapping for class (folder name) to class index if set;
            images whose class is not in the mapping are dropped
        leaf_name_only: use only leaf-name of folder walk for class names, otherwise
            the relative path with separators replaced by '_'
        sort: re-sort found images by natural order of their path (for consistent ordering)

    Returns:
        A list of (image path, class index) tuples, and the class_to_idx mapping used
    """
    # Extensions are compared against ext.lower(), so normalise the requested set too.
    types = set(_DEFAULT_IMG_EXTENSIONS) if not types else {t.lower() for t in types}
    labels = []
    filenames = []
    for root, _, files in os.walk(folder, topdown=False, followlinks=True):
        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
        for f in files:
            ext = os.path.splitext(f)[1]
            if ext.lower() in types:
                filenames.append(os.path.join(root, f))
                labels.append(label)
    if class_to_idx is None:
        # Build the class index from the discovered labels in natural order.
        sorted_labels = sorted(set(labels), key=_natural_key)
        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
    images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
    if sort:
        images_and_targets = sorted(images_and_targets, key=lambda k: _natural_key(k[0]))
    return images_and_targets, class_to_idx


class ReaderDetection(Reader):
    """Reader over a COCO-style ``instances_*.json`` annotation file.

    One sample is produced per annotation (not per image), so an image with
    several annotated objects appears several times.
    """

    def __init__(
            self,
            root,
            name='coco',
            split='train',
            img_size=256,
    ):
        """ Parse the annotation file for ``split`` and build the sample list.

        Args:
            root: dataset root containing ``annotations/`` and ``images/``
            name: 'cococ' yields samples whose target is the raw COCO
                ``category_id``; any other value yields samples whose target is
                ``[x1, y1, x2, y2, category_id]`` with the box rescaled to
                ``img_size`` x ``img_size``
            split: 'train' reads the train2017 files; every other value
                ('validation', 'test', ...) reads the val2017 files
            img_size: side length of the square the boxes are rescaled to
                (256 was the previously hard-coded value)

        Raises:
            KeyError: if an annotation references an unknown image or category id.
        """
        super().__init__()

        # NOTE: 'test' deliberately reuses val2017, same as any non-'train' split.
        if split == 'train':
            json_file_path = os.path.join(root, "annotations/instances_train2017.json")
            image_root_dir = os.path.join(root, "images/train2017")
        else:
            json_file_path = os.path.join(root, "annotations/instances_val2017.json")
            image_root_dir = os.path.join(root, "images/val2017")
        self.image_root_dir = image_root_dir

        with open(json_file_path, 'r', encoding='utf-8') as f:
            saved_json_file = json.load(f)

        known_category_ids = {cat['id'] for cat in saved_json_file['categories']}
        image_id_to_info = {img['id']: img for img in saved_json_file['images']}

        self.samples = []
        for annotation_info in saved_json_file['annotations']:
            image_info = image_id_to_info[annotation_info['image_id']]
            category_id = annotation_info['category_id']
            if category_id not in known_category_ids:
                raise KeyError(category_id)

            if name == 'cococ':
                target = category_id
            else:
                box = self._scale_bbox(
                    annotation_info['bbox'], image_info['width'], image_info['height'], img_size)
                target = box + [category_id]
            self.samples.append({
                'file_name': image_info['file_name'],
                'target': target,
            })

    @staticmethod
    def _scale_bbox(bbox, width, height, img_size):
        """Convert a COCO ``[x, y, w, h]`` box to ``[x1, y1, x2, y2]`` on an ``img_size`` square.

        x and y are scaled independently (``img_size / width`` and
        ``img_size / height``), i.e. as for a resize that does not preserve
        the aspect ratio.
        """
        x, y, w, h = bbox
        scale_x = img_size / width
        scale_y = img_size / height
        x1 = x * scale_x
        y1 = y * scale_y
        return [x1, y1, x1 + w * scale_x, y1 + h * scale_y]

    def __getitem__(self, index):
        """Return ``(image_path, target)``; the image itself is not loaded here."""
        sample = self.samples[index]
        image = os.path.join(self.image_root_dir, sample['file_name'])
        target = sample['target']

        return image, target

    def __len__(self):
        return len(self.samples)
