import json
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from typing import Union

from PIL import Image
import numpy as np
import pandas as pd
from torchvision.datasets import VisionDataset
import torch
def pil_loader(path: str) -> Image.Image:
    """Open the image file at ``path`` and return it converted to RGB.

    The file is opened in binary mode, and the RGB conversion is performed
    inside the ``with`` block, i.e. before the file handle is closed.
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image


def load_COCOGB_V1(dataset_path, split, COCOGB_V1):
    """Select the image records belonging to ``split`` from COCO-GB v1 annotations.

    Args:
        dataset_path: Dataset root; must contain a ``coco`` sub-directory.
        split: One of ``'train'``, ``'val'``, ``'restval'``, ``'test'``.
            ``'train'`` and ``'restval'`` both return the union of the
            ``'train'`` and ``'restval'`` records; ``'test'`` returns the
            records whose ``'cocoid'`` appears in ``COCOGB_V1['secret_test']``.
        COCOGB_V1: Parsed annotation dict. Its ``'images'`` list holds records
            with ``'split'`` and ``'cocoid'`` keys; for the test split, its
            ``'secret_test'`` list holds records with a ``'coco_id'`` key.

    Returns:
        List of records from ``COCOGB_V1['images']``, in their original order.

    Raises:
        NotADirectoryError: If ``<dataset_path>/coco`` does not exist.
        ValueError: If ``split`` is not one of the available splits.
    """
    coco_folder = os.path.join(dataset_path, "coco")
    if not os.path.isdir(coco_folder):
        raise NotADirectoryError(f"Directory not found: {coco_folder}")

    available_splits = ['train', 'val', 'restval', 'test']
    # Explicit raise rather than `assert`: asserts are stripped under `python -O`,
    # which would let an unknown split silently yield an empty list.
    if split not in available_splits:
        raise ValueError(f"Invalid split: {split}. Available splits: {available_splits}")

    images = COCOGB_V1['images']
    if split == "test":
        # A set gives O(1) membership instead of rescanning a list per image.
        secret_ids = {item['coco_id'] for item in COCOGB_V1["secret_test"]}
        return [x for x in images if x['cocoid'] in secret_ids]

    # 'train' and 'restval' are served together as one training pool.
    splits = ('train', 'restval') if split in ('train', 'restval') else (split,)
    return [x for x in images if x['split'] in splits]


class COCO_GB_V1_dataset(VisionDataset):
    """COCO-GB v1 annotations over MS-COCO images as a torchvision-style dataset.

    Annotations are read from ``<root>/cocogbv1/Ksplit_gender_category.json``
    and image paths are built as ``<root>/coco/<filepath>/<filename>``. Each
    target is an int64 array ``[gender, cat_0, cat_1, ...]`` whose category ids
    are right-padded with ``-1`` to the longest category list in the split.
    """

    def __init__(
            self,
            root: str,
            split: str,
            loader: Callable[[str], Any] = pil_loader,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        """Index the images of ``split``.

        Args:
            root: Dataset root containing ``cocogbv1/`` and ``coco/``.
            split: Split name forwarded to ``load_COCOGB_V1``
                (``'train'``, ``'val'``, ``'restval'`` or ``'test'``).
            loader: Callable that loads an image from a path.
            transform: Optional transform applied to the loaded image.
            target_transform: Optional transform applied to the target array.

        Raises:
            FileNotFoundError: If the annotation JSON is missing.
            Errors raised by ``load_COCOGB_V1`` (missing ``coco/`` directory,
            unknown split) propagate unchanged.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.loader = loader
        # (image_path, target) pairs; target = [gender, padded category ids].
        self.samples: List[Tuple[str, np.ndarray]] = []

        file_path = os.path.join(root, "cocogbv1", "Ksplit_gender_category.json")
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(file_path, 'r', encoding="utf-8") as f:
            annotations = json.load(f)

        data = load_COCOGB_V1(root, split, annotations)

        # 'category_id' is a space-separated string of ints. Parse it once and
        # reuse the result for both the padding width and the targets.
        parsed_categories = [[int(x) for x in item['category_id'].split()] for item in data]
        max_category_length = max((len(ids) for ids in parsed_categories), default=0)

        for item, category_ids in zip(data, parsed_categories):
            # Pad with -1 so every target in the split has the same length.
            padded_category_ids = category_ids + [-1] * (max_category_length - len(category_ids))
            image_path = os.path.join(self.root, 'coco', item['filepath'], item['filename'])
            # gender: 0 = female, 1 = male
            target = np.array([item['gender']] + padded_category_ids, dtype=np.int64)
            self.samples.append((image_path, target))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load one sample.

        Args:
            index (int): Position in ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the output of
            ``self.loader`` (then ``self.transform``, if set) and ``target`` is
            the int64 array ``[gender, category ids..., -1 padding]`` (then
            ``self.target_transform``, if set).
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)