import json
from torchvision.datasets.folder import default_loader
from collections import defaultdict
import torch.nn as nn
from torch.utils.data import Dataset
import os
from PIL import Image


class COCO(Dataset):
    """COCO 2017 images together with instance and caption annotation lookups.

    Expected layout under ``coco_dir``::

        images/<split>2017/<file_name>
        annotations/instances_<split>2017.json
        annotations/captions_<split>2017.json

    NOTE: items and all ``get_*`` helpers taking ``idx`` are addressed by COCO
    *image id* (the keys of ``im_dict``), not by a position in
    ``range(len(self))``. Use ``get_imgIds()`` to enumerate the valid keys.
    """

    def __init__(self, coco_dir, split='train', transform=None):
        """Load both annotation files and build the lookup tables.

        Args:
            coco_dir: Root directory of the COCO dataset.
            split: Split prefix used in directory and file names (e.g. 'train', 'val').
            transform: Either a callable applied to the loaded image, or a
                size accepted by ``PIL.Image.Image.resize`` (the image is then
                resized with LANCZOS resampling). ``None`` leaves the image as is.
        """
        self.image_dir = os.path.join(coco_dir, f"images/{split}2017/")
        with open(os.path.join(coco_dir, f"annotations/instances_{split}2017.json"), 'r', encoding='utf-8') as file:
            coco = json.load(file)
        with open(os.path.join(coco_dir, f"annotations/captions_{split}2017.json"), 'r', encoding='utf-8') as file:
            self.coco_captions = json.load(file)

        self.transform = transform
        self.annIm_dict = defaultdict(list)  # image id -> list of instance annotations
        self.capIm_dict = defaultdict(list)  # image id -> list of caption strings
        self.cat_dict = {}                   # category id -> category record
        self.annId_dict = {}                 # annotation id -> instance annotation
        self.im_dict = {}                    # image id -> image record

        for ann in coco['annotations']:
            self.annIm_dict[ann['image_id']].append(ann)
            self.annId_dict[ann['id']] = ann

        for img in coco['images']:
            self.im_dict[img['id']] = img

        for cat in coco['categories']:
            self.cat_dict[cat['id']] = cat

        # Index captions once so get_captions() does not rescan the whole
        # annotation list on every call; file order is preserved per image.
        for ann in self.coco_captions['annotations']:
            self.capIm_dict[ann['image_id']].append(ann['caption'])

    def __len__(self):
        """Return the number of images listed in the instances file."""
        return len(self.im_dict)

    def __getitem__(self, idx):
        """Load the image whose COCO image id is ``idx``.

        Returns:
            ``(image, path)`` where ``image`` is the loaded (and optionally
            transformed) image and ``path`` is the file it was read from.

        Raises:
            KeyError: if ``idx`` is not a known image id.
        """
        img = self.im_dict[idx]
        path = os.path.join(self.image_dir, img['file_name'])
        image = default_loader(path)
        if self.transform is not None:
            if callable(self.transform):
                # Same convention as the other datasets in this file.
                image = self.transform(image)
            else:
                # Legacy behaviour: a non-callable transform is a target size.
                image = image.resize(self.transform, Image.LANCZOS)

        return image, path

    def get_targets(self, idx):
        """Return the category name of every instance annotation of image ``idx``."""
        # .get avoids inserting an empty entry into the defaultdict for unknown ids.
        return [self.cat_dict[ann['category_id']]['name'] for ann in self.annIm_dict.get(idx, ())]

    def get_bounding_boxes(self, idx):
        """Return ``(category_name, bbox)`` for every instance annotation of image ``idx``."""
        return [(self.cat_dict[ann['category_id']]['name'], ann['bbox']) for ann in self.annIm_dict.get(idx, ())]

    def get_captions(self, idx):
        """Return a new list with the captions of image ``idx`` (empty if none)."""
        return list(self.capIm_dict.get(idx, ()))

    def get_categories(self, supercategory):
        """Return the names of all categories belonging to ``supercategory``."""
        return [cat['name'] for cat in self.cat_dict.values() if cat['supercategory'] == supercategory]

    def get_all_supercategories(self):
        """Return the set of distinct supercategory names."""
        return {cat['supercategory'] for cat in self.cat_dict.values()}

    def get_spurious_supercategories(self):
        """Return a fixed, hard-coded list of supercategory names."""
        return ['kitchen', 'food', 'vehicle',
                'furniture', 'appliance', 'indoor',
                'outdoor', 'electronic', 'sports',
                'accessory', 'animal']

    def get_no_classes(self, supercategories):
        """Return how many categories have a supercategory contained in ``supercategories``."""
        return sum(1 for cat in self.cat_dict.values() if cat['supercategory'] in supercategories)

    def get_imgIds(self):
        """Return all image ids as a list."""
        return list(self.im_dict)

    def get_all_targets_names(self):
        """Return the names of all categories."""
        return [cat['name'] for cat in self.cat_dict.values()]

    def get_imgIds_by_class(self, present_classes=(), absent_classes=()):
        """Return ids of images with at least one of ``present_classes`` and none of ``absent_classes``.

        With an empty ``present_classes`` no image can qualify, so the result is empty.
        """
        present = frozenset(present_classes)
        absent = frozenset(absent_classes)
        ids = []
        for img_id in self.im_dict:
            targets = self.get_targets(img_id)
            if not present.isdisjoint(targets) and absent.isdisjoint(targets):
                ids.append(img_id)
        return ids



