import os
import os.path
import json
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from PIL import Image

import numpy as np

from torchvision.datasets import DatasetFolder
import torchvision.transforms as transforms
import torchvision.transforms.functional as fn

IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def pil_loader(path: str):
    """Load the image stored at ``path`` with PIL and return it converted to RGB.

    The file is opened explicitly in binary mode and handed to ``Image.open``
    so the handle is closed by the ``with`` block, which avoids a
    ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    The RGB conversion happens while the file is still open.
    """
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")


# TODO: specify the return type
def accimage_loader(path: str):
    """Load ``path`` with the accimage backend, falling back to PIL on OSError.

    Returns an ``accimage.Image`` on success. If accimage raises ``OSError``
    (for example a decoding problem), the result of :func:`pil_loader` is
    returned instead. A missing ``accimage`` module is not caught.
    """
    import accimage

    try:
        image = accimage.Image(path)
    except OSError:
        # accimage could not read the file; retry with PIL.
        image = pil_loader(path)
    return image


def default_loader(path: str):
    """Load the image at ``path`` using torchvision's configured image backend.

    Dispatches to :func:`accimage_loader` when
    ``torchvision.get_image_backend()`` reports ``"accimage"`` and to
    :func:`pil_loader` for any other backend.
    """
    from torchvision import get_image_backend

    backend = get_image_backend()
    if backend != "accimage":
        return pil_loader(path)
    return accimage_loader(path)

def node_to_state(node, tm):
    """Return the state vector for a single ``node``.

    The state is simply column ``node`` of ``tm`` (``tm[:, node]``). For a
    numpy array and an integer ``node``, this is a view into ``tm``, not a
    copy.
    """
    return tm[:, node]

def nodes_to_state(nodes, tm):
    """Return the element-wise logical OR of the columns of ``tm`` chosen by ``nodes``.

    The result is a boolean vector of length ``tm.shape[0]``. Non-zero
    entries of ``tm`` count as True, and an empty ``nodes`` yields an
    all-False vector.
    """
    combined = np.zeros(tm.shape[0], dtype=bool)
    for column in (tm[:, n] for n in nodes):
        combined = np.logical_or(combined, column)
    return combined

class CustomDataset(DatasetFolder):
    """Image-folder dataset whose targets are state vectors built from ``tm``.

    Class folders are named ``"n" + synset``. A JSON config file supplies the
    synset-to-node lookup (``"s2n"``) and, when ``mapped`` is true, an
    additional synset-to-node-list lookup (``"mapping"``). Each sample's raw
    target is the node (unmapped) or list of nodes (mapped) found for its
    folder. Unless ``target_transform`` is given, that target is converted
    with :func:`node_to_state` (unmapped) or :func:`nodes_to_state` (mapped).

    Args:
        root: Directory containing one sub-folder per class.
        config_path: Path to a JSON file with an ``"s2n"`` entry, plus a
            ``"mapping"`` entry when ``mapped`` is true.
        tm: Matrix whose columns are selected by node index (``tm[:, node]``)
            to build targets.
        loader: Callable that loads a sample given its path.
        extensions: Allowed file extensions. When ``None`` (and
            ``is_valid_file`` is ``None``), ``IMG_EXTENSIONS`` is used.
            Ignored when ``is_valid_file`` is given, because torchvision's
            ``DatasetFolder`` expects only one of the two.
        transform: Sample transform. Defaults to a random-resized-crop /
            horizontal-flip / to-tensor / normalize pipeline.
        mapped: Whether to also use the config's ``"mapping"`` entry.
        target_transform: Target transform. Defaults to the state-vector
            conversion described above.
        is_valid_file: Optional predicate deciding which files are samples.
    """

    def __init__(self, root,
                 config_path,
                 tm,
                 loader=default_loader,
                 extensions=None,
                 transform=None,
                 mapped=False,
                 target_transform=None,
                 is_valid_file=None):
        from functools import partial

        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Set before super().__init__(): the base constructor is expected to
        # call self.find_classes(), which reads both attributes.
        self.s2n = config["s2n"]
        self.map = config["mapping"] if mapped else None

        if target_transform is None:
            # Unmapped folders carry a single node; mapped folders carry a
            # list of nodes. A partial of a module-level function (unlike a
            # lambda) is picklable, e.g. for multi-process data loading.
            state_fn = node_to_state if self.map is None else nodes_to_state
            target_transform = partial(state_fn, tm=tm)

        if transform is None:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(227),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        if is_valid_file is not None:
            # An explicit predicate replaces extension-based filtering.
            extensions = None
        elif extensions is None:
            extensions = IMG_EXTENSIONS

        super().__init__(root,
                         loader,
                         extensions,
                         transform,
                         target_transform,
                         is_valid_file)

    def find_classes(self, directory):
        """Return the class folders under ``directory`` and their raw targets.

        Only sub-folders named ``"n" + synset`` for a synset present in the
        config are kept. Unmapped mode: each class maps to its ``s2n`` node.
        Mapped mode: each class maps to a list of nodes, either
        ``[s2n node]`` or the ``mapping`` entry. The ``mapping`` entry wins
        when a synset appears in both.

        Returns:
            ``(classes, class_to_idx)``, where ``classes`` is the sorted list
            of kept folder names (exactly the keys of ``class_to_idx``).

        Raises:
            FileNotFoundError: if no matching class folder exists.
        """
        with os.scandir(directory) as entries:
            folder_names = {entry.name for entry in entries if entry.is_dir()}

        if self.map is None:
            class_to_idx = {'n' + synset: node
                            for synset, node in self.s2n.items()
                            if 'n' + synset in folder_names}
        else:
            small_class_to_idxs = {'n' + synset: [node]
                                   for synset, node in self.s2n.items()
                                   if 'n' + synset in folder_names}
            mapped_class_to_idxs = {'n' + synset: nodes
                                    for synset, nodes in self.map.items()
                                    if 'n' + synset in folder_names}
            class_to_idx = {**small_class_to_idxs, **mapped_class_to_idxs}

        # Derive classes from class_to_idx so every listed class has a target.
        classes = sorted(class_to_idx)
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        return classes, class_to_idx