import os
import os.path as osp
import json
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import yaml
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.transforms import PILToTensor, ToTensor, Compose, Resize, Normalize


class ImageFolderWithAttribution(Dataset):
    """ImageFolder with attribution maps and gt_masks.

    This class can also be used as plain ImageFolder without attribution maps.
    For experiments on synthetic dataset, the gt_masks are generated along the
    images by ``DatasetGenerator``.

    Images are expected at ``img_root/<class_name>/<file>``. Attribution maps and
    gt_masks (when used) are looked up under the same relative path in
    ``attr_root`` / ``gt_mask_root``. For ``.npy`` attribution maps, only the
    extension is replaced by ``.npy``. Only sub-directories of ``img_root`` are
    treated as classes, and only regular files inside them as samples. Stray files
    such as ``.DS_Store`` next to the class folders are therefore ignored. Samples
    are listed in sorted order, so indices are reproducible across file systems.

    Args:
        num_channels: Number of channels of images. Must be 1 or 3. If 3, images will be
            treated as RGB images. If 1, images will be converted to gray images.
        img_root: Root of the images.
        attr_root: Root of the attribution maps. If None, no attribution maps will be loaded.
        attr_file_format: The file extension of attribution. Can be either "npy" or "png".
            Must be None if and only if ``attr_root`` is None.
        imagenet_transform: If True, apply Resize -> ToTensor -> Normalize transforms.
            In the end, the images will be normalized using the mean and std from ImageNet
            dataset. Requires ``num_channels == 3``.
        gt_mask_root: Root of the gt_masks. If None, no gt masks will be loaded.
        cls_to_ind: Defining how to map the class names (sub-folders) to class indices.
            If "str_to_int", then simply call ``int(sub_folder)`` to convert a
            sub-folder name to class index. If "alphabet", then sort the sub-folders
            in alphabet order, and map the sub-folders to their sorted indices.
            In both cases excluded classes still take part in the mapping, so indices
            match those of the full dataset.
            Otherwise, this argument is a path to a yaml (``.yaml``/``.yml``) or json
            (``.json``) file, which can be parsed as a dict that maps sub-folders to
            class indices. The extension is matched case-insensitively.
        excluded_classes: If not None, exclude all the samples that belong to these classes.

    Raises:
        ValueError: If ``num_channels`` is not 1 or 3, if ``attr_root`` and
            ``attr_file_format`` are inconsistent, if ``imagenet_transform`` is
            requested with ``num_channels != 3``, or if the ``cls_to_ind`` file has an
            unsupported extension.

    """

    def __init__(
        self,
        num_channels: int,
        img_root: str,
        attr_root: Optional[str],
        attr_file_format: Optional[str],
        imagenet_transform: bool = False,
        gt_mask_root: Optional[str] = None,
        cls_to_ind: str = 'str_to_int',
        excluded_classes: Optional[List[str]] = None,
    ) -> None:
        # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
        if num_channels not in (1, 3):
            raise ValueError(f'num_channels must be 1 or 3, but got {num_channels}')

        if attr_root is not None:
            if attr_file_format not in ('png', 'npy'):
                raise ValueError(
                    f'attr_file_format must be "png" or "npy" when attr_root is given, '
                    f'but got {attr_file_format!r}')
            # A named static method instead of a lambda keeps the dataset picklable,
            # which DataLoader workers started with `spawn` require.
            self.attr_transform = ToTensor() if attr_file_format == 'png' else self._npy_to_tensor
        elif attr_file_format is not None:
            raise ValueError(
                f'attr_file_format must be None when attr_root is None, but got {attr_file_format!r}')
        self.attr_file_format = attr_file_format

        self.num_channels = num_channels
        self.img_root = img_root
        self.attr_root = attr_root
        self.gt_mask_root = gt_mask_root

        # Sorted names of the class sub-directories (listed once and reused below).
        class_dirs = self._list_class_dirs(self.img_root)
        excluded = set() if excluded_classes is None else set(excluded_classes)
        # file name in format class_name/xxx.png
        self.files = self._list_files(self.img_root, class_dirs, excluded)

        if imagenet_transform:
            if num_channels != 3:
                raise ValueError(f'imagenet_transform requires num_channels == 3, but got {num_channels}')
            self.img_transform = Compose(
                [Resize((224, 224)),
                 ToTensor(),
                 Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                 ])
        else:
            # RGB: PILToTensor keeps the raw pixel values.
            # Gray: ToTensor yields (1, H, W) scaled to [0.0, 1.0].
            self.img_transform = PILToTensor() if num_channels == 3 else ToTensor()

        self.gt_mask_transform = PILToTensor() if self.gt_mask_root is not None else None

        if cls_to_ind == 'str_to_int':
            self.cls_to_ind_dict = {c: int(c) for c in class_dirs}
        elif cls_to_ind == 'alphabet':
            # class_dirs is already sorted alphabetically.
            self.cls_to_ind_dict = {c: i for i, c in enumerate(class_dirs)}
        else:
            self.cls_to_ind_dict = self._load_cls_to_ind_file(cls_to_ind)

    @staticmethod
    def _list_class_dirs(root: str) -> List[str]:
        """Return the sorted names of the immediate sub-directories of ``root``."""
        with os.scandir(root) as entries:
            return sorted(entry.name for entry in entries if entry.is_dir())

    @staticmethod
    def _list_files(img_root: str, class_dirs: List[str], excluded: set) -> List[str]:
        """Return sorted ``class_name/file_name`` paths of the regular files in each non-excluded class dir."""
        files = []
        for cls_name in class_dirs:
            if cls_name in excluded:
                continue
            with os.scandir(osp.join(img_root, cls_name)) as entries:
                names = sorted(entry.name for entry in entries if entry.is_file())
            files.extend(osp.join(cls_name, name) for name in names)
        return files

    @staticmethod
    def _load_cls_to_ind_file(path: str) -> Dict[str, int]:
        """Parse a yaml or json file that maps class (sub-folder) names to class indices.

        Raises:
            ValueError: If the file extension is not .yaml, .yml or .json (case-insensitive).
        """
        # splitext will give something like ('xxx', '.json')
        raw_ext = osp.splitext(path)[1]
        file_ext = raw_ext.lower()
        if file_ext not in ('.yaml', '.yml', '.json'):
            raise ValueError(f'Unsupported cls_to_ind file format: {raw_ext}')
        with open(path, 'r', encoding='utf-8') as f:
            if file_ext == '.json':
                return json.load(f)
            return yaml.safe_load(f)

    @staticmethod
    def _npy_to_tensor(array: np.ndarray) -> Tensor:
        """Convert a numpy array to a tensor via ``torch.from_numpy`` (shares memory with ``array``)."""
        return torch.from_numpy(array)

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Dict[str, Union[Tensor, str, int, List[int]]]:
        """Load one sample.

        Args:
            index: Sample index

        Returns:
            A dict containing:
                - "img": float32 image tensor with shape (num_channels, height, width).
                - "attr_map" (optional): float32 attribution map tensor with shape (height, width).
                - "gt_mask" (optional): gt_mask with shape (height, width).
                - "img_file": image file with format: "cls_name/xxx.png"
                - "ori_size": [width, height] of the image as read from disk (PIL ``Image.size``).
                - "label": (int) class index.

        Raises:
            ValueError: If the image, attribution map and gt_mask shapes are inconsistent.

        """
        file = self.files[index]
        # Gray synthetic images -> mode L (ToTensor gives (1, H, W) in [0.0, 1.0]).
        # RGB synthetic/real images -> mode RGB (PILToTensor without scaling, or the ImageNet pipeline).
        target_mode = 'L' if self.num_channels == 1 else 'RGB'
        # Open images as context managers so the file handle is released immediately
        # rather than whenever the lazily-loaded PIL object gets garbage-collected.
        # The transform must run inside the `with`, because a closed PIL image can't be read.
        with Image.open(osp.join(self.img_root, file)) as raw_img:
            pil_img = raw_img if raw_img.mode == target_mode else raw_img.convert(mode=target_mode)
            ori_size = list(pil_img.size)
            img = self.img_transform(pil_img).to(torch.float32)
        cls_name = osp.dirname(file)
        label = self.cls_to_ind_dict[cls_name]
        result = {'img': img, 'img_file': file, 'label': label, 'ori_size': ori_size}

        if self.attr_file_format is None:
            attr_map = None
        else:
            if self.attr_file_format == 'png':
                with Image.open(osp.join(self.attr_root, file)) as attr_img:
                    attr_map = self.attr_transform(attr_img).to(torch.float32).squeeze(0)
            else:
                attr_array_file = osp.splitext(file)[0] + '.npy'
                attr_array = np.load(osp.join(self.attr_root, attr_array_file))
                attr_map = self.attr_transform(attr_array).to(torch.float32)
            result['attr_map'] = attr_map

        if self.gt_mask_root is not None:
            with Image.open(osp.join(self.gt_mask_root, file)) as mask_img:
                gt_mask = self.gt_mask_transform(mask_img).squeeze(0)
            result['gt_mask'] = gt_mask
        else:
            gt_mask = None

        self._valid(img, attr_map, gt_mask)

        return result

    @staticmethod
    def _valid(img: Tensor, attr_map: Optional[Tensor], gt_mask: Optional[Tensor]) -> None:
        """Check that ``img`` is (C, H, W) with C in {1, 3} and that the given maps are exactly (H, W).

        Raises:
            ValueError: If any shape is inconsistent.
        """
        if img.dim() != 3 or (img.shape[0] != 1 and img.shape[0] != 3):
            raise ValueError(f'img has invalid shape: {img.shape}')
        img_size = img.shape[-2:]

        if attr_map is not None and attr_map.shape != img_size:
            raise ValueError(f'img has spatial size: {img_size}, but attr_map has shape: {attr_map.shape}')

        if gt_mask is not None and gt_mask.shape != img_size:
            raise ValueError(f'img has spatial size: {img_size}, but gt_mask has shape: {gt_mask.shape}')