class NegativeImageFolder(Dataset):
    """
    A dataset for negative samples, structured as root/object_name/image.png.
    It returns the loaded (optionally transformed) image and the corresponding object_name.
    """

    # Object classes used when the caller does not supply ``object_names``.
    DEFAULT_OBJECT_NAMES = ("traffic light", "carrot", "toilet", "knife", "bottle",
                            "vase", "clock", "bus", "boat", "suitcase")
    # File extensions (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, root_dir, transform=None, object_names=None):
        """
        Args:
            root_dir (string): Either a template containing a ``{cls}``
                placeholder that is filled with each object name, or a plain
                directory whose subdirectories are named after the objects.
            transform (callable, optional): Optional transform to be applied on a sample.
            object_names (iterable of str, optional): Object classes to load.
                Defaults to ``DEFAULT_OBJECT_NAMES``.

        Raises:
            FileNotFoundError: if the directory of an object class does not exist.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.loader = default_loader
        self.samples = []

        if object_names is None:
            object_names = self.DEFAULT_OBJECT_NAMES
        self.object_names = list(object_names)

        for object_name in self.object_names:
            object_dir = self._object_dir(object_name)
            # Sorted so that the sample order is deterministic across runs.
            for img_file in sorted(os.listdir(object_dir)):
                # Ensure we are only picking up image files
                if img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                    path = os.path.join(object_dir, img_file)
                    self.samples.append((path, object_name))

    def _object_dir(self, object_name):
        """Resolve the directory that holds the images of ``object_name``."""
        formatted = self.root_dir.format(cls=object_name)
        if formatted != self.root_dir:
            return formatted
        # No placeholder was substituted: treat root_dir as the parent directory.
        return os.path.join(self.root_dir, object_name)

    def __len__(self):
        """Return the number of (path, object_name) samples found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, object_name)``."""
        path, object_name = self.samples[idx]
        pil_image = self.loader(path).convert("RGB") # Ensure image is RGB
        if self.transform is not None:
            pil_image = self.transform(pil_image)
        return pil_image, object_name


class COCOPositive(Dataset):
    """
    A dataset of positive samples for a given list of object names, read from
    one image directory per object (presumably images extracted from COCO).
    An image file present in several object directories yields one sample per directory.
    Returns the loaded (optionally transformed) image and the corresponding object_name.
    """

    # Object classes used when ``object_names`` is None.
    DEFAULT_OBJECT_NAMES = ("traffic light", "carrot", "toilet", "knife", "bottle",
                            "vase", "clock", "bus", "boat", "suitcase")
    # File extensions (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, root_dir, object_names, transform=None):
        """
        Args:
            root_dir (string): Either a template containing a ``{cls}``
                placeholder that is filled with each object name, or a plain
                directory whose subdirectories are named after the objects.
            object_names (list): A list of object names to include as positive
                samples. ``None`` selects ``DEFAULT_OBJECT_NAMES``.
            transform (callable, optional): Optional transform to be applied on a sample.

        Raises:
            FileNotFoundError: if the directory of an object class does not exist.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.loader = default_loader
        self.samples = []

        # Honour the caller's selection instead of overwriting it with a fixed list.
        if object_names is None:
            object_names = self.DEFAULT_OBJECT_NAMES
        self.object_names = list(object_names)

        for object_name in self.object_names:
            object_dir = self._object_dir(object_name)
            # Sorted so that the sample order is deterministic across runs.
            for img_file in sorted(os.listdir(object_dir)):
                # Ensure we are only picking up image files
                if img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                    path = os.path.join(object_dir, img_file)
                    self.samples.append((path, object_name))

    def _object_dir(self, object_name):
        """Resolve the directory that holds the images of ``object_name``."""
        formatted = self.root_dir.format(cls=object_name)
        if formatted != self.root_dir:
            return formatted
        # No placeholder was substituted: treat root_dir as the parent directory.
        return os.path.join(self.root_dir, object_name)

    def __len__(self):
        """Return the number of (path, object_name) samples found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, object_name)``."""
        path, object_name = self.samples[idx]
        pil_image = self.loader(path).convert("RGB") # Ensure image is RGB
        if self.transform is not None:
            pil_image = self.transform(pil_image)
        return pil_image, object_name