import os.path as osp
from PIL import Image
import cv2
import numpy as np
import torch
from typing import List, Tuple, Optional
from torch.utils.data.dataset import Dataset

from utils.file_io import load_list, load_leison_classes
from utils.visualizer import Visualizer
import dataset.paired_transforms_tv04 as p_tr
from utils.colormap import colormap

# Offset added to region areas in RetinalLesionsDataset.get_region_size when
# the areas are normalized.
_EPS: float = 1e-10
# Number of lesion classes that receive a display color below.
_NUM_CLASSES: int = 8
# First _NUM_CLASSES rows of the shared colormap, converted to nested lists.
# NOTE(review): presumably rgb=True / maximum=1 yields RGB channels in [0, 1];
# confirm against utils.colormap.
_LESION_COLORS = colormap(rgb=True, maximum=1)[:_NUM_CLASSES].tolist()


class RetinalLesionsDataset(Dataset):
    """
    Wrapper for retinal lesion dataset.

    Images are read from ``<data_root>/images_896x896/<sample>.jpg``. Per-class
    masks are read from ``<data_root>/lesion_segs_896x896/<sample>/<class>.png``.
    A class whose mask file is missing is treated as empty.

    Args:
        data_root: dataset root containing the two subdirectories above.
        sample_list_path: file listing sample names (read with ``load_list``).
        classes_path: file read with ``load_leison_classes``; provides the class
            names (used as mask file names) and their abbreviations.
        transforms: optional callable applied as ``transforms(img, targets)``
            and returning ``(img, targets)``.
        label_values: mask pixel values treated as foreground. Defaults to ``[255]``.
        binary: if True, merge all class masks into a single-channel target.
        region_size: if True, ``__getitem__`` also returns per-channel region area.
        region_number: if True (and ``region_size`` is False), ``__getitem__``
            also returns per-channel contour counts.
        normalize_region_size: divide region areas by the number of pixels.
        return_id: if True, append the sample name to each returned tuple.
    """
    def __init__(self, data_root: str,
                 sample_list_path: str,
                 classes_path: str,
                 transforms: "Optional[p_tr.Compose]" = None,
                 label_values: Optional[List[int]] = None,
                 binary: bool = False,
                 region_size: bool = False,
                 region_number: bool = False,
                 normalize_region_size: bool = True,
                 return_id: bool = False) -> None:
        self.data_root = data_root
        self.image_dir = osp.join(self.data_root, "images_896x896")
        self.seg_dir = osp.join(self.data_root, "lesion_segs_896x896")
        self.samples: List[str] = load_list(sample_list_path)
        self.classes, self.classes_abbrev = load_leison_classes(classes_path)
        self.transforms = transforms
        # None instead of a list literal avoids a mutable default shared across instances.
        self.label_values = [255] if label_values is None else label_values
        self.binary = binary
        self.region_size = region_size
        self.region_number = region_number
        self.normalize_region_size = normalize_region_size
        self.return_id = return_id

    def _image_path(self, sample_name: str) -> str:
        """Return the path of the JPEG image for ``sample_name``."""
        return osp.join(self.image_dir, "{}.jpg".format(sample_name))

    def turnoff_region_info(self) -> None:
        """Stop ``__getitem__`` from returning region size / region number."""
        self.region_size = False
        self.region_number = False

    def get_target(self, label_dir: str,
                   label_values: Optional[List[int]] = None
                   ) -> Tuple[List["Image.Image"], List[int]]:
        """Load the per-class binary masks stored in ``label_dir``.

        Args:
            label_dir: directory expected to hold ``<class_name>.png`` files.
            label_values: pixel values set to 1 in the output masks. Defaults
                to ``self.label_values``.

        Returns:
            ``(targets, labels)``: a mask image (1 = foreground, 0 otherwise)
            for each class whose file exists, and the matching class indices.
        """
        if label_values is None:
            label_values = self.label_values
        targets = []
        labels = []
        for i, class_name in enumerate(self.classes):
            expected_path = osp.join(label_dir, "{}.png".format(class_name))
            if not osp.exists(expected_path):
                continue
            with Image.open(expected_path) as img:
                arr = np.array(img)
            mask = np.zeros_like(arr)
            mask[np.isin(arr, label_values)] = 1
            targets.append(Image.fromarray(mask))
            labels.append(i)
        return targets, labels

    def save_binary_mask(self, index: int, save_dir: str) -> None:
        """Write the union of all class masks of sample ``index`` to ``save_dir``.

        The output file is ``<sample_name>.jpg``, with foreground at 255 and
        background at 0.
        """
        sample_name = self.samples[index]
        # Only the dimensions are needed; PIL reads them without decoding pixels.
        with Image.open(self._image_path(sample_name)) as img:
            # PIL sizes are (width, height) but NumPy shapes are (rows, cols).
            shape = (img.height, img.width)
        targets, _ = self.get_target(
            osp.join(self.seg_dir, sample_name), self.label_values
        )

        mask = np.zeros(shape, dtype=int)
        for target in targets:
            mask = np.bitwise_or(mask, np.array(target))
        # NOTE(review): JPEG is lossy, so the saved mask may not be exactly 0/255
        # near edges; consider PNG if exact masks are needed downstream.
        Image.fromarray((mask * 255).astype(np.uint8)).save(
            osp.join(save_dir, "{}.jpg".format(sample_name)), "JPEG"
        )

    def get_region_size(self, target: torch.Tensor) -> torch.Tensor:
        """Return the per-channel sum of a 3-D ``target``, shape ``(C,)``.

        When ``normalize_region_size`` is set, the sum (plus ``_EPS``) is
        divided by ``target.shape[1] * target.shape[2]``.
        """
        region_area = torch.einsum("cij->c", target)
        if not self.normalize_region_size:
            return region_area
        return (region_area + _EPS) / (target.shape[1] * target.shape[2])

    def get_region_number(self, target: torch.Tensor) -> torch.Tensor:
        """Return the number of contours in each channel, shape ``(C, 1)``.

        Assumes mask values lie in [0, 1]; they are scaled by 255 to uint8
        before contour detection.
        """
        region_number = torch.zeros(target.size(0), 1)
        for c in range(target.size(0)):
            thres = (target[c].detach().cpu().numpy() * 255).astype(np.uint8)
            # OpenCV 3 returns (image, contours, hierarchy) and OpenCV 4 returns
            # (contours, hierarchy); [-2] selects the contours in both versions.
            contours = cv2.findContours(
                thres, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
            region_number[c] = len(contours)
        return region_number

    def __getitem__(self, index: int) -> Tuple:
        """Return ``(img, target[, region_info][, sample_name])`` for ``index``.

        ``region_info`` is the region size if ``region_size`` is set, otherwise
        the region number if ``region_number`` is set. ``sample_name`` is
        appended when ``return_id`` is set.
        """
        sample_name = self.samples[index]
        with Image.open(self._image_path(sample_name)) as raw:
            img = raw.convert("RGB")
        targets, labels = self.get_target(
            osp.join(self.seg_dir, sample_name), self.label_values
        )

        if self.transforms is not None:
            img, targets = self.transforms(img, targets)

        # NOTE(review): img.shape[1] / img.shape[2] assume the transforms return
        # a channel-first tensor; without transforms img is still a PIL image
        # and this line raises AttributeError.
        height, width = img.shape[1], img.shape[2]
        if self.binary:
            target = torch.zeros(height, width, dtype=torch.int32)
            for x in targets:
                target = torch.bitwise_or(target, x.type(torch.int32))
            target = torch.unsqueeze(target, 0).type(torch.float32)
        else:
            target = torch.zeros(len(self.classes), height, width)
            for label, mask in zip(labels, targets):
                target[label, :, :] = mask

        outputs = [img, target]
        if self.region_size:
            outputs.append(self.get_region_size(target))
        elif self.region_number:
            outputs.append(self.get_region_number(target))
        if self.return_id:
            outputs.append(sample_name)
        return tuple(outputs)

    def __len__(self) -> int:
        """Return the number of samples listed in the sample list."""
        return len(self.samples)
