import os.path as osp
from PIL import Image
import numpy as np
import torch
from typing import List, Tuple, Optional
from torch.utils.data.dataset import Dataset

from utils.file_io import load_list, load_leison_classes
from utils.visualizer import Visualizer
import dataset.paired_transforms_tv04 as p_tr
from utils.colormap import colormap

# Number of palette rows kept below; the name suggests this is the number of
# lesion classes — confirm against the dataset definition.
_NUM_CLASSES: int = 8
# First _NUM_CLASSES colours of the project colormap, converted to nested
# Python lists (one [r, g, b] entry per row). `maximum=1` presumably scales
# channel values into [0, 1] — TODO confirm in utils.colormap.
_LESION_COLORS = (
    colormap(rgb=True, maximum=1)[:_NUM_CLASSES]
    .tolist()
)


class RetinalLesionsClassDataset(Dataset):
    """
    Single-class wrapper for the retinal lesion dataset.

    Each sample ``<name>`` pairs the image
    ``<data_root>/images_896x896/<name>.jpg`` with the mask
    ``<data_root>/lesion_segs_896x896/<name>/<class_name>.png``. If that mask
    file does not exist, an all-zero mask the size of the image is used.

    Args:
        data_root: Root directory of the dataset.
        sample_list_path: File listing the sample names (read with ``load_list``).
        class_name: Lesion class whose mask is used as the target.
        transforms: Optional paired transform applied to ``(image, target)``.
        label_values: Mask pixel values counted as foreground. Defaults to
            ``[255]``.
    """

    def __init__(self, data_root : str,
                 sample_list_path : str,
                 class_name: str,
                 transforms : Optional[p_tr.Compose] = None,
                 label_values : Optional[List[int]] = None) -> None:
        self.data_root = data_root
        self.image_dir = osp.join(self.data_root, "images_896x896")
        self.seg_dir = osp.join(self.data_root, "lesion_segs_896x896")
        self.samples : List[str] = load_list(sample_list_path)
        self.class_name = class_name
        self.transforms = transforms
        # Copy into a fresh list so neither a shared default nor the caller's
        # list is aliased by this instance.
        self.label_values : List[int] = (
            [255] if label_values is None else list(label_values)
        )

    def _load_image(self, sample_name : str) -> Image.Image:
        """Load ``<image_dir>/<sample_name>.jpg`` as an RGB PIL image."""
        path = osp.join(self.image_dir, "{}.jpg".format(sample_name))
        # ``convert`` returns an independent copy, so the file can be closed.
        with Image.open(path) as raw:
            return raw.convert("RGB")

    def get_target(self, img, label_dir : str, class_name : str,
                   label_values : Optional[List[int]] = None):
        """
        Build the binary mask for one lesion class of one sample.

        Args:
            img: The sample's PIL image; only its size is used, to shape the
                all-zero mask returned when the class has no mask file.
            label_dir: Directory holding the sample's per-class mask PNGs.
            class_name: Class whose ``<class_name>.png`` is read.
            label_values: Mask pixel values treated as foreground. Defaults to
                ``self.label_values``.

        Returns:
            ``(target, label)``. ``target`` is a PIL image built from a uint8
            array holding 1 on foreground pixels and 0 elsewhere. ``label`` is
            1 if the mask file exists, else 0.
        """
        if label_values is None:
            label_values = self.label_values

        expected_path = osp.join(label_dir, "{}.png".format(class_name))
        if not osp.exists(expected_path):
            # PIL sizes are (width, height) but numpy shapes are (rows, cols).
            # Use uint8 so absent and present masks share a dtype/mode.
            empty = np.zeros((img.height, img.width), dtype=np.uint8)
            return Image.fromarray(empty), 0

        with Image.open(expected_path) as lesion_img:
            arr = np.array(lesion_img)
        mask = np.isin(arr, label_values).astype(np.uint8)
        return Image.fromarray(mask), 1

    def save_binary_mask(self, index : int, save_dir : str):
        """
        Save the ``class_name`` mask of sample ``index`` as an 8-bit image.

        Foreground pixels are written as 255 and background as 0, to
        ``<save_dir>/<sample_name>.jpg``.

        NOTE(review): JPEG is lossy, so the saved mask may contain values
        other than 0/255 near edges. The format is kept so the output file
        names stay unchanged; consider PNG if exact masks are needed.
        """
        sample_name = self.samples[index]
        img = self._load_image(sample_name)
        target, _ = self.get_target(
            img, osp.join(self.seg_dir, sample_name), self.class_name, self.label_values
        )
        mask = np.where(np.array(target) > 0, 255, 0).astype(np.uint8)
        Image.fromarray(mask).save(
            osp.join(save_dir, "{}.jpg".format(sample_name)), "JPEG"
        )

    def __getitem__(self, index : int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return ``(image, target)`` for sample ``index``.

        ``target`` is returned as a float32 tensor with a leading channel
        dimension added. If ``transforms`` is None (or it does not convert the
        target), the mask is converted to a tensor here. The image, however,
        is returned as produced by ``transforms`` (a PIL image when there are
        no transforms).
        """
        sample_name = self.samples[index]
        img = self._load_image(sample_name)
        target, _ = self.get_target(
            img, osp.join(self.seg_dir, sample_name), self.class_name, self.label_values
        )

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        if not isinstance(target, torch.Tensor):
            # Without a tensorizing transform the target is still a PIL image.
            target = torch.from_numpy(np.array(target))

        return img, target.unsqueeze(dim=0).type(torch.float32)

    def __len__(self) -> int:
        """Number of samples listed in ``sample_list_path``."""
        return len(self.samples)
